mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-18 15:34:16 +01:00
evals: v3, embeddings at the side
additional indexes
This commit is contained in:
@@ -8,6 +8,23 @@ use retrieval_pipeline::RetrievalStrategy;
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
|
||||
fn workspace_root() -> PathBuf {
|
||||
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf()
|
||||
}
|
||||
|
||||
fn default_report_dir() -> PathBuf {
|
||||
workspace_root().join("eval/reports")
|
||||
}
|
||||
|
||||
fn default_cache_dir() -> PathBuf {
|
||||
workspace_root().join("eval/cache")
|
||||
}
|
||||
|
||||
fn default_ingestion_cache_dir() -> PathBuf {
|
||||
workspace_root().join("eval/cache/ingested")
|
||||
}
|
||||
|
||||
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -129,7 +146,7 @@ impl Default for Config {
|
||||
corpus_limit: None,
|
||||
raw_dataset_path: dataset.default_raw_path(),
|
||||
converted_dataset_path: dataset.default_converted_path(),
|
||||
report_dir: PathBuf::from("eval/reports"),
|
||||
report_dir: default_report_dir(),
|
||||
k: 5,
|
||||
limit: Some(200),
|
||||
summary_sample: 5,
|
||||
@@ -138,8 +155,8 @@ impl Default for Config {
|
||||
concurrency: 4,
|
||||
embedding_backend: EmbeddingBackend::FastEmbed,
|
||||
embedding_model: None,
|
||||
cache_dir: PathBuf::from("eval/cache"),
|
||||
ingestion_cache_dir: PathBuf::from("eval/cache/ingested"),
|
||||
cache_dir: default_cache_dir(),
|
||||
ingestion_cache_dir: default_ingestion_cache_dir(),
|
||||
ingestion_batch_size: 5,
|
||||
ingestion_max_retries: 3,
|
||||
refresh_embeddings_only: false,
|
||||
@@ -585,6 +602,13 @@ where
|
||||
}
|
||||
|
||||
pub fn print_help() {
|
||||
let report_default = default_report_dir();
|
||||
let cache_default = default_cache_dir();
|
||||
let ingestion_cache_default = default_ingestion_cache_dir();
|
||||
let report_default_display = report_default.display();
|
||||
let cache_default_display = cache_default.display();
|
||||
let ingestion_cache_default_display = ingestion_cache_default.display();
|
||||
|
||||
println!(
|
||||
"\
|
||||
eval — dataset conversion, ingestion, and retrieval evaluation CLI
|
||||
@@ -610,7 +634,7 @@ OPTIONS:
|
||||
--corpus-limit <int> Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000.
|
||||
--raw <path> Path to the raw dataset (defaults per dataset).
|
||||
--converted <path> Path to write/read the converted dataset (defaults per dataset).
|
||||
--report-dir <path> Directory to write evaluation reports (default: eval/reports).
|
||||
--report-dir <path> Directory to write evaluation reports (default: {report_default_display}).
|
||||
--k <int> Precision@k cutoff (default: 5).
|
||||
--limit <int> Limit the number of questions evaluated (default: 200, 0 = all).
|
||||
--sample <int> Number of mismatches to surface in the Markdown summary (default: 5).
|
||||
@@ -632,9 +656,9 @@ OPTIONS:
|
||||
--embedding <name> Embedding backend: 'fastembed' (default) or 'hashed'.
|
||||
--embedding-model <code>
|
||||
FastEmbed model code (defaults to crate preset when omitted).
|
||||
--cache-dir <path> Directory for embedding caches (default: eval/cache).
|
||||
--cache-dir <path> Directory for embedding caches (default: {cache_default_display}).
|
||||
--ingestion-cache-dir <path>
|
||||
Directory for ingestion corpora caches (default: eval/cache/ingested).
|
||||
Directory for ingestion corpora caches (default: {ingestion_cache_default_display}).
|
||||
--ingestion-batch-size <int>
|
||||
Number of paragraphs to ingest concurrently (default: 5).
|
||||
--ingestion-max-retries <int>
|
||||
|
||||
@@ -1,187 +1,30 @@
|
||||
use anyhow::{Context, Result};
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};
|
||||
use serde::Deserialize;
|
||||
use tracing::info;
|
||||
|
||||
// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings
|
||||
// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings.
|
||||
pub async fn change_embedding_length_in_hnsw_indexes(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<()> {
|
||||
tracing::info!("Changing embedding length in HNSW indexes");
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
|
||||
COMMIT TRANSACTION;",
|
||||
dim = dimension
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.await
|
||||
.context("changing HNSW indexes")?;
|
||||
tracing::info!("HNSW indexes successfully changed");
|
||||
// No-op for now; runtime indexes are created after ingestion with the correct dimension.
|
||||
let _ = (db, dimension);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Helper functions for index management during namespace reseed
|
||||
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)");
|
||||
|
||||
// Remove ALL indexes from ALL tables to ensure no cache access
|
||||
db.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
-- HNSW indexes
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
|
||||
-- FTS indexes on text_content (remove ALL of them)
|
||||
REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content;
|
||||
|
||||
-- FTS indexes on knowledge_entity
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity;
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity;
|
||||
|
||||
-- FTS indexes on text_chunk
|
||||
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk;
|
||||
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("removing all indexes before namespace reseed")?;
|
||||
|
||||
tracing::info!("All indexes removed before namespace reseed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating FTS analyzers for namespace reseed");
|
||||
let res = db
|
||||
.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("creating FTS analyzers for namespace reseed")?;
|
||||
|
||||
res.check().context("failed to create the tokenizer")?;
|
||||
let _ = db;
|
||||
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)");
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
create_tokenizer(db)
|
||||
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
||||
ensure_runtime_indexes(db, dimension)
|
||||
.await
|
||||
.context("creating FTS analyzer")?;
|
||||
|
||||
// For now we don't remove these plain indexes; we could if they prove to negatively impact performance
|
||||
// create_regular_indexes_for_snapshot(db)
|
||||
// .await
|
||||
// .context("creating regular indexes for namespace reseed")?;
|
||||
|
||||
let fts_start = std::time::Instant::now();
|
||||
create_fts_indexes_for_snapshot(db)
|
||||
.await
|
||||
.context("creating FTS indexes for namespace reseed")?;
|
||||
tracing::info!(duration = ?fts_start.elapsed(), "FTS indexes created");
|
||||
|
||||
let hnsw_start = std::time::Instant::now();
|
||||
create_hnsw_indexes_for_snapshot(db, dimension)
|
||||
.await
|
||||
.context("creating HNSW indexes for namespace reseed")?;
|
||||
tracing::info!(duration = ?hnsw_start.elapsed(), "HNSW indexes created");
|
||||
|
||||
tracing::info!(duration = ?total_start.elapsed(), "All index groups recreated successfully in sequence");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // For now we dont do this. We could
|
||||
async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)");
|
||||
let res = db
|
||||
.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id;
|
||||
DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at;
|
||||
DEFINE INDEX text_content_category_idx ON text_content FIELDS category;
|
||||
DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
||||
DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
||||
DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||
DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||
DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("creating regular indexes for namespace reseed")?;
|
||||
|
||||
res.check().context("one of the regular indexes failed")?;
|
||||
|
||||
tracing::info!("Regular indexes for namespace reseed created");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> {
|
||||
tracing::info!("Creating FTS indexes for namespace reseed (group 2)");
|
||||
let res = db.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text;
|
||||
DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.context("sending FTS index creation query")?;
|
||||
|
||||
// This actually surfaces statement-level errors
|
||||
res.check()
|
||||
.context("one or more FTS index statements failed")?;
|
||||
|
||||
tracing::info!("FTS indexes for namespace reseed created");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
tracing::info!("Creating HNSW indexes for namespace reseed (group 3)");
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim};
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim};
|
||||
COMMIT TRANSACTION;",
|
||||
dim = dimension
|
||||
);
|
||||
|
||||
let res = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.context("creating HNSW indexes for namespace reseed")?;
|
||||
|
||||
res.check()
|
||||
.context("one or more HNSW index statements failed")?;
|
||||
|
||||
tracing::info!("HNSW indexes for namespace reseed created");
|
||||
Ok(())
|
||||
.context("creating runtime indexes")
|
||||
}
|
||||
|
||||
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
|
||||
@@ -207,7 +50,6 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -12,7 +12,7 @@ use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
types::{system_settings::SystemSettings, user::User, StoredObject},
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::RetrievalTuning;
|
||||
@@ -172,18 +172,26 @@ pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> R
|
||||
|
||||
info!("Warming HNSW caches with sample queries");
|
||||
|
||||
// Warm up chunk index
|
||||
// Warm up chunk embedding index - just query the embedding table to load HNSW index
|
||||
let _ = db
|
||||
.client
|
||||
.query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5")
|
||||
.query(
|
||||
"SELECT chunk_id \
|
||||
FROM text_chunk_embedding \
|
||||
WHERE embedding <|1,1|> $embedding LIMIT 5",
|
||||
)
|
||||
.bind(("embedding", dummy_embedding.clone()))
|
||||
.await
|
||||
.context("warming text chunk HNSW cache")?;
|
||||
|
||||
// Warm up entity index
|
||||
// Warm up entity embedding index
|
||||
let _ = db
|
||||
.client
|
||||
.query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5")
|
||||
.query(
|
||||
"SELECT entity_id \
|
||||
FROM knowledge_entity_embedding \
|
||||
WHERE embedding <|1,1|> $embedding LIMIT 5",
|
||||
)
|
||||
.bind(("embedding", dummy_embedding))
|
||||
.await
|
||||
.context("warming knowledge entity HNSW cache")?;
|
||||
@@ -206,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||
timezone: "UTC".to_string(),
|
||||
};
|
||||
|
||||
if let Some(existing) = db.get_item::<User>(&user.id).await? {
|
||||
if let Some(existing) = db.get_item::<User>(&user.get_id()).await? {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
@@ -321,11 +329,11 @@ pub(crate) async fn can_reuse_namespace(
|
||||
}
|
||||
};
|
||||
|
||||
if state.slice_case_count < slice_case_count {
|
||||
if state.slice_case_count != slice_case_count {
|
||||
info!(
|
||||
requested_cases = slice_case_count,
|
||||
stored_cases = state.slice_case_count,
|
||||
"Skipping live namespace reuse; ledger grew beyond cached state"
|
||||
"Skipping live namespace reuse; cached state does not match requested window"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
@@ -420,12 +428,12 @@ pub(crate) async fn enforce_system_settings(
|
||||
) -> Result<SystemSettings> {
|
||||
let mut updated_settings = settings.clone();
|
||||
let mut needs_settings_update = false;
|
||||
let mut embedding_dimension_changed = false;
|
||||
// let mut embedding_dimension_changed = false;
|
||||
|
||||
if provider_dimension != settings.embedding_dimensions as usize {
|
||||
updated_settings.embedding_dimensions = provider_dimension as u32;
|
||||
needs_settings_update = true;
|
||||
embedding_dimension_changed = true;
|
||||
// embedding_dimension_changed = true;
|
||||
}
|
||||
if let Some(query_override) = config.query_model.as_deref() {
|
||||
if settings.query_model != query_override {
|
||||
@@ -442,16 +450,18 @@ pub(crate) async fn enforce_system_settings(
|
||||
.await
|
||||
.context("updating system settings overrides")?;
|
||||
}
|
||||
if embedding_dimension_changed {
|
||||
change_embedding_length_in_hnsw_indexes(db, provider_dimension)
|
||||
.await
|
||||
.context("redefining HNSW indexes for new embedding dimension")?;
|
||||
}
|
||||
// We don't need to do this, we've changed from default settings already
|
||||
// if embedding_dimension_changed {
|
||||
// change_embedding_length_in_hnsw_indexes(db, provider_dimension)
|
||||
// .await
|
||||
// .context("redefining HNSW indexes for new embedding dimension")?;
|
||||
// }
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
pub(crate) async fn load_or_init_system_settings(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(SystemSettings, bool)> {
|
||||
match SystemSettings::get_current(db).await {
|
||||
Ok(settings) => Ok((settings, false)),
|
||||
@@ -460,7 +470,6 @@ pub(crate) async fn load_or_init_system_settings(
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("applying database migrations after missing system settings")?;
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
let settings = SystemSettings::get_current(db)
|
||||
.await
|
||||
.context("loading system settings after migrations")?;
|
||||
@@ -473,8 +482,8 @@ pub(crate) async fn load_or_init_system_settings(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ingest::store::CorpusParagraph;
|
||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion};
|
||||
use crate::ingest::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk};
|
||||
use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION};
|
||||
use chrono::Utc;
|
||||
use common::storage::types::text_content::TextContent;
|
||||
|
||||
@@ -491,9 +500,9 @@ mod tests {
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::new(),
|
||||
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::new(),
|
||||
chunks: Vec::<EmbeddedTextChunk>::new(),
|
||||
},
|
||||
CorpusParagraph {
|
||||
paragraph_id: "p2".to_string(),
|
||||
@@ -506,9 +515,9 @@ mod tests {
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::new(),
|
||||
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::new(),
|
||||
chunks: Vec::<EmbeddedTextChunk>::new(),
|
||||
},
|
||||
];
|
||||
let questions = vec![
|
||||
@@ -541,7 +550,7 @@ mod tests {
|
||||
},
|
||||
];
|
||||
CorpusManifest {
|
||||
version: 1,
|
||||
version: MANIFEST_VERSION,
|
||||
metadata: CorpusMetadata {
|
||||
dataset_id: "ds".to_string(),
|
||||
dataset_label: "Dataset".to_string(),
|
||||
|
||||
@@ -5,9 +5,12 @@ use std::{
|
||||
};
|
||||
|
||||
use async_openai::Client;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{PipelineStageTimings, RetrievalConfig},
|
||||
@@ -18,7 +21,6 @@ use crate::{
|
||||
args::Config,
|
||||
cache::EmbeddingCache,
|
||||
datasets::ConvertedDataset,
|
||||
embedding::EmbeddingProvider,
|
||||
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
|
||||
ingest, slice, snapshot,
|
||||
};
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::time::Instant;
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{ingest, slice, snapshot};
|
||||
use crate::{eval::can_reuse_namespace, ingest, slice, snapshot};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
@@ -26,19 +26,78 @@ pub(crate) async fn prepare_corpus(
|
||||
let cache_settings = ingest::CorpusCacheConfig::from(config);
|
||||
let embedding_provider = ctx.embedding_provider().clone();
|
||||
let openai_client = ctx.openai_client();
|
||||
let slice = ctx.slice();
|
||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||
.context("selecting slice window for corpus preparation")?;
|
||||
|
||||
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider());
|
||||
let expected_fingerprint = ingest::compute_ingestion_fingerprint(
|
||||
ctx.dataset(),
|
||||
slice,
|
||||
config.converted_dataset_path.as_path(),
|
||||
)?;
|
||||
let base_dir = ingest::cached_corpus_dir(
|
||||
&cache_settings,
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
);
|
||||
|
||||
if !config.reseed_slice {
|
||||
let requested_cases = window.cases.len();
|
||||
if can_reuse_namespace(
|
||||
ctx.db(),
|
||||
&descriptor,
|
||||
&ctx.namespace,
|
||||
&ctx.database,
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if let Some(manifest) = ingest::load_cached_manifest(&base_dir)? {
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
namespace = ctx.namespace.as_str(),
|
||||
database = ctx.database.as_str(),
|
||||
"Namespace already seeded; reusing cached corpus manifest"
|
||||
);
|
||||
let corpus_handle = ingest::corpus_handle_from_manifest(manifest, base_dir);
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = 0;
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
return machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard));
|
||||
} else {
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
"Namespace reusable but cached manifest missing; regenerating corpus"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let eval_user_id = "eval-user".to_string();
|
||||
let ingestion_timer = Instant::now();
|
||||
let corpus_handle = {
|
||||
let slice = ctx.slice();
|
||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||
.context("selecting slice window for corpus preparation")?;
|
||||
ingest::ensure_corpus(
|
||||
ctx.dataset(),
|
||||
slice,
|
||||
&window,
|
||||
&cache_settings,
|
||||
&embedding_provider,
|
||||
embedding_provider.clone().into(),
|
||||
openai_client,
|
||||
&eval_user_id,
|
||||
config.converted_dataset_path.as_path(),
|
||||
@@ -64,11 +123,7 @@ pub(crate) async fn prepare_corpus(
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
||||
ctx.descriptor = Some(snapshot::Descriptor::new(
|
||||
config,
|
||||
ctx.slice(),
|
||||
ctx.embedding_provider(),
|
||||
));
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
|
||||
@@ -6,12 +6,12 @@ use tracing::info;
|
||||
use crate::{
|
||||
args::EmbeddingBackend,
|
||||
cache::EmbeddingCache,
|
||||
embedding,
|
||||
eval::{
|
||||
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
|
||||
},
|
||||
openai,
|
||||
};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
@@ -35,15 +35,22 @@ pub(crate) async fn prepare_db(
|
||||
let config = ctx.config();
|
||||
|
||||
let db = connect_eval_db(config, &namespace, &database).await?;
|
||||
let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?;
|
||||
|
||||
let embedding_provider =
|
||||
embedding::build_provider(config, settings.embedding_dimensions as usize)
|
||||
.await
|
||||
.context("building embedding provider")?;
|
||||
let (raw_openai_client, openai_base_url) =
|
||||
openai::build_client_from_env().context("building OpenAI client")?;
|
||||
let openai_client = Arc::new(raw_openai_client);
|
||||
|
||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
||||
let embedding_provider = match config.embedding_backend {
|
||||
crate::args::EmbeddingBackend::FastEmbed => {
|
||||
EmbeddingProvider::new_fastembed(config.embedding_model.clone())
|
||||
.await
|
||||
.context("creating FastEmbed provider")?
|
||||
}
|
||||
crate::args::EmbeddingBackend::Hashed => {
|
||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||
}
|
||||
};
|
||||
let provider_dimension = embedding_provider.dimension();
|
||||
if provider_dimension == 0 {
|
||||
return Err(anyhow!(
|
||||
@@ -62,6 +69,9 @@ pub(crate) async fn prepare_db(
|
||||
);
|
||||
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
|
||||
|
||||
let (mut settings, settings_missing) =
|
||||
load_or_init_system_settings(&db, provider_dimension).await?;
|
||||
|
||||
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||
if let Some(model_code) = embedding_provider.model_code() {
|
||||
let sanitized = sanitize_model_code(&model_code);
|
||||
|
||||
@@ -41,6 +41,22 @@ pub(crate) async fn prepare_namespace(
|
||||
let database = ctx.database.clone();
|
||||
let embedding_provider = ctx.embedding_provider().clone();
|
||||
|
||||
let corpus_handle = ctx.corpus_handle();
|
||||
let base_manifest = &corpus_handle.manifest;
|
||||
let manifest_for_seed =
|
||||
if ctx.window_offset == 0 && ctx.window_length >= base_manifest.questions.len() {
|
||||
base_manifest.clone()
|
||||
} else {
|
||||
ingest::window_manifest(
|
||||
base_manifest,
|
||||
ctx.window_offset,
|
||||
ctx.window_length,
|
||||
ctx.config().negative_multiplier,
|
||||
)
|
||||
.context("selecting manifest window for seeding")?
|
||||
};
|
||||
let requested_cases = manifest_for_seed.questions.len();
|
||||
|
||||
let mut namespace_reused = false;
|
||||
if !config.reseed_slice {
|
||||
namespace_reused = {
|
||||
@@ -53,7 +69,7 @@ pub(crate) async fn prepare_namespace(
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
slice.manifest.case_count,
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
@@ -79,25 +95,39 @@ pub(crate) async fn prepare_namespace(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
window_offset = ctx.window_offset,
|
||||
window_length = ctx.window_length,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total = slice.manifest.total_paragraphs,
|
||||
positives = manifest_for_seed
|
||||
.questions
|
||||
.iter()
|
||||
.map(|q| q.paragraph_id.as_str())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len(),
|
||||
negatives = manifest_for_seed.paragraphs.len().saturating_sub(
|
||||
manifest_for_seed
|
||||
.questions
|
||||
.iter()
|
||||
.map(|q| q.paragraph_id.as_str())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len(),
|
||||
),
|
||||
total = manifest_for_seed.paragraphs.len(),
|
||||
"Seeding ingestion corpus into SurrealDB"
|
||||
);
|
||||
}
|
||||
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
|
||||
|
||||
let seed_start = Instant::now();
|
||||
ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest)
|
||||
ingest::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
|
||||
.await
|
||||
.context("seeding ingestion corpus from manifest")?;
|
||||
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128);
|
||||
|
||||
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
|
||||
if indexes_disabled {
|
||||
info!("Recreating indexes after namespace reset");
|
||||
if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await {
|
||||
warn!(error = %err, "failed to restore indexes after namespace reset");
|
||||
} else {
|
||||
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
|
||||
}
|
||||
info!("Recreating indexes after seeding data");
|
||||
recreate_indexes(ctx.db(), embedding_provider.dimension())
|
||||
.await
|
||||
.context("recreating indexes with correct dimension")?;
|
||||
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
|
||||
}
|
||||
{
|
||||
let slice = ctx.slice();
|
||||
@@ -108,7 +138,7 @@ pub(crate) async fn prepare_namespace(
|
||||
expected_fingerprint.as_str(),
|
||||
&namespace,
|
||||
&database,
|
||||
slice.manifest.case_count,
|
||||
requested_cases,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -128,11 +158,10 @@ pub(crate) async fn prepare_namespace(
|
||||
let user = ensure_eval_user(ctx.db()).await?;
|
||||
ctx.eval_user = Some(user);
|
||||
|
||||
let corpus_handle = ctx.corpus_handle();
|
||||
let total_manifest_questions = corpus_handle.manifest.questions.len();
|
||||
let cases = cases_from_manifest(&corpus_handle.manifest);
|
||||
let include_impossible = corpus_handle.manifest.metadata.include_unanswerable;
|
||||
let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks;
|
||||
let total_manifest_questions = manifest_for_seed.questions.len();
|
||||
let cases = cases_from_manifest(&manifest_for_seed);
|
||||
let include_impossible = manifest_for_seed.metadata.include_unanswerable;
|
||||
let require_verified_chunks = manifest_for_seed.metadata.require_verified_chunks;
|
||||
let filtered = total_manifest_questions.saturating_sub(cases.len());
|
||||
if filtered > 0 {
|
||||
info!(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{collections::HashSet, sync::Arc, time::Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use common::storage::types::StoredObject;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
@@ -174,6 +175,7 @@ pub(crate) async fn run_queries(
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics(
|
||||
&db,
|
||||
&openai_client,
|
||||
Some(&embedding_provider),
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
@@ -187,6 +189,7 @@ pub(crate) async fn run_queries(
|
||||
let outcome = pipeline::run_pipeline_with_embedding_with_metrics(
|
||||
&db,
|
||||
&openai_client,
|
||||
Some(&embedding_provider),
|
||||
query_embedding,
|
||||
&question,
|
||||
&user_id,
|
||||
@@ -228,9 +231,10 @@ pub(crate) async fn run_queries(
|
||||
}
|
||||
let chunk_id_for_entity = if chunk_id_required {
|
||||
expected_chunk_ids_set.contains(candidate.source_id.as_str())
|
||||
|| candidate.chunks.iter().any(|chunk| {
|
||||
expected_chunk_ids_set.contains(chunk.chunk.id.as_str())
|
||||
})
|
||||
|| candidate
|
||||
.chunks
|
||||
.iter()
|
||||
.any(|chunk| expected_chunk_ids_set.contains(&chunk.chunk.get_id()))
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use retrieval_pipeline::{
|
||||
PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput,
|
||||
};
|
||||
@@ -164,7 +165,7 @@ impl EvaluationCandidate {
|
||||
fn from_entity(entity: RetrievedEntity) -> Self {
|
||||
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
|
||||
Self {
|
||||
entity_id: entity.entity.id.clone(),
|
||||
entity_id: entity.entity.get_id().to_string(),
|
||||
source_id: entity.entity.source_id.clone(),
|
||||
entity_name: entity.entity.name.clone(),
|
||||
entity_description: Some(entity.entity.description.clone()),
|
||||
@@ -177,7 +178,7 @@ impl EvaluationCandidate {
|
||||
fn from_chunk(chunk: RetrievedChunk) -> Self {
|
||||
let snippet = chunk_snippet(&chunk.chunk.chunk);
|
||||
Self {
|
||||
entity_id: chunk.chunk.id.clone(),
|
||||
entity_id: chunk.chunk.get_id().to_string(),
|
||||
source_id: chunk.chunk.source_id.clone(),
|
||||
entity_name: chunk.chunk.source_id.clone(),
|
||||
entity_description: Some(snippet),
|
||||
@@ -301,7 +302,9 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL
|
||||
graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.graph_expansion_ms()
|
||||
})),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())),
|
||||
chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| {
|
||||
entry.chunk_attach_ms()
|
||||
})),
|
||||
rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())),
|
||||
assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())),
|
||||
}
|
||||
@@ -332,11 +335,11 @@ pub fn build_case_diagnostics(
|
||||
let mut chunk_entries = Vec::new();
|
||||
for chunk in &candidate.chunks {
|
||||
let contains_answer = text_contains_answer(&chunk.chunk.chunk, answers_lower);
|
||||
let expected_chunk = expected_set.contains(chunk.chunk.id.as_str());
|
||||
seen_chunks.insert(chunk.chunk.id.clone());
|
||||
attached_chunk_ids.push(chunk.chunk.id.clone());
|
||||
let expected_chunk = expected_set.contains(chunk.chunk.get_id());
|
||||
seen_chunks.insert(chunk.chunk.get_id().to_string());
|
||||
attached_chunk_ids.push(chunk.chunk.get_id().to_string());
|
||||
chunk_entries.push(ChunkDiagnosticsEntry {
|
||||
chunk_id: chunk.chunk.id.clone(),
|
||||
chunk_id: chunk.chunk.get_id().to_string(),
|
||||
score: chunk.score,
|
||||
contains_answer,
|
||||
expected_chunk,
|
||||
|
||||
@@ -59,6 +59,25 @@ impl CorpusEmbeddingProvider for EmbeddingProvider {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CorpusEmbeddingProvider for common::utils::embedding::EmbeddingProvider {
|
||||
fn backend_label(&self) -> &str {
|
||||
common::utils::embedding::EmbeddingProvider::backend_label(self)
|
||||
}
|
||||
|
||||
fn model_code(&self) -> Option<String> {
|
||||
common::utils::embedding::EmbeddingProvider::model_code(self)
|
||||
}
|
||||
|
||||
fn dimension(&self) -> usize {
|
||||
common::utils::embedding::EmbeddingProvider::dimension(self)
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
common::utils::embedding::EmbeddingProvider::embed_batch(self, texts).await
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
CorpusCacheConfig::new(
|
||||
|
||||
@@ -2,9 +2,13 @@ mod config;
|
||||
mod orchestrator;
|
||||
pub(crate) mod store;
|
||||
|
||||
pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider};
|
||||
pub use orchestrator::ensure_corpus;
|
||||
pub use store::{
|
||||
seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
|
||||
ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
pub use config::CorpusCacheConfig;
|
||||
pub use orchestrator::{
|
||||
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
|
||||
load_cached_manifest,
|
||||
};
|
||||
pub use store::{
|
||||
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||
CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard,
|
||||
ParagraphShardStore, MANIFEST_VERSION,
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
io::Read,
|
||||
path::Path,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@@ -13,10 +13,7 @@ use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
store::{DynStore, StorageManager},
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
|
||||
},
|
||||
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
|
||||
},
|
||||
utils::config::{AppConfig, StorageKind},
|
||||
};
|
||||
@@ -29,12 +26,14 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
|
||||
db_helpers::change_embedding_length_in_hnsw_indexes,
|
||||
slices::{self, ResolvedSlice, SliceParagraphKind},
|
||||
};
|
||||
|
||||
use crate::ingest::{
|
||||
CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
|
||||
EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore,
|
||||
MANIFEST_VERSION,
|
||||
};
|
||||
|
||||
const INGESTION_SPEC_VERSION: u32 = 1;
|
||||
@@ -108,12 +107,12 @@ struct IngestionStats {
|
||||
negative_ingested: usize,
|
||||
}
|
||||
|
||||
pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
pub async fn ensure_corpus(
|
||||
dataset: &ConvertedDataset,
|
||||
slice: &ResolvedSlice<'_>,
|
||||
window: &slices::SliceWindow<'_>,
|
||||
cache: &CorpusCacheConfig,
|
||||
embedding: &E,
|
||||
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
|
||||
openai: Arc<OpenAIClient>,
|
||||
user_id: &str,
|
||||
converted_path: &Path,
|
||||
@@ -122,10 +121,11 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
|
||||
let ingestion_fingerprint = build_ingestion_fingerprint(dataset, slice, &checksum);
|
||||
|
||||
let base_dir = cache
|
||||
.ingestion_cache_dir
|
||||
.join(dataset.metadata.id.as_str())
|
||||
.join(slice.manifest.slice_id.as_str());
|
||||
let base_dir = cached_corpus_dir(
|
||||
cache,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
);
|
||||
if cache.force_refresh && !cache.refresh_embeddings_only {
|
||||
let _ = fs::remove_dir_all(&base_dir);
|
||||
}
|
||||
@@ -144,11 +144,19 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
));
|
||||
}
|
||||
|
||||
let desired_negatives =
|
||||
((positive_set.len() as f32) * slice.manifest.negative_multiplier).ceil() as usize;
|
||||
let mut plan = Vec::new();
|
||||
let mut negatives_added = 0usize;
|
||||
for (idx, entry) in slice.manifest.paragraphs.iter().enumerate() {
|
||||
let include = match &entry.kind {
|
||||
SliceParagraphKind::Positive { .. } => positive_set.contains(entry.id.as_str()),
|
||||
SliceParagraphKind::Negative => true,
|
||||
SliceParagraphKind::Negative => {
|
||||
negatives_added < desired_negatives && {
|
||||
negatives_added += 1;
|
||||
true
|
||||
}
|
||||
}
|
||||
};
|
||||
if include {
|
||||
let paragraph = slice
|
||||
@@ -224,7 +232,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
let new_shards = ingest_paragraph_batch(
|
||||
dataset,
|
||||
&ingest_requests,
|
||||
embedding,
|
||||
embedding.clone(),
|
||||
openai.clone(),
|
||||
user_id,
|
||||
&ingestion_fingerprint,
|
||||
@@ -251,8 +259,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
.as_mut()
|
||||
.context("shard record missing after ingestion run")?;
|
||||
if cache.refresh_embeddings_only || shard_record.needs_reembed {
|
||||
reembed_entities(&mut shard_record.shard.entities, embedding).await?;
|
||||
reembed_chunks(&mut shard_record.shard.chunks, embedding).await?;
|
||||
// Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed
|
||||
shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone();
|
||||
shard_record.shard.ingested_at = Utc::now();
|
||||
shard_record.shard.embedding_backend = embedding_backend_label.clone();
|
||||
@@ -320,7 +327,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
corpus_questions.push(CorpusQuestion {
|
||||
question_id: case.question.id.clone(),
|
||||
paragraph_id: case.paragraph.id.clone(),
|
||||
text_content_id: record.shard.text_content.id.clone(),
|
||||
text_content_id: record.shard.text_content.get_id().to_string(),
|
||||
question_text: case.question.question.clone(),
|
||||
answers: case.question.answers.clone(),
|
||||
is_impossible: case.question.is_impossible,
|
||||
@@ -361,7 +368,7 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
|
||||
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
|
||||
|
||||
Ok(CorpusHandle {
|
||||
let handle = CorpusHandle {
|
||||
manifest,
|
||||
path: base_dir,
|
||||
reused_ingestion,
|
||||
@@ -370,64 +377,17 @@ pub async fn ensure_corpus<E: CorpusEmbeddingProvider>(
|
||||
positive_ingested: stats.positive_ingested,
|
||||
negative_reused: stats.negative_reused,
|
||||
negative_ingested: stats.negative_ingested,
|
||||
})
|
||||
};
|
||||
|
||||
persist_manifest(&handle).context("persisting corpus manifest")?;
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn reembed_entities<E: CorpusEmbeddingProvider>(
|
||||
entities: &mut [KnowledgeEntity],
|
||||
embedding: &E,
|
||||
) -> Result<()> {
|
||||
if entities.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let payloads: Vec<String> = entities.iter().map(entity_embedding_text).collect();
|
||||
let vectors = embedding.embed_batch(payloads).await?;
|
||||
if vectors.len() != entities.len() {
|
||||
return Err(anyhow!(
|
||||
"entity embedding batch mismatch (expected {}, got {})",
|
||||
entities.len(),
|
||||
vectors.len()
|
||||
));
|
||||
}
|
||||
for (entity, vector) in entities.iter_mut().zip(vectors.into_iter()) {
|
||||
entity.embedding = vector;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn reembed_chunks<E: CorpusEmbeddingProvider>(
|
||||
chunks: &mut [TextChunk],
|
||||
embedding: &E,
|
||||
) -> Result<()> {
|
||||
if chunks.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let payloads: Vec<String> = chunks.iter().map(|chunk| chunk.chunk.clone()).collect();
|
||||
let vectors = embedding.embed_batch(payloads).await?;
|
||||
if vectors.len() != chunks.len() {
|
||||
return Err(anyhow!(
|
||||
"chunk embedding batch mismatch (expected {}, got {})",
|
||||
chunks.len(),
|
||||
vectors.len()
|
||||
));
|
||||
}
|
||||
for (chunk, vector) in chunks.iter_mut().zip(vectors.into_iter()) {
|
||||
chunk.embedding = vector;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn entity_embedding_text(entity: &KnowledgeEntity) -> String {
|
||||
format!(
|
||||
"name: {}\ndescription: {}\ntype: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
)
|
||||
}
|
||||
|
||||
async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
async fn ingest_paragraph_batch(
|
||||
dataset: &ConvertedDataset,
|
||||
targets: &[IngestRequest<'_>],
|
||||
embedding: &E,
|
||||
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
|
||||
openai: Arc<OpenAIClient>,
|
||||
user_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
@@ -444,12 +404,16 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::memory(&namespace, "corpus")
|
||||
.await
|
||||
.context("creating ingestion SurrealDB instance")?,
|
||||
.context("creating in-memory surrealdb for ingestion")?,
|
||||
);
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("applying migrations for ingestion")?;
|
||||
|
||||
change_embedding_length_in_hnsw_indexes(&db, embedding_dimension)
|
||||
.await
|
||||
.context("failed setting new hnsw length")?;
|
||||
|
||||
let mut app_config = AppConfig::default();
|
||||
app_config.storage = StorageKind::Memory;
|
||||
let backend: DynStore = Arc::new(InMemory::new());
|
||||
@@ -461,6 +425,7 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
app_config,
|
||||
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
|
||||
storage,
|
||||
embedding.clone(),
|
||||
)
|
||||
.await?;
|
||||
let pipeline = Arc::new(pipeline);
|
||||
@@ -483,7 +448,6 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
pipeline_clone.clone(),
|
||||
request,
|
||||
category_clone.clone(),
|
||||
embedding,
|
||||
user_id,
|
||||
ingestion_fingerprint,
|
||||
backend_clone.clone(),
|
||||
@@ -501,11 +465,10 @@ async fn ingest_paragraph_batch<E: CorpusEmbeddingProvider>(
|
||||
Ok(shards)
|
||||
}
|
||||
|
||||
async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
async fn ingest_single_paragraph(
|
||||
pipeline: Arc<IngestionPipeline>,
|
||||
request: IngestRequest<'_>,
|
||||
category: String,
|
||||
embedding: &E,
|
||||
user_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
embedding_backend: String,
|
||||
@@ -524,17 +487,32 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
};
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
match pipeline.produce_artifacts(&task).await {
|
||||
Ok(mut artifacts) => {
|
||||
reembed_entities(&mut artifacts.entities, embedding).await?;
|
||||
reembed_chunks(&mut artifacts.chunks, embedding).await?;
|
||||
Ok(artifacts) => {
|
||||
let entities: Vec<EmbeddedKnowledgeEntity> = artifacts
|
||||
.entities
|
||||
.into_iter()
|
||||
.map(|e| EmbeddedKnowledgeEntity {
|
||||
entity: e.entity,
|
||||
embedding: e.embedding,
|
||||
})
|
||||
.collect();
|
||||
let chunks: Vec<EmbeddedTextChunk> = artifacts
|
||||
.chunks
|
||||
.into_iter()
|
||||
.map(|c| EmbeddedTextChunk {
|
||||
chunk: c.chunk,
|
||||
embedding: c.embedding,
|
||||
})
|
||||
.collect();
|
||||
// No need to reembed - pipeline now uses FastEmbed internally
|
||||
let mut shard = ParagraphShard::new(
|
||||
paragraph,
|
||||
request.shard_path,
|
||||
ingestion_fingerprint,
|
||||
artifacts.text_content,
|
||||
artifacts.entities,
|
||||
entities,
|
||||
artifacts.relationships,
|
||||
artifacts.chunks,
|
||||
chunks,
|
||||
&embedding_backend,
|
||||
embedding_model.clone(),
|
||||
embedding_dimension,
|
||||
@@ -572,7 +550,11 @@ async fn ingest_single_paragraph<E: CorpusEmbeddingProvider>(
|
||||
.context(format!("running ingestion for paragraph {}", paragraph.id)))
|
||||
}
|
||||
|
||||
fn build_ingestion_fingerprint(
|
||||
pub fn cached_corpus_dir(cache: &CorpusCacheConfig, dataset_id: &str, slice_id: &str) -> PathBuf {
|
||||
cache.ingestion_cache_dir.join(dataset_id).join(slice_id)
|
||||
}
|
||||
|
||||
pub fn build_ingestion_fingerprint(
|
||||
dataset: &ConvertedDataset,
|
||||
slice: &ResolvedSlice<'_>,
|
||||
checksum: &str,
|
||||
@@ -592,6 +574,59 @@ fn build_ingestion_fingerprint(
|
||||
)
|
||||
}
|
||||
|
||||
pub fn compute_ingestion_fingerprint(
|
||||
dataset: &ConvertedDataset,
|
||||
slice: &ResolvedSlice<'_>,
|
||||
converted_path: &Path,
|
||||
) -> Result<String> {
|
||||
let checksum = compute_file_checksum(converted_path)?;
|
||||
Ok(build_ingestion_fingerprint(dataset, slice, &checksum))
|
||||
}
|
||||
|
||||
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
||||
let path = base_dir.join("manifest.json");
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut file = fs::File::open(&path)
|
||||
.with_context(|| format!("opening cached manifest {}", path.display()))?;
|
||||
let mut buf = Vec::new();
|
||||
file.read_to_end(&mut buf)
|
||||
.with_context(|| format!("reading cached manifest {}", path.display()))?;
|
||||
let manifest: CorpusManifest = serde_json::from_slice(&buf)
|
||||
.with_context(|| format!("deserialising cached manifest {}", path.display()))?;
|
||||
Ok(Some(manifest))
|
||||
}
|
||||
|
||||
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||
let path = handle.path.join("manifest.json");
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating manifest directory {}", parent.display()))?;
|
||||
}
|
||||
let tmp_path = path.with_extension("json.tmp");
|
||||
let blob =
|
||||
serde_json::to_vec_pretty(&handle.manifest).context("serialising corpus manifest")?;
|
||||
fs::write(&tmp_path, &blob)
|
||||
.with_context(|| format!("writing temporary manifest {}", tmp_path.display()))?;
|
||||
fs::rename(&tmp_path, &path)
|
||||
.with_context(|| format!("replacing manifest {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf) -> CorpusHandle {
|
||||
CorpusHandle {
|
||||
manifest,
|
||||
path: base_dir,
|
||||
reused_ingestion: true,
|
||||
reused_embeddings: true,
|
||||
positive_reused: 0,
|
||||
positive_ingested: 0,
|
||||
negative_reused: 0,
|
||||
negative_ingested: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_file_checksum(path: &Path) -> Result<String> {
|
||||
let mut file = fs::File::open(path)
|
||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||
|
||||
@@ -1,23 +1,126 @@
|
||||
use std::{collections::HashMap, fs, io::BufReader, path::PathBuf};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
io::BufReader,
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
|
||||
text_chunk::TextChunk,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
text_content::TextContent,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use surrealdb::sql::Thing;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
pub const MANIFEST_VERSION: u32 = 1;
|
||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 1;
|
||||
pub const MANIFEST_VERSION: u32 = 2;
|
||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 2;
|
||||
const MANIFEST_BATCH_SIZE: usize = 100;
|
||||
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
|
||||
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
|
||||
const MAX_BATCHES_PER_REQUEST: usize = 24;
|
||||
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
|
||||
|
||||
fn current_manifest_version() -> u32 {
|
||||
MANIFEST_VERSION
|
||||
}
|
||||
|
||||
fn current_paragraph_shard_version() -> u32 {
|
||||
PARAGRAPH_SHARD_VERSION
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct EmbeddedKnowledgeEntity {
|
||||
pub entity: KnowledgeEntity,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct EmbeddedTextChunk {
|
||||
pub chunk: TextChunk,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
struct LegacyKnowledgeEntity {
|
||||
#[serde(flatten)]
|
||||
pub entity: KnowledgeEntity,
|
||||
#[serde(default)]
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
struct LegacyTextChunk {
|
||||
#[serde(flatten)]
|
||||
pub chunk: TextChunk,
|
||||
#[serde(default)]
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
fn deserialize_embedded_entities<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Vec<EmbeddedKnowledgeEntity>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum EntityInput {
|
||||
Embedded(Vec<EmbeddedKnowledgeEntity>),
|
||||
Legacy(Vec<LegacyKnowledgeEntity>),
|
||||
}
|
||||
|
||||
match EntityInput::deserialize(deserializer)? {
|
||||
EntityInput::Embedded(items) => Ok(items),
|
||||
EntityInput::Legacy(items) => Ok(items
|
||||
.into_iter()
|
||||
.map(|legacy| EmbeddedKnowledgeEntity {
|
||||
entity: legacy.entity,
|
||||
embedding: legacy.embedding,
|
||||
})
|
||||
.collect()),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_embedded_chunks<'de, D>(deserializer: D) -> Result<Vec<EmbeddedTextChunk>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ChunkInput {
|
||||
Embedded(Vec<EmbeddedTextChunk>),
|
||||
Legacy(Vec<LegacyTextChunk>),
|
||||
}
|
||||
|
||||
match ChunkInput::deserialize(deserializer)? {
|
||||
ChunkInput::Embedded(items) => Ok(items),
|
||||
ChunkInput::Legacy(items) => Ok(items
|
||||
.into_iter()
|
||||
.map(|legacy| EmbeddedTextChunk {
|
||||
chunk: legacy.chunk,
|
||||
embedding: legacy.embedding,
|
||||
})
|
||||
.collect()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusManifest {
|
||||
#[serde(default = "current_manifest_version")]
|
||||
pub version: u32,
|
||||
pub metadata: CorpusMetadata,
|
||||
pub paragraphs: Vec<CorpusParagraph>,
|
||||
@@ -47,9 +150,11 @@ pub struct CorpusParagraph {
|
||||
pub paragraph_id: String,
|
||||
pub title: String,
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
#[serde(deserialize_with = "deserialize_embedded_entities")]
|
||||
pub entities: Vec<EmbeddedKnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
#[serde(deserialize_with = "deserialize_embedded_chunks")]
|
||||
pub chunks: Vec<EmbeddedTextChunk>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
@@ -74,8 +179,189 @@ pub struct CorpusHandle {
|
||||
pub negative_ingested: usize,
|
||||
}
|
||||
|
||||
pub fn window_manifest(
|
||||
manifest: &CorpusManifest,
|
||||
offset: usize,
|
||||
length: usize,
|
||||
negative_multiplier: f32,
|
||||
) -> Result<CorpusManifest> {
|
||||
let total = manifest.questions.len();
|
||||
if total == 0 {
|
||||
return Err(anyhow!(
|
||||
"manifest contains no questions; cannot select a window"
|
||||
));
|
||||
}
|
||||
if offset >= total {
|
||||
return Err(anyhow!(
|
||||
"window offset {} exceeds manifest questions ({})",
|
||||
offset,
|
||||
total
|
||||
));
|
||||
}
|
||||
let end = (offset + length).min(total);
|
||||
let questions = manifest.questions[offset..end].to_vec();
|
||||
|
||||
let selected_positive_ids: HashSet<_> =
|
||||
questions.iter().map(|q| q.paragraph_id.clone()).collect();
|
||||
let positives_all: HashSet<_> = manifest
|
||||
.questions
|
||||
.iter()
|
||||
.map(|q| q.paragraph_id.as_str())
|
||||
.collect();
|
||||
let available_negatives = manifest
|
||||
.paragraphs
|
||||
.len()
|
||||
.saturating_sub(positives_all.len());
|
||||
let desired_negatives =
|
||||
((selected_positive_ids.len() as f32) * negative_multiplier).ceil() as usize;
|
||||
let desired_negatives = desired_negatives.min(available_negatives);
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
let mut negative_count = 0usize;
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if selected_positive_ids.contains(¶graph.paragraph_id) {
|
||||
paragraphs.push(paragraph.clone());
|
||||
} else if negative_count < desired_negatives {
|
||||
paragraphs.push(paragraph.clone());
|
||||
negative_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut narrowed = manifest.clone();
|
||||
narrowed.questions = questions;
|
||||
narrowed.paragraphs = paragraphs;
|
||||
narrowed.metadata.paragraph_count = narrowed.paragraphs.len();
|
||||
narrowed.metadata.question_count = narrowed.questions.len();
|
||||
|
||||
Ok(narrowed)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct RelationInsert {
|
||||
#[serde(rename = "in")]
|
||||
pub in_: Thing,
|
||||
#[serde(rename = "out")]
|
||||
pub out: Thing,
|
||||
pub id: String,
|
||||
pub metadata: RelationshipMetadata,
|
||||
}
|
||||
|
||||
/// A group of items together with the approximate size of their JSON
/// serialisation.
#[derive(Debug)]
struct SizedBatch<T> {
    /// Sum of the serialised sizes of `items`, in bytes (at least 1 as
    /// produced by `chunk_items`).
    approx_bytes: usize,
    /// The items in this batch.
    items: Vec<T>,
}
|
||||
|
||||
struct ManifestBatches {
|
||||
text_contents: Vec<SizedBatch<TextContent>>,
|
||||
entities: Vec<SizedBatch<KnowledgeEntity>>,
|
||||
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
|
||||
relationships: Vec<SizedBatch<RelationInsert>>,
|
||||
chunks: Vec<SizedBatch<TextChunk>>,
|
||||
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
|
||||
}
|
||||
|
||||
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
|
||||
let mut text_contents = Vec::new();
|
||||
let mut entities = Vec::new();
|
||||
let mut entity_embeddings = Vec::new();
|
||||
let mut relationships = Vec::new();
|
||||
let mut chunks = Vec::new();
|
||||
let mut chunk_embeddings = Vec::new();
|
||||
|
||||
let mut seen_text_content = HashSet::new();
|
||||
let mut seen_entities = HashSet::new();
|
||||
let mut seen_relationships = HashSet::new();
|
||||
let mut seen_chunks = HashSet::new();
|
||||
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||
text_contents.push(paragraph.text_content.clone());
|
||||
}
|
||||
|
||||
for embedded_entity in ¶graph.entities {
|
||||
if seen_entities.insert(embedded_entity.entity.id.clone()) {
|
||||
let entity = embedded_entity.entity.clone();
|
||||
entities.push(entity.clone());
|
||||
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
||||
&entity.id,
|
||||
embedded_entity.embedding.clone(),
|
||||
entity.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
for relationship in ¶graph.relationships {
|
||||
if seen_relationships.insert(relationship.id.clone()) {
|
||||
let table = KnowledgeEntity::table_name();
|
||||
let in_id = relationship
|
||||
.in_
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.in_);
|
||||
let out_id = relationship
|
||||
.out
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.out);
|
||||
let in_thing = Thing::from((table, in_id));
|
||||
let out_thing = Thing::from((table, out_id));
|
||||
relationships.push(RelationInsert {
|
||||
in_: in_thing,
|
||||
out: out_thing,
|
||||
id: relationship.id.clone(),
|
||||
metadata: relationship.metadata.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for embedded_chunk in ¶graph.chunks {
|
||||
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
|
||||
let chunk = embedded_chunk.chunk.clone();
|
||||
chunks.push(chunk.clone());
|
||||
chunk_embeddings.push(TextChunkEmbedding::new(
|
||||
&chunk.id,
|
||||
chunk.source_id.clone(),
|
||||
embedded_chunk.embedding.clone(),
|
||||
chunk.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ManifestBatches {
|
||||
text_contents: chunk_items(
|
||||
&text_contents,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_content payloads")?,
|
||||
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking knowledge_entity payloads")?,
|
||||
entity_embeddings: chunk_items(
|
||||
&entity_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking knowledge_entity_embedding payloads")?,
|
||||
relationships: chunk_items(
|
||||
&relationships,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking relationship payloads")?,
|
||||
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking text_chunk payloads")?,
|
||||
chunk_embeddings: chunk_items(
|
||||
&chunk_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_chunk_embedding payloads")?,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ParagraphShard {
|
||||
#[serde(default = "current_paragraph_shard_version")]
|
||||
pub version: u32,
|
||||
pub paragraph_id: String,
|
||||
pub shard_path: String,
|
||||
@@ -83,9 +369,11 @@ pub struct ParagraphShard {
|
||||
pub ingested_at: DateTime<Utc>,
|
||||
pub title: String,
|
||||
pub text_content: TextContent,
|
||||
pub entities: Vec<KnowledgeEntity>,
|
||||
#[serde(deserialize_with = "deserialize_embedded_entities")]
|
||||
pub entities: Vec<EmbeddedKnowledgeEntity>,
|
||||
pub relationships: Vec<KnowledgeRelationship>,
|
||||
pub chunks: Vec<TextChunk>,
|
||||
#[serde(deserialize_with = "deserialize_embedded_chunks")]
|
||||
pub chunks: Vec<EmbeddedTextChunk>,
|
||||
#[serde(default)]
|
||||
pub question_bindings: HashMap<String, Vec<String>>,
|
||||
#[serde(default)]
|
||||
@@ -126,30 +414,34 @@ impl ParagraphShardStore {
|
||||
let reader = BufReader::new(file);
|
||||
let mut shard: ParagraphShard = serde_json::from_reader(reader)
|
||||
.with_context(|| format!("parsing shard {}", path.display()))?;
|
||||
|
||||
if shard.ingestion_fingerprint != fingerprint {
|
||||
return Ok(None);
|
||||
}
|
||||
if shard.version != PARAGRAPH_SHARD_VERSION {
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
version = shard.version,
|
||||
expected = PARAGRAPH_SHARD_VERSION,
|
||||
"Skipping shard due to version mismatch"
|
||||
"Upgrading shard to current version"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
if shard.ingestion_fingerprint != fingerprint {
|
||||
return Ok(None);
|
||||
shard.version = PARAGRAPH_SHARD_VERSION;
|
||||
}
|
||||
shard.shard_path = relative.to_string();
|
||||
Ok(Some(shard))
|
||||
}
|
||||
|
||||
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
|
||||
let mut shard = shard.clone();
|
||||
shard.version = PARAGRAPH_SHARD_VERSION;
|
||||
|
||||
let path = self.resolve(&shard.shard_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating shard dir {}", parent.display()))?;
|
||||
}
|
||||
let tmp_path = path.with_extension("json.tmp");
|
||||
let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?;
|
||||
let body = serde_json::to_vec_pretty(&shard).context("serialising paragraph shard")?;
|
||||
fs::write(&tmp_path, &body)
|
||||
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
|
||||
fs::rename(&tmp_path, &path)
|
||||
@@ -164,9 +456,9 @@ impl ParagraphShard {
|
||||
shard_path: String,
|
||||
ingestion_fingerprint: &str,
|
||||
text_content: TextContent,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
entities: Vec<EmbeddedKnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
chunks: Vec<TextChunk>,
|
||||
chunks: Vec<EmbeddedTextChunk>,
|
||||
embedding_backend: &str,
|
||||
embedding_model: Option<String>,
|
||||
embedding_dimension: usize,
|
||||
@@ -216,7 +508,7 @@ impl ParagraphShard {
|
||||
|
||||
fn validate_answers(
|
||||
content: &TextContent,
|
||||
chunks: &[TextChunk],
|
||||
chunks: &[EmbeddedTextChunk],
|
||||
question: &ConvertedQuestion,
|
||||
) -> Result<Vec<String>> {
|
||||
if question.is_impossible || question.answers.is_empty() {
|
||||
@@ -236,12 +528,12 @@ fn validate_answers(
|
||||
found_any = true;
|
||||
}
|
||||
for chunk in chunks {
|
||||
let chunk_text = chunk.chunk.to_ascii_lowercase();
|
||||
let chunk_text = chunk.chunk.chunk.to_ascii_lowercase();
|
||||
let chunk_norm = normalize_answer_text(&chunk_text);
|
||||
if chunk_text.contains(&needle)
|
||||
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
|
||||
{
|
||||
matches.insert(chunk.id.clone());
|
||||
matches.insert(chunk.chunk.get_id().to_string());
|
||||
found_any = true;
|
||||
}
|
||||
}
|
||||
@@ -272,28 +564,492 @@ fn normalize_answer_text(text: &str) -> String {
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
for paragraph in &manifest.paragraphs {
|
||||
db.upsert_item(paragraph.text_content.clone())
|
||||
fn chunk_items<T: Clone + Serialize>(
|
||||
items: &[T],
|
||||
max_items: usize,
|
||||
max_bytes: usize,
|
||||
) -> Result<Vec<SizedBatch<T>>> {
|
||||
if items.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut batches = Vec::new();
|
||||
let mut current = Vec::new();
|
||||
let mut current_bytes = 0usize;
|
||||
|
||||
for item in items {
|
||||
let size = serde_json::to_vec(item)
|
||||
.map(|buf| buf.len())
|
||||
.context("serialising batch item for sizing")?;
|
||||
|
||||
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
|
||||
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
|
||||
|
||||
if would_overflow_items || would_overflow_bytes {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: std::mem::take(&mut current),
|
||||
});
|
||||
current_bytes = 0;
|
||||
}
|
||||
|
||||
current_bytes += size;
|
||||
current.push(item.clone());
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: current,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(batches)
|
||||
}
|
||||
|
||||
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
|
||||
db: &SurrealDbClient,
|
||||
statement: impl AsRef<str>,
|
||||
prefix: &str,
|
||||
batches: &[SizedBatch<T>],
|
||||
) -> Result<()> {
|
||||
if batches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut start = 0;
|
||||
while start < batches.len() {
|
||||
let mut group_bytes = 0usize;
|
||||
let mut group_end = start;
|
||||
let mut group_count = 0usize;
|
||||
|
||||
while group_end < batches.len() {
|
||||
let batch_bytes = batches[group_end].approx_bytes.max(1);
|
||||
if group_count > 0
|
||||
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|
||||
|| group_count >= MAX_BATCHES_PER_REQUEST)
|
||||
{
|
||||
break;
|
||||
}
|
||||
group_bytes += batch_bytes;
|
||||
group_end += 1;
|
||||
group_count += 1;
|
||||
}
|
||||
|
||||
let slice = &batches[start..group_end];
|
||||
let mut query = db.client.query("BEGIN TRANSACTION;");
|
||||
let mut bind_index = 0usize;
|
||||
for batch in slice {
|
||||
let name = format!("{prefix}{bind_index}");
|
||||
bind_index += 1;
|
||||
query = query
|
||||
.query(format!("{} ${};", statement.as_ref(), name))
|
||||
.bind((name, batch.items.clone()));
|
||||
}
|
||||
let response = query
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.await
|
||||
.context("storing text_content from manifest")?;
|
||||
for entity in ¶graph.entities {
|
||||
db.upsert_item(entity.clone())
|
||||
.await
|
||||
.context("storing knowledge_entity from manifest")?;
|
||||
}
|
||||
for relationship in ¶graph.relationships {
|
||||
relationship
|
||||
.store_relationship(db)
|
||||
.await
|
||||
.context("storing knowledge_relationship from manifest")?;
|
||||
}
|
||||
for chunk in ¶graph.chunks {
|
||||
db.upsert_item(chunk.clone())
|
||||
.await
|
||||
.context("storing text_chunk from manifest")?;
|
||||
.context("executing batched insert transaction")?;
|
||||
if let Err(err) = response.check() {
|
||||
return Err(anyhow!(
|
||||
"batched insert failed for statement '{}': {err:?}",
|
||||
statement.as_ref()
|
||||
));
|
||||
}
|
||||
|
||||
start = group_end;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
|
||||
|
||||
let result = (|| async {
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextContent::table_name()),
|
||||
"tc",
|
||||
&batches.text_contents,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
|
||||
"ke",
|
||||
&batches.entities,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunk::table_name()),
|
||||
"ch",
|
||||
&batches.chunks,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
"INSERT RELATION INTO relates_to",
|
||||
"rel",
|
||||
&batches.relationships,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
|
||||
"kee",
|
||||
&batches.entity_embeddings,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
|
||||
"tce",
|
||||
&batches.chunk_embeddings,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})()
|
||||
.await;
|
||||
|
||||
if result.is_err() {
|
||||
// Best-effort cleanup to avoid leaving partial manifest data behind.
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
"BEGIN TRANSACTION;
|
||||
DELETE text_chunk_embedding;
|
||||
DELETE knowledge_entity_embedding;
|
||||
DELETE relates_to;
|
||||
DELETE text_chunk;
|
||||
DELETE knowledge_entity;
|
||||
DELETE text_content;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db_helpers::change_embedding_length_in_hnsw_indexes;
    use chrono::Utc;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use uuid::Uuid;

    /// Runs `SELECT * FROM <table>;` and decodes the first result set into
    /// `Vec<$row>`.
    ///
    /// Panics when the query or the decoding fails. Falling back to an empty
    /// vector here would let "table is empty" assertions pass even though the
    /// rows could simply not be read.
    macro_rules! select_all {
        ($db:expr, $row:ty, $table:expr) => {{
            let table = $table;
            let rows: Vec<$row> = $db
                .client
                .query(format!("SELECT * FROM {};", table))
                .await
                .unwrap_or_else(|err| panic!("select {}: {:?}", table, err))
                .take(0)
                .unwrap_or_else(|err| panic!("decode {} rows: {:?}", table, err));
            rows
        }};
    }

    /// Builds a two-paragraph manifest whose paragraphs share the same text
    /// content, entity and chunk, plus a single question pointing at the
    /// first paragraph. All embeddings have three components.
    fn build_manifest() -> CorpusManifest {
        let user_id = "user-1".to_string();
        let source_id = "source-1".to_string();
        let now = Utc::now();
        let text_content_id = Uuid::new_v4().to_string();

        let text_content = TextContent {
            id: text_content_id.clone(),
            created_at: now,
            updated_at: now,
            text: "Hello world".to_string(),
            file_info: None,
            url_info: None,
            context: None,
            category: "test".to_string(),
            user_id: user_id.clone(),
        };

        let entity = KnowledgeEntity {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id: source_id.clone(),
            name: "Entity".to_string(),
            description: "A test entity".to_string(),
            entity_type: KnowledgeEntityType::Document,
            metadata: None,
            user_id: user_id.clone(),
        };
        // Self-referencing relationship: both endpoints are the same entity.
        let relationship = KnowledgeRelationship::new(
            format!("knowledge_entity:{}", entity.id),
            format!("knowledge_entity:{}", entity.id),
            user_id.clone(),
            source_id.clone(),
            "related".to_string(),
        );

        let chunk = TextChunk {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id: source_id.clone(),
            chunk: "chunk text".to_string(),
            user_id: user_id.clone(),
        };

        let paragraph_one = CorpusParagraph {
            paragraph_id: "p1".to_string(),
            title: "Paragraph 1".to_string(),
            text_content: text_content.clone(),
            entities: vec![EmbeddedKnowledgeEntity {
                entity: entity.clone(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            relationships: vec![relationship],
            chunks: vec![EmbeddedTextChunk {
                chunk: chunk.clone(),
                embedding: vec![0.3, 0.2, 0.1],
            }],
        };

        // Duplicate content/entities should be de-duplicated by the loader.
        let paragraph_two = CorpusParagraph {
            paragraph_id: "p2".to_string(),
            title: "Paragraph 2".to_string(),
            text_content: text_content.clone(),
            entities: vec![EmbeddedKnowledgeEntity {
                entity: entity.clone(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            relationships: Vec::new(),
            chunks: vec![EmbeddedTextChunk {
                chunk: chunk.clone(),
                embedding: vec![0.3, 0.2, 0.1],
            }],
        };

        let question = CorpusQuestion {
            question_id: "q1".to_string(),
            paragraph_id: paragraph_one.paragraph_id.clone(),
            text_content_id,
            question_text: "What is this?".to_string(),
            answers: vec!["Hello".to_string()],
            is_impossible: false,
            matching_chunk_ids: vec![chunk.id.clone()],
        };

        CorpusManifest {
            version: current_manifest_version(),
            metadata: CorpusMetadata {
                dataset_id: "dataset".to_string(),
                dataset_label: "Dataset".to_string(),
                slice_id: "slice".to_string(),
                include_unanswerable: false,
                require_verified_chunks: false,
                ingestion_fingerprint: "fp".to_string(),
                embedding_backend: "test".to_string(),
                embedding_model: Some("model".to_string()),
                embedding_dimension: 3,
                converted_checksum: "checksum".to_string(),
                generated_at: now,
                paragraph_count: 2,
                question_count: 1,
            },
            paragraphs: vec![paragraph_one, paragraph_two],
            questions: vec![question],
        }
    }

    /// Seeding succeeds once the index dimension matches the 3-component
    /// embeddings, and the duplicated records end up stored exactly once.
    #[tokio::test]
    async fn seeds_manifest_with_transactional_batches() {
        let namespace = "test_ns";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("memory db");
        db.apply_migrations()
            .await
            .expect("apply migrations for memory db");
        change_embedding_length_in_hnsw_indexes(&db, 3)
            .await
            .expect("set embedding index dimension for test");

        let manifest = build_manifest();
        seed_manifest_into_db(&db, &manifest)
            .await
            .expect("manifest seed should succeed");

        let text_contents = select_all!(db, TextContent, TextContent::table_name());
        assert_eq!(text_contents.len(), 1);

        let entities = select_all!(db, KnowledgeEntity, KnowledgeEntity::table_name());
        assert_eq!(entities.len(), 1);

        let chunks = select_all!(db, TextChunk, TextChunk::table_name());
        assert_eq!(chunks.len(), 1);

        let relationships = select_all!(db, KnowledgeRelationship, "relates_to");
        assert_eq!(relationships.len(), 1);

        let entity_embeddings = select_all!(
            db,
            KnowledgeEntityEmbedding,
            KnowledgeEntityEmbedding::table_name()
        );
        assert_eq!(entity_embeddings.len(), 1);

        let chunk_embeddings =
            select_all!(db, TextChunkEmbedding, TextChunkEmbedding::table_name());
        assert_eq!(chunk_embeddings.len(), 1);
    }

    /// Without adjusting the index dimension the seed is expected to fail,
    /// after which none of the six tables may contain any rows.
    #[tokio::test]
    async fn rolls_back_when_embeddings_mismatch_index_dimension() {
        let namespace = "test_ns_rollback";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("memory db");
        db.apply_migrations()
            .await
            .expect("apply migrations for memory db");

        let manifest = build_manifest();
        let result = seed_manifest_into_db(&db, &manifest).await;
        assert!(
            result.is_err(),
            "expected embedding dimension mismatch to fail"
        );

        let text_contents = select_all!(db, TextContent, TextContent::table_name());
        let entities = select_all!(db, KnowledgeEntity, KnowledgeEntity::table_name());
        let chunks = select_all!(db, TextChunk, TextChunk::table_name());
        let relationships = select_all!(db, KnowledgeRelationship, "relates_to");
        let entity_embeddings = select_all!(
            db,
            KnowledgeEntityEmbedding,
            KnowledgeEntityEmbedding::table_name()
        );
        let chunk_embeddings =
            select_all!(db, TextChunkEmbedding, TextChunkEmbedding::table_name());

        assert!(
            text_contents.is_empty()
                && entities.is_empty()
                && chunks.is_empty()
                && relationships.is_empty()
                && entity_embeddings.is_empty()
                && chunk_embeddings.is_empty(),
            "no rows should be inserted when transaction fails"
        );
    }

    /// Windowing to a single question keeps that question's paragraph and at
    /// least one negative paragraph, never more paragraphs than the source.
    #[test]
    fn window_manifest_trims_questions_and_negatives() {
        let mut manifest = build_manifest();

        // Add extra negatives to simulate multiplier ~4x
        let template = manifest.paragraphs[0].clone();
        manifest.paragraphs.extend((0..8).map(|_| {
            let mut p = template.clone();
            p.paragraph_id = Uuid::new_v4().to_string();
            p.entities.clear();
            p.relationships.clear();
            p.chunks.clear();
            p
        }));
        manifest.metadata.paragraph_count = manifest.paragraphs.len();

        let windowed = window_manifest(&manifest, 0, 1, 4.0).expect("window manifest");
        assert_eq!(windowed.questions.len(), 1);
        // Expect roughly 4x negatives (bounded by available paragraphs)
        assert!(
            windowed.paragraphs.len() <= manifest.paragraphs.len(),
            "windowed paragraphs should never exceed original"
        );

        // Paragraphs referenced by a kept question count as positives.
        let positive_set: std::collections::HashSet<_> = windowed
            .questions
            .iter()
            .map(|q| q.paragraph_id.as_str())
            .collect();
        let positives = windowed
            .paragraphs
            .iter()
            .filter(|p| positive_set.contains(p.paragraph_id.as_str()))
            .count();
        let negatives = windowed.paragraphs.len().saturating_sub(positives);
        assert_eq!(positives, 1);
        assert!(negatives >= 1, "should include some negatives");
    }
}
|
||||
|
||||
@@ -121,13 +121,14 @@ fn build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap<String, Chun
|
||||
for paragraph in &manifest.paragraphs {
|
||||
for chunk in ¶graph.chunks {
|
||||
let snippet = chunk
|
||||
.chunk
|
||||
.chunk
|
||||
.chars()
|
||||
.take(160)
|
||||
.collect::<String>()
|
||||
.replace('\n', " ");
|
||||
lookup.insert(
|
||||
chunk.id.clone(),
|
||||
chunk.chunk.id.clone(),
|
||||
ChunkEntry {
|
||||
paragraph_title: paragraph.title.clone(),
|
||||
snippet,
|
||||
|
||||
@@ -6,7 +6,8 @@ use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{args::Config, embedding::EmbeddingProvider, slice};
|
||||
use crate::{args::Config, slice};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SnapshotMetadata {
|
||||
|
||||
Reference in New Issue
Block a user