evals: v3, embeddings at the side

additional indexes
This commit is contained in:
Per Stark
2025-11-26 15:00:55 +01:00
parent 226b2db43a
commit 030f0fc17d
63 changed files with 3859 additions and 1124 deletions

View File

@@ -8,6 +8,23 @@ use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
// Resolve the workspace root: the parent of this crate's manifest directory
// (CARGO_MANIFEST_DIR is fixed at compile time by Cargo).
// Falls back to the manifest directory itself if it has no parent.
fn workspace_root() -> PathBuf {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf()
}
// Default directory for evaluation reports, anchored at the workspace root
// so the path is stable regardless of the process's working directory.
fn default_report_dir() -> PathBuf {
workspace_root().join("eval/reports")
}
// Default directory for embedding caches, anchored at the workspace root.
fn default_cache_dir() -> PathBuf {
workspace_root().join("eval/cache")
}
// Default directory for ingested-corpus caches; nested under the embedding
// cache directory returned by default_cache_dir().
fn default_ingestion_cache_dir() -> PathBuf {
workspace_root().join("eval/cache/ingested")
}
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -129,7 +146,7 @@ impl Default for Config {
corpus_limit: None,
raw_dataset_path: dataset.default_raw_path(),
converted_dataset_path: dataset.default_converted_path(),
report_dir: PathBuf::from("eval/reports"),
report_dir: default_report_dir(),
k: 5,
limit: Some(200),
summary_sample: 5,
@@ -138,8 +155,8 @@ impl Default for Config {
concurrency: 4,
embedding_backend: EmbeddingBackend::FastEmbed,
embedding_model: None,
cache_dir: PathBuf::from("eval/cache"),
ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
cache_dir: default_cache_dir(),
ingestion_cache_dir: default_ingestion_cache_dir(),
ingestion_batch_size: 5,
ingestion_max_retries: 3,
refresh_embeddings_only: false,
@@ -585,6 +602,13 @@ where
}
pub fn print_help() {
let report_default = default_report_dir();
let cache_default = default_cache_dir();
let ingestion_cache_default = default_ingestion_cache_dir();
let report_default_display = report_default.display();
let cache_default_display = cache_default.display();
let ingestion_cache_default_display = ingestion_cache_default.display();
println!(
"\
eval — dataset conversion, ingestion, and retrieval evaluation CLI
@@ -610,7 +634,7 @@ OPTIONS:
--corpus-limit <int> Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000.
--raw <path> Path to the raw dataset (defaults per dataset).
--converted <path> Path to write/read the converted dataset (defaults per dataset).
--report-dir <path> Directory to write evaluation reports (default: eval/reports).
--report-dir <path> Directory to write evaluation reports (default: {report_default_display}).
--k <int> Precision@k cutoff (default: 5).
--limit <int> Limit the number of questions evaluated (default: 200, 0 = all).
--sample <int> Number of mismatches to surface in the Markdown summary (default: 5).
@@ -632,9 +656,9 @@ OPTIONS:
--embedding <name> Embedding backend: 'fastembed' (default) or 'hashed'.
--embedding-model <code>
FastEmbed model code (defaults to crate preset when omitted).
--cache-dir <path> Directory for embedding caches (default: eval/cache).
--cache-dir <path> Directory for embedding caches (default: {cache_default_display}).
--ingestion-cache-dir <path>
Directory for ingestion corpora caches (default: eval/cache/ingested).
Directory for ingestion corpora caches (default: {ingestion_cache_default_display}).
--ingestion-batch-size <int>
Number of paragraphs to ingest concurrently (default: 5).
--ingestion-max-retries <int>

View File

@@ -1,187 +1,30 @@
use anyhow::{Context, Result};
use common::storage::db::SurrealDbClient;
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};
use serde::Deserialize;
use tracing::info;
// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings
// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings.
pub async fn change_embedding_length_in_hnsw_indexes(
db: &SurrealDbClient,
dimension: usize,
) -> Result<()> {
tracing::info!("Changing embedding length in HNSW indexes");
let query = format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
COMMIT TRANSACTION;",
dim = dimension
);
db.client
.query(query)
.await
.context("changing HNSW indexes")?;
tracing::info!("HNSW indexes successfully changed");
// No-op for now; runtime indexes are created after ingestion with the correct dimension.
let _ = (db, dimension);
Ok(())
}
// Helper functions for index management during namespace reseed
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)");
// Remove ALL indexes from ALL tables to ensure no cache access
db.client
.query(
"BEGIN TRANSACTION;
-- HNSW indexes
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
-- FTS indexes on text_content (remove ALL of them)
REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content;
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content;
-- FTS indexes on knowledge_entity
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity;
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity;
-- FTS indexes on text_chunk
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk;
COMMIT TRANSACTION;",
)
.await
.context("removing all indexes before namespace reseed")?;
tracing::info!("All indexes removed before namespace reseed");
Ok(())
}
async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> {
tracing::info!("Creating FTS analyzers for namespace reseed");
let res = db
.client
.query(
"BEGIN TRANSACTION;
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);
COMMIT TRANSACTION;",
)
.await
.context("creating FTS analyzers for namespace reseed")?;
res.check().context("failed to create the tokenizer")?;
let _ = db;
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
Ok(())
}
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)");
let total_start = std::time::Instant::now();
create_tokenizer(db)
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
ensure_runtime_indexes(db, dimension)
.await
.context("creating FTS analyzer")?;
// For now we dont remove these plain indexes, we could if they prove negatively impacting performance
// create_regular_indexes_for_snapshot(db)
// .await
// .context("creating regular indexes for namespace reseed")?;
let fts_start = std::time::Instant::now();
create_fts_indexes_for_snapshot(db)
.await
.context("creating FTS indexes for namespace reseed")?;
tracing::info!(duration = ?fts_start.elapsed(), "FTS indexes created");
let hnsw_start = std::time::Instant::now();
create_hnsw_indexes_for_snapshot(db, dimension)
.await
.context("creating HNSW indexes for namespace reseed")?;
tracing::info!(duration = ?hnsw_start.elapsed(), "HNSW indexes created");
tracing::info!(duration = ?total_start.elapsed(), "All index groups recreated successfully in sequence");
Ok(())
}
#[allow(dead_code)] // For now we dont do this. We could
async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)");
let res = db
.client
.query(
"BEGIN TRANSACTION;
DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id;
DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at;
DEFINE INDEX text_content_category_idx ON text_content FIELDS category;
DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id;
DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id;
DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
COMMIT TRANSACTION;",
)
.await
.context("creating regular indexes for namespace reseed")?;
res.check().context("one of the regular indexes failed")?;
tracing::info!("Regular indexes for namespace reseed created");
Ok(())
}
async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
tracing::info!("Creating FTS indexes for namespace reseed (group 2)");
let res = db.client
.query(
"BEGIN TRANSACTION;
DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text;
DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk
SEARCH ANALYZER app_en_fts_analyzer BM25;
COMMIT TRANSACTION;",
)
.await
.context("sending FTS index creation query")?;
// This actually surfaces statement-level errors
res.check()
.context("one or more FTS index statements failed")?;
tracing::info!("FTS indexes for namespace reseed created");
Ok(())
}
async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> {
tracing::info!("Creating HNSW indexes for namespace reseed (group 3)");
let query = format!(
"BEGIN TRANSACTION;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
COMMIT TRANSACTION;",
dim = dimension
);
let res = db
.client
.query(query)
.await
.context("creating HNSW indexes for namespace reseed")?;
res.check()
.context("one or more HNSW index statements failed")?;
tracing::info!("HNSW indexes for namespace reseed created");
Ok(())
.context("creating runtime indexes")
}
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
@@ -207,7 +50,6 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
#[cfg(test)]
mod tests {
use super::*;
use serde::Deserialize;
use uuid::Uuid;
#[derive(Debug, Deserialize)]

View File

@@ -12,7 +12,7 @@ use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
types::{system_settings::SystemSettings, user::User, StoredObject},
},
};
use retrieval_pipeline::RetrievalTuning;
@@ -172,18 +172,26 @@ pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> R
info!("Warming HNSW caches with sample queries");
// Warm up chunk index
// Warm up chunk embedding index - just query the embedding table to load HNSW index
let _ = db
.client
.query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5")
.query(
"SELECT chunk_id \
FROM text_chunk_embedding \
WHERE embedding <|1,1|> $embedding LIMIT 5",
)
.bind(("embedding", dummy_embedding.clone()))
.await
.context("warming text chunk HNSW cache")?;
// Warm up entity index
// Warm up entity embedding index
let _ = db
.client
.query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5")
.query(
"SELECT entity_id \
FROM knowledge_entity_embedding \
WHERE embedding <|1,1|> $embedding LIMIT 5",
)
.bind(("embedding", dummy_embedding))
.await
.context("warming knowledge entity HNSW cache")?;
@@ -206,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
timezone: "UTC".to_string(),
};
if let Some(existing) = db.get_item::<User>(&user.id).await? {
if let Some(existing) = db.get_item::<User>(&user.get_id()).await? {
return Ok(existing);
}
@@ -321,11 +329,11 @@ pub(crate) async fn can_reuse_namespace(
}
};
if state.slice_case_count < slice_case_count {
if state.slice_case_count != slice_case_count {
info!(
requested_cases = slice_case_count,
stored_cases = state.slice_case_count,
"Skipping live namespace reuse; ledger grew beyond cached state"
"Skipping live namespace reuse; cached state does not match requested window"
);
return Ok(false);
}
@@ -420,12 +428,12 @@ pub(crate) async fn enforce_system_settings(
) -> Result<SystemSettings> {
let mut updated_settings = settings.clone();
let mut needs_settings_update = false;
let mut embedding_dimension_changed = false;
// let mut embedding_dimension_changed = false;
if provider_dimension != settings.embedding_dimensions as usize {
updated_settings.embedding_dimensions = provider_dimension as u32;
needs_settings_update = true;
embedding_dimension_changed = true;
// embedding_dimension_changed = true;
}
if let Some(query_override) = config.query_model.as_deref() {
if settings.query_model != query_override {
@@ -442,16 +450,18 @@ pub(crate) async fn enforce_system_settings(
.await
.context("updating system settings overrides")?;
}
if embedding_dimension_changed {
change_embedding_length_in_hnsw_indexes(db, provider_dimension)
.await
.context("redefining HNSW indexes for new embedding dimension")?;
}
// We dont need to do this, we've changed from default settings already
// if embedding_dimension_changed {
// change_embedding_length_in_hnsw_indexes(db, provider_dimension)
// .await
// .context("redefining HNSW indexes for new embedding dimension")?;
// }
Ok(settings)
}
pub(crate) async fn load_or_init_system_settings(
db: &SurrealDbClient,
dimension: usize,
) -> Result<(SystemSettings, bool)> {
match SystemSettings::get_current(db).await {
Ok(settings) => Ok((settings, false)),
@@ -460,7 +470,6 @@ pub(crate) async fn load_or_init_system_settings(
db.apply_migrations()
.await
.context("applying database migrations after missing system settings")?;
tokio::time::sleep(Duration::from_millis(50)).await;
let settings = SystemSettings::get_current(db)
.await
.context("loading system settings after migrations")?;
@@ -473,8 +482,8 @@ pub(crate) async fn load_or_init_system_settings(
#[cfg(test)]
mod tests {
use super::*;
use crate::ingest::store::CorpusParagraph;
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion};
use crate::ingest::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION};
use chrono::Utc;
use common::storage::types::text_content::TextContent;
@@ -491,9 +500,9 @@ mod tests {
None,
"user".to_string(),
),
entities: Vec::new(),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
CorpusParagraph {
paragraph_id: "p2".to_string(),
@@ -506,9 +515,9 @@ mod tests {
None,
"user".to_string(),
),
entities: Vec::new(),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
];
let questions = vec![
@@ -541,7 +550,7 @@ mod tests {
},
];
CorpusManifest {
version: 1,
version: MANIFEST_VERSION,
metadata: CorpusMetadata {
dataset_id: "ds".to_string(),
dataset_label: "Dataset".to_string(),

View File

@@ -5,9 +5,12 @@ use std::{
};
use async_openai::Client;
use common::storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
use common::{
storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
},
utils::embedding::EmbeddingProvider,
};
use retrieval_pipeline::{
pipeline::{PipelineStageTimings, RetrievalConfig},
@@ -18,7 +21,6 @@ use crate::{
args::Config,
cache::EmbeddingCache,
datasets::ConvertedDataset,
embedding::EmbeddingProvider,
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
ingest, slice, snapshot,
};

View File

@@ -3,7 +3,7 @@ use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{ingest, slice, snapshot};
use crate::{eval::can_reuse_namespace, ingest, slice, snapshot};
use super::super::{
context::{EvalStage, EvaluationContext},
@@ -26,19 +26,78 @@ pub(crate) async fn prepare_corpus(
let cache_settings = ingest::CorpusCacheConfig::from(config);
let embedding_provider = ctx.embedding_provider().clone();
let openai_client = ctx.openai_client();
let slice = ctx.slice();
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
.context("selecting slice window for corpus preparation")?;
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider());
let expected_fingerprint = ingest::compute_ingestion_fingerprint(
ctx.dataset(),
slice,
config.converted_dataset_path.as_path(),
)?;
let base_dir = ingest::cached_corpus_dir(
&cache_settings,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if !config.reseed_slice {
let requested_cases = window.cases.len();
if can_reuse_namespace(
ctx.db(),
&descriptor,
&ctx.namespace,
&ctx.database,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
requested_cases,
)
.await?
{
if let Some(manifest) = ingest::load_cached_manifest(&base_dir)? {
info!(
cache = %base_dir.display(),
namespace = ctx.namespace.as_str(),
database = ctx.database.as_str(),
"Namespace already seeded; reusing cached corpus manifest"
);
let corpus_handle = ingest::corpus_handle_from_manifest(manifest, base_dir);
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = 0;
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
return machine
.prepare_corpus()
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard));
} else {
info!(
cache = %base_dir.display(),
"Namespace reusable but cached manifest missing; regenerating corpus"
);
}
}
}
let eval_user_id = "eval-user".to_string();
let ingestion_timer = Instant::now();
let corpus_handle = {
let slice = ctx.slice();
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
.context("selecting slice window for corpus preparation")?;
ingest::ensure_corpus(
ctx.dataset(),
slice,
&window,
&cache_settings,
&embedding_provider,
embedding_provider.clone().into(),
openai_client,
&eval_user_id,
config.converted_dataset_path.as_path(),
@@ -64,11 +123,7 @@ pub(crate) async fn prepare_corpus(
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = ingestion_duration_ms;
ctx.descriptor = Some(snapshot::Descriptor::new(
config,
ctx.slice(),
ctx.embedding_provider(),
));
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);

View File

@@ -6,12 +6,12 @@ use tracing::info;
use crate::{
args::EmbeddingBackend,
cache::EmbeddingCache,
embedding,
eval::{
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
},
openai,
};
use common::utils::embedding::EmbeddingProvider;
use super::super::{
context::{EvalStage, EvaluationContext},
@@ -35,15 +35,22 @@ pub(crate) async fn prepare_db(
let config = ctx.config();
let db = connect_eval_db(config, &namespace, &database).await?;
let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?;
let embedding_provider =
embedding::build_provider(config, settings.embedding_dimensions as usize)
.await
.context("building embedding provider")?;
let (raw_openai_client, openai_base_url) =
openai::build_client_from_env().context("building OpenAI client")?;
let openai_client = Arc::new(raw_openai_client);
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
let embedding_provider = match config.embedding_backend {
crate::args::EmbeddingBackend::FastEmbed => {
EmbeddingProvider::new_fastembed(config.embedding_model.clone())
.await
.context("creating FastEmbed provider")?
}
crate::args::EmbeddingBackend::Hashed => {
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
}
};
let provider_dimension = embedding_provider.dimension();
if provider_dimension == 0 {
return Err(anyhow!(
@@ -62,6 +69,9 @@ pub(crate) async fn prepare_db(
);
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
let (mut settings, settings_missing) =
load_or_init_system_settings(&db, provider_dimension).await?;
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
if let Some(model_code) = embedding_provider.model_code() {
let sanitized = sanitize_model_code(&model_code);

View File

@@ -41,6 +41,22 @@ pub(crate) async fn prepare_namespace(
let database = ctx.database.clone();
let embedding_provider = ctx.embedding_provider().clone();
let corpus_handle = ctx.corpus_handle();
let base_manifest = &corpus_handle.manifest;
let manifest_for_seed =
if ctx.window_offset == 0 && ctx.window_length >= base_manifest.questions.len() {
base_manifest.clone()
} else {
ingest::window_manifest(
base_manifest,
ctx.window_offset,
ctx.window_length,
ctx.config().negative_multiplier,
)
.context("selecting manifest window for seeding")?
};
let requested_cases = manifest_for_seed.questions.len();
let mut namespace_reused = false;
if !config.reseed_slice {
namespace_reused = {
@@ -53,7 +69,7 @@ pub(crate) async fn prepare_namespace(
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
slice.manifest.case_count,
requested_cases,
)
.await?
};
@@ -79,25 +95,39 @@ pub(crate) async fn prepare_namespace(
slice = slice.manifest.slice_id.as_str(),
window_offset = ctx.window_offset,
window_length = ctx.window_length,
positives = slice.manifest.positive_paragraphs,
negatives = slice.manifest.negative_paragraphs,
total = slice.manifest.total_paragraphs,
positives = manifest_for_seed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect::<std::collections::HashSet<_>>()
.len(),
negatives = manifest_for_seed.paragraphs.len().saturating_sub(
manifest_for_seed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect::<std::collections::HashSet<_>>()
.len(),
),
total = manifest_for_seed.paragraphs.len(),
"Seeding ingestion corpus into SurrealDB"
);
}
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
let seed_start = Instant::now();
ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest)
ingest::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
.await
.context("seeding ingestion corpus from manifest")?;
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
if indexes_disabled {
info!("Recreating indexes after namespace reset");
if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await {
warn!(error = %err, "failed to restore indexes after namespace reset");
} else {
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
}
info!("Recreating indexes after seeding data");
recreate_indexes(ctx.db(), embedding_provider.dimension())
.await
.context("recreating indexes with correct dimension")?;
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
}
{
let slice = ctx.slice();
@@ -108,7 +138,7 @@ pub(crate) async fn prepare_namespace(
expected_fingerprint.as_str(),
&namespace,
&database,
slice.manifest.case_count,
requested_cases,
)
.await;
}
@@ -128,11 +158,10 @@ pub(crate) async fn prepare_namespace(
let user = ensure_eval_user(ctx.db()).await?;
ctx.eval_user = Some(user);
let corpus_handle = ctx.corpus_handle();
let total_manifest_questions = corpus_handle.manifest.questions.len();
let cases = cases_from_manifest(&corpus_handle.manifest);
let include_impossible = corpus_handle.manifest.metadata.include_unanswerable;
let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks;
let total_manifest_questions = manifest_for_seed.questions.len();
let cases = cases_from_manifest(&manifest_for_seed);
let include_impossible = manifest_for_seed.metadata.include_unanswerable;
let require_verified_chunks = manifest_for_seed.metadata.require_verified_chunks;
let filtered = total_manifest_questions.saturating_sub(cases.len());
if filtered > 0 {
info!(

View File

@@ -1,6 +1,7 @@
use std::{collections::HashSet, sync::Arc, time::Instant};
use anyhow::Context;
use common::storage::types::StoredObject;
use futures::stream::{self, StreamExt};
use tracing::{debug, info};
@@ -174,6 +175,7 @@ pub(crate) async fn run_queries(
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
&db,
&openai_client,
Some(&embedding_provider),
query_embedding,
&question,
&user_id,
@@ -187,6 +189,7 @@ pub(crate) async fn run_queries(
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
&db,
&openai_client,
Some(&embedding_provider),
query_embedding,
&question,
&user_id,
@@ -228,9 +231,10 @@ pub(crate) async fn run_queries(
}
let chunk_id_for_entity = if chunk_id_required {
expected_chunk_ids_set.contains(candidate.source_id.as_str())
|| candidate.chunks.iter().any(|chunk| {
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
})
|| candidate
.chunks
.iter()
.any(|chunk| expected_chunk_ids_set.contains(&chunk.chunk.get_id()))
} else {
true
};

View File

@@ -1,6 +1,7 @@
use std::collections::HashSet;
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use retrieval_pipeline::{
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
};
@@ -164,7 +165,7 @@ impl EvaluationCandidate {
fn from_entity(entity: RetrievedEntity) -> Self {
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
Self {
entity_id: entity.entity.id.clone(),
entity_id: entity.entity.get_id().to_string(),
source_id: entity.entity.source_id.clone(),
entity_name: entity.entity.name.clone(),
entity_description: Some(entity.entity.description.clone()),
@@ -177,7 +178,7 @@ impl EvaluationCandidate {
fn from_chunk(chunk: RetrievedChunk) -> Self {
let snippet = chunk_snippet(&chunk.chunk.chunk);
Self {
entity_id: chunk.chunk.id.clone(),
entity_id: chunk.chunk.get_id().to_string(),
source_id: chunk.chunk.source_id.clone(),
entity_name: chunk.chunk.source_id.clone(),
entity_description: Some(snippet),
@@ -301,7 +302,9 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
entry.graph_expansion_ms()
})),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())),
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
entry.chunk_attach_ms()
})),
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())),
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())),
}
@@ -332,11 +335,11 @@ pub fn build_case_diagnostics(
let mut chunk_entries = Vec::new();
for chunk in &candidate.chunks {
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
seen_chunks.insert(chunk.chunk.id.clone());
attached_chunk_ids.push(chunk.chunk.id.clone());
let expected_chunk = expected_set.contains(chunk.chunk.get_id());
seen_chunks.insert(chunk.chunk.get_id().to_string());
attached_chunk_ids.push(chunk.chunk.get_id().to_string());
chunk_entries.push(ChunkDiagnosticsEntry {
chunk_id: chunk.chunk.id.clone(),
chunk_id: chunk.chunk.get_id().to_string(),
score: chunk.score,
contains_answer,
expected_chunk,

View File

@@ -59,6 +59,25 @@ impl CorpusEmbeddingProvider for EmbeddingProvider {
}
}
#[async_trait]
impl CorpusEmbeddingProvider for common::utils::embedding::EmbeddingProvider {
fn backend_label(&self) -> &str {
common::utils::embedding::EmbeddingProvider::backend_label(self)
}
fn model_code(&self) -> Option<String> {
common::utils::embedding::EmbeddingProvider::model_code(self)
}
fn dimension(&self) -> usize {
common::utils::embedding::EmbeddingProvider::dimension(self)
}
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
common::utils::embedding::EmbeddingProvider::embed_batch(self, texts).await
}
}
impl From<&Config> for CorpusCacheConfig {
fn from(config: &Config) -> Self {
CorpusCacheConfig::new(

View File

@@ -2,9 +2,13 @@ mod config;
mod orchestrator;
pub(crate) mod store;
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
pub use orchestrator::ensure_corpus;
pub use store::{
seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
pub use config::CorpusCacheConfig;
pub use orchestrator::{
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
load_cached_manifest,
};
pub use store::{
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard,
ParagraphShardStore, MANIFEST_VERSION,
};

View File

@@ -2,7 +2,7 @@ use std::{
collections::{HashMap, HashSet},
fs,
io::Read,
path::Path,
path::{Path, PathBuf},
sync::Arc,
};
@@ -13,10 +13,7 @@ use common::{
storage::{
db::SurrealDbClient,
store::{DynStore, StorageManager},
types::{
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
},
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
},
utils::config::{AppConfig, StorageKind},
};
@@ -29,12 +26,14 @@ use uuid::Uuid;
use crate::{
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
db_helpers::change_embedding_length_in_hnsw_indexes,
slices::{self, ResolvedSlice, SliceParagraphKind},
};
use crate::ingest::{
CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore,
MANIFEST_VERSION,
};
const INGESTION_SPEC_VERSION: u32 = 1;
@@ -108,12 +107,12 @@ struct IngestionStats {
negative_ingested: usize,
}
pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
pub async fn ensure_corpus(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
window: &slices::SliceWindow<'_>,
cache: &CorpusCacheConfig,
embedding: &E,
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
converted_path: &Path,
@@ -122,10 +121,11 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
let ingestion_fingerprint = build_ingestion_fingerprint(dataset, slice, &checksum);
let base_dir = cache
.ingestion_cache_dir
.join(dataset.metadata.id.as_str())
.join(slice.manifest.slice_id.as_str());
let base_dir = cached_corpus_dir(
cache,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if cache.force_refresh && !cache.refresh_embeddings_only {
let _ = fs::remove_dir_all(&base_dir);
}
@@ -144,11 +144,19 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
));
}
let desired_negatives =
((positive_set.len() as f32) * slice.manifest.negative_multiplier).ceil() as usize;
let mut plan = Vec::new();
let mut negatives_added = 0usize;
for (idx, entry) in slice.manifest.paragraphs.iter().enumerate() {
let include = match &entry.kind {
SliceParagraphKind::Positive { .. } => positive_set.contains(entry.id.as_str()),
SliceParagraphKind::Negative => true,
SliceParagraphKind::Negative => {
negatives_added < desired_negatives && {
negatives_added += 1;
true
}
}
};
if include {
let paragraph = slice
@@ -224,7 +232,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
let new_shards = ingest_paragraph_batch(
dataset,
&ingest_requests,
embedding,
embedding.clone(),
openai.clone(),
user_id,
&ingestion_fingerprint,
@@ -251,8 +259,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
.as_mut()
.context("shard record missing after ingestion run")?;
if cache.refresh_embeddings_only || shard_record.needs_reembed {
reembed_entities(&mut shard_record.shard.entities, embedding).await?;
reembed_chunks(&mut shard_record.shard.chunks, embedding).await?;
// Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed
shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone();
shard_record.shard.ingested_at = Utc::now();
shard_record.shard.embedding_backend = embedding_backend_label.clone();
@@ -320,7 +327,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
corpus_questions.push(CorpusQuestion {
question_id: case.question.id.clone(),
paragraph_id: case.paragraph.id.clone(),
text_content_id: record.shard.text_content.id.clone(),
text_content_id: record.shard.text_content.get_id().to_string(),
question_text: case.question.question.clone(),
answers: case.question.answers.clone(),
is_impossible: case.question.is_impossible,
@@ -361,7 +368,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
Ok(CorpusHandle {
let handle = CorpusHandle {
manifest,
path: base_dir,
reused_ingestion,
@@ -370,64 +377,17 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
positive_ingested: stats.positive_ingested,
negative_reused: stats.negative_reused,
negative_ingested: stats.negative_ingested,
})
};
persist_manifest(&handle).context("persisting corpus manifest")?;
Ok(handle)
}
async fn reembed_entities<E: CorpusEmbeddingProvider>(
entities: &mut [KnowledgeEntity],
embedding: &E,
) -> Result<()> {
if entities.is_empty() {
return Ok(());
}
let payloads: Vec<String> = entities.iter().map(entity_embedding_text).collect();
let vectors = embedding.embed_batch(payloads).await?;
if vectors.len() != entities.len() {
return Err(anyhow!(
"entity embedding batch mismatch (expected {}, got {})",
entities.len(),
vectors.len()
));
}
for (entity, vector) in entities.iter_mut().zip(vectors.into_iter()) {
entity.embedding = vector;
}
Ok(())
}
async fn reembed_chunks<E: CorpusEmbeddingProvider>(
chunks: &mut [TextChunk],
embedding: &E,
) -> Result<()> {
if chunks.is_empty() {
return Ok(());
}
let payloads: Vec<String> = chunks.iter().map(|chunk| chunk.chunk.clone()).collect();
let vectors = embedding.embed_batch(payloads).await?;
if vectors.len() != chunks.len() {
return Err(anyhow!(
"chunk embedding batch mismatch (expected {}, got {})",
chunks.len(),
vectors.len()
));
}
for (chunk, vector) in chunks.iter_mut().zip(vectors.into_iter()) {
chunk.embedding = vector;
}
Ok(())
}
fn entity_embedding_text(entity: &KnowledgeEntity) -> String {
format!(
"name: {}\ndescription: {}\ntype: {:?}",
entity.name, entity.description, entity.entity_type
)
}
async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
async fn ingest_paragraph_batch(
dataset: &ConvertedDataset,
targets: &[IngestRequest<'_>],
embedding: &E,
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
ingestion_fingerprint: &str,
@@ -444,12 +404,16 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
let db = Arc::new(
SurrealDbClient::memory(&namespace, "corpus")
.await
.context("creating ingestion SurrealDB instance")?,
.context("creating in-memory surrealdb for ingestion")?,
);
db.apply_migrations()
.await
.context("applying migrations for ingestion")?;
change_embedding_length_in_hnsw_indexes(&db, embedding_dimension)
.await
.context("failed setting new hnsw length")?;
let mut app_config = AppConfig::default();
app_config.storage = StorageKind::Memory;
let backend: DynStore = Arc::new(InMemory::new());
@@ -461,6 +425,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
app_config,
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
storage,
embedding.clone(),
)
.await?;
let pipeline = Arc::new(pipeline);
@@ -483,7 +448,6 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
pipeline_clone.clone(),
request,
category_clone.clone(),
embedding,
user_id,
ingestion_fingerprint,
backend_clone.clone(),
@@ -501,11 +465,10 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
Ok(shards)
}
async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>,
request: IngestRequest<'_>,
category: String,
embedding: &E,
user_id: &str,
ingestion_fingerprint: &str,
embedding_backend: String,
@@ -524,17 +487,32 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
};
let task = IngestionTask::new(payload, user_id.to_string());
match pipeline.produce_artifacts(&task).await {
Ok(mut artifacts) => {
reembed_entities(&mut artifacts.entities, embedding).await?;
reembed_chunks(&mut artifacts.chunks, embedding).await?;
Ok(artifacts) => {
let entities: Vec<EmbeddedKnowledgeEntity> = artifacts
.entities
.into_iter()
.map(|e| EmbeddedKnowledgeEntity {
entity: e.entity,
embedding: e.embedding,
})
.collect();
let chunks: Vec<EmbeddedTextChunk> = artifacts
.chunks
.into_iter()
.map(|c| EmbeddedTextChunk {
chunk: c.chunk,
embedding: c.embedding,
})
.collect();
// No need to reembed - pipeline now uses FastEmbed internally
let mut shard = ParagraphShard::new(
paragraph,
request.shard_path,
ingestion_fingerprint,
artifacts.text_content,
artifacts.entities,
entities,
artifacts.relationships,
artifacts.chunks,
chunks,
&embedding_backend,
embedding_model.clone(),
embedding_dimension,
@@ -572,7 +550,11 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
.context(format!("running ingestion for paragraph {}", paragraph.id)))
}
fn build_ingestion_fingerprint(
/// Location of the cached corpus for a given dataset/slice pair.
///
/// Layout: `<ingestion_cache_dir>/<dataset_id>/<slice_id>`.
pub fn cached_corpus_dir(cache: &CorpusCacheConfig, dataset_id: &str, slice_id: &str) -> PathBuf {
    let mut dir = cache.ingestion_cache_dir.clone();
    dir.push(dataset_id);
    dir.push(slice_id);
    dir
}
pub fn build_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
checksum: &str,
@@ -592,6 +574,59 @@ fn build_ingestion_fingerprint(
)
}
/// Convenience wrapper: checksum the converted dataset file on disk and
/// derive the ingestion fingerprint for this dataset/slice from it.
pub fn compute_ingestion_fingerprint(
    dataset: &ConvertedDataset,
    slice: &ResolvedSlice<'_>,
    converted_path: &Path,
) -> Result<String> {
    compute_file_checksum(converted_path)
        .map(|checksum| build_ingestion_fingerprint(dataset, slice, &checksum))
}
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
let path = base_dir.join("manifest.json");
if !path.exists() {
return Ok(None);
}
let mut file = fs::File::open(&path)
.with_context(|| format!("opening cached manifest {}", path.display()))?;
let mut buf = Vec::new();
file.read_to_end(&mut buf)
.with_context(|| format!("reading cached manifest {}", path.display()))?;
let manifest: CorpusManifest = serde_json::from_slice(&buf)
.with_context(|| format!("deserialising cached manifest {}", path.display()))?;
Ok(Some(manifest))
}
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
let path = handle.path.join("manifest.json");
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating manifest directory {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let blob =
serde_json::to_vec_pretty(&handle.manifest).context("serialising corpus manifest")?;
fs::write(&tmp_path, &blob)
.with_context(|| format!("writing temporary manifest {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("replacing manifest {}", path.display()))?;
Ok(())
}
/// Wrap an already-cached manifest in a [`CorpusHandle`].
///
/// Both ingestion and embeddings are flagged as reused (nothing was rebuilt
/// for this run), so every per-run counter starts at zero.
pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf) -> CorpusHandle {
    let (reused_ingestion, reused_embeddings) = (true, true);
    CorpusHandle {
        manifest,
        path: base_dir,
        reused_ingestion,
        reused_embeddings,
        positive_reused: 0,
        negative_reused: 0,
        positive_ingested: 0,
        negative_ingested: 0,
    }
}
fn compute_file_checksum(path: &Path) -> Result<String> {
let mut file = fs::File::open(path)
.with_context(|| format!("opening file {} for checksum", path.display()))?;

View File

@@ -1,23 +1,126 @@
use std::{collections::HashMap, fs, io::BufReader, path::PathBuf};
use std::{
collections::{HashMap, HashSet},
fs,
io::BufReader,
path::PathBuf,
};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
knowledge_entity::KnowledgeEntity,
knowledge_entity_embedding::KnowledgeEntityEmbedding,
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
text_chunk::TextChunk,
text_chunk_embedding::TextChunkEmbedding,
text_content::TextContent,
},
};
use serde::Deserialize;
use serde::Serialize;
use surrealdb::sql::Thing;
use tracing::warn;
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
pub const MANIFEST_VERSION: u32 = 1;
pub const PARAGRAPH_SHARD_VERSION: u32 = 1;
pub const MANIFEST_VERSION: u32 = 2;
pub const PARAGRAPH_SHARD_VERSION: u32 = 2;
const MANIFEST_BATCH_SIZE: usize = 100;
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
const MAX_BATCHES_PER_REQUEST: usize = 24;
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
/// Serde default for `CorpusManifest::version` — applied when a manifest on
/// disk predates the `version` field.
fn current_manifest_version() -> u32 {
MANIFEST_VERSION
}
/// Serde default for `ParagraphShard::version` (same role as above).
fn current_paragraph_shard_version() -> u32 {
PARAGRAPH_SHARD_VERSION
}
/// A knowledge entity paired with the embedding vector computed for it.
///
/// Current on-disk layout: the embedding sits alongside the entity rather
/// than being flattened into it (the `Legacy*` types cover the older shape).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddedKnowledgeEntity {
pub entity: KnowledgeEntity,
pub embedding: Vec<f32>,
}
/// A text chunk paired with its embedding vector (current layout, as above).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddedTextChunk {
pub chunk: TextChunk,
pub embedding: Vec<f32>,
}
/// Legacy (pre-v2) shard/manifest layout where the embedding was serialised
/// inline with the entity fields (`#[serde(flatten)]`); `#[serde(default)]`
/// additionally tolerates records written with no embedding at all.
/// Deserialisation-only — new data is always written in the embedded layout.
#[derive(Debug, Clone, serde::Deserialize)]
struct LegacyKnowledgeEntity {
#[serde(flatten)]
pub entity: KnowledgeEntity,
#[serde(default)]
pub embedding: Vec<f32>,
}
/// Legacy layout for text chunks; see `LegacyKnowledgeEntity`.
#[derive(Debug, Clone, serde::Deserialize)]
struct LegacyTextChunk {
#[serde(flatten)]
pub chunk: TextChunk,
#[serde(default)]
pub embedding: Vec<f32>,
}
fn deserialize_embedded_entities<'de, D>(
deserializer: D,
) -> Result<Vec<EmbeddedKnowledgeEntity>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum EntityInput {
Embedded(Vec<EmbeddedKnowledgeEntity>),
Legacy(Vec<LegacyKnowledgeEntity>),
}
match EntityInput::deserialize(deserializer)? {
EntityInput::Embedded(items) => Ok(items),
EntityInput::Legacy(items) => Ok(items
.into_iter()
.map(|legacy| EmbeddedKnowledgeEntity {
entity: legacy.entity,
embedding: legacy.embedding,
})
.collect()),
}
}
/// Deserialise chunks from either the current `[{chunk, embedding}]` layout
/// or the legacy flattened layout; `#[serde(untagged)]` tries the current
/// shape first.
fn deserialize_embedded_chunks<'de, D>(deserializer: D) -> Result<Vec<EmbeddedTextChunk>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum ChunkInput {
Embedded(Vec<EmbeddedTextChunk>),
Legacy(Vec<LegacyTextChunk>),
}
match ChunkInput::deserialize(deserializer)? {
ChunkInput::Embedded(items) => Ok(items),
// Upgrade legacy records to the current representation on the fly.
ChunkInput::Legacy(items) => Ok(items
.into_iter()
.map(|legacy| EmbeddedTextChunk {
chunk: legacy.chunk,
embedding: legacy.embedding,
})
.collect()),
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusManifest {
#[serde(default = "current_manifest_version")]
pub version: u32,
pub metadata: CorpusMetadata,
pub paragraphs: Vec<CorpusParagraph>,
@@ -47,9 +150,11 @@ pub struct CorpusParagraph {
pub paragraph_id: String,
pub title: String,
pub text_content: TextContent,
pub entities: Vec<KnowledgeEntity>,
#[serde(deserialize_with = "deserialize_embedded_entities")]
pub entities: Vec<EmbeddedKnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
pub chunks: Vec<TextChunk>,
#[serde(deserialize_with = "deserialize_embedded_chunks")]
pub chunks: Vec<EmbeddedTextChunk>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -74,8 +179,189 @@ pub struct CorpusHandle {
pub negative_ingested: usize,
}
/// Narrow `manifest` to the question window `[offset, offset + length)` plus
/// a proportional set of distractor paragraphs.
///
/// Paragraphs referenced by the selected questions are always kept; other
/// paragraphs are kept (in manifest order) until `negative_multiplier` ×
/// |selected positives| distractors have been collected, capped by the number
/// of paragraphs that are not positive for ANY question in the manifest.
///
/// NOTE(review): paragraphs that are positive only for questions *outside*
/// the window are eligible to be kept as distractors, while the cap is
/// computed against non-positives only — confirm this asymmetry is intended.
///
/// # Errors
/// Fails when the manifest has no questions or `offset` is out of range.
pub fn window_manifest(
    manifest: &CorpusManifest,
    offset: usize,
    length: usize,
    negative_multiplier: f32,
) -> Result<CorpusManifest> {
    let total = manifest.questions.len();
    if total == 0 {
        return Err(anyhow!(
            "manifest contains no questions; cannot select a window"
        ));
    }
    if offset >= total {
        return Err(anyhow!(
            "window offset {} exceeds manifest questions ({})",
            offset,
            total
        ));
    }

    let end = (offset + length).min(total);
    let questions: Vec<_> = manifest.questions[offset..end].to_vec();

    // Paragraph ids that are positive for the selected window …
    let selected_positive_ids: HashSet<String> =
        questions.iter().map(|q| q.paragraph_id.clone()).collect();
    // … and for any question at all (used only to bound the negative count).
    let positives_all: HashSet<&str> = manifest
        .questions
        .iter()
        .map(|q| q.paragraph_id.as_str())
        .collect();

    let available_negatives = manifest
        .paragraphs
        .len()
        .saturating_sub(positives_all.len());
    let desired_negatives = (((selected_positive_ids.len() as f32) * negative_multiplier).ceil()
        as usize)
        .min(available_negatives);

    let mut paragraphs = Vec::new();
    let mut negative_count = 0usize;
    for paragraph in &manifest.paragraphs {
        if selected_positive_ids.contains(&paragraph.paragraph_id) {
            paragraphs.push(paragraph.clone());
        } else if negative_count < desired_negatives {
            negative_count += 1;
            paragraphs.push(paragraph.clone());
        }
    }

    let mut narrowed = manifest.clone();
    narrowed.questions = questions;
    narrowed.paragraphs = paragraphs;
    narrowed.metadata.paragraph_count = narrowed.paragraphs.len();
    narrowed.metadata.question_count = narrowed.questions.len();
    Ok(narrowed)
}
/// Row shape for `INSERT RELATION INTO relates_to`: the `in`/`out` edges must
/// be real record pointers (`Thing`), hence the serde renames around the
/// reserved field name `in`.
#[derive(Debug, Clone, Serialize)]
struct RelationInsert {
#[serde(rename = "in")]
pub in_: Thing,
#[serde(rename = "out")]
pub out: Thing,
pub id: String,
pub metadata: RelationshipMetadata,
}
/// One batch of rows plus its approximate serialised size, used to pack
/// batches into bounded-size query requests.
#[derive(Debug)]
struct SizedBatch<T> {
approx_bytes: usize,
items: Vec<T>,
}
/// All manifest rows, deduplicated and pre-chunked per table, ready to be
/// inserted by `seed_manifest_into_db`.
struct ManifestBatches {
text_contents: Vec<SizedBatch<TextContent>>,
entities: Vec<SizedBatch<KnowledgeEntity>>,
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
relationships: Vec<SizedBatch<RelationInsert>>,
chunks: Vec<SizedBatch<TextChunk>>,
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
}
/// Flatten a manifest into per-table row lists, deduplicating shared records
/// and chunking each list into size-bounded batches for insertion.
///
/// Deduplication is by record id: paragraphs may share text contents,
/// entities, relationships, or chunks, and each must be inserted once.
/// Embedding rows are created alongside their parent entity/chunk so both
/// are emitted (or skipped) together.
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
let mut text_contents = Vec::new();
let mut entities = Vec::new();
let mut entity_embeddings = Vec::new();
let mut relationships = Vec::new();
let mut chunks = Vec::new();
let mut chunk_embeddings = Vec::new();
// `HashSet::insert` returns false for ids already seen — used as the
// dedup gate for each record kind below.
let mut seen_text_content = HashSet::new();
let mut seen_entities = HashSet::new();
let mut seen_relationships = HashSet::new();
let mut seen_chunks = HashSet::new();
for paragraph in &manifest.paragraphs {
if seen_text_content.insert(paragraph.text_content.id.clone()) {
text_contents.push(paragraph.text_content.clone());
}
for embedded_entity in &paragraph.entities {
if seen_entities.insert(embedded_entity.entity.id.clone()) {
let entity = embedded_entity.entity.clone();
entities.push(entity.clone());
entity_embeddings.push(KnowledgeEntityEmbedding::new(
&entity.id,
embedded_entity.embedding.clone(),
entity.user_id.clone(),
));
}
}
for relationship in &paragraph.relationships {
if seen_relationships.insert(relationship.id.clone()) {
// Relationship endpoints may be stored as "table:id" strings or
// bare ids; strip the table prefix if present before building
// the record pointers.
let table = KnowledgeEntity::table_name();
let in_id = relationship
.in_
.strip_prefix(&format!("{table}:"))
.unwrap_or(&relationship.in_);
let out_id = relationship
.out
.strip_prefix(&format!("{table}:"))
.unwrap_or(&relationship.out);
let in_thing = Thing::from((table, in_id));
let out_thing = Thing::from((table, out_id));
relationships.push(RelationInsert {
in_: in_thing,
out: out_thing,
id: relationship.id.clone(),
metadata: relationship.metadata.clone(),
});
}
}
for embedded_chunk in &paragraph.chunks {
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
let chunk = embedded_chunk.chunk.clone();
chunks.push(chunk.clone());
chunk_embeddings.push(TextChunkEmbedding::new(
&chunk.id,
chunk.source_id.clone(),
embedded_chunk.embedding.clone(),
chunk.user_id.clone(),
));
}
}
}
// Text bodies can be large, so they get the tighter per-batch byte cap.
Ok(ManifestBatches {
text_contents: chunk_items(
&text_contents,
MANIFEST_BATCH_SIZE,
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
)
.context("chunking text_content payloads")?,
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
.context("chunking knowledge_entity payloads")?,
entity_embeddings: chunk_items(
&entity_embeddings,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking knowledge_entity_embedding payloads")?,
relationships: chunk_items(
&relationships,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking relationship payloads")?,
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
.context("chunking text_chunk payloads")?,
chunk_embeddings: chunk_items(
&chunk_embeddings,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking text_chunk_embedding payloads")?,
})
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParagraphShard {
#[serde(default = "current_paragraph_shard_version")]
pub version: u32,
pub paragraph_id: String,
pub shard_path: String,
@@ -83,9 +369,11 @@ pub struct ParagraphShard {
pub ingested_at: DateTime<Utc>,
pub title: String,
pub text_content: TextContent,
pub entities: Vec<KnowledgeEntity>,
#[serde(deserialize_with = "deserialize_embedded_entities")]
pub entities: Vec<EmbeddedKnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
pub chunks: Vec<TextChunk>,
#[serde(deserialize_with = "deserialize_embedded_chunks")]
pub chunks: Vec<EmbeddedTextChunk>,
#[serde(default)]
pub question_bindings: HashMap<String, Vec<String>>,
#[serde(default)]
@@ -126,30 +414,34 @@ impl ParagraphShardStore {
let reader = BufReader::new(file);
let mut shard: ParagraphShard = serde_json::from_reader(reader)
.with_context(|| format!("parsing shard {}", path.display()))?;
if shard.ingestion_fingerprint != fingerprint {
return Ok(None);
}
if shard.version != PARAGRAPH_SHARD_VERSION {
warn!(
path = %path.display(),
version = shard.version,
expected = PARAGRAPH_SHARD_VERSION,
"Skipping shard due to version mismatch"
"Upgrading shard to current version"
);
return Ok(None);
}
if shard.ingestion_fingerprint != fingerprint {
return Ok(None);
shard.version = PARAGRAPH_SHARD_VERSION;
}
shard.shard_path = relative.to_string();
Ok(Some(shard))
}
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
let mut shard = shard.clone();
shard.version = PARAGRAPH_SHARD_VERSION;
let path = self.resolve(&shard.shard_path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating shard dir {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
let body = serde_json::to_vec_pretty(&shard).context("serialising paragraph shard")?;
fs::write(&tmp_path, &body)
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
@@ -164,9 +456,9 @@ impl ParagraphShard {
shard_path: String,
ingestion_fingerprint: &str,
text_content: TextContent,
entities: Vec<KnowledgeEntity>,
entities: Vec<EmbeddedKnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
chunks: Vec<TextChunk>,
chunks: Vec<EmbeddedTextChunk>,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
@@ -216,7 +508,7 @@ impl ParagraphShard {
fn validate_answers(
content: &TextContent,
chunks: &[TextChunk],
chunks: &[EmbeddedTextChunk],
question: &ConvertedQuestion,
) -> Result<Vec<String>> {
if question.is_impossible || question.answers.is_empty() {
@@ -236,12 +528,12 @@ fn validate_answers(
found_any = true;
}
for chunk in chunks {
let chunk_text = chunk.chunk.to_ascii_lowercase();
let chunk_text = chunk.chunk.chunk.to_ascii_lowercase();
let chunk_norm = normalize_answer_text(&chunk_text);
if chunk_text.contains(&needle)
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
{
matches.insert(chunk.id.clone());
matches.insert(chunk.chunk.get_id().to_string());
found_any = true;
}
}
@@ -272,28 +564,492 @@ fn normalize_answer_text(text: &str) -> String {
.join(" ")
}
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
for paragraph in &manifest.paragraphs {
db.upsert_item(paragraph.text_content.clone())
/// Split `items` into batches bounded by both item count and approximate
/// serialised size (each item is sized by serialising it to JSON).
///
/// A batch is closed as soon as the next item would exceed `max_items` or
/// `max_bytes`; a single oversized item still gets a batch of its own, so a
/// batch is never empty and every item is placed.
fn chunk_items<T: Clone + Serialize>(
    items: &[T],
    max_items: usize,
    max_bytes: usize,
) -> Result<Vec<SizedBatch<T>>> {
    if items.is_empty() {
        return Ok(Vec::new());
    }
    let mut batches = Vec::new();
    let mut pending: Vec<T> = Vec::new();
    let mut pending_bytes = 0usize;
    for item in items {
        let size = serde_json::to_vec(item)
            .map(|buf| buf.len())
            .context("serialising batch item for sizing")?;
        // Close the in-progress batch before it would exceed either cap; an
        // oversized first item is still admitted so no batch ends up empty.
        let full = !pending.is_empty()
            && (pending.len() >= max_items || pending_bytes + size > max_bytes);
        if full {
            batches.push(SizedBatch {
                approx_bytes: pending_bytes.max(1),
                items: std::mem::take(&mut pending),
            });
            pending_bytes = 0;
        }
        pending.push(item.clone());
        pending_bytes += size;
    }
    // `items` was non-empty, so there is always a final partial batch.
    batches.push(SizedBatch {
        approx_bytes: pending_bytes.max(1),
        items: pending,
    });
    Ok(batches)
}
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
db: &SurrealDbClient,
statement: impl AsRef<str>,
prefix: &str,
batches: &[SizedBatch<T>],
) -> Result<()> {
if batches.is_empty() {
return Ok(());
}
let mut start = 0;
while start < batches.len() {
let mut group_bytes = 0usize;
let mut group_end = start;
let mut group_count = 0usize;
while group_end < batches.len() {
let batch_bytes = batches[group_end].approx_bytes.max(1);
if group_count > 0
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|| group_count >= MAX_BATCHES_PER_REQUEST)
{
break;
}
group_bytes += batch_bytes;
group_end += 1;
group_count += 1;
}
let slice = &batches[start..group_end];
let mut query = db.client.query("BEGIN TRANSACTION;");
let mut bind_index = 0usize;
for batch in slice {
let name = format!("{prefix}{bind_index}");
bind_index += 1;
query = query
.query(format!("{} ${};", statement.as_ref(), name))
.bind((name, batch.items.clone()));
}
let response = query
.query("COMMIT TRANSACTION;")
.await
.context("storing text_content from manifest")?;
for entity in &paragraph.entities {
db.upsert_item(entity.clone())
.await
.context("storing knowledge_entity from manifest")?;
}
for relationship in &paragraph.relationships {
relationship
.store_relationship(db)
.await
.context("storing knowledge_relationship from manifest")?;
}
for chunk in &paragraph.chunks {
db.upsert_item(chunk.clone())
.await
.context("storing text_chunk from manifest")?;
.context("executing batched insert transaction")?;
if let Err(err) = response.check() {
return Err(anyhow!(
"batched insert failed for statement '{}': {err:?}",
statement.as_ref()
));
}
start = group_end;
}
Ok(())
}
/// Insert every record from `manifest` into `db` using size-bounded,
/// transactional batched inserts.
///
/// Parent rows (text contents, entities, chunks) are inserted before the
/// rows that reference them (relations, embedding rows). If any step fails,
/// a best-effort cleanup deletes all six tables so a failed seed does not
/// leave partial data behind; the original error is still returned.
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
// Immediately-invoked async block so any `?` failure falls through to the
// shared cleanup below instead of returning early.
let result = (|| async {
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextContent::table_name()),
"tc",
&batches.text_contents,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
"ke",
&batches.entities,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextChunk::table_name()),
"ch",
&batches.chunks,
)
.await?;
execute_batched_inserts(
db,
"INSERT RELATION INTO relates_to",
"rel",
&batches.relationships,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
"kee",
&batches.entity_embeddings,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
"tce",
&batches.chunk_embeddings,
)
.await?;
Ok(())
})()
.await;
if result.is_err() {
// Best-effort cleanup to avoid leaving partial manifest data behind.
let _ = db
.client
.query(
"BEGIN TRANSACTION;
DELETE text_chunk_embedding;
DELETE knowledge_entity_embedding;
DELETE relates_to;
DELETE text_chunk;
DELETE knowledge_entity;
DELETE text_content;
COMMIT TRANSACTION;",
)
.await;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db_helpers::change_embedding_length_in_hnsw_indexes;
use chrono::Utc;
use common::storage::types::knowledge_entity::KnowledgeEntityType;
use uuid::Uuid;
/// Test fixture: a two-paragraph manifest with one question, where both
/// paragraphs deliberately share the same text content, entity, and chunk
/// (3-dimensional embeddings) to exercise the loader's deduplication.
fn build_manifest() -> CorpusManifest {
let user_id = "user-1".to_string();
let source_id = "source-1".to_string();
let now = Utc::now();
let text_content_id = Uuid::new_v4().to_string();
let text_content = TextContent {
id: text_content_id.clone(),
created_at: now,
updated_at: now,
text: "Hello world".to_string(),
file_info: None,
url_info: None,
context: None,
category: "test".to_string(),
user_id: user_id.clone(),
};
let entity = KnowledgeEntity {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id: source_id.clone(),
name: "Entity".to_string(),
description: "A test entity".to_string(),
entity_type: KnowledgeEntityType::Document,
metadata: None,
user_id: user_id.clone(),
};
// Self-referential edge: both endpoints are the same entity.
let relationship = KnowledgeRelationship::new(
format!("knowledge_entity:{}", entity.id),
format!("knowledge_entity:{}", entity.id),
user_id.clone(),
source_id.clone(),
"related".to_string(),
);
let chunk = TextChunk {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id: source_id.clone(),
chunk: "chunk text".to_string(),
user_id: user_id.clone(),
};
let paragraph_one = CorpusParagraph {
paragraph_id: "p1".to_string(),
title: "Paragraph 1".to_string(),
text_content: text_content.clone(),
entities: vec![EmbeddedKnowledgeEntity {
entity: entity.clone(),
embedding: vec![0.1, 0.2, 0.3],
}],
relationships: vec![relationship],
chunks: vec![EmbeddedTextChunk {
chunk: chunk.clone(),
embedding: vec![0.3, 0.2, 0.1],
}],
};
// Duplicate content/entities should be de-duplicated by the loader.
let paragraph_two = CorpusParagraph {
paragraph_id: "p2".to_string(),
title: "Paragraph 2".to_string(),
text_content: text_content.clone(),
entities: vec![EmbeddedKnowledgeEntity {
entity: entity.clone(),
embedding: vec![0.1, 0.2, 0.3],
}],
relationships: Vec::new(),
chunks: vec![EmbeddedTextChunk {
chunk: chunk.clone(),
embedding: vec![0.3, 0.2, 0.1],
}],
};
let question = CorpusQuestion {
question_id: "q1".to_string(),
paragraph_id: paragraph_one.paragraph_id.clone(),
text_content_id: text_content_id,
question_text: "What is this?".to_string(),
answers: vec!["Hello".to_string()],
is_impossible: false,
matching_chunk_ids: vec![chunk.id.clone()],
};
CorpusManifest {
version: current_manifest_version(),
metadata: CorpusMetadata {
dataset_id: "dataset".to_string(),
dataset_label: "Dataset".to_string(),
slice_id: "slice".to_string(),
include_unanswerable: false,
require_verified_chunks: false,
ingestion_fingerprint: "fp".to_string(),
embedding_backend: "test".to_string(),
embedding_model: Some("model".to_string()),
embedding_dimension: 3,
converted_checksum: "checksum".to_string(),
generated_at: now,
paragraph_count: 2,
question_count: 1,
},
paragraphs: vec![paragraph_one, paragraph_two],
questions: vec![question],
}
}
#[tokio::test]
// Happy path: after resizing the HNSW indexes to the fixture's 3-dim
// embeddings, seeding the duplicate-heavy manifest must insert exactly one
// row per table (proving dedup across paragraphs worked).
async fn seeds_manifest_with_transactional_batches() {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("memory db");
db.apply_migrations()
.await
.expect("apply migrations for memory db");
// Match the index dimension to the fixture's 3-element embeddings.
change_embedding_length_in_hnsw_indexes(&db, 3)
.await
.expect("set embedding index dimension for test");
let manifest = build_manifest();
seed_manifest_into_db(&db, &manifest)
.await
.expect("manifest seed should succeed");
let text_contents: Vec<TextContent> = db
.client
.query(format!("SELECT * FROM {};", TextContent::table_name()))
.await
.expect("select text_content")
.take(0)
.unwrap_or_default();
assert_eq!(text_contents.len(), 1);
let entities: Vec<KnowledgeEntity> = db
.client
.query(format!("SELECT * FROM {};", KnowledgeEntity::table_name()))
.await
.expect("select knowledge_entity")
.take(0)
.unwrap_or_default();
assert_eq!(entities.len(), 1);
let chunks: Vec<TextChunk> = db
.client
.query(format!("SELECT * FROM {};", TextChunk::table_name()))
.await
.expect("select text_chunk")
.take(0)
.unwrap_or_default();
assert_eq!(chunks.len(), 1);
let relationships: Vec<KnowledgeRelationship> = db
.client
.query("SELECT * FROM relates_to;")
.await
.expect("select relates_to")
.take(0)
.unwrap_or_default();
assert_eq!(relationships.len(), 1);
let entity_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("select knowledge_entity_embedding")
.take(0)
.unwrap_or_default();
assert_eq!(entity_embeddings.len(), 1);
let chunk_embeddings: Vec<TextChunkEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
TextChunkEmbedding::table_name()
))
.await
.expect("select text_chunk_embedding")
.take(0)
.unwrap_or_default();
assert_eq!(chunk_embeddings.len(), 1);
}
#[tokio::test]
// Failure path: the HNSW index dimension is deliberately NOT resized to 3,
// so seeding must fail — and the cleanup in seed_manifest_into_db must leave
// every table empty (no partially seeded data).
async fn rolls_back_when_embeddings_mismatch_index_dimension() {
let namespace = "test_ns_rollback";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("memory db");
db.apply_migrations()
.await
.expect("apply migrations for memory db");
let manifest = build_manifest();
let result = seed_manifest_into_db(&db, &manifest).await;
assert!(
result.is_err(),
"expected embedding dimension mismatch to fail"
);
let text_contents: Vec<TextContent> = db
.client
.query(format!("SELECT * FROM {};", TextContent::table_name()))
.await
.expect("select text_content")
.take(0)
.unwrap_or_default();
let entities: Vec<KnowledgeEntity> = db
.client
.query(format!("SELECT * FROM {};", KnowledgeEntity::table_name()))
.await
.expect("select knowledge_entity")
.take(0)
.unwrap_or_default();
let chunks: Vec<TextChunk> = db
.client
.query(format!("SELECT * FROM {};", TextChunk::table_name()))
.await
.expect("select text_chunk")
.take(0)
.unwrap_or_default();
let relationships: Vec<KnowledgeRelationship> = db
.client
.query("SELECT * FROM relates_to;")
.await
.expect("select relates_to")
.take(0)
.unwrap_or_default();
let entity_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("select knowledge_entity_embedding")
.take(0)
.unwrap_or_default();
let chunk_embeddings: Vec<TextChunkEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
TextChunkEmbedding::table_name()
))
.await
.expect("select text_chunk_embedding")
.take(0)
.unwrap_or_default();
assert!(
text_contents.is_empty()
&& entities.is_empty()
&& chunks.is_empty()
&& relationships.is_empty()
&& entity_embeddings.is_empty()
&& chunk_embeddings.is_empty(),
"no rows should be inserted when transaction fails"
);
}
#[test]
// window_manifest on a 1-question window with multiplier 4 should keep the
// single positive paragraph plus a bounded number of negative distractors.
fn window_manifest_trims_questions_and_negatives() {
let manifest = build_manifest();
// Add extra negatives to simulate multiplier ~4x
let mut manifest = manifest;
let mut extra_paragraphs = Vec::new();
for _ in 0..8 {
// Fresh paragraph ids with no entities/relationships/chunks — pure
// distractor paragraphs not tied to any question.
let mut p = manifest.paragraphs[0].clone();
p.paragraph_id = Uuid::new_v4().to_string();
p.entities.clear();
p.relationships.clear();
p.chunks.clear();
extra_paragraphs.push(p);
}
manifest.paragraphs.extend(extra_paragraphs);
manifest.metadata.paragraph_count = manifest.paragraphs.len();
let windowed = window_manifest(&manifest, 0, 1, 4.0).expect("window manifest");
assert_eq!(windowed.questions.len(), 1);
// Expect roughly 4x negatives (bounded by available paragraphs)
assert!(
windowed.paragraphs.len() <= manifest.paragraphs.len(),
"windowed paragraphs should never exceed original"
);
let positive_set: std::collections::HashSet<_> = windowed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect();
let positives = windowed
.paragraphs
.iter()
.filter(|p| positive_set.contains(p.paragraph_id.as_str()))
.count();
let negatives = windowed.paragraphs.len().saturating_sub(positives);
assert_eq!(positives, 1);
assert!(negatives >= 1, "should include some negatives");
}
}

View File

@@ -121,13 +121,14 @@ fn build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap<String, Chun
for paragraph in &manifest.paragraphs {
for chunk in &paragraph.chunks {
let snippet = chunk
.chunk
.chunk
.chars()
.take(160)
.collect::<String>()
.replace('\n', " ");
lookup.insert(
chunk.id.clone(),
chunk.chunk.id.clone(),
ChunkEntry {
paragraph_title: paragraph.title.clone(),
snippet,

View File

@@ -6,7 +6,8 @@ use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::fs;
use crate::{args::Config, embedding::EmbeddingProvider, slice};
use crate::{args::Config, slice};
use common::utils::embedding::EmbeddingProvider;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnapshotMetadata {