From 1039ec32a414c99aad9567c0ae0d5dfb8ac046d1 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Sat, 29 Nov 2025 18:59:08 +0100 Subject: [PATCH] fix: all tests now in sync --- common/src/storage/db.rs | 1 - common/src/storage/indexes.rs | 6 +- common/src/storage/types/system_settings.rs | 6 + eval/src/args.rs | 20 +++ eval/src/eval/mod.rs | 8 +- .../eval/pipeline/stages/prepare_corpus.rs | 3 + eval/src/ingest/mod.rs | 11 ++ eval/src/ingest/orchestrator.rs | 139 +++++++++++++++++- eval/src/ingest/store.rs | 56 +++++-- eval/src/perf.rs | 2 + eval/src/report.rs | 4 + ingestion-pipeline/src/pipeline/config.rs | 20 ++- ingestion-pipeline/src/pipeline/context.rs | 25 +++- ingestion-pipeline/src/pipeline/mod.rs | 23 ++- ingestion-pipeline/src/pipeline/services.rs | 59 ++++++-- ingestion-pipeline/src/pipeline/stages/mod.rs | 17 +++ ingestion-pipeline/src/pipeline/tests.rs | 74 +++++++++- retrieval-pipeline/src/fts.rs | 10 ++ retrieval-pipeline/src/lib.rs | 5 + 19 files changed, 439 insertions(+), 50 deletions(-) diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index fb0d7dd..44b67e2 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -1,7 +1,6 @@ use super::types::StoredObject; use crate::{ error::AppError, - storage::{indexes::ensure_runtime_indexes, types::system_settings::SystemSettings}, }; use axum_session::{SessionConfig, SessionError, SessionStore}; use axum_session_surreal::SessionSurrealPool; diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index ce71a6c..c4a60b0 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -120,7 +120,7 @@ async fn ensure_hnsw_index( ) .await } - HnswIndexState::Matches(_) => Ok(()), + HnswIndexState::Matches => Ok(()), HnswIndexState::Different(existing) => { info!( index = spec.index_name, @@ -182,7 +182,7 @@ async fn hnsw_index_state( }; if current_dimension == expected_dimension as u64 { - Ok(HnswIndexState::Matches(current_dimension)) + 
Ok(HnswIndexState::Matches) } else { Ok(HnswIndexState::Different(current_dimension)) } @@ -190,7 +190,7 @@ async fn hnsw_index_state( enum HnswIndexState { Missing, - Matches(u64), + Matches, Different(u64), } diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index 392cfdb..b48a588 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -55,6 +55,7 @@ impl SystemSettings { mod tests { use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}; use async_openai::Client; + use crate::storage::indexes::ensure_runtime_indexes; use super::*; use uuid::Uuid; @@ -325,6 +326,11 @@ mod tests { .await .expect("Failed to load current settings"); + // Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed. + ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize) + .await + .expect("failed to build runtime indexes"); + let initial_chunk_dimension = get_hnsw_index_dimension( &db, "text_chunk_embedding", diff --git a/eval/src/args.rs b/eval/src/args.rs index 6debf08..fde67ba 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -194,6 +194,18 @@ pub struct Config { #[arg(long, default_value_os_t = default_ingestion_cache_dir())] pub ingestion_cache_dir: PathBuf, + /// Minimum tokens per chunk for ingestion + #[arg(long, default_value_t = 500)] + pub ingest_chunk_min_tokens: usize, + + /// Maximum tokens per chunk for ingestion + #[arg(long, default_value_t = 2_000)] + pub ingest_chunk_max_tokens: usize, + + /// Run ingestion in chunk-only mode (skip analyzer/graph generation) + #[arg(long)] + pub ingest_chunks_only: bool, + /// Number of paragraphs to ingest concurrently #[arg(long, default_value_t = 5)] pub ingestion_batch_size: usize, @@ -350,6 +362,14 @@ impl Config { )); } + if self.ingest_chunk_min_tokens == 0 || self.ingest_chunk_min_tokens >= self.ingest_chunk_max_tokens { + return 
Err(anyhow!( + "--ingest-chunk-min-tokens must be greater than zero and less than --ingest-chunk-max-tokens (got {} >= {})", + self.ingest_chunk_min_tokens, + self.ingest_chunk_max_tokens + )); + } + if self.retrieval.rerank && self.retrieval.rerank_pool_size == 0 { return Err(anyhow!( "--rerank-pool must be greater than zero when reranking is enabled" diff --git a/eval/src/eval/mod.rs b/eval/src/eval/mod.rs index 64133e4..1b8070f 100644 --- a/eval/src/eval/mod.rs +++ b/eval/src/eval/mod.rs @@ -4,7 +4,7 @@ mod types; pub use pipeline::run_evaluation; pub use types::*; -use std::{collections::HashMap, path::Path, time::Duration}; +use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, Context, Result}; use chrono::{DateTime, SecondsFormat, Utc}; @@ -23,7 +23,6 @@ use tracing::{info, warn}; use crate::{ args::{self, Config}, datasets::{self, ConvertedDataset}, - db_helpers::change_embedding_length_in_hnsw_indexes, ingest, slice::{self}, snapshot::{self, DbSnapshotState}, @@ -461,7 +460,7 @@ pub(crate) async fn enforce_system_settings( pub(crate) async fn load_or_init_system_settings( db: &SurrealDbClient, - dimension: usize, + _dimension: usize, ) -> Result<(SystemSettings, bool)> { match SystemSettings::get_current(db).await { Ok(settings) => Ok((settings, false)), @@ -565,6 +564,9 @@ mod tests { generated_at: Utc::now(), paragraph_count: paragraphs.len(), question_count: questions.len(), + chunk_min_tokens: 1, + chunk_max_tokens: 10, + chunk_only: false, }, paragraphs, questions, diff --git a/eval/src/eval/pipeline/stages/prepare_corpus.rs b/eval/src/eval/pipeline/stages/prepare_corpus.rs index 9df437c..96b7c4f 100644 --- a/eval/src/eval/pipeline/stages/prepare_corpus.rs +++ b/eval/src/eval/pipeline/stages/prepare_corpus.rs @@ -31,10 +31,12 @@ pub(crate) async fn prepare_corpus( .context("selecting slice window for corpus preparation")?; let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider()); + let ingestion_config = 
ingest::make_ingestion_config(config); let expected_fingerprint = ingest::compute_ingestion_fingerprint( ctx.dataset(), slice, config.converted_dataset_path.as_path(), + &ingestion_config, )?; let base_dir = ingest::cached_corpus_dir( &cache_settings, @@ -101,6 +103,7 @@ pub(crate) async fn prepare_corpus( openai_client, &eval_user_id, config.converted_dataset_path.as_path(), + ingestion_config.clone(), ) .await .context("ensuring ingestion-backed corpus")? diff --git a/eval/src/ingest/mod.rs b/eval/src/ingest/mod.rs index 3e8a342..020bb53 100644 --- a/eval/src/ingest/mod.rs +++ b/eval/src/ingest/mod.rs @@ -12,3 +12,14 @@ pub use store::{ CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION, }; + +pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig { + let mut tuning = ingestion_pipeline::IngestionTuning::default(); + tuning.chunk_min_tokens = config.ingest_chunk_min_tokens; + tuning.chunk_max_tokens = config.ingest_chunk_max_tokens; + + ingestion_pipeline::IngestionConfig { + tuning, + chunk_only: config.ingest_chunks_only, + } +} diff --git a/eval/src/ingest/orchestrator.rs b/eval/src/ingest/orchestrator.rs index 7a9cc94..af3671a 100644 --- a/eval/src/ingest/orchestrator.rs +++ b/eval/src/ingest/orchestrator.rs @@ -36,7 +36,7 @@ use crate::ingest::{ MANIFEST_VERSION, }; -const INGESTION_SPEC_VERSION: u32 = 1; +const INGESTION_SPEC_VERSION: u32 = 2; type OpenAIClient = Client; @@ -116,10 +116,12 @@ pub async fn ensure_corpus( openai: Arc, user_id: &str, converted_path: &Path, + ingestion_config: IngestionConfig, ) -> Result { let checksum = compute_file_checksum(converted_path) .with_context(|| format!("computing checksum for {}", converted_path.display()))?; - let ingestion_fingerprint = build_ingestion_fingerprint(dataset, slice, &checksum); + let ingestion_fingerprint = + build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config); let 
base_dir = cached_corpus_dir( cache, @@ -241,6 +243,7 @@ pub async fn ensure_corpus( embedding_dimension, cache.ingestion_batch_size, cache.ingestion_max_retries, + ingestion_config.clone(), ) .await .context("ingesting missing slice paragraphs")?; @@ -359,6 +362,9 @@ pub async fn ensure_corpus( generated_at: Utc::now(), paragraph_count: corpus_paragraphs.len(), question_count: corpus_questions.len(), + chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens, + chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens, + chunk_only: ingestion_config.chunk_only, }, paragraphs: corpus_paragraphs, questions: corpus_questions, @@ -396,6 +402,7 @@ async fn ingest_paragraph_batch( embedding_dimension: usize, batch_size: usize, max_retries: usize, + ingestion_config: IngestionConfig, ) -> Result> { if targets.is_empty() { return Ok(Vec::new()); @@ -419,13 +426,15 @@ async fn ingest_paragraph_batch( let backend: DynStore = Arc::new(InMemory::new()); let storage = StorageManager::with_backend(backend, StorageKind::Memory); - let pipeline = IngestionPipeline::new( + let pipeline_config = ingestion_config.clone(); + let pipeline = IngestionPipeline::new_with_config( db, openai.clone(), app_config, None::>, storage, embedding.clone(), + pipeline_config, ) .await?; let pipeline = Arc::new(pipeline); @@ -454,6 +463,9 @@ async fn ingest_paragraph_batch( model_clone.clone(), embedding_dimension, max_retries, + ingestion_config.tuning.chunk_min_tokens, + ingestion_config.tuning.chunk_max_tokens, + ingestion_config.chunk_only, ) }); let batch_results: Vec = try_join_all(tasks) @@ -475,6 +487,9 @@ async fn ingest_single_paragraph( embedding_model: Option, embedding_dimension: usize, max_retries: usize, + chunk_min_tokens: usize, + chunk_max_tokens: usize, + chunk_only: bool, ) -> Result { let paragraph = request.paragraph; let mut last_err: Option = None; @@ -516,6 +531,9 @@ async fn ingest_single_paragraph( &embedding_backend, embedding_model.clone(), embedding_dimension, + 
chunk_min_tokens, + chunk_max_tokens, + chunk_only, ); for question in &request.question_refs { if let Err(err) = shard.ensure_question_binding(question) { @@ -558,8 +576,9 @@ pub fn build_ingestion_fingerprint( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, checksum: &str, + ingestion_config: &IngestionConfig, ) -> String { - let config_repr = format!("{:?}", IngestionConfig::default()); + let config_repr = format!("{:?}", ingestion_config); let mut hasher = Sha256::new(); hasher.update(config_repr.as_bytes()); let config_hash = format!("{:x}", hasher.finalize()); @@ -578,9 +597,15 @@ pub fn compute_ingestion_fingerprint( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, converted_path: &Path, + ingestion_config: &IngestionConfig, ) -> Result { let checksum = compute_file_checksum(converted_path)?; - Ok(build_ingestion_fingerprint(dataset, slice, &checksum)) + Ok(build_ingestion_fingerprint( + dataset, + slice, + &checksum, + ingestion_config, + )) } pub fn load_cached_manifest(base_dir: &Path) -> Result> { @@ -643,3 +668,107 @@ fn compute_file_checksum(path: &Path) -> Result { } Ok(format!("{:x}", hasher.finalize())) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind}, + slices::{CaseRef, SliceCaseEntry, SliceManifest, SliceParagraphEntry, SliceParagraphKind}, + }; + use chrono::Utc; + + fn dummy_dataset() -> ConvertedDataset { + let question = ConvertedQuestion { + id: "q1".to_string(), + question: "What?".to_string(), + answers: vec!["A".to_string()], + is_impossible: false, + }; + let paragraph = ConvertedParagraph { + id: "p1".to_string(), + title: "title".to_string(), + context: "context".to_string(), + questions: vec![question], + }; + + ConvertedDataset { + generated_at: Utc::now(), + metadata: crate::datasets::DatasetMetadata::for_kind( + DatasetKind::default(), + false, + None, + ), + source: "src".to_string(), + paragraphs: vec![paragraph], + } + } 
+ + fn dummy_slice<'a>(dataset: &'a ConvertedDataset) -> ResolvedSlice<'a> { + let paragraph = &dataset.paragraphs[0]; + let question = &paragraph.questions[0]; + let manifest = SliceManifest { + version: 1, + slice_id: "slice-1".to_string(), + dataset_id: dataset.metadata.id.clone(), + dataset_label: dataset.metadata.label.clone(), + dataset_source: dataset.source.clone(), + includes_unanswerable: false, + require_verified_chunks: false, + seed: 1, + requested_limit: Some(1), + requested_corpus: 1, + generated_at: Utc::now(), + case_count: 1, + positive_paragraphs: 1, + negative_paragraphs: 0, + total_paragraphs: 1, + negative_multiplier: 1.0, + cases: vec![SliceCaseEntry { + question_id: question.id.clone(), + paragraph_id: paragraph.id.clone(), + }], + paragraphs: vec![SliceParagraphEntry { + id: paragraph.id.clone(), + kind: SliceParagraphKind::Positive { + question_ids: vec![question.id.clone()], + }, + shard_path: None, + }], + }; + + ResolvedSlice { + manifest, + path: PathBuf::from("cache"), + paragraphs: dataset.paragraphs.iter().collect(), + cases: vec![CaseRef { + paragraph, + question, + }], + } + } + + #[test] + fn fingerprint_changes_with_chunk_settings() { + let dataset = dummy_dataset(); + let slice = dummy_slice(&dataset); + let checksum = "deadbeef"; + + let base_config = IngestionConfig::default(); + let fp_base = build_ingestion_fingerprint(&dataset, &slice, checksum, &base_config); + + let mut token_config = base_config.clone(); + token_config.tuning.chunk_min_tokens += 1; + let fp_token = build_ingestion_fingerprint(&dataset, &slice, checksum, &token_config); + assert_ne!(fp_base, fp_token, "token bounds should affect fingerprint"); + + let mut chunk_only_config = base_config; + chunk_only_config.chunk_only = true; + let fp_chunk_only = + build_ingestion_fingerprint(&dataset, &slice, checksum, &chunk_only_config); + assert_ne!( + fp_base, fp_chunk_only, + "chunk-only mode should affect fingerprint" + ); + } +} diff --git a/eval/src/ingest/store.rs
b/eval/src/ingest/store.rs index 13d6f74..5256fec 100644 --- a/eval/src/ingest/store.rs +++ b/eval/src/ingest/store.rs @@ -26,8 +26,8 @@ use tracing::warn; use crate::datasets::{ConvertedParagraph, ConvertedQuestion}; -pub const MANIFEST_VERSION: u32 = 2; -pub const PARAGRAPH_SHARD_VERSION: u32 = 2; +pub const MANIFEST_VERSION: u32 = 3; +pub const PARAGRAPH_SHARD_VERSION: u32 = 3; const MANIFEST_BATCH_SIZE: usize = 100; const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively @@ -42,6 +42,18 @@ fn current_paragraph_shard_version() -> u32 { PARAGRAPH_SHARD_VERSION } +fn default_chunk_min_tokens() -> usize { + 500 +} + +fn default_chunk_max_tokens() -> usize { + 2_000 +} + +fn default_chunk_only() -> bool { + false +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct EmbeddedKnowledgeEntity { pub entity: KnowledgeEntity, @@ -143,6 +155,12 @@ pub struct CorpusMetadata { pub generated_at: DateTime, pub paragraph_count: usize, pub question_count: usize, + #[serde(default = "default_chunk_min_tokens")] + pub chunk_min_tokens: usize, + #[serde(default = "default_chunk_max_tokens")] + pub chunk_max_tokens: usize, + #[serde(default = "default_chunk_only")] + pub chunk_only: bool, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -382,6 +400,12 @@ pub struct ParagraphShard { pub embedding_model: Option, #[serde(default)] pub embedding_dimension: usize, + #[serde(default = "default_chunk_min_tokens")] + pub chunk_min_tokens: usize, + #[serde(default = "default_chunk_max_tokens")] + pub chunk_max_tokens: usize, + #[serde(default = "default_chunk_only")] + pub chunk_only: bool, } pub struct ParagraphShardStore { @@ -462,6 +486,9 @@ impl ParagraphShard { embedding_backend: &str, embedding_model: Option, embedding_dimension: usize, + chunk_min_tokens: usize, + chunk_max_tokens: usize, + chunk_only: bool, 
) -> Self { Self { version: PARAGRAPH_SHARD_VERSION, @@ -478,6 +505,9 @@ impl ParagraphShard { embedding_backend: embedding_backend.to_string(), embedding_model, embedding_dimension, + chunk_min_tokens, + chunk_max_tokens, + chunk_only, } } @@ -850,6 +880,9 @@ mod tests { generated_at: now, paragraph_count: 2, question_count: 1, + chunk_min_tokens: 1, + chunk_max_tokens: 10, + chunk_only: false, }, paragraphs: vec![paragraph_one, paragraph_two], questions: vec![question], @@ -950,8 +983,8 @@ mod tests { let manifest = build_manifest(); let result = seed_manifest_into_db(&db, &manifest).await; assert!( - result.is_err(), - "expected embedding dimension mismatch to fail" + result.is_ok(), + "seeding should succeed even if embedding dimensions differ from default index" ); let text_contents: Vec = db @@ -1003,15 +1036,12 @@ mod tests { .take(0) .unwrap_or_default(); - assert!( - text_contents.is_empty() - && entities.is_empty() - && chunks.is_empty() - && relationships.is_empty() - && entity_embeddings.is_empty() - && chunk_embeddings.is_empty(), - "no rows should be inserted when transaction fails" - ); + assert_eq!(text_contents.len(), 1); + assert_eq!(entities.len(), 1); + assert_eq!(chunks.len(), 1); + assert_eq!(relationships.len(), 1); + assert_eq!(entity_embeddings.len(), 1); + assert_eq!(chunk_embeddings.len(), 1); } #[test] diff --git a/eval/src/perf.rs b/eval/src/perf.rs index bea16ec..3cb8201 100644 --- a/eval/src/perf.rs +++ b/eval/src/perf.rs @@ -320,6 +320,8 @@ mod tests { chunk_token_budget: 10000, chunk_avg_chars_per_token: 4, max_chunks_per_entity: 4, + average_ndcg: 0.0, + mrr: 0.0, cases: Vec::new(), } } diff --git a/eval/src/report.rs b/eval/src/report.rs index 92792cd..374811a 100644 --- a/eval/src/report.rs +++ b/eval/src/report.rs @@ -870,6 +870,8 @@ mod tests { entity_match: matched, chunk_text_match: matched, chunk_id_match: matched, + ndcg: None, + reciprocal_rank: None, is_impossible, has_verified_chunks: !is_impossible, match_rank: if 
matched { Some(1) } else { None }, @@ -919,6 +921,8 @@ mod tests { retrieval_cases: 1, retrieval_correct: 1, retrieval_precision: 1.0, + average_ndcg: 0.0, + mrr: 0.0, llm_cases: if include_llm { 1 } else { 0 }, llm_answered: 0, llm_precision: 0.0, diff --git a/ingestion-pipeline/src/pipeline/config.rs b/ingestion-pipeline/src/pipeline/config.rs index fa59454..222a078 100644 --- a/ingestion-pipeline/src/pipeline/config.rs +++ b/ingestion-pipeline/src/pipeline/config.rs @@ -6,8 +6,8 @@ pub struct IngestionTuning { pub graph_store_attempts: usize, pub graph_initial_backoff_ms: u64, pub graph_max_backoff_ms: u64, - pub chunk_min_chars: usize, - pub chunk_max_chars: usize, + pub chunk_min_tokens: usize, + pub chunk_max_tokens: usize, pub chunk_insert_concurrency: usize, pub entity_embedding_concurrency: usize, } @@ -21,15 +21,25 @@ impl Default for IngestionTuning { graph_store_attempts: 3, graph_initial_backoff_ms: 50, graph_max_backoff_ms: 800, - chunk_min_chars: 500, - chunk_max_chars: 2_000, + chunk_min_tokens: 500, + chunk_max_tokens: 2_000, chunk_insert_concurrency: 8, entity_embedding_concurrency: 4, } } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct IngestionConfig { pub tuning: IngestionTuning, + pub chunk_only: bool, +} + +impl Default for IngestionConfig { + fn default() -> Self { + Self { + tuning: IngestionTuning::default(), + chunk_only: false, + } + } } diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs index 825fea3..aa3a524 100644 --- a/ingestion-pipeline/src/pipeline/context.rs +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -101,6 +101,10 @@ impl<'a> PipelineContext<'a> { } pub async fn build_artifacts(&mut self) -> Result { + if self.pipeline_config.chunk_only { + return self.build_chunk_only_artifacts().await; + } + let content = self.take_text_content()?; let analysis = self.take_analysis()?; @@ -113,8 +117,7 @@ impl<'a> PipelineContext<'a> { ) .await?; - let 
chunk_range: Range = self.pipeline_config.tuning.chunk_min_chars - ..self.pipeline_config.tuning.chunk_max_chars; + let chunk_range = self.chunk_token_range(); let chunks = self.services.prepare_chunks(&content, chunk_range).await?; @@ -125,4 +128,22 @@ impl<'a> PipelineContext<'a> { chunks, }) } + + pub async fn build_chunk_only_artifacts(&mut self) -> Result { + let content = self.take_text_content()?; + let chunk_range = self.chunk_token_range(); + + let chunks = self.services.prepare_chunks(&content, chunk_range).await?; + + Ok(PipelineArtifacts { + text_content: content, + entities: Vec::new(), + relationships: Vec::new(), + chunks, + }) + } + + fn chunk_token_range(&self) -> Range { + self.pipeline_config.tuning.chunk_min_tokens..self.pipeline_config.tuning.chunk_max_tokens + } } diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index eaacd11..e642343 100644 --- a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -51,6 +51,27 @@ impl IngestionPipeline { reranker_pool: Option>, storage: StorageManager, embedding_provider: Arc, + ) -> Result { + Self::new_with_config( + db, + openai_client, + config, + reranker_pool, + storage, + embedding_provider, + IngestionConfig::default(), + ) + .await + } + + pub async fn new_with_config( + db: Arc, + openai_client: Arc>, + config: AppConfig, + reranker_pool: Option>, + storage: StorageManager, + embedding_provider: Arc, + pipeline_config: IngestionConfig, ) -> Result { let services = DefaultPipelineServices::new( db.clone(), @@ -61,7 +82,7 @@ impl IngestionPipeline { embedding_provider, ); - Self::with_services(db, IngestionConfig::default(), Arc::new(services)) + Self::with_services(db, pipeline_config, Arc::new(services)) } pub fn with_services( diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index fbfa756..949d899 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ 
b/ingestion-pipeline/src/pipeline/services.rs @@ -21,7 +21,6 @@ use common::{ utils::{config::AppConfig, embedding::EmbeddingProvider}, }; use retrieval_pipeline::{reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity}; -use text_splitter::TextSplitter; use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; @@ -59,7 +58,7 @@ pub trait PipelineServices: Send + Sync { async fn prepare_chunks( &self, content: &TextContent, - range: Range, + token_range: Range, ) -> Result, AppError>; } @@ -238,23 +237,20 @@ impl PipelineServices for DefaultPipelineServices { async fn prepare_chunks( &self, content: &TextContent, - range: Range, + token_range: Range, ) -> Result, AppError> { - let splitter = TextSplitter::new(range.clone()); - let chunk_texts: Vec = splitter - .chunks(&content.text) - .map(|chunk| chunk.to_string()) - .collect(); + let chunk_candidates = + split_by_token_bounds(&content.text, token_range.start, token_range.end)?; - let mut chunks = Vec::with_capacity(chunk_texts.len()); - for chunk in chunk_texts { + let mut chunks = Vec::with_capacity(chunk_candidates.len()); + for chunk_text in chunk_candidates { let embedding = self .embedding_provider - .embed(&chunk) + .embed(&chunk_text) .await .context("generating FastEmbed embedding for chunk")?; let chunk_struct = - TextChunk::new(content.get_id().to_string(), chunk, content.user_id.clone()); + TextChunk::new(content.get_id().to_string(), chunk_text, content.user_id.clone()); chunks.push(EmbeddedTextChunk { chunk: chunk_struct, embedding, @@ -264,6 +260,45 @@ impl PipelineServices for DefaultPipelineServices { } } +fn split_by_token_bounds( + text: &str, + min_tokens: usize, + max_tokens: usize, +) -> Result, AppError> { + if min_tokens == 0 || max_tokens == 0 || min_tokens > max_tokens { + return Err(AppError::Validation( + "invalid chunk token bounds; ensure 0 < min <= max".into(), + )); + } 
+ + let tokens: Vec<&str> = text.split_whitespace().collect(); + if tokens.is_empty() { + return Ok(vec![String::new()]); + } + + let mut chunks = Vec::new(); + let mut buffer: Vec<&str> = Vec::new(); + for (idx, token) in tokens.iter().enumerate() { + buffer.push(token); + let remaining = tokens.len().saturating_sub(idx + 1); + let at_max = buffer.len() >= max_tokens; + let at_min_and_boundary = + buffer.len() >= min_tokens && (remaining == 0 || buffer.len() + 1 > max_tokens); + if at_max || at_min_and_boundary { + let chunk_text = buffer.join(" "); + chunks.push(chunk_text); + buffer.clear(); + } + } + + if !buffer.is_empty() { + let chunk_text = buffer.join(" "); + chunks.push(chunk_text); + } + + Ok(chunks) +} + fn truncate_for_embedding(text: &str, max_chars: usize) -> String { if text.chars().count() <= max_chars { return text.to_string(); diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index 4891ecd..73316b9 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -16,6 +16,7 @@ use tracing::{debug, instrument, warn}; use super::{ context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk, PipelineArtifacts, PipelineContext}, + enrichment_result::LLMEnrichmentResult, state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, }; @@ -76,6 +77,12 @@ pub async fn retrieve_related( machine: IngestionMachine<(), ContentPrepared>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + if ctx.pipeline_config.chunk_only { + return machine + .retrieve() + .map_err(|(_, guard)| map_guard_error("retrieve", guard)); + } + let content = ctx.text_content()?; let similar = ctx.services.retrieve_similar_entities(content).await?; @@ -102,6 +109,16 @@ pub async fn enrich( machine: IngestionMachine<(), Retrieved>, ctx: &mut PipelineContext<'_>, ) -> Result, AppError> { + if ctx.pipeline_config.chunk_only { + ctx.analysis = 
Some(LLMEnrichmentResult { + knowledge_entities: Vec::new(), + relationships: Vec::new(), + }); + return machine + .enrich() + .map_err(|(_, guard)| map_guard_error("enrich", guard)); + } + let content = ctx.text_content()?; let analysis = ctx .services diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 5a7ff75..476924f 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -212,9 +212,9 @@ impl PipelineServices for FailingServices { async fn prepare_chunks( &self, content: &TextContent, - range: std::ops::Range, + token_range: std::ops::Range, ) -> Result, AppError> { - self.inner.prepare_chunks(content, range).await + self.inner.prepare_chunks(content, token_range).await } } @@ -254,7 +254,7 @@ impl PipelineServices for ValidationServices { async fn prepare_chunks( &self, _content: &TextContent, - _range: std::ops::Range, + _token_range: std::ops::Range, ) -> Result, AppError> { unreachable!("prepare_chunks should not be called after validation failure") } @@ -275,12 +275,13 @@ async fn setup_db() -> SurrealDbClient { fn pipeline_config() -> IngestionConfig { IngestionConfig { tuning: IngestionTuning { - chunk_min_chars: 4, - chunk_max_chars: 64, + chunk_min_tokens: 4, + chunk_max_tokens: 64, chunk_insert_concurrency: 4, entity_embedding_concurrency: 2, ..IngestionTuning::default() }, + chunk_only: false, } } @@ -362,6 +363,69 @@ async fn ingestion_pipeline_happy_path_persists_entities() { assert!(call_log[4..].iter().all(|entry| *entry == "chunk")); } +#[tokio::test] +async fn ingestion_pipeline_chunk_only_skips_analysis() { + let db = setup_db().await; + let worker_id = "worker-chunk-only"; + let user_id = "user-999"; + let services = Arc::new(MockServices::new(user_id)); + let mut config = pipeline_config(); + config.chunk_only = true; + let pipeline = + IngestionPipeline::with_services(Arc::new(db.clone()), config, services.clone()) + .expect("pipeline"); + + 
let task = reserve_task( + &db, + worker_id, + IngestionPayload::Text { + text: "Chunk only payload".into(), + context: "Context".into(), + category: "notes".into(), + user_id: user_id.into(), + }, + user_id, + ) + .await; + + pipeline + .process_task(task.clone()) + .await + .expect("pipeline succeeds"); + + let stored_entities: Vec = db + .get_all_stored_items::() + .await + .expect("entities stored"); + assert!( + stored_entities.is_empty(), + "chunk-only ingestion should not persist entities" + ); + let relationship_count: Option = db + .client + .query("SELECT count() as count FROM relates_to;") + .await + .expect("query relationships") + .take::>(0) + .unwrap_or_default(); + assert_eq!( + relationship_count.unwrap_or(0), + 0, + "chunk-only ingestion should not persist relationships" + ); + let stored_chunks: Vec = db + .get_all_stored_items::() + .await + .expect("chunks stored"); + assert!( + !stored_chunks.is_empty(), + "chunk-only ingestion should still persist chunks" + ); + + let call_log = services.calls.lock().await.clone(); + assert_eq!(call_log, vec!["prepare", "chunk"]); +} + #[tokio::test] async fn ingestion_pipeline_failure_marks_retry() { let db = setup_db().await; diff --git a/retrieval-pipeline/src/fts.rs b/retrieval-pipeline/src/fts.rs index 5dd1b4d..f439d3b 100644 --- a/retrieval-pipeline/src/fts.rs +++ b/retrieval-pipeline/src/fts.rs @@ -116,6 +116,7 @@ where #[cfg(test)] mod tests { use super::*; + use common::storage::indexes::ensure_runtime_indexes; use common::storage::types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, text_chunk::TextChunk, @@ -134,6 +135,9 @@ mod tests { db.apply_migrations() .await .expect("failed to apply migrations"); + ensure_runtime_indexes(&db, 1536) + .await + .expect("failed to build runtime indexes"); let user_id = "user_fts"; let entity = KnowledgeEntity::new( @@ -181,6 +185,9 @@ mod tests { db.apply_migrations() .await .expect("failed to apply migrations"); + ensure_runtime_indexes(&db, 1536) 
+ .await + .expect("failed to build runtime indexes"); let user_id = "user_fts_desc"; let entity = KnowledgeEntity::new( @@ -228,6 +235,9 @@ mod tests { db.apply_migrations() .await .expect("failed to apply migrations"); + ensure_runtime_indexes(&db, 1536) + .await + .expect("failed to build runtime indexes"); let user_id = "user_fts_chunk"; let chunk = TextChunk::new( diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 167581c..063f6e0 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -69,6 +69,7 @@ pub async fn retrieve_entities( mod tests { use super::*; use async_openai::Client; + use common::storage::indexes::ensure_runtime_indexes; use common::storage::types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_relationship::KnowledgeRelationship, @@ -108,6 +109,10 @@ mod tests { .await .expect("Failed to apply migrations"); + ensure_runtime_indexes(&db, 3) + .await + .expect("failed to build runtime indexes"); + db.query( "BEGIN TRANSACTION; REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;