benchmarks: v2

Restructure the eval crate: retrieval tuning knobs move from loose Config fields into a dedicated RetrievalSettings struct, datasets are loaded through a manifest-driven catalog with per-format converters (SQuAD v2, Natural Questions), evaluation summary types move into eval/src/eval/types.rs, the composite-retrieval dependency is replaced by retrieval-pipeline, impossible and chunk-unverified questions are filtered per evaluation mode and reported as separate retrieval/LLM metrics, and ingestion gains batch-size and max-retries flags.
This commit is contained in:
Per Stark
2025-11-18 22:51:06 +01:00
parent f535df7e61
commit bd519ab269
22 changed files with 2794 additions and 2035 deletions

View File

@@ -8,7 +8,7 @@ anyhow = { workspace = true }
async-openai = { workspace = true }
chrono = { workspace = true }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
ingestion-pipeline = { path = "../ingestion-pipeline" }
futures = { workspace = true }
fastembed = { workspace = true }

View File

@@ -4,6 +4,7 @@ use std::{
};
use anyhow::{anyhow, Context, Result};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
@@ -35,6 +36,41 @@ impl std::str::FromStr for EmbeddingBackend {
}
}
#[derive(Debug, Clone)]
pub struct RetrievalSettings {
pub chunk_min_chars: usize,
pub chunk_max_chars: usize,
pub chunk_vector_take: Option<usize>,
pub chunk_fts_take: Option<usize>,
pub chunk_token_budget: Option<usize>,
pub chunk_avg_chars_per_token: Option<usize>,
pub max_chunks_per_entity: Option<usize>,
pub rerank: bool,
pub rerank_pool_size: usize,
pub rerank_keep_top: usize,
pub require_verified_chunks: bool,
pub strategy: RetrievalStrategy,
}
impl Default for RetrievalSettings {
fn default() -> Self {
Self {
chunk_min_chars: 500,
chunk_max_chars: 2_000,
chunk_vector_take: None,
chunk_fts_take: None,
chunk_token_budget: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: true,
rerank_pool_size: 16,
rerank_keep_top: 10,
require_verified_chunks: true,
strategy: RetrievalStrategy::Initial,
}
}
}
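// Illustrative usage, not part of this commit: callers can start from the
// defaults above and override individual knobs with struct-update syntax.
//
//     let settings = RetrievalSettings {
//         rerank: false,
//         chunk_vector_take: Some(80),
//         ..RetrievalSettings::default()
//     };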
#[derive(Debug, Clone)]
pub struct Config {
pub convert_only: bool,
@@ -49,21 +85,14 @@ pub struct Config {
pub limit: Option<usize>,
pub summary_sample: usize,
pub full_context: bool,
pub chunk_min_chars: usize,
pub chunk_max_chars: usize,
pub chunk_vector_take: Option<usize>,
pub chunk_fts_take: Option<usize>,
pub chunk_token_budget: Option<usize>,
pub chunk_avg_chars_per_token: Option<usize>,
pub max_chunks_per_entity: Option<usize>,
pub rerank: bool,
pub rerank_pool_size: usize,
pub rerank_keep_top: usize,
pub retrieval: RetrievalSettings,
pub concurrency: usize,
pub embedding_backend: EmbeddingBackend,
pub embedding_model: Option<String>,
pub cache_dir: PathBuf,
pub ingestion_cache_dir: PathBuf,
pub ingestion_batch_size: usize,
pub ingestion_max_retries: usize,
pub refresh_embeddings_only: bool,
pub detailed_report: bool,
pub slice: Option<String>,
@@ -105,21 +134,14 @@ impl Default for Config {
limit: Some(200),
summary_sample: 5,
full_context: false,
chunk_min_chars: 500,
chunk_max_chars: 2_000,
chunk_vector_take: None,
chunk_fts_take: None,
chunk_token_budget: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: true,
rerank_pool_size: 16,
rerank_keep_top: 10,
retrieval: RetrievalSettings::default(),
concurrency: 4,
embedding_backend: EmbeddingBackend::FastEmbed,
embedding_model: None,
cache_dir: PathBuf::from("eval/cache"),
ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
ingestion_batch_size: 5,
ingestion_max_retries: 3,
refresh_embeddings_only: false,
detailed_report: false,
slice: None,
@@ -176,6 +198,7 @@ pub fn parse() -> Result<ParsedArgs> {
"--force" | "--refresh" => config.force_convert = true,
"--llm-mode" => {
config.llm_mode = true;
config.retrieval.require_verified_chunks = false;
}
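// --llm-mode keeps impossible and chunk-unverified questions in scope;
// should_include_question in the eval module applies the same two flags later.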
"--dataset" => {
let value = take_value("--dataset", &mut args)?;
@@ -279,14 +302,14 @@ pub fn parse() -> Result<ParsedArgs> {
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --chunk-min value '{value}' as usize")
})?;
config.chunk_min_chars = parsed.max(1);
config.retrieval.chunk_min_chars = parsed.max(1);
}
"--chunk-max" => {
let value = take_value("--chunk-max", &mut args)?;
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --chunk-max value '{value}' as usize")
})?;
config.chunk_max_chars = parsed.max(1);
config.retrieval.chunk_max_chars = parsed.max(1);
}
"--chunk-vector-take" => {
let value = take_value("--chunk-vector-take", &mut args)?;
@@ -296,7 +319,7 @@ pub fn parse() -> Result<ParsedArgs> {
if parsed == 0 {
return Err(anyhow!("--chunk-vector-take must be greater than zero"));
}
config.chunk_vector_take = Some(parsed);
config.retrieval.chunk_vector_take = Some(parsed);
}
"--chunk-fts-take" => {
let value = take_value("--chunk-fts-take", &mut args)?;
@@ -306,7 +329,7 @@ pub fn parse() -> Result<ParsedArgs> {
if parsed == 0 {
return Err(anyhow!("--chunk-fts-take must be greater than zero"));
}
config.chunk_fts_take = Some(parsed);
config.retrieval.chunk_fts_take = Some(parsed);
}
"--chunk-token-budget" => {
let value = take_value("--chunk-token-budget", &mut args)?;
@@ -316,7 +339,7 @@ pub fn parse() -> Result<ParsedArgs> {
if parsed == 0 {
return Err(anyhow!("--chunk-token-budget must be greater than zero"));
}
config.chunk_token_budget = Some(parsed);
config.retrieval.chunk_token_budget = Some(parsed);
}
"--chunk-token-chars" => {
let value = take_value("--chunk-token-chars", &mut args)?;
@@ -326,7 +349,14 @@ pub fn parse() -> Result<ParsedArgs> {
if parsed == 0 {
return Err(anyhow!("--chunk-token-chars must be greater than zero"));
}
config.chunk_avg_chars_per_token = Some(parsed);
config.retrieval.chunk_avg_chars_per_token = Some(parsed);
}
"--retrieval-strategy" => {
let value = take_value("--retrieval-strategy", &mut args)?;
let parsed = value.parse::<RetrievalStrategy>().map_err(|err| {
anyhow!("failed to parse --retrieval-strategy value '{value}': {err}")
})?;
config.retrieval.strategy = parsed;
}
"--max-chunks-per-entity" => {
let value = take_value("--max-chunks-per-entity", &mut args)?;
@@ -336,7 +366,7 @@ pub fn parse() -> Result<ParsedArgs> {
if parsed == 0 {
return Err(anyhow!("--max-chunks-per-entity must be greater than zero"));
}
config.max_chunks_per_entity = Some(parsed);
config.retrieval.max_chunks_per_entity = Some(parsed);
}
"--embedding" => {
let value = take_value("--embedding", &mut args)?;
@@ -354,6 +384,23 @@ pub fn parse() -> Result<ParsedArgs> {
let value = take_value("--ingestion-cache-dir", &mut args)?;
config.ingestion_cache_dir = PathBuf::from(value);
}
"--ingestion-batch-size" => {
let value = take_value("--ingestion-batch-size", &mut args)?;
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --ingestion-batch-size value '{value}' as usize")
})?;
if parsed == 0 {
return Err(anyhow!("--ingestion-batch-size must be greater than zero"));
}
config.ingestion_batch_size = parsed;
}
"--ingestion-max-retries" => {
let value = take_value("--ingestion-max-retries", &mut args)?;
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --ingestion-max-retries value '{value}' as usize")
})?;
config.ingestion_max_retries = parsed;
}
"--negative-multiplier" => {
let value = take_value("--negative-multiplier", &mut args)?;
let parsed = value.parse::<f32>().with_context(|| {
@@ -367,21 +414,21 @@ pub fn parse() -> Result<ParsedArgs> {
config.negative_multiplier = parsed;
}
"--no-rerank" => {
config.rerank = false;
config.retrieval.rerank = false;
}
"--rerank-pool" => {
let value = take_value("--rerank-pool", &mut args)?;
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --rerank-pool value '{value}' as usize")
})?;
config.rerank_pool_size = parsed.max(1);
config.retrieval.rerank_pool_size = parsed.max(1);
}
"--rerank-keep" => {
let value = take_value("--rerank-keep", &mut args)?;
let parsed = value.parse::<usize>().with_context(|| {
format!("failed to parse --rerank-keep value '{value}' as usize")
})?;
config.rerank_keep_top = parsed.max(1);
config.retrieval.rerank_keep_top = parsed.max(1);
}
"--concurrency" => {
let value = take_value("--concurrency", &mut args)?;
@@ -451,15 +498,15 @@ pub fn parse() -> Result<ParsedArgs> {
}
}
if config.chunk_min_chars >= config.chunk_max_chars {
if config.retrieval.chunk_min_chars >= config.retrieval.chunk_max_chars {
return Err(anyhow!(
"--chunk-min must be less than --chunk-max (got {} >= {})",
config.chunk_min_chars,
config.chunk_max_chars
config.retrieval.chunk_min_chars,
config.retrieval.chunk_max_chars
));
}
if config.rerank && config.rerank_pool_size == 0 {
if config.retrieval.rerank && config.retrieval.rerank_pool_size == 0 {
return Err(anyhow!(
"--rerank-pool must be greater than zero when reranking is enabled"
));
@@ -578,14 +625,20 @@ OPTIONS:
Override chunk token budget estimate for assembly (default: 10000).
--chunk-token-chars <int>
Override average characters per token used for budgeting (default: 4).
--retrieval-strategy <initial|revised>
Select the retrieval pipeline strategy (default: initial).
--max-chunks-per-entity <int>
Override maximum chunks attached per entity (default: 4).
--embedding <name> Embedding backend: 'fastembed' (default) or 'hashed'.
--embedding-model <code>
FastEmbed model code (defaults to crate preset when omitted).
--cache-dir <path> Directory for embedding caches (default: eval/cache).
--ingestion-cache-dir <path>
Directory for ingestion corpora caches (default: eval/cache/ingested).
--ingestion-cache-dir <path>
Directory for ingestion corpora caches (default: eval/cache/ingested).
--ingestion-batch-size <int>
Number of paragraphs to ingest concurrently (default: 5).
--ingestion-max-retries <int>
Maximum retries for ingestion failures per paragraph (default: 3).
--negative-multiplier <float>
Target negative-to-positive paragraph ratio for slice growth (default: 4.0).
--refresh-embeddings Recompute embeddings for cached corpora without re-running ingestion.
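EXAMPLE (hypothetical invocation; binary name and values are illustrative, flags as documented above):
eval --dataset squad --retrieval-strategy revised \
--ingestion-batch-size 8 --ingestion-max-retries 5 --no-rerank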

File diff suppressed because it is too large

493
eval/src/datasets/mod.rs Normal file
View File

@@ -0,0 +1,493 @@
mod nq;
mod squad;
use std::{
collections::{BTreeMap, HashMap},
fs::{self},
path::{Path, PathBuf},
str::FromStr,
};
use anyhow::{anyhow, bail, Context, Result};
use chrono::{DateTime, TimeZone, Utc};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use tracing::warn;
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetCatalog {
datasets: BTreeMap<String, DatasetEntry>,
slices: HashMap<String, SliceLocation>,
default_dataset: String,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetEntry {
pub metadata: DatasetMetadata,
pub raw_path: PathBuf,
pub converted_path: PathBuf,
pub include_unanswerable: bool,
pub slices: Vec<SliceEntry>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SliceEntry {
pub id: String,
pub dataset_id: String,
pub label: String,
pub description: Option<String>,
pub limit: Option<usize>,
pub corpus_limit: Option<usize>,
pub include_unanswerable: Option<bool>,
pub seed: Option<u64>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct SliceLocation {
dataset_id: String,
slice_index: usize,
}
#[derive(Debug, Deserialize)]
struct ManifestFile {
default_dataset: Option<String>,
datasets: Vec<ManifestDataset>,
}
#[derive(Debug, Deserialize)]
struct ManifestDataset {
id: String,
label: String,
category: String,
#[serde(default)]
entity_suffix: Option<String>,
#[serde(default)]
source_prefix: Option<String>,
raw: String,
converted: String,
#[serde(default)]
include_unanswerable: bool,
#[serde(default)]
slices: Vec<ManifestSlice>,
}
#[derive(Debug, Deserialize)]
struct ManifestSlice {
id: String,
label: String,
#[serde(default)]
description: Option<String>,
#[serde(default)]
limit: Option<usize>,
#[serde(default)]
corpus_limit: Option<usize>,
#[serde(default)]
include_unanswerable: Option<bool>,
#[serde(default)]
seed: Option<u64>,
}
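// Shape sketch of a manifest.yaml these structs accept via serde_yaml;
// the ids, labels, and paths below are hypothetical:
//
//   default_dataset: squad-v2
//   datasets:
//     - id: squad-v2
//       label: "SQuAD v2.0"
//       category: "SQuAD v2.0"
//       raw: data/squad/dev-v2.0.json
//       converted: eval/converted/squad-v2.json
//       slices:
//         - id: squad-v2-smoke
//           label: Smoke test slice
//           limit: 50
//           seed: 42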
impl DatasetCatalog {
pub fn load() -> Result<Self> {
let manifest_raw = fs::read_to_string(MANIFEST_PATH)
.with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?;
let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw)
.with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?;
let root = Path::new(env!("CARGO_MANIFEST_DIR"));
let mut datasets = BTreeMap::new();
let mut slices = HashMap::new();
for dataset in manifest.datasets {
let raw_path = resolve_path(root, &dataset.raw);
let converted_path = resolve_path(root, &dataset.converted);
if !raw_path.exists() {
bail!(
"dataset '{}' raw file missing at {}",
dataset.id,
raw_path.display()
);
}
if !converted_path.exists() {
warn!(
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
dataset.id,
converted_path.display()
);
}
let metadata = DatasetMetadata {
id: dataset.id.clone(),
label: dataset.label.clone(),
category: dataset.category.clone(),
entity_suffix: dataset
.entity_suffix
.clone()
.unwrap_or_else(|| dataset.label.clone()),
source_prefix: dataset
.source_prefix
.clone()
.unwrap_or_else(|| dataset.id.clone()),
include_unanswerable: dataset.include_unanswerable,
context_token_limit: None,
};
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
for (index, manifest_slice) in dataset.slices.into_iter().enumerate() {
if slices.contains_key(&manifest_slice.id) {
bail!(
"slice '{}' defined multiple times in manifest",
manifest_slice.id
);
}
entry_slices.push(SliceEntry {
id: manifest_slice.id.clone(),
dataset_id: dataset.id.clone(),
label: manifest_slice.label,
description: manifest_slice.description,
limit: manifest_slice.limit,
corpus_limit: manifest_slice.corpus_limit,
include_unanswerable: manifest_slice.include_unanswerable,
seed: manifest_slice.seed,
});
slices.insert(
manifest_slice.id,
SliceLocation {
dataset_id: dataset.id.clone(),
slice_index: index,
},
);
}
datasets.insert(
metadata.id.clone(),
DatasetEntry {
metadata,
raw_path,
converted_path,
include_unanswerable: dataset.include_unanswerable,
slices: entry_slices,
},
);
}
let default_dataset = manifest
.default_dataset
.or_else(|| datasets.keys().next().cloned())
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
Ok(Self {
datasets,
slices,
default_dataset,
})
}
pub fn global() -> Result<&'static Self> {
DATASET_CATALOG.get_or_try_init(Self::load)
}
pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> {
self.datasets
.get(id)
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
}
#[allow(dead_code)]
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
self.dataset(&self.default_dataset)
}
#[allow(dead_code)]
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
let location = self
.slices
.get(slice_id)
.ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?;
let dataset = self
.datasets
.get(&location.dataset_id)
.ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?;
let slice = dataset
.slices
.get(location.slice_index)
.ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?;
Ok((dataset, slice))
}
}
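// Lookup sketch (hypothetical slice id): catalog()?.slice("squad-v2-smoke")
// returns the owning DatasetEntry together with the SliceEntry, so callers
// can resolve a slice without knowing its dataset up front.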
fn resolve_path(root: &Path, value: &str) -> PathBuf {
let path = Path::new(value);
if path.is_absolute() {
path.to_path_buf()
} else {
root.join(path)
}
}
pub fn catalog() -> Result<&'static DatasetCatalog> {
DatasetCatalog::global()
}
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
let catalog = catalog()?;
catalog.dataset(kind.id())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatasetKind {
SquadV2,
NaturalQuestions,
}
impl DatasetKind {
pub fn id(self) -> &'static str {
match self {
Self::SquadV2 => "squad-v2",
Self::NaturalQuestions => "natural-questions-dev",
}
}
pub fn label(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions (dev)",
}
}
pub fn category(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions",
}
}
pub fn entity_suffix(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD",
Self::NaturalQuestions => "Natural Questions",
}
}
pub fn source_prefix(self) -> &'static str {
match self {
Self::SquadV2 => "squad",
Self::NaturalQuestions => "nq",
}
}
pub fn default_raw_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.raw_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
pub fn default_converted_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.converted_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
}
impl Default for DatasetKind {
fn default() -> Self {
Self::SquadV2
}
}
impl FromStr for DatasetKind {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2),
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
Ok(Self::NaturalQuestions)
}
other => {
anyhow::bail!("unknown dataset '{other}'. Expected 'squad' or 'natural-questions'.")
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
pub id: String,
pub label: String,
pub category: String,
pub entity_suffix: String,
pub source_prefix: String,
#[serde(default)]
pub include_unanswerable: bool,
#[serde(default)]
pub context_token_limit: Option<usize>,
}
impl DatasetMetadata {
pub fn for_kind(
kind: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Self {
if let Ok(entry) = dataset_entry_for_kind(kind) {
return Self {
id: entry.metadata.id.clone(),
label: entry.metadata.label.clone(),
category: entry.metadata.category.clone(),
entity_suffix: entry.metadata.entity_suffix.clone(),
source_prefix: entry.metadata.source_prefix.clone(),
include_unanswerable,
context_token_limit,
};
}
Self {
id: kind.id().to_string(),
label: kind.label().to_string(),
category: kind.category().to_string(),
entity_suffix: kind.entity_suffix().to_string(),
source_prefix: kind.source_prefix().to_string(),
include_unanswerable,
context_token_limit,
}
}
}
fn default_metadata() -> DatasetMetadata {
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedDataset {
pub generated_at: DateTime<Utc>,
#[serde(default = "default_metadata")]
pub metadata: DatasetMetadata,
pub source: String,
pub paragraphs: Vec<ConvertedParagraph>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedParagraph {
pub id: String,
pub title: String,
pub context: String,
pub questions: Vec<ConvertedQuestion>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedQuestion {
pub id: String,
pub question: String,
pub answers: Vec<String>,
pub is_impossible: bool,
}
pub fn convert(
raw_path: &Path,
dataset: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
let paragraphs = match dataset {
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
DatasetKind::NaturalQuestions => {
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
}
};
let metadata_limit = match dataset {
DatasetKind::NaturalQuestions => None,
_ => context_token_limit,
};
Ok(ConvertedDataset {
generated_at: Utc::now(),
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
source: raw_path.display().to_string(),
paragraphs,
})
}
fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
ensure_parent(converted_path)?;
let json =
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
fs::write(converted_path, json)
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
}
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
let raw = fs::read_to_string(converted_path)
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
if dataset.metadata.id.trim().is_empty() {
dataset.metadata = default_metadata();
}
if dataset.source.is_empty() {
dataset.source = converted_path.display().to_string();
}
Ok(dataset)
}
pub fn ensure_converted(
dataset_kind: DatasetKind,
raw_path: &Path,
converted_path: &Path,
force: bool,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
if force || !converted_path.exists() {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
return Ok(dataset);
}
match read_converted(converted_path) {
Ok(dataset)
if dataset.metadata.id == dataset_kind.id()
&& dataset.metadata.include_unanswerable == include_unanswerable
&& dataset.metadata.context_token_limit == context_token_limit =>
{
Ok(dataset)
}
_ => {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
Ok(dataset)
}
}
}
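// Cache policy in short: reconvert when forced, when the converted file is
// missing, or when the cached metadata (dataset id, include_unanswerable,
// context_token_limit) no longer matches the requested configuration.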
pub fn base_timestamp() -> DateTime<Utc> {
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
}

234
eval/src/datasets/nq.rs Normal file
View File

@@ -0,0 +1,234 @@
use std::{
collections::BTreeSet,
fs::File,
io::{BufRead, BufReader},
path::Path,
};
use anyhow::{Context, Result};
use serde::Deserialize;
use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_nq(
raw_path: &Path,
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqExample {
question_text: String,
document_title: String,
example_id: i64,
document_tokens: Vec<NqToken>,
long_answer_candidates: Vec<NqLongAnswerCandidate>,
annotations: Vec<NqAnnotation>,
}
#[derive(Debug, Deserialize)]
struct NqToken {
token: String,
#[serde(default)]
html_token: bool,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswerCandidate {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqAnnotation {
short_answers: Vec<NqShortAnswer>,
#[serde(default)]
yes_no_answer: String,
long_answer: NqLongAnswer,
}
#[derive(Debug, Deserialize)]
struct NqShortAnswer {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswer {
candidate_index: i32,
}
fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String {
let mut buffer = String::new();
let end = end.min(tokens.len());
for token in tokens.iter().skip(start).take(end.saturating_sub(start)) {
if token.html_token {
continue;
}
let text = token.token.trim();
if text.is_empty() {
continue;
}
let attach = matches!(
text,
"," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "" | "..."
) || text.starts_with('\'')
|| text == "n't"
|| text == "'s"
|| text == "'re"
|| text == "'ve"
|| text == "'d"
|| text == "'ll";
if buffer.is_empty() || attach {
buffer.push_str(text);
} else {
buffer.push(' ');
buffer.push_str(text);
}
}
buffer.trim().to_string()
}
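// Worked example with hypothetical tokens: ["Paris", ",", "it", "is", "n't", "small", "."]
// joins to "Paris, it isn't small." because punctuation and contraction
// suffixes attach to the previous token without a space.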
let file = File::open(raw_path).with_context(|| {
format!(
"opening Natural Questions dataset at {}",
raw_path.display()
)
})?;
let reader = BufReader::new(file);
let mut paragraphs = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"reading Natural Questions line {} from {}",
line_idx + 1,
raw_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let example: NqExample = serde_json::from_str(&line).with_context(|| {
format!(
"parsing Natural Questions JSON (line {}) at {}",
line_idx + 1,
raw_path.display()
)
})?;
let mut answer_texts: Vec<String> = Vec::new();
let mut short_answer_texts: Vec<String> = Vec::new();
let mut has_short_or_yesno = false;
let mut has_short_answer = false;
for annotation in &example.annotations {
for short in &annotation.short_answers {
if short.start_token < 0 || short.end_token <= short.start_token {
continue;
}
let start = short.start_token as usize;
let end = short.end_token as usize;
if start >= example.document_tokens.len() || end > example.document_tokens.len() {
continue;
}
let text = join_tokens(&example.document_tokens, start, end);
if !text.is_empty() {
answer_texts.push(text.clone());
short_answer_texts.push(text);
has_short_or_yesno = true;
has_short_answer = true;
}
}
match annotation
.yes_no_answer
.trim()
.to_ascii_lowercase()
.as_str()
{
"yes" => {
answer_texts.push("yes".to_string());
has_short_or_yesno = true;
}
"no" => {
answer_texts.push("no".to_string());
has_short_or_yesno = true;
}
_ => {}
}
}
let mut answers = dedupe_strings(answer_texts);
let is_unanswerable = !has_short_or_yesno || answers.is_empty();
if is_unanswerable {
if !include_unanswerable {
continue;
}
answers.clear();
}
let paragraph_id = format!("nq-{}", example.example_id);
let question_id = format!("nq-{}", example.example_id);
let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len());
if context.is_empty() {
continue;
}
if has_short_answer && !short_answer_texts.is_empty() {
let normalized_context = context.to_ascii_lowercase();
let missing_answer = short_answer_texts.iter().any(|answer| {
let needle = answer.trim().to_ascii_lowercase();
!needle.is_empty() && !normalized_context.contains(&needle)
});
if missing_answer {
warn!(
question_id = %question_id,
"Skipping Natural Questions example because answers were not found in the assembled context"
);
continue;
}
}
if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) {
// yes/no-only questions are excluded by default unless --llm-mode is used
continue;
}
let question = ConvertedQuestion {
id: question_id,
question: example.question_text.trim().to_string(),
answers,
is_impossible: is_unanswerable,
};
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: example.document_title.trim().to_string(),
context,
questions: vec![question],
});
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}

107
eval/src/datasets/squad.rs Normal file
View File

@@ -0,0 +1,107 @@
use std::{collections::BTreeSet, fs, path::Path};
use anyhow::{Context, Result};
use serde::Deserialize;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_squad(raw_path: &Path) -> Result<Vec<ConvertedParagraph>> {
#[derive(Debug, Deserialize)]
struct SquadDataset {
data: Vec<SquadArticle>,
}
#[derive(Debug, Deserialize)]
struct SquadArticle {
title: String,
paragraphs: Vec<SquadParagraph>,
}
#[derive(Debug, Deserialize)]
struct SquadParagraph {
context: String,
qas: Vec<SquadQuestion>,
}
#[derive(Debug, Deserialize)]
struct SquadQuestion {
id: String,
question: String,
answers: Vec<SquadAnswer>,
#[serde(default)]
is_impossible: bool,
}
#[derive(Debug, Deserialize)]
struct SquadAnswer {
text: String,
}
let raw = fs::read_to_string(raw_path)
.with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?;
let parsed: SquadDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?;
let mut paragraphs = Vec::new();
for (article_idx, article) in parsed.data.into_iter().enumerate() {
for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() {
let mut questions = Vec::new();
for qa in paragraph.qas {
let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text));
questions.push(ConvertedQuestion {
id: qa.id,
question: qa.question.trim().to_string(),
answers,
is_impossible: qa.is_impossible,
});
}
let paragraph_id =
format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx);
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: article.title.trim().to_string(),
context: paragraph.context.trim().to_string(),
questions,
});
}
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}
fn slugify(input: &str, fallback_idx: usize) -> String {
let mut slug = String::new();
let mut last_dash = false;
for ch in input.chars() {
let c = ch.to_ascii_lowercase();
if c.is_ascii_alphanumeric() {
slug.push(c);
last_dash = false;
} else if !last_dash {
slug.push('-');
last_dash = true;
}
}
slug = slug.trim_matches('-').to_string();
if slug.is_empty() {
slug = format!("article-{fallback_idx}");
}
slug
}
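// Example: slugify("The Norman Conquest!", 3) yields "the-norman-conquest";
// an all-punctuation title falls back to "article-3".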

View File

@@ -1,12 +1,10 @@
mod pipeline;
mod types;
pub use pipeline::run_evaluation;
pub use types::*;
use std::{
collections::{HashMap, HashSet},
path::Path,
time::Duration,
};
use std::{collections::HashMap, path::Path, time::Duration};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, SecondsFormat, Utc};
@@ -17,10 +15,8 @@ use common::{
types::{system_settings::SystemSettings, user::User},
},
};
use composite_retrieval::pipeline as retrieval_pipeline;
use composite_retrieval::pipeline::PipelineStageTimings;
use composite_retrieval::pipeline::RetrievalTuning;
use serde::{Deserialize, Serialize};
use retrieval_pipeline::RetrievalTuning;
use serde::Deserialize;
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
@@ -33,178 +29,6 @@ use crate::{
snapshot::{self, DbSnapshotState},
};
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
pub generated_at: DateTime<Utc>,
pub k: usize,
pub limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub run_label: Option<String>,
pub total_cases: usize,
pub correct: usize,
pub precision: f64,
pub correct_at_1: usize,
pub correct_at_2: usize,
pub correct_at_3: usize,
pub precision_at_1: f64,
pub precision_at_2: f64,
pub precision_at_3: f64,
pub duration_ms: u128,
pub dataset_id: String,
pub dataset_label: String,
pub dataset_includes_unanswerable: bool,
pub dataset_source: String,
pub slice_id: String,
pub slice_seed: u64,
pub slice_total_cases: usize,
pub slice_window_offset: usize,
pub slice_window_length: usize,
pub slice_cases: usize,
pub slice_positive_paragraphs: usize,
pub slice_negative_paragraphs: usize,
pub slice_total_paragraphs: usize,
pub slice_negative_multiplier: f32,
pub namespace_reused: bool,
pub corpus_paragraphs: usize,
pub ingestion_cache_path: String,
pub ingestion_reused: bool,
pub ingestion_embeddings_reused: bool,
pub ingestion_fingerprint: String,
pub positive_paragraphs_reused: usize,
pub negative_paragraphs_reused: usize,
pub latency_ms: LatencyStats,
pub perf: PerformanceTimings,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize,
pub concurrency: usize,
pub detailed_report: bool,
pub chunk_vector_take: usize,
pub chunk_fts_take: usize,
pub chunk_token_budget: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>,
}
#[derive(Debug, Serialize)]
pub struct CaseSummary {
pub question_id: String,
pub question: String,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_source: String,
pub answers: Vec<String>,
pub matched: bool,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub match_rank: Option<usize>,
pub latency_ms: u128,
pub retrieved: Vec<RetrievedSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
pub avg: f64,
pub p50: u128,
pub p95: u128,
}
#[derive(Debug, Clone, Serialize)]
pub struct StageLatencyBreakdown {
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
}
#[derive(Debug, Default, Clone, Serialize)]
pub struct EvaluationStageTimings {
pub prepare_slice_ms: u128,
pub prepare_db_ms: u128,
pub prepare_corpus_ms: u128,
pub prepare_namespace_ms: u128,
pub run_queries_ms: u128,
pub summarize_ms: u128,
pub finalize_ms: u128,
}
#[derive(Debug, Serialize)]
pub struct PerformanceTimings {
pub openai_base_url: String,
pub ingestion_ms: u128,
#[serde(skip_serializing_if = "Option::is_none")]
pub namespace_seed_ms: Option<u128>,
pub evaluation_stage_ms: EvaluationStageTimings,
pub stage_latency: StageLatencyBreakdown,
}
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
pub rank: usize,
pub entity_id: String,
pub source_id: String,
pub entity_name: String,
pub score: f32,
pub matched: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_category: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_text_match: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_id_match: Option<bool>,
}
#[derive(Debug, Serialize)]
pub(crate) struct CaseDiagnostics {
question_id: String,
question: String,
paragraph_id: String,
paragraph_title: String,
expected_source: String,
expected_chunk_ids: Vec<String>,
answers: Vec<String>,
entity_match: bool,
chunk_text_match: bool,
chunk_id_match: bool,
failure_reasons: Vec<String>,
missing_expected_chunk_ids: Vec<String>,
attached_chunk_ids: Vec<String>,
retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")]
pipeline: Option<retrieval_pipeline::PipelineDiagnostics>,
}
#[derive(Debug, Serialize)]
struct EntityDiagnostics {
rank: usize,
entity_id: String,
source_id: String,
name: String,
score: f32,
entity_match: bool,
chunk_text_match: bool,
chunk_id_match: bool,
chunks: Vec<ChunkDiagnosticsEntry>,
}
#[derive(Debug, Serialize)]
struct ChunkDiagnosticsEntry {
chunk_id: String,
score: f32,
contains_answer: bool,
expected_chunk: bool,
snippet: String,
}
pub(crate) struct SeededCase {
question_id: String,
question: String,
@@ -213,6 +37,8 @@ pub(crate) struct SeededCase {
paragraph_id: String,
paragraph_title: String,
expected_chunk_ids: Vec<String>,
is_impossible: bool,
has_verified_chunks: bool,
}
pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<SeededCase> {
@@ -221,10 +47,15 @@ pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<Seed
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
}
let include_impossible = manifest.metadata.include_unanswerable;
let require_verified_chunks = manifest.metadata.require_verified_chunks;
manifest
.questions
.iter()
.filter(|question| !question.is_impossible)
.filter(|question| {
should_include_question(question, include_impossible, require_verified_chunks)
})
.map(|question| {
let title = title_map
.get(question.paragraph_id.as_str())
@@ -238,66 +69,25 @@ pub(crate) fn cases_from_manifest(manifest: &ingest::CorpusManifest) -> Vec<Seed
paragraph_id: question.paragraph_id.clone(),
paragraph_title: title,
expected_chunk_ids: question.matching_chunk_ids.clone(),
is_impossible: question.is_impossible,
has_verified_chunks: !question.matching_chunk_ids.is_empty(),
}
})
.collect()
}
pub(crate) fn text_contains_answer(text: &str, answers: &[String]) -> bool {
if answers.is_empty() {
return true;
fn should_include_question(
question: &ingest::CorpusQuestion,
include_impossible: bool,
require_verified_chunks: bool,
) -> bool {
if !include_impossible && question.is_impossible {
return false;
}
let haystack = text.to_ascii_lowercase();
answers.iter().any(|needle| haystack.contains(needle))
}
pub(crate) fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
if latencies.is_empty() {
return LatencyStats {
avg: 0.0,
p50: 0,
p95: 0,
};
if require_verified_chunks && question.matching_chunk_ids.is_empty() {
return false;
}
let mut sorted = latencies.to_vec();
sorted.sort_unstable();
let sum: u128 = sorted.iter().copied().sum();
let avg = sum as f64 / (sorted.len() as f64);
let p50 = percentile(&sorted, 0.50);
let p95 = percentile(&sorted, 0.95);
LatencyStats { avg, p50, p95 }
}
pub(crate) fn build_stage_latency_breakdown(
samples: &[PipelineStageTimings],
) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
StageLatencyBreakdown {
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
}
}
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
if sorted.is_empty() {
return 0;
}
let clamped = fraction.clamp(0.0, 1.0);
let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize;
sorted[idx.min(sorted.len() - 1)]
true
}
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
@@ -345,16 +135,16 @@ pub(crate) fn apply_dataset_tuning_overrides(
return;
}
if config.chunk_vector_take.is_none() {
if config.retrieval.chunk_vector_take.is_none() {
tuning.chunk_vector_take = tuning.chunk_vector_take.max(80);
}
if config.chunk_fts_take.is_none() {
if config.retrieval.chunk_fts_take.is_none() {
tuning.chunk_fts_take = tuning.chunk_fts_take.max(80);
}
if config.chunk_token_budget.is_none() {
if config.retrieval.chunk_token_budget.is_none() {
tuning.token_budget_estimate = tuning.token_budget_estimate.max(20_000);
}
if config.max_chunks_per_entity.is_none() {
if config.retrieval.max_chunks_per_entity.is_none() {
tuning.max_chunks_per_entity = tuning.max_chunks_per_entity.max(12);
}
if tuning.lexical_match_weight < 0.25 {
@@ -362,92 +152,6 @@ pub(crate) fn apply_dataset_tuning_overrides(
}
}
pub(crate) fn build_case_diagnostics(
summary: &CaseSummary,
expected_chunk_ids: &[String],
answers_lower: &[String],
entities: &[composite_retrieval::RetrievedEntity],
pipeline_stats: Option<retrieval_pipeline::PipelineDiagnostics>,
) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
let mut seen_chunks: HashSet<String> = HashSet::new();
let mut attached_chunk_ids = Vec::new();
let mut entity_diagnostics = Vec::new();
for (idx, entity) in entities.iter().enumerate() {
let mut chunk_entries = Vec::new();
for chunk in &entity.chunks {
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
seen_chunks.insert(chunk.chunk.id.clone());
attached_chunk_ids.push(chunk.chunk.id.clone());
chunk_entries.push(ChunkDiagnosticsEntry {
chunk_id: chunk.chunk.id.clone(),
score: chunk.score,
contains_answer,
expected_chunk,
snippet: chunk_preview(&chunk.chunk.chunk),
});
}
entity_diagnostics.push(EntityDiagnostics {
rank: idx + 1,
entity_id: entity.entity.id.clone(),
source_id: entity.entity.source_id.clone(),
name: entity.entity.name.clone(),
score: entity.score,
entity_match: entity.entity.source_id == summary.expected_source,
chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
chunks: chunk_entries,
});
}
let missing_expected_chunk_ids = expected_chunk_ids
.iter()
.filter(|id| !seen_chunks.contains(id.as_str()))
.cloned()
.collect::<Vec<_>>();
let mut failure_reasons = Vec::new();
if !summary.entity_match {
failure_reasons.push("entity_miss".to_string());
}
if !summary.chunk_text_match {
failure_reasons.push("chunk_text_missing".to_string());
}
if !summary.chunk_id_match {
failure_reasons.push("chunk_id_missing".to_string());
}
if !missing_expected_chunk_ids.is_empty() {
failure_reasons.push("expected_chunk_absent".to_string());
}
CaseDiagnostics {
question_id: summary.question_id.clone(),
question: summary.question.clone(),
paragraph_id: summary.paragraph_id.clone(),
paragraph_title: summary.paragraph_title.clone(),
expected_source: summary.expected_source.clone(),
expected_chunk_ids: expected_chunk_ids.to_vec(),
answers: summary.answers.clone(),
entity_match: summary.entity_match,
chunk_text_match: summary.chunk_text_match,
chunk_id_match: summary.chunk_id_match,
failure_reasons,
missing_expected_chunk_ids,
attached_chunk_ids,
retrieved: entity_diagnostics,
pipeline: pipeline_stats,
}
}
fn chunk_preview(text: &str) -> String {
text.chars()
.take(200)
.collect::<String>()
.replace('\n', " ")
}
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
args::ensure_parent(path)?;
let mut file = tokio::fs::File::create(path)
@@ -765,3 +469,118 @@ pub(crate) async fn load_or_init_system_settings(
Err(err) => Err(err).context("loading system settings"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusParagraph, CorpusQuestion};
use chrono::Utc;
use common::storage::types::text_content::TextContent;
fn sample_manifest() -> CorpusManifest {
let paragraphs = vec![
CorpusParagraph {
paragraph_id: "p1".to_string(),
title: "Alpha".to_string(),
text_content: TextContent::new(
"alpha context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::new(),
relationships: Vec::new(),
chunks: Vec::new(),
},
CorpusParagraph {
paragraph_id: "p2".to_string(),
title: "Beta".to_string(),
text_content: TextContent::new(
"beta context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::new(),
relationships: Vec::new(),
chunks: Vec::new(),
},
];
let questions = vec![
CorpusQuestion {
question_id: "q1".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "What is Alpha?".to_string(),
answers: vec!["Alpha".to_string()],
is_impossible: false,
matching_chunk_ids: vec!["chunk-alpha".to_string()],
},
CorpusQuestion {
question_id: "q2".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "Unanswerable?".to_string(),
answers: Vec::new(),
is_impossible: true,
matching_chunk_ids: Vec::new(),
},
CorpusQuestion {
question_id: "q3".to_string(),
paragraph_id: "p2".to_string(),
text_content_id: "tc-beta".to_string(),
question_text: "Where is Beta?".to_string(),
answers: vec!["Beta".to_string()],
is_impossible: false,
matching_chunk_ids: Vec::new(),
},
];
CorpusManifest {
version: 1,
metadata: CorpusMetadata {
dataset_id: "ds".to_string(),
dataset_label: "Dataset".to_string(),
slice_id: "slice".to_string(),
include_unanswerable: true,
require_verified_chunks: true,
ingestion_fingerprint: "fp".to_string(),
embedding_backend: "test".to_string(),
embedding_model: None,
embedding_dimension: 3,
converted_checksum: "chk".to_string(),
generated_at: Utc::now(),
paragraph_count: paragraphs.len(),
question_count: questions.len(),
},
paragraphs,
questions,
}
}
#[test]
fn cases_respect_mode_filters() {
let mut manifest = sample_manifest();
manifest.metadata.include_unanswerable = false;
manifest.metadata.require_verified_chunks = true;
let strict_cases = cases_from_manifest(&manifest);
assert_eq!(strict_cases.len(), 1);
assert_eq!(strict_cases[0].question_id, "q1");
assert_eq!(strict_cases[0].paragraph_title, "Alpha");
let mut llm_manifest = manifest.clone();
llm_manifest.metadata.include_unanswerable = true;
llm_manifest.metadata.require_verified_chunks = false;
let llm_cases = cases_from_manifest(&llm_manifest);
let ids: Vec<_> = llm_cases
.iter()
.map(|case| case.question_id.as_str())
.collect();
assert_eq!(ids, vec!["q1", "q2", "q3"]);
}
}

View File

@@ -9,7 +9,7 @@ use common::storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
};
use composite_retrieval::{
use retrieval_pipeline::{
pipeline::{PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
@@ -52,6 +52,7 @@ pub(super) struct EvaluationContext<'a> {
pub eval_user: Option<User>,
pub corpus_handle: Option<ingest::CorpusHandle>,
pub cases: Vec<SeededCase>,
pub filtered_questions: usize,
pub stage_latency_samples: Vec<PipelineStageTimings>,
pub latencies: Vec<u128>,
pub diagnostics_output: Vec<CaseDiagnostics>,
@@ -94,6 +95,7 @@ impl<'a> EvaluationContext<'a> {
eval_user: None,
corpus_handle: None,
cases: Vec::new(),
filtered_questions: 0,
stage_latency_samples: Vec::new(),
latencies: Vec::new(),
diagnostics_output: Vec::new(),

View File

@@ -128,13 +128,28 @@ pub(crate) async fn prepare_namespace(
let user = ensure_eval_user(ctx.db()).await?;
ctx.eval_user = Some(user);
let cases = cases_from_manifest(&ctx.corpus_handle().manifest);
let corpus_handle = ctx.corpus_handle();
let total_manifest_questions = corpus_handle.manifest.questions.len();
let cases = cases_from_manifest(&corpus_handle.manifest);
let include_impossible = corpus_handle.manifest.metadata.include_unanswerable;
let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks;
let filtered = total_manifest_questions.saturating_sub(cases.len());
if filtered > 0 {
info!(
filtered_questions = filtered,
total_questions = total_manifest_questions,
includes_impossible = include_impossible,
require_verified_chunks = require_verified_chunks,
"Filtered questions not eligible for this evaluation mode (impossible or unverifiable)"
);
}
if cases.is_empty() {
return Err(anyhow!(
"no answerable questions found in converted dataset for evaluation"
"no eligible questions found in converted dataset for evaluation (consider --llm-mode or refreshing ingestion data)"
));
}
ctx.cases = cases;
ctx.filtered_questions = filtered;
ctx.namespace_reused = namespace_reused;
ctx.namespace_seed_ms = namespace_seed_ms;

View File

@@ -1,15 +1,17 @@
use std::{collections::HashSet, sync::Arc, time::Instant};
use anyhow::Context;
use futures::stream::{self, StreamExt, TryStreamExt};
use futures::stream::{self, StreamExt};
use tracing::{debug, info};
use crate::eval::{
apply_dataset_tuning_overrides, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
CaseSummary, RetrievedSummary,
adapt_strategy_output, apply_dataset_tuning_overrides, build_case_diagnostics,
text_contains_answer, CaseDiagnostics, CaseSummary, RetrievedSummary,
};
use retrieval_pipeline::{
pipeline::{self, PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
use composite_retrieval::pipeline::{self, PipelineStageTimings, RetrievalConfig};
use composite_retrieval::reranking::RerankerPool;
use tokio::sync::Semaphore;
use super::super::{
@@ -38,30 +40,34 @@ pub(crate) async fn run_queries(
let total_cases = ctx.cases.len();
let cases_iter = std::mem::take(&mut ctx.cases).into_iter().enumerate();
let rerank_pool = if config.rerank {
Some(RerankerPool::new(config.rerank_pool_size).context("initialising reranker pool")?)
let rerank_pool = if config.retrieval.rerank {
Some(
RerankerPool::new(config.retrieval.rerank_pool_size)
.context("initialising reranker pool")?,
)
} else {
None
};
let mut retrieval_config = RetrievalConfig::default();
retrieval_config.tuning.rerank_keep_top = config.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.rerank_keep_top;
retrieval_config.strategy = config.retrieval.strategy;
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
}
if let Some(value) = config.chunk_vector_take {
if let Some(value) = config.retrieval.chunk_vector_take {
retrieval_config.tuning.chunk_vector_take = value;
}
if let Some(value) = config.chunk_fts_take {
if let Some(value) = config.retrieval.chunk_fts_take {
retrieval_config.tuning.chunk_fts_take = value;
}
if let Some(value) = config.chunk_token_budget {
if let Some(value) = config.retrieval.chunk_token_budget {
retrieval_config.tuning.token_budget_estimate = value;
}
if let Some(value) = config.chunk_avg_chars_per_token {
if let Some(value) = config.retrieval.chunk_avg_chars_per_token {
retrieval_config.tuning.avg_chars_per_token = value;
}
if let Some(value) = config.max_chunks_per_entity {
if let Some(value) = config.retrieval.max_chunks_per_entity {
retrieval_config.tuning.max_chunks_per_entity = value;
}
@@ -69,9 +75,11 @@ pub(crate) async fn run_queries(
let active_tuning = retrieval_config.tuning.clone();
let effective_chunk_vector = config
.retrieval
.chunk_vector_take
.unwrap_or(active_tuning.chunk_vector_take);
let effective_chunk_fts = config
.retrieval
.chunk_fts_take
.unwrap_or(active_tuning.chunk_fts_take);
@@ -83,11 +91,11 @@ pub(crate) async fn run_queries(
.limit
.unwrap_or(ctx.window_total_cases),
negative_multiplier = %slice_settings.negative_multiplier,
rerank_enabled = config.rerank,
rerank_pool_size = config.rerank_pool_size,
rerank_keep_top = config.rerank_keep_top,
chunk_min = config.chunk_min_chars,
chunk_max = config.chunk_max_chars,
rerank_enabled = config.retrieval.rerank,
rerank_pool_size = config.retrieval.rerank_pool_size,
rerank_keep_top = config.retrieval.rerank_keep_top,
chunk_min = config.retrieval.chunk_min_chars,
chunk_max = config.retrieval.chunk_max_chars,
chunk_vector_take = effective_chunk_vector,
chunk_fts_take = effective_chunk_fts,
chunk_token_budget = active_tuning.token_budget_estimate,
@@ -122,12 +130,7 @@ pub(crate) async fn run_queries(
let db = ctx.db().clone();
let openai_client = ctx.openai_client();
let results: Vec<(
usize,
CaseSummary,
Option<CaseDiagnostics>,
PipelineStageTimings,
)> = stream::iter(cases_iter)
let raw_results = stream::iter(cases_iter)
.map(move |(idx, case)| {
let db = db.clone();
let openai_client = openai_client.clone();
@@ -152,6 +155,8 @@ pub(crate) async fn run_queries(
paragraph_id,
paragraph_title,
expected_chunk_ids,
is_impossible,
has_verified_chunks,
} = case;
let query_start = Instant::now();
@@ -165,7 +170,7 @@ pub(crate) async fn run_queries(
None => None,
};
let (results, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let (result_output, pipeline_diagnostics, stage_timings) = if diagnostics_enabled {
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
&db,
&openai_client,
@@ -194,26 +199,27 @@ pub(crate) async fn run_queries(
};
let query_latency = query_start.elapsed().as_millis() as u128;
let candidates = adapt_strategy_output(result_output);
let mut retrieved = Vec::new();
let mut match_rank = None;
let answers_lower: Vec<String> =
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
let expected_chunk_ids_set: HashSet<&str> =
expected_chunk_ids.iter().map(|id| id.as_str()).collect();
let chunk_id_required = !expected_chunk_ids_set.is_empty();
let chunk_id_required = has_verified_chunks;
let mut entity_hit = false;
let mut chunk_text_hit = false;
let mut chunk_id_hit = !chunk_id_required;
for (idx_entity, entity) in results.iter().enumerate() {
for (idx_entity, candidate) in candidates.iter().enumerate() {
if idx_entity >= config.k {
break;
}
let entity_match = entity.entity.source_id == expected_source;
let entity_match = candidate.source_id == expected_source;
if entity_match {
entity_hit = true;
}
let chunk_text_for_entity = entity
let chunk_text_for_entity = candidate
.chunks
.iter()
.any(|chunk| text_contains_answer(&chunk.chunk.chunk, &answers_lower));
@@ -221,8 +227,8 @@ pub(crate) async fn run_queries(
chunk_text_hit = true;
}
let chunk_id_for_entity = if chunk_id_required {
expected_chunk_ids_set.contains(entity.entity.source_id.as_str())
|| entity.chunks.iter().any(|chunk| {
expected_chunk_ids_set.contains(candidate.source_id.as_str())
|| candidate.chunks.iter().any(|chunk| {
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
})
} else {
@@ -236,9 +242,11 @@ pub(crate) async fn run_queries(
match_rank = Some(idx_entity + 1);
}
let detail_fields = if config.detailed_report {
let description = candidate.entity_description.clone();
let category = candidate.entity_category.clone();
(
Some(entity.entity.description.clone()),
Some(format!("{:?}", entity.entity.entity_type)),
description,
category,
Some(chunk_text_for_entity),
Some(chunk_id_for_entity),
)
@@ -247,10 +255,10 @@ pub(crate) async fn run_queries(
};
retrieved.push(RetrievedSummary {
rank: idx_entity + 1,
entity_id: entity.entity.id.clone(),
source_id: entity.entity.source_id.clone(),
entity_name: entity.entity.name.clone(),
score: entity.score,
entity_id: candidate.entity_id.clone(),
source_id: candidate.source_id.clone(),
entity_name: candidate.entity_name.clone(),
score: candidate.score,
matched: success,
entity_description: detail_fields.0,
entity_category: detail_fields.1,
@@ -271,6 +279,8 @@ pub(crate) async fn run_queries(
entity_match: entity_hit,
chunk_text_match: chunk_text_hit,
chunk_id_match: chunk_id_hit,
is_impossible,
has_verified_chunks,
match_rank,
latency_ms: query_latency,
retrieved,
@@ -281,7 +291,7 @@ pub(crate) async fn run_queries(
&summary,
&expected_chunk_ids,
&answers_lower,
&results,
&candidates,
pipeline_diagnostics,
))
} else {
@@ -300,8 +310,18 @@ pub(crate) async fn run_queries(
}
})
.buffer_unordered(concurrency)
.try_collect()
.await?;
.collect::<Vec<_>>()
.await;
let mut results = Vec::with_capacity(raw_results.len());
for result in raw_results {
match result {
Ok(val) => results.push(val),
Err(err) => {
tracing::error!(error = ?err, "Query execution failed");
}
}
}
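// Design note: buffer_unordered with per-case error logging replaces the old
// try_collect, so one failed query no longer aborts the whole run; failed
// cases are logged and simply dropped from the summary counts.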
let mut ordered = results;
ordered.sort_by_key(|(idx, ..)| *idx);

View File

@@ -42,7 +42,18 @@ pub(crate) async fn summarize(
let mut correct_at_1 = 0usize;
let mut correct_at_2 = 0usize;
let mut correct_at_3 = 0usize;
let mut retrieval_cases = 0usize;
let mut llm_cases = 0usize;
let mut llm_answered = 0usize;
for summary in &summaries {
if summary.is_impossible {
llm_cases += 1;
if summary.matched {
llm_answered += 1;
}
continue;
}
retrieval_cases += 1;
if summary.matched {
correct += 1;
if let Some(rank) = summary.match_rank {
@@ -62,25 +73,31 @@ pub(crate) async fn summarize(
let latency_stats = compute_latency_stats(&latencies);
let stage_latency = build_stage_latency_breakdown(&stage_latency_samples);
let precision = if total_cases == 0 {
let retrieval_precision = if retrieval_cases == 0 {
0.0
} else {
(correct as f64) / (total_cases as f64)
(correct as f64) / (retrieval_cases as f64)
};
let precision_at_1 = if total_cases == 0 {
let llm_precision = if llm_cases == 0 {
0.0
} else {
(correct_at_1 as f64) / (total_cases as f64)
(llm_answered as f64) / (llm_cases as f64)
};
let precision_at_2 = if total_cases == 0 {
let precision = retrieval_precision;
let precision_at_1 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_2 as f64) / (total_cases as f64)
(correct_at_1 as f64) / (retrieval_cases as f64)
};
let precision_at_3 = if total_cases == 0 {
let precision_at_2 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_3 as f64) / (total_cases as f64)
(correct_at_2 as f64) / (retrieval_cases as f64)
};
let precision_at_3 = if retrieval_cases == 0 {
0.0
} else {
(correct_at_3 as f64) / (retrieval_cases as f64)
};
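// Worked example with hypothetical counts: 8 retrieval cases with 6 matches
// and 2 impossible cases with 1 answered give retrieval_precision = 0.75 and
// llm_precision = 0.5; the @k precisions use the retrieval denominator only.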
let active_tuning = ctx
@@ -119,6 +136,15 @@ pub(crate) async fn summarize(
dataset_label: dataset.metadata.label.clone(),
dataset_includes_unanswerable: dataset.metadata.include_unanswerable,
dataset_source: dataset.source.clone(),
includes_impossible_cases: slice.manifest.includes_unanswerable,
require_verified_chunks: slice.manifest.require_verified_chunks,
filtered_questions: ctx.filtered_questions,
retrieval_cases,
retrieval_correct: correct,
retrieval_precision,
llm_cases,
llm_answered,
llm_precision,
slice_id: slice.manifest.slice_id.clone(),
slice_seed: slice.manifest.seed,
slice_total_cases: slice.manifest.case_count,
@@ -146,11 +172,15 @@ pub(crate) async fn summarize(
embedding_backend: ctx.embedding_provider().backend_label().to_string(),
embedding_model: ctx.embedding_provider().model_code(),
embedding_dimension: ctx.embedding_provider().dimension(),
rerank_enabled: config.rerank,
rerank_pool_size: ctx.rerank_pool.as_ref().map(|_| config.rerank_pool_size),
rerank_keep_top: config.rerank_keep_top,
rerank_enabled: config.retrieval.rerank,
rerank_pool_size: ctx
.rerank_pool
.as_ref()
.map(|_| config.retrieval.rerank_pool_size),
rerank_keep_top: config.retrieval.rerank_keep_top,
concurrency: config.concurrency.max(1),
detailed_report: config.detailed_report,
retrieval_strategy: config.retrieval.strategy.to_string(),
chunk_vector_take: active_tuning.chunk_vector_take,
chunk_fts_take: active_tuning.chunk_fts_take,
chunk_token_budget: active_tuning.token_budget_estimate,

396
eval/src/eval/types.rs Normal file
View File

@@ -0,0 +1,396 @@
use std::collections::HashSet;
use chrono::{DateTime, Utc};
use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize)]
pub struct EvaluationSummary {
pub generated_at: DateTime<Utc>,
pub k: usize,
pub limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub run_label: Option<String>,
pub total_cases: usize,
pub correct: usize,
pub precision: f64,
pub correct_at_1: usize,
pub correct_at_2: usize,
pub correct_at_3: usize,
pub precision_at_1: f64,
pub precision_at_2: f64,
pub precision_at_3: f64,
pub duration_ms: u128,
pub dataset_id: String,
pub dataset_label: String,
pub dataset_includes_unanswerable: bool,
pub dataset_source: String,
pub includes_impossible_cases: bool,
pub require_verified_chunks: bool,
pub filtered_questions: usize,
pub retrieval_cases: usize,
pub retrieval_correct: usize,
pub retrieval_precision: f64,
pub llm_cases: usize,
pub llm_answered: usize,
pub llm_precision: f64,
pub slice_id: String,
pub slice_seed: u64,
pub slice_total_cases: usize,
pub slice_window_offset: usize,
pub slice_window_length: usize,
pub slice_cases: usize,
pub slice_positive_paragraphs: usize,
pub slice_negative_paragraphs: usize,
pub slice_total_paragraphs: usize,
pub slice_negative_multiplier: f32,
pub namespace_reused: bool,
pub corpus_paragraphs: usize,
pub ingestion_cache_path: String,
pub ingestion_reused: bool,
pub ingestion_embeddings_reused: bool,
pub ingestion_fingerprint: String,
pub positive_paragraphs_reused: usize,
pub negative_paragraphs_reused: usize,
pub latency_ms: LatencyStats,
pub perf: PerformanceTimings,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub rerank_enabled: bool,
pub rerank_pool_size: Option<usize>,
pub rerank_keep_top: usize,
pub concurrency: usize,
pub detailed_report: bool,
pub retrieval_strategy: String,
pub chunk_vector_take: usize,
pub chunk_fts_take: usize,
pub chunk_token_budget: usize,
pub chunk_avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub cases: Vec<CaseSummary>,
}
#[derive(Debug, Serialize)]
pub struct CaseSummary {
pub question_id: String,
pub question: String,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_source: String,
pub answers: Vec<String>,
pub matched: bool,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub is_impossible: bool,
pub has_verified_chunks: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub match_rank: Option<usize>,
pub latency_ms: u128,
pub retrieved: Vec<RetrievedSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
pub avg: f64,
pub p50: u128,
pub p95: u128,
}
#[derive(Debug, Clone, Serialize)]
pub struct StageLatencyBreakdown {
pub embed: LatencyStats,
pub collect_candidates: LatencyStats,
pub graph_expansion: LatencyStats,
pub chunk_attach: LatencyStats,
pub rerank: LatencyStats,
pub assemble: LatencyStats,
}
#[derive(Debug, Default, Clone, Serialize)]
pub struct EvaluationStageTimings {
pub prepare_slice_ms: u128,
pub prepare_db_ms: u128,
pub prepare_corpus_ms: u128,
pub prepare_namespace_ms: u128,
pub run_queries_ms: u128,
pub summarize_ms: u128,
pub finalize_ms: u128,
}
#[derive(Debug, Serialize)]
pub struct PerformanceTimings {
pub openai_base_url: String,
pub ingestion_ms: u128,
#[serde(skip_serializing_if = "Option::is_none")]
pub namespace_seed_ms: Option<u128>,
pub evaluation_stage_ms: EvaluationStageTimings,
pub stage_latency: StageLatencyBreakdown,
}
#[derive(Debug, Serialize)]
pub struct RetrievedSummary {
pub rank: usize,
pub entity_id: String,
pub source_id: String,
pub entity_name: String,
pub score: f32,
pub matched: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub entity_category: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_text_match: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chunk_id_match: Option<bool>,
}
#[derive(Debug, Clone)]
pub struct EvaluationCandidate {
pub entity_id: String,
pub source_id: String,
pub entity_name: String,
pub entity_description: Option<String>,
pub entity_category: Option<String>,
pub score: f32,
pub chunks: Vec<RetrievedChunk>,
}
impl EvaluationCandidate {
fn from_entity(entity: RetrievedEntity) -> Self {
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
Self {
entity_id: entity.entity.id.clone(),
source_id: entity.entity.source_id.clone(),
entity_name: entity.entity.name.clone(),
entity_description: Some(entity.entity.description.clone()),
entity_category,
score: entity.score,
chunks: entity.chunks,
}
}
fn from_chunk(chunk: RetrievedChunk) -> Self {
let snippet = chunk_snippet(&chunk.chunk.chunk);
Self {
entity_id: chunk.chunk.id.clone(),
source_id: chunk.chunk.source_id.clone(),
entity_name: chunk.chunk.source_id.clone(),
entity_description: Some(snippet),
entity_category: Some("Chunk".to_string()),
score: chunk.score,
chunks: vec![chunk],
}
}
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect(),
StrategyOutput::Chunks(chunks) => chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect(),
}
}
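A short usage sketch (the helper is hypothetical) showing that both strategy shapes flatten into one candidate list, so downstream scoring never branches on the output kind:

// Sketch: entities and bare chunks both become EvaluationCandidate rows.
fn top_source_ids(output: StrategyOutput) -> Vec<String> {
    adapt_strategy_output(output)
        .into_iter()
        .map(|candidate| candidate.source_id)
        .collect()
}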
#[derive(Debug, Serialize)]
pub struct CaseDiagnostics {
pub question_id: String,
pub question: String,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_source: String,
pub expected_chunk_ids: Vec<String>,
pub answers: Vec<String>,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub failure_reasons: Vec<String>,
pub missing_expected_chunk_ids: Vec<String>,
pub attached_chunk_ids: Vec<String>,
pub retrieved: Vec<EntityDiagnostics>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pipeline: Option<PipelineDiagnostics>,
}
#[derive(Debug, Serialize)]
pub struct EntityDiagnostics {
pub rank: usize,
pub entity_id: String,
pub source_id: String,
pub name: String,
pub score: f32,
pub entity_match: bool,
pub chunk_text_match: bool,
pub chunk_id_match: bool,
pub chunks: Vec<ChunkDiagnosticsEntry>,
}
#[derive(Debug, Serialize)]
pub struct ChunkDiagnosticsEntry {
pub chunk_id: String,
pub score: f32,
pub contains_answer: bool,
pub expected_chunk: bool,
pub snippet: String,
}
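/// Case-insensitive containment check. Callers are expected to pass answers
/// that are already lower-cased (see `answers_lower` at the call sites);
/// an empty answer list counts as a trivial match.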
pub fn text_contains_answer(text: &str, answers: &[String]) -> bool {
if answers.is_empty() {
return true;
}
let haystack = text.to_ascii_lowercase();
answers.iter().any(|needle| haystack.contains(needle))
}
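/// Truncates on char boundaries rather than byte offsets, so multi-byte
/// UTF-8 text can never split a code point.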
fn chunk_snippet(text: &str) -> String {
const MAX_CHARS: usize = 160;
let trimmed = text.trim();
if trimmed.chars().count() <= MAX_CHARS {
return trimmed.to_string();
}
let mut acc = String::with_capacity(MAX_CHARS + 3);
for (idx, ch) in trimmed.chars().enumerate() {
if idx >= MAX_CHARS {
break;
}
acc.push(ch);
}
acc.push_str("...");
acc
}
pub fn compute_latency_stats(latencies: &[u128]) -> LatencyStats {
if latencies.is_empty() {
return LatencyStats {
avg: 0.0,
p50: 0,
p95: 0,
};
}
let mut sorted = latencies.to_vec();
sorted.sort_unstable();
let sum: u128 = sorted.iter().copied().sum();
let avg = sum as f64 / (sorted.len() as f64);
let p50 = percentile(&sorted, 0.50);
let p95 = percentile(&sorted, 0.95);
LatencyStats { avg, p50, p95 }
}
pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageLatencyBreakdown {
fn collect_stage<F>(samples: &[PipelineStageTimings], selector: F) -> Vec<u128>
where
F: Fn(&PipelineStageTimings) -> u128,
{
samples.iter().map(selector).collect()
}
StageLatencyBreakdown {
embed: compute_latency_stats(&collect_stage(samples, |entry| entry.embed_ms)),
collect_candidates: compute_latency_stats(&collect_stage(samples, |entry| {
entry.collect_candidates_ms
})),
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms)),
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms)),
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms)),
}
}
fn percentile(sorted: &[u128], fraction: f64) -> u128 {
if sorted.is_empty() {
return 0;
}
let clamped = fraction.clamp(0.0, 1.0);
let idx = (clamped * (sorted.len() as f64 - 1.0)).round() as usize;
sorted[idx.min(sorted.len() - 1)]
}
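A worked example of the nearest-rank percentile behaviour these two helpers implement (a test sketch, not part of the diff):

#[cfg(test)]
mod latency_tests {
    use super::*;

    #[test]
    fn nearest_rank_percentiles() {
        // Five samples: avg = 200 / 5 = 40.0; p50 takes index
        // round(0.5 * 4) = 2 -> 30; p95 takes round(0.95 * 4) = 4 -> 100.
        let stats = compute_latency_stats(&[10, 20, 30, 40, 100]);
        assert_eq!(stats.avg, 40.0);
        assert_eq!(stats.p50, 30);
        assert_eq!(stats.p95, 100);
    }
}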
pub fn build_case_diagnostics(
summary: &CaseSummary,
expected_chunk_ids: &[String],
answers_lower: &[String],
candidates: &[EvaluationCandidate],
pipeline_stats: Option<PipelineDiagnostics>,
) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(|id| id.as_str()).collect();
let mut seen_chunks: HashSet<String> = HashSet::new();
let mut attached_chunk_ids = Vec::new();
let mut entity_diagnostics = Vec::new();
for (idx, candidate) in candidates.iter().enumerate() {
let mut chunk_entries = Vec::new();
for chunk in &candidate.chunks {
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
seen_chunks.insert(chunk.chunk.id.clone());
attached_chunk_ids.push(chunk.chunk.id.clone());
chunk_entries.push(ChunkDiagnosticsEntry {
chunk_id: chunk.chunk.id.clone(),
score: chunk.score,
contains_answer,
expected_chunk,
snippet: chunk_snippet(&chunk.chunk.chunk),
});
}
entity_diagnostics.push(EntityDiagnostics {
rank: idx + 1,
entity_id: candidate.entity_id.clone(),
source_id: candidate.source_id.clone(),
name: candidate.entity_name.clone(),
score: candidate.score,
entity_match: candidate.source_id == summary.expected_source,
chunk_text_match: chunk_entries.iter().any(|entry| entry.contains_answer),
chunk_id_match: chunk_entries.iter().any(|entry| entry.expected_chunk),
chunks: chunk_entries,
});
}
let missing_expected_chunk_ids = expected_chunk_ids
.iter()
.filter(|id| !seen_chunks.contains(id.as_str()))
.cloned()
.collect::<Vec<_>>();
let mut failure_reasons = Vec::new();
if !summary.entity_match {
failure_reasons.push("entity_miss".to_string());
}
if !summary.chunk_text_match {
failure_reasons.push("chunk_text_missing".to_string());
}
if !summary.chunk_id_match {
failure_reasons.push("chunk_id_missing".to_string());
}
if !missing_expected_chunk_ids.is_empty() {
failure_reasons.push("expected_chunk_absent".to_string());
}
CaseDiagnostics {
question_id: summary.question_id.clone(),
question: summary.question.clone(),
paragraph_id: summary.paragraph_id.clone(),
paragraph_title: summary.paragraph_title.clone(),
expected_source: summary.expected_source.clone(),
expected_chunk_ids: expected_chunk_ids.to_vec(),
answers: summary.answers.clone(),
entity_match: summary.entity_match,
chunk_text_match: summary.chunk_text_match,
chunk_id_match: summary.chunk_id_match,
failure_reasons,
missing_expected_chunk_ids,
attached_chunk_ids,
retrieved: entity_diagnostics,
pipeline: pipeline_stats,
}
}

72
eval/src/ingest/config.rs Normal file
View File

@@ -0,0 +1,72 @@
use std::path::PathBuf;
use anyhow::Result;
use async_trait::async_trait;
use crate::{args::Config, embedding::EmbeddingProvider};
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
pub ingestion_cache_dir: PathBuf,
pub force_refresh: bool,
pub refresh_embeddings_only: bool,
pub ingestion_batch_size: usize,
pub ingestion_max_retries: usize,
}
impl CorpusCacheConfig {
pub fn new(
ingestion_cache_dir: impl Into<PathBuf>,
force_refresh: bool,
refresh_embeddings_only: bool,
ingestion_batch_size: usize,
ingestion_max_retries: usize,
) -> Self {
Self {
ingestion_cache_dir: ingestion_cache_dir.into(),
force_refresh,
refresh_embeddings_only,
ingestion_batch_size,
ingestion_max_retries,
}
}
}
#[async_trait]
pub trait CorpusEmbeddingProvider: Send + Sync {
fn backend_label(&self) -> &str;
fn model_code(&self) -> Option<String>;
fn dimension(&self) -> usize;
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
}
#[async_trait]
impl CorpusEmbeddingProvider for EmbeddingProvider {
fn backend_label(&self) -> &str {
EmbeddingProvider::backend_label(self)
}
fn model_code(&self) -> Option<String> {
EmbeddingProvider::model_code(self)
}
fn dimension(&self) -> usize {
EmbeddingProvider::dimension(self)
}
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
EmbeddingProvider::embed_batch(self, texts).await
}
}
impl From<&Config> for CorpusCacheConfig {
fn from(config: &Config) -> Self {
CorpusCacheConfig::new(
config.ingestion_cache_dir.clone(),
config.force_convert || config.slice_reset_ingestion,
config.refresh_embeddings_only,
config.ingestion_batch_size,
config.ingestion_max_retries,
)
}
}
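A minimal construction sketch, assuming a parsed `config: Config` in scope (the explicit path and batching values are illustrative):

// Derive cache settings straight from the CLI config...
let cache = CorpusCacheConfig::from(&config);
// ...or spell them out when a caller wants different batching.
let custom = CorpusCacheConfig::new("eval/cache/ingested", false, false, 8, 5);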

10
eval/src/ingest/mod.rs Normal file
View File

@@ -0,0 +1,10 @@
mod config;
mod orchestrator;
mod store;
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
pub use orchestrator::ensure_corpus;
pub use store::{
seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
};

View File

@@ -1,23 +1,21 @@
use std::{
collections::{HashMap, HashSet},
fs,
io::{BufReader, Read},
path::{Path, PathBuf},
io::Read,
path::Path,
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use async_openai::Client;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use chrono::Utc;
use common::{
storage::{
db::SurrealDbClient,
store::{DynStore, StorageManager},
types::{
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
},
},
utils::config::{AppConfig, StorageKind},
@@ -30,274 +28,19 @@ use tracing::{info, warn};
use uuid::Uuid;
use crate::{
args::Config,
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
embedding::EmbeddingProvider,
slices::{self, ResolvedSlice, SliceParagraphKind},
};
const MANIFEST_VERSION: u32 = 1;
use crate::ingest::{
CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
};
const INGESTION_SPEC_VERSION: u32 = 1;
const INGESTION_MAX_RETRIES: usize = 3;
const INGESTION_BATCH_SIZE: usize = 5;
const PARAGRAPH_SHARD_VERSION: u32 = 1;
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
pub ingestion_cache_dir: PathBuf,
pub force_refresh: bool,
pub refresh_embeddings_only: bool,
}
impl CorpusCacheConfig {
pub fn new(
ingestion_cache_dir: impl Into<PathBuf>,
force_refresh: bool,
refresh_embeddings_only: bool,
) -> Self {
Self {
ingestion_cache_dir: ingestion_cache_dir.into(),
force_refresh,
refresh_embeddings_only,
}
}
}
#[async_trait]
pub trait CorpusEmbeddingProvider: Send + Sync {
fn backend_label(&self) -> &str;
fn model_code(&self) -> Option<String>;
fn dimension(&self) -> usize;
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
}
type OpenAIClient = Client<async_openai::config::OpenAIConfig>;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusManifest {
pub version: u32,
pub metadata: CorpusMetadata,
pub paragraphs: Vec<CorpusParagraph>,
pub questions: Vec<CorpusQuestion>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusMetadata {
pub dataset_id: String,
pub dataset_label: String,
pub slice_id: String,
pub include_unanswerable: bool,
pub ingestion_fingerprint: String,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub converted_checksum: String,
pub generated_at: DateTime<Utc>,
pub paragraph_count: usize,
pub question_count: usize,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusParagraph {
pub paragraph_id: String,
pub title: String,
pub text_content: TextContent,
pub entities: Vec<KnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
pub chunks: Vec<TextChunk>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusQuestion {
pub question_id: String,
pub paragraph_id: String,
pub text_content_id: String,
pub question_text: String,
pub answers: Vec<String>,
pub is_impossible: bool,
pub matching_chunk_ids: Vec<String>,
}
pub struct CorpusHandle {
pub manifest: CorpusManifest,
pub path: PathBuf,
pub reused_ingestion: bool,
pub reused_embeddings: bool,
pub positive_reused: usize,
pub positive_ingested: usize,
pub negative_reused: usize,
pub negative_ingested: usize,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct ParagraphShard {
version: u32,
paragraph_id: String,
shard_path: String,
ingestion_fingerprint: String,
ingested_at: DateTime<Utc>,
title: String,
text_content: TextContent,
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
chunks: Vec<TextChunk>,
#[serde(default)]
question_bindings: HashMap<String, Vec<String>>,
#[serde(default)]
embedding_backend: String,
#[serde(default)]
embedding_model: Option<String>,
#[serde(default)]
embedding_dimension: usize,
}
struct ParagraphShardStore {
base_dir: PathBuf,
}
impl ParagraphShardStore {
fn new(base_dir: PathBuf) -> Self {
Self { base_dir }
}
fn ensure_base_dir(&self) -> Result<()> {
fs::create_dir_all(&self.base_dir)
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
}
fn resolve(&self, relative: &str) -> PathBuf {
self.base_dir.join(relative)
}
fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
let path = self.resolve(relative);
let file = match fs::File::open(&path) {
Ok(file) => file,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => {
return Err(err).with_context(|| format!("opening shard {}", path.display()))
}
};
let reader = BufReader::new(file);
let mut shard: ParagraphShard = serde_json::from_reader(reader)
.with_context(|| format!("parsing shard {}", path.display()))?;
if shard.version != PARAGRAPH_SHARD_VERSION {
warn!(
path = %path.display(),
version = shard.version,
expected = PARAGRAPH_SHARD_VERSION,
"Skipping shard due to version mismatch"
);
return Ok(None);
}
if shard.ingestion_fingerprint != fingerprint {
return Ok(None);
}
shard.shard_path = relative.to_string();
Ok(Some(shard))
}
fn persist(&self, shard: &ParagraphShard) -> Result<()> {
let path = self.resolve(&shard.shard_path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating shard dir {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
fs::write(&tmp_path, &body)
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
Ok(())
}
}
#[async_trait]
impl CorpusEmbeddingProvider for EmbeddingProvider {
fn backend_label(&self) -> &str {
EmbeddingProvider::backend_label(self)
}
fn model_code(&self) -> Option<String> {
EmbeddingProvider::model_code(self)
}
fn dimension(&self) -> usize {
EmbeddingProvider::dimension(self)
}
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
EmbeddingProvider::embed_batch(self, texts).await
}
}
impl From<&Config> for CorpusCacheConfig {
fn from(config: &Config) -> Self {
CorpusCacheConfig::new(
config.ingestion_cache_dir.clone(),
config.force_convert || config.slice_reset_ingestion,
config.refresh_embeddings_only,
)
}
}
impl ParagraphShard {
fn new(
paragraph: &ConvertedParagraph,
shard_path: String,
ingestion_fingerprint: &str,
text_content: TextContent,
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
chunks: Vec<TextChunk>,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
) -> Self {
Self {
version: PARAGRAPH_SHARD_VERSION,
paragraph_id: paragraph.id.clone(),
shard_path,
ingestion_fingerprint: ingestion_fingerprint.to_string(),
ingested_at: Utc::now(),
title: paragraph.title.clone(),
text_content,
entities,
relationships,
chunks,
question_bindings: HashMap::new(),
embedding_backend: embedding_backend.to_string(),
embedding_model,
embedding_dimension,
}
}
fn to_corpus_paragraph(&self) -> CorpusParagraph {
CorpusParagraph {
paragraph_id: self.paragraph_id.clone(),
title: self.title.clone(),
text_content: self.text_content.clone(),
entities: self.entities.clone(),
relationships: self.relationships.clone(),
chunks: self.chunks.clone(),
}
}
fn ensure_question_binding(
&mut self,
question: &ConvertedQuestion,
) -> Result<(Vec<String>, bool)> {
if let Some(existing) = self.question_bindings.get(&question.id) {
return Ok((existing.clone(), false));
}
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
self.question_bindings
.insert(question.id.clone(), chunk_ids.clone());
Ok((chunk_ids, true))
}
}
#[derive(Clone)]
struct ParagraphShardRecord {
shard: ParagraphShard,
@@ -390,6 +133,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
store.ensure_base_dir()?;
let positive_set: HashSet<&str> = window.positive_ids().collect();
let require_verified_chunks = slice.manifest.require_verified_chunks;
let embedding_backend_label = embedding.backend_label().to_string();
let embedding_model_code = embedding.model_code();
let embedding_dimension = embedding.dimension();
@@ -487,6 +231,8 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
&embedding_backend_label,
embedding_model_code.clone(),
embedding_dimension,
cache.ingestion_batch_size,
cache.ingestion_max_retries,
)
.await
.context("ingesting missing slice paragraphs")?;
@@ -548,6 +294,12 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) {
Ok(result) => result,
Err(err) => {
if require_verified_chunks {
return Err(err).context(format!(
"locating answer text for question '{}' in paragraph '{}'",
case.question.id, case.paragraph.id
));
}
warn!(
question_id = %case.question.id,
paragraph_id = %case.paragraph.id,
@@ -591,6 +343,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
dataset_label: dataset.metadata.label.clone(),
slice_id: slice.manifest.slice_id.clone(),
include_unanswerable: slice.manifest.includes_unanswerable,
require_verified_chunks: slice.manifest.require_verified_chunks,
ingestion_fingerprint: ingestion_fingerprint.clone(),
embedding_backend: embedding.backend_label().to_string(),
embedding_model: embedding.model_code(),
@@ -681,6 +434,8 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
batch_size: usize,
max_retries: usize,
) -> Result<Vec<ParagraphShard>> {
if targets.is_empty() {
return Ok(Vec::new());
@@ -704,7 +459,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
db,
openai.clone(),
app_config,
None::<Arc<composite_retrieval::reranking::RerankerPool>>,
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
storage,
)
.await?;
@@ -712,11 +467,11 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
let mut shards = Vec::with_capacity(targets.len());
let category = dataset.metadata.category.clone();
for (batch_index, batch) in targets.chunks(INGESTION_BATCH_SIZE).enumerate() {
for (batch_index, batch) in targets.chunks(batch_size).enumerate() {
info!(
batch = batch_index,
batch_size = batch.len(),
total_batches = (targets.len() + INGESTION_BATCH_SIZE - 1) / INGESTION_BATCH_SIZE,
total_batches = (targets.len() + batch_size - 1) / batch_size,
"Ingesting paragraph batch"
);
let model_clone = embedding_model.clone();
@@ -734,6 +489,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
backend_clone.clone(),
model_clone.clone(),
embedding_dimension,
max_retries,
)
});
let batch_results: Vec<ParagraphShard> = try_join_all(tasks)
@@ -755,10 +511,11 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
embedding_backend: String,
embedding_model: Option<String>,
embedding_dimension: usize,
max_retries: usize,
) -> Result<ParagraphShard> {
let paragraph = request.paragraph;
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=INGESTION_MAX_RETRIES {
for attempt in 1..=max_retries {
let payload = IngestionPayload::Text {
text: paragraph.context.clone(),
context: paragraph.title.clone(),
@@ -801,7 +558,7 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
warn!(
paragraph_id = %paragraph.id,
attempt,
max_attempts = INGESTION_MAX_RETRIES,
max_attempts = max_retries,
error = ?err,
"ingestion attempt failed for paragraph; retrying"
);
@@ -815,49 +572,6 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
.context(format!("running ingestion for paragraph {}", paragraph.id)))
}
fn validate_answers(
content: &TextContent,
chunks: &[TextChunk],
question: &ConvertedQuestion,
) -> Result<Vec<String>> {
if question.is_impossible || question.answers.is_empty() {
return Ok(Vec::new());
}
let mut matches = std::collections::BTreeSet::new();
let mut found_any = false;
let haystack = content.text.to_ascii_lowercase();
let haystack_norm = normalize_answer_text(&haystack);
for answer in &question.answers {
let needle: String = answer.to_ascii_lowercase();
let needle_norm = normalize_answer_text(&needle);
let text_match = haystack.contains(&needle)
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
if text_match {
found_any = true;
}
for chunk in chunks {
let chunk_text = chunk.chunk.to_ascii_lowercase();
let chunk_norm = normalize_answer_text(&chunk_text);
if chunk_text.contains(&needle)
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
{
matches.insert(chunk.id.clone());
found_any = true;
}
}
}
if !found_any {
Err(anyhow!(
"expected answer for question '{}' was not found in ingested content",
question.id
))
} else {
Ok(matches.into_iter().collect())
}
}
fn build_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
@@ -894,107 +608,3 @@ fn compute_file_checksum(path: &Path) -> Result<String> {
}
Ok(format!("{:x}", hasher.finalize()))
}
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
for paragraph in &manifest.paragraphs {
db.store_item(paragraph.text_content.clone())
.await
.context("storing text_content from manifest")?;
for entity in &paragraph.entities {
db.store_item(entity.clone())
.await
.context("storing knowledge_entity from manifest")?;
}
for relationship in &paragraph.relationships {
relationship
.store_relationship(db)
.await
.context("storing knowledge_relationship from manifest")?;
}
for chunk in &paragraph.chunks {
db.store_item(chunk.clone())
.await
.context("storing text_chunk from manifest")?;
}
}
Ok(())
}
fn normalize_answer_text(text: &str) -> String {
text.chars()
.map(|ch| {
if ch.is_alphanumeric() || ch.is_whitespace() {
ch.to_ascii_lowercase()
} else {
' '
}
})
.collect::<String>()
.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::datasets::ConvertedQuestion;
fn mock_text_content() -> TextContent {
TextContent {
id: "tc1".into(),
created_at: Utc::now(),
updated_at: Utc::now(),
text: "alpha beta gamma".into(),
file_info: None,
url_info: None,
context: Some("ctx".into()),
category: "cat".into(),
user_id: "user".into(),
}
}
fn mock_chunk(id: &str, text: &str) -> TextChunk {
TextChunk {
id: id.into(),
created_at: Utc::now(),
updated_at: Utc::now(),
source_id: "src".into(),
chunk: text.into(),
embedding: vec![],
user_id: "user".into(),
}
}
#[test]
fn validate_answers_passes_when_present() {
let content = mock_text_content();
let chunk = mock_chunk("chunk1", "alpha chunk");
let question = ConvertedQuestion {
id: "q1".into(),
question: "?".into(),
answers: vec!["Alpha".into()],
is_impossible: false,
};
let matches = validate_answers(&content, &[chunk], &question).expect("answers match");
assert_eq!(matches, vec!["chunk1".to_string()]);
}
#[test]
fn validate_answers_fails_when_missing() {
let question = ConvertedQuestion {
id: "q1".into(),
question: "?".into(),
answers: vec!["delta".into()],
is_impossible: false,
};
let err = validate_answers(
&mock_text_content(),
&[mock_chunk("chunk", "alpha")],
&question,
)
.expect_err("missing answer should fail");
assert!(err.to_string().contains("not found"));
}
}

299
eval/src/ingest/store.rs Normal file
View File

@@ -0,0 +1,299 @@
use std::{collections::HashMap, fs, io::BufReader, path::PathBuf};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
};
use tracing::warn;
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
pub const MANIFEST_VERSION: u32 = 1;
pub const PARAGRAPH_SHARD_VERSION: u32 = 1;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusManifest {
pub version: u32,
pub metadata: CorpusMetadata,
pub paragraphs: Vec<CorpusParagraph>,
pub questions: Vec<CorpusQuestion>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusMetadata {
pub dataset_id: String,
pub dataset_label: String,
pub slice_id: String,
pub include_unanswerable: bool,
#[serde(default)]
pub require_verified_chunks: bool,
pub ingestion_fingerprint: String,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub converted_checksum: String,
pub generated_at: DateTime<Utc>,
pub paragraph_count: usize,
pub question_count: usize,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusParagraph {
pub paragraph_id: String,
pub title: String,
pub text_content: TextContent,
pub entities: Vec<KnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
pub chunks: Vec<TextChunk>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusQuestion {
pub question_id: String,
pub paragraph_id: String,
pub text_content_id: String,
pub question_text: String,
pub answers: Vec<String>,
pub is_impossible: bool,
pub matching_chunk_ids: Vec<String>,
}
pub struct CorpusHandle {
pub manifest: CorpusManifest,
pub path: PathBuf,
pub reused_ingestion: bool,
pub reused_embeddings: bool,
pub positive_reused: usize,
pub positive_ingested: usize,
pub negative_reused: usize,
pub negative_ingested: usize,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParagraphShard {
pub version: u32,
pub paragraph_id: String,
pub shard_path: String,
pub ingestion_fingerprint: String,
pub ingested_at: DateTime<Utc>,
pub title: String,
pub text_content: TextContent,
pub entities: Vec<KnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
pub chunks: Vec<TextChunk>,
#[serde(default)]
pub question_bindings: HashMap<String, Vec<String>>,
#[serde(default)]
pub embedding_backend: String,
#[serde(default)]
pub embedding_model: Option<String>,
#[serde(default)]
pub embedding_dimension: usize,
}
pub struct ParagraphShardStore {
base_dir: PathBuf,
}
impl ParagraphShardStore {
pub fn new(base_dir: PathBuf) -> Self {
Self { base_dir }
}
pub fn ensure_base_dir(&self) -> Result<()> {
fs::create_dir_all(&self.base_dir)
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
}
fn resolve(&self, relative: &str) -> PathBuf {
self.base_dir.join(relative)
}
pub fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
let path = self.resolve(relative);
let file = match fs::File::open(&path) {
Ok(file) => file,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => {
return Err(err).with_context(|| format!("opening shard {}", path.display()))
}
};
let reader = BufReader::new(file);
let mut shard: ParagraphShard = serde_json::from_reader(reader)
.with_context(|| format!("parsing shard {}", path.display()))?;
if shard.version != PARAGRAPH_SHARD_VERSION {
warn!(
path = %path.display(),
version = shard.version,
expected = PARAGRAPH_SHARD_VERSION,
"Skipping shard due to version mismatch"
);
return Ok(None);
}
if shard.ingestion_fingerprint != fingerprint {
return Ok(None);
}
shard.shard_path = relative.to_string();
Ok(Some(shard))
}
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
let path = self.resolve(&shard.shard_path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating shard dir {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
fs::write(&tmp_path, &body)
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
Ok(())
}
}
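persist writes through a .json.tmp sibling and renames it into place, so a crashed writer never leaves a half-written shard behind. A round-trip sketch, assuming a populated `shard: ParagraphShard` and a Result-returning context:

let store = ParagraphShardStore::new(PathBuf::from("eval/cache/ingested"));
store.ensure_base_dir()?;
store.persist(&shard)?;
// A matching fingerprint loads the shard back...
let hit = store.load(&shard.shard_path, &shard.ingestion_fingerprint)?;
// ...while a stale fingerprint reads as None, forcing re-ingestion upstream.
let miss = store.load(&shard.shard_path, "stale-fingerprint")?;
assert!(hit.is_some() && miss.is_none());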
impl ParagraphShard {
pub fn new(
paragraph: &ConvertedParagraph,
shard_path: String,
ingestion_fingerprint: &str,
text_content: TextContent,
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
chunks: Vec<TextChunk>,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
) -> Self {
Self {
version: PARAGRAPH_SHARD_VERSION,
paragraph_id: paragraph.id.clone(),
shard_path,
ingestion_fingerprint: ingestion_fingerprint.to_string(),
ingested_at: Utc::now(),
title: paragraph.title.clone(),
text_content,
entities,
relationships,
chunks,
question_bindings: HashMap::new(),
embedding_backend: embedding_backend.to_string(),
embedding_model,
embedding_dimension,
}
}
pub fn to_corpus_paragraph(&self) -> CorpusParagraph {
CorpusParagraph {
paragraph_id: self.paragraph_id.clone(),
title: self.title.clone(),
text_content: self.text_content.clone(),
entities: self.entities.clone(),
relationships: self.relationships.clone(),
chunks: self.chunks.clone(),
}
}
pub fn ensure_question_binding(
&mut self,
question: &ConvertedQuestion,
) -> Result<(Vec<String>, bool)> {
if let Some(existing) = self.question_bindings.get(&question.id) {
return Ok((existing.clone(), false));
}
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
self.question_bindings
.insert(question.id.clone(), chunk_ids.clone());
Ok((chunk_ids, true))
}
}
fn validate_answers(
content: &TextContent,
chunks: &[TextChunk],
question: &ConvertedQuestion,
) -> Result<Vec<String>> {
if question.is_impossible || question.answers.is_empty() {
return Ok(Vec::new());
}
let mut matches = std::collections::BTreeSet::new();
let mut found_any = false;
let haystack = content.text.to_ascii_lowercase();
let haystack_norm = normalize_answer_text(&haystack);
for answer in &question.answers {
let needle: String = answer.to_ascii_lowercase();
let needle_norm = normalize_answer_text(&needle);
let text_match = haystack.contains(&needle)
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
if text_match {
found_any = true;
}
for chunk in chunks {
let chunk_text = chunk.chunk.to_ascii_lowercase();
let chunk_norm = normalize_answer_text(&chunk_text);
if chunk_text.contains(&needle)
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
{
matches.insert(chunk.id.clone());
found_any = true;
}
}
}
if !found_any {
Err(anyhow!(
"expected answer for question '{}' was not found in ingested content",
question.id
))
} else {
Ok(matches.into_iter().collect())
}
}
fn normalize_answer_text(text: &str) -> String {
text.chars()
.map(|ch| {
if ch.is_alphanumeric() || ch.is_whitespace() {
ch.to_ascii_lowercase()
} else {
' '
}
})
.collect::<String>()
.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
}
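The normalization strips punctuation and case before matching; a small illustration of what counts as a hit (test sketch, not part of the diff):

#[cfg(test)]
mod normalize_tests {
    use super::*;

    #[test]
    fn punctuation_and_case_are_ignored() {
        // "Alpha-Beta!" normalizes to "alpha beta", which the normalized
        // haystack "alpha beta gamma" contains.
        assert_eq!(normalize_answer_text("Alpha-Beta!"), "alpha beta");
        assert!(normalize_answer_text("alpha beta gamma").contains("alpha beta"));
    }
}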
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
for paragraph in &manifest.paragraphs {
db.upsert_item(paragraph.text_content.clone())
.await
.context("storing text_content from manifest")?;
for entity in &paragraph.entities {
db.upsert_item(entity.clone())
.await
.context("storing knowledge_entity from manifest")?;
}
for relationship in &paragraph.relationships {
relationship
.store_relationship(db)
.await
.context("storing knowledge_relationship from manifest")?;
}
for chunk in &paragraph.chunks {
db.upsert_item(chunk.clone())
.await
.context("storing text_chunk from manifest")?;
}
}
Ok(())
}
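Seeding now upserts rather than inserts, so replaying a manifest into a namespace that already holds the records is idempotent. A call-site sketch, assuming a connected `db` and a `CorpusHandle`:

// Safe to run twice: upserts overwrite the seeded records instead of
// duplicating them.
seed_manifest_into_db(&db, &handle.manifest).await?;
seed_manifest_into_db(&db, &handle.manifest).await?;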

View File

@@ -194,17 +194,34 @@ async fn async_main() -> anyhow::Result<()> {
)
})?;
println!(
"[{}] Precision@{k}: {precision:.3} ({correct}/{total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
total = summary.total_cases,
json = report_paths.json.display(),
md = report_paths.markdown.display(),
perf = perf_log_path.display()
);
if summary.llm_cases > 0 {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) + LLM: {llm_answered}/{llm_total} ({llm_precision:.3}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
llm_answered = summary.llm_answered,
llm_total = summary.llm_cases,
llm_precision = summary.llm_precision,
json = report_paths.json.display(),
md = report_paths.markdown.display(),
perf = perf_log_path.display()
);
} else {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | Perf: {perf}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
json = report_paths.json.display(),
md = report_paths.markdown.display(),
perf = perf_log_path.display()
);
}
if parsed.config.perf_log_console {
perf::print_console_summary(&summary);

View File

@@ -19,6 +19,7 @@ struct PerformanceLogEntry {
dataset_id: String,
dataset_label: String,
run_label: Option<String>,
retrieval_strategy: String,
slice_id: String,
slice_seed: u64,
slice_window_offset: usize,
@@ -27,6 +28,10 @@ struct PerformanceLogEntry {
total_cases: usize,
correct: usize,
precision: f64,
retrieval_cases: usize,
llm_cases: usize,
llm_answered: usize,
llm_precision: f64,
k: usize,
openai_base_url: String,
ingestion: IngestionPerf,
@@ -87,7 +92,7 @@ impl PerformanceLogEntry {
rerank_enabled: summary.rerank_enabled,
rerank_pool_size: summary.rerank_pool_size,
rerank_keep_top: summary.rerank_keep_top,
evaluated_cases: summary.total_cases,
evaluated_cases: summary.retrieval_cases,
};
Self {
@@ -95,6 +100,7 @@ impl PerformanceLogEntry {
dataset_id: summary.dataset_id.clone(),
dataset_label: summary.dataset_label.clone(),
run_label: summary.run_label.clone(),
retrieval_strategy: summary.retrieval_strategy.clone(),
slice_id: summary.slice_id.clone(),
slice_seed: summary.slice_seed,
slice_window_offset: summary.slice_window_offset,
@@ -103,6 +109,10 @@ impl PerformanceLogEntry {
total_cases: summary.total_cases,
correct: summary.correct,
precision: summary.precision,
retrieval_cases: summary.retrieval_cases,
llm_cases: summary.llm_cases,
llm_answered: summary.llm_answered,
llm_precision: summary.llm_precision,
k: summary.k,
openai_base_url: summary.perf.openai_base_url.clone(),
ingestion,
@@ -162,6 +172,13 @@ pub fn write_perf_logs(
pub fn print_console_summary(summary: &EvaluationSummary) {
let perf = &summary.perf;
println!(
"[perf] retrieval strategy={} | rerank={} (pool {:?}, keep {})",
summary.retrieval_strategy,
summary.rerank_enabled,
summary.rerank_pool_size,
summary.rerank_keep_top
);
println!(
"[perf] ingestion={}ms | namespace_seed={}",
perf.ingestion_ms,
@@ -169,7 +186,8 @@ pub fn print_console_summary(summary: &EvaluationSummary) {
);
let stage = &perf.stage_latency;
println!(
"[perf] stage avg ms → collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
stage.embed.avg,
stage.collect_candidates.avg,
stage.graph_expansion.avg,
stage.chunk_attach.avg,
@@ -212,6 +230,7 @@ mod tests {
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown {
embed: sample_latency(),
collect_candidates: sample_latency(),
graph_expansion: sample_latency(),
chunk_attach: sample_latency(),
@@ -252,6 +271,15 @@ mod tests {
dataset_label: "SQuAD v2".into(),
dataset_includes_unanswerable: false,
dataset_source: "dev".into(),
includes_impossible_cases: false,
require_verified_chunks: true,
filtered_questions: 0,
retrieval_cases: 2,
retrieval_correct: 1,
retrieval_precision: 0.5,
llm_cases: 0,
llm_answered: 0,
llm_precision: 0.0,
slice_id: "slice123".into(),
slice_seed: 42,
slice_total_cases: 400,
@@ -285,6 +313,7 @@ mod tests {
rerank_pool_size: Some(4),
rerank_keep_top: 10,
concurrency: 2,
retrieval_strategy: "initial".into(),
detailed_report: false,
chunk_vector_take: 20,
chunk_fts_take: 20,

File diff suppressed because it is too large

View File

@@ -23,5 +23,6 @@ pub fn slice_config_with_limit<'a>(
slice_seed: config.slice_seed,
llm_mode: config.llm_mode,
negative_multiplier: config.negative_multiplier,
require_verified_chunks: config.retrieval.require_verified_chunks,
}
}

View File

@@ -26,6 +26,7 @@ pub struct SliceConfig<'a> {
pub slice_seed: u64,
pub llm_mode: bool,
pub negative_multiplier: f32,
pub require_verified_chunks: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -36,6 +37,8 @@ pub struct SliceManifest {
pub dataset_label: String,
pub dataset_source: String,
pub includes_unanswerable: bool,
#[serde(default = "default_require_verified_chunks")]
pub require_verified_chunks: bool,
pub seed: u64,
pub requested_limit: Option<usize>,
pub requested_corpus: usize,
@@ -49,6 +52,10 @@ pub struct SliceManifest {
pub paragraphs: Vec<SliceParagraphEntry>,
}
fn default_require_verified_chunks() -> bool {
false
}
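The serde default keeps pre-existing manifests loadable: a file written before this field existed deserializes with require_verified_chunks == false and is then caught by the mismatch checks in resolve_slice. A stand-in sketch of that fallback (the Probe struct is hypothetical):

#[derive(serde::Deserialize)]
struct Probe {
    #[serde(default = "default_require_verified_chunks")]
    require_verified_chunks: bool,
}
// An empty document omits the field, so the default fn supplies `false`.
let probe: Probe = serde_json::from_str("{}").unwrap();
assert!(!probe.require_verified_chunks);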
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SliceCaseEntry {
pub question_id: String,
@@ -184,6 +191,7 @@ impl DatasetIndex {
struct SliceKey<'a> {
dataset_id: &'a str,
includes_unanswerable: bool,
require_verified_chunks: bool,
requested_corpus: usize,
seed: u64,
}
@@ -222,7 +230,8 @@ pub fn resolve_slice<'a>(
.max(1);
let key = SliceKey {
dataset_id: dataset.metadata.id.as_str(),
includes_unanswerable: dataset.metadata.include_unanswerable,
includes_unanswerable: config.llm_mode,
require_verified_chunks: config.require_verified_chunks,
requested_corpus,
seed: config.slice_seed,
};
@@ -248,13 +257,24 @@ pub fn resolve_slice<'a>(
let mut manifest = if !config.force_convert && path.exists() {
match read_manifest(&path) {
Ok(manifest) if manifest.dataset_id == dataset.metadata.id => {
if manifest.includes_unanswerable != dataset.metadata.include_unanswerable {
if manifest.includes_unanswerable != config.llm_mode {
warn!(
slice = manifest.slice_id,
path = %path.display(),
expected = config.llm_mode,
found = manifest.includes_unanswerable,
"Slice manifest includes_unanswerable mismatch; regenerating"
);
None
} else if manifest.require_verified_chunks != config.require_verified_chunks {
warn!(
slice = manifest.slice_id,
path = %path.display(),
expected = config.require_verified_chunks,
found = manifest.require_verified_chunks,
"Slice manifest verified-chunk requirement mismatch; regenerating"
);
None
} else {
Some(manifest)
}
@@ -312,6 +332,7 @@ pub fn resolve_slice<'a>(
&params,
requested_corpus,
config.negative_multiplier,
config.require_verified_chunks,
config.limit,
)
});
@@ -319,6 +340,8 @@ pub fn resolve_slice<'a>(
manifest.requested_limit = config.limit;
manifest.requested_corpus = requested_corpus;
manifest.negative_multiplier = config.negative_multiplier;
manifest.includes_unanswerable = config.llm_mode;
manifest.require_verified_chunks = config.require_verified_chunks;
let mut changed = ensure_shard_paths(&mut manifest);
@@ -439,6 +462,22 @@ fn load_explicit_slice<'a>(
dataset.metadata.id
));
}
if manifest.includes_unanswerable != config.llm_mode {
return Err(anyhow!(
"slice '{}' includes_unanswerable mismatch (expected {}, found {})",
manifest.slice_id,
config.llm_mode,
manifest.includes_unanswerable
));
}
if manifest.require_verified_chunks != config.require_verified_chunks {
return Err(anyhow!(
"slice '{}' verified-chunk requirement mismatch (expected {}, found {})",
manifest.slice_id,
config.require_verified_chunks,
manifest.require_verified_chunks
));
}
// Validate the manifest before returning.
manifest_to_resolved(dataset, index, manifest.clone(), candidate_path.clone())?;
@@ -452,6 +491,7 @@ fn empty_manifest(
params: &BuildParams,
requested_corpus: usize,
negative_multiplier: f32,
require_verified_chunks: bool,
requested_limit: Option<usize>,
) -> SliceManifest {
SliceManifest {
@@ -460,7 +500,8 @@ fn empty_manifest(
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
dataset_source: dataset.source.clone(),
includes_unanswerable: dataset.metadata.include_unanswerable,
includes_unanswerable: params.include_impossible,
require_verified_chunks,
seed: params.base_seed,
requested_limit,
requested_corpus,
@@ -891,6 +932,7 @@ mod tests {
slice_seed: 0x5eed_2025,
llm_mode: false,
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
require_verified_chunks: true,
};
let first = resolve_slice(&dataset, &config)?;
@@ -922,6 +964,7 @@ mod tests {
slice_seed: 0x5eed_2025,
llm_mode: false,
negative_multiplier: DEFAULT_NEGATIVE_MULTIPLIER,
require_verified_chunks: true,
};
let resolved = resolve_slice(&dataset, &config)?;
let window = select_window(&resolved, 1, Some(1))?;

View File

@@ -54,9 +54,9 @@ impl Descriptor {
embedding_backend: embedding_provider.backend_label().to_string(),
embedding_model: embedding_provider.model_code(),
embedding_dimension: embedding_provider.dimension(),
chunk_min_chars: config.chunk_min_chars,
chunk_max_chars: config.chunk_max_chars,
rerank_enabled: config.rerank,
chunk_min_chars: config.retrieval.chunk_min_chars,
chunk_max_chars: config.retrieval.chunk_max_chars,
rerank_enabled: config.retrieval.rerank,
};
let dir = config