diff --git a/CHANGELOG.md b/CHANGELOG.md index 5408f44..32a6e7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,12 @@ ## Unreleased - Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables) - Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id) +- Performance: entity embeddings during ingestion use batched `embed_batch`, matching chunk embedding - Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed - Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children` - Fix: regression re suggestion of relationships - Internal: eval corpus DB seed uses `persist_artifacts` instead of a separate batched insert path +- Internal: removed unused `entity_embedding_concurrency` ingest tuning knob ## 1.0.3 (2026-06-12) - Search: filter results by type — knowledge entities, ingested content, or both diff --git a/ingestion-pipeline/src/pipeline/config.rs b/ingestion-pipeline/src/pipeline/config.rs index 1820c26..ceb00d4 100644 --- a/ingestion-pipeline/src/pipeline/config.rs +++ b/ingestion-pipeline/src/pipeline/config.rs @@ -9,7 +9,6 @@ pub struct IngestionTuning { pub chunk_min_tokens: usize, pub chunk_max_tokens: usize, pub chunk_overlap_tokens: usize, - pub entity_embedding_concurrency: usize, /// Maximum characters of content body used to build the similarity-search query /// during retrieval. Longer bodies are truncated to keep embedding inputs bounded. pub embedding_query_char_limit: usize, @@ -27,7 +26,6 @@ impl Default for IngestionTuning { chunk_min_tokens: 256, chunk_max_tokens: 512, chunk_overlap_tokens: 50, - entity_embedding_concurrency: 4, embedding_query_char_limit: 12_000, } } diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs index e3f818f..9014f94 100644 --- a/ingestion-pipeline/src/pipeline/context.rs +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -111,11 +111,7 @@ impl<'a> PipelineContext<'a> { let (entities, relationships) = self .services - .convert_analysis( - &content, - &analysis, - self.pipeline_config.tuning.entity_embedding_concurrency, - ) + .convert_analysis(&content, &analysis) .await?; let chunk_range = self.chunk_token_range(); diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index aa2ceae..900f66b 100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - use chrono::Utc; -use futures::stream::{self, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use common::{ @@ -43,22 +40,15 @@ impl LLMEnrichmentResult { &self, source_id: &str, user_id: &str, - entity_concurrency: usize, embedding_provider: &EmbeddingProvider, ) -> Result<(Vec, Vec), AppError> { - let mapper = Arc::new(self.create_mapper()); + let mapper = self.create_mapper(); let entities = self - .process_entities( - source_id, - user_id, - Arc::clone(&mapper), - entity_concurrency, - embedding_provider, - ) + .process_entities(source_id, user_id, &mapper, embedding_provider) .await?; - let relationships = self.process_relationships(source_id, user_id, mapper.as_ref())?; + let relationships = self.process_relationships(source_id, user_id, &mapper)?; Ok((entities, relationships)) } @@ -77,36 +67,64 @@ impl LLMEnrichmentResult { &self, source_id: &str, user_id: &str, - mapper: Arc, - entity_concurrency: usize, + mapper: &GraphMapper, embedding_provider: &EmbeddingProvider, ) -> Result, AppError> { - let tasks: Vec<_> = self - .knowledge_entities - .iter() - .map(|entity| { - let llm_entity = entity.clone(); - let mapper = Arc::clone(&mapper); - let source_id = source_id.to_string(); - let user_id = user_id.to_string(); + if self.knowledge_entities.is_empty() { + return Ok(Vec::new()); + } - async move { - create_single_entity( - llm_entity, - &source_id, - &user_id, - mapper, - embedding_provider, - ) - .await - } - }) - .collect(); + let now = Utc::now(); + let mut prepared = Vec::with_capacity(self.knowledge_entities.len()); + let mut embedding_inputs = Vec::with_capacity(self.knowledge_entities.len()); - stream::iter(tasks) - .buffer_unordered(entity_concurrency.max(1)) - .try_collect() + for llm_entity in &self.knowledge_entities { + let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); + let entity_type = KnowledgeEntityType::from(llm_entity.entity_type.clone()); + embedding_inputs.push(KnowledgeEntity::embedding_input_text( + &llm_entity.name, + &llm_entity.description, + entity_type, + )); + prepared.push((llm_entity, assigned_id, entity_type)); + } + + // Embed all entities from this document in one batch: a single lock acquisition and one + // blocking hop, letting the backend batch the inference internally. + let embeddings = embedding_provider + .embed_batch(&embedding_inputs) .await + .map_err(|e| AppError::InternalError(format!("entity embedding batch failed: {e}")))?; + + if embeddings.len() != prepared.len() { + return Err(AppError::InternalError(format!( + "embedding batch returned {} vectors for {} entities", + embeddings.len(), + prepared.len() + ))); + } + + let mut entities = Vec::with_capacity(prepared.len()); + for ((llm_entity, assigned_id, entity_type), embedding) in + prepared.into_iter().zip(embeddings) + { + entities.push(EmbeddedKnowledgeEntity { + entity: KnowledgeEntity { + id: assigned_id, + created_at: now, + updated_at: now, + name: llm_entity.name.clone(), + description: llm_entity.description.clone(), + entity_type, + source_id: source_id.to_string(), + metadata: None, + user_id: user_id.to_string(), + }, + embedding, + }); + } + + Ok(entities) } fn process_relationships( @@ -133,44 +151,11 @@ impl LLMEnrichmentResult { } } -async fn create_single_entity( - llm_entity: LLMKnowledgeEntity, - source_id: &str, - user_id: &str, - mapper: Arc, - embedding_provider: &EmbeddingProvider, -) -> Result { - let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); - - let entity_type = KnowledgeEntityType::from(llm_entity.entity_type); - let embedding_input = KnowledgeEntity::embedding_input_text( - &llm_entity.name, - &llm_entity.description, - entity_type, - ); - - let embedding = embedding_provider.embed(&embedding_input).await?; - - let now = Utc::now(); - let entity = KnowledgeEntity { - id: assigned_id, - created_at: now, - updated_at: now, - name: llm_entity.name, - description: llm_entity.description, - entity_type, - source_id: source_id.to_string(), - metadata: None, - user_id: user_id.into(), - }; - - Ok(EmbeddedKnowledgeEntity { entity, embedding }) -} - #[cfg(test)] mod tests { #![allow(clippy::expect_used)] use super::*; + use common::utils::embedding::EmbeddingProvider; use uuid::Uuid; fn entity(key: &str) -> LLMKnowledgeEntity { @@ -247,6 +232,32 @@ mod tests { ); } + #[tokio::test] + async fn process_entities_batches_embeddings_and_preserves_order() -> anyhow::Result<()> { + let result = LLMEnrichmentResult { + knowledge_entities: vec![entity("k1"), entity("k2"), entity("k3")], + relationships: Vec::new(), + }; + let mapper = result.create_mapper(); + let provider = EmbeddingProvider::new_hashed(8)?; + + let entities = result + .process_entities("source-1", "user-1", &mapper, &provider) + .await?; + + assert_eq!(entities.len(), 3); + let first = entities.first().expect("first entity"); + let second = entities.get(1).expect("second entity"); + let third = entities.get(2).expect("third entity"); + assert_eq!(first.entity.name, "name-k1"); + assert_eq!(second.entity.name, "name-k2"); + assert_eq!(third.entity.name, "name-k3"); + assert!(entities.iter().all(|item| item.embedding.len() == 8)); + assert_ne!(first.embedding, second.embedding); + + Ok(()) + } + #[test] fn process_relationships_errors_on_unknown_endpoint() { let result = LLMEnrichmentResult { diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index 0c9a604..75d853a 100644 --- a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -347,3 +347,85 @@ mod test_support; #[cfg(test)] mod tests; + +#[cfg(test)] +mod finalize_tests { + use std::{sync::Arc, time::Duration}; + + use common::storage::types::{ + ingestion_payload::IngestionPayload, + ingestion_task::{IngestionTask, TaskState}, + }; + use tokio::time::sleep; + + use super::{ + config::IngestionTuning, + test_support::setup_db, + tests::{pipeline_config, reserve_task, MockServices}, + IngestionPipeline, PipelineServices, + }; + + #[tokio::test] + async fn finalize_succeeded_retries_mark_succeeded() -> anyhow::Result<()> { + use anyhow::Context; + let db = setup_db().await?; + let worker_id = "worker-finalize-retry"; + let user_id = "user-finalize-retry"; + let services: Arc = Arc::new(MockServices::new(user_id)); + let mut config = pipeline_config(); + config.tuning = IngestionTuning { + persist_attempts: 3, + persist_initial_backoff_ms: 10, + persist_max_backoff_ms: 10, + ..IngestionTuning::default() + }; + let pipeline = + IngestionPipeline::with_services(Arc::new(db.clone()), config, services)?; + + let task = reserve_task( + &db, + worker_id, + IngestionPayload::Text { + text: "Finalize retry payload".into(), + context: "Context".into(), + category: "notes".into(), + user_id: user_id.into(), + }, + user_id, + ) + .await?; + let processing = task.mark_processing(&db).await?; + + db.client + .query( + "UPDATE type::thing('ingestion_task', $id) SET worker_id = $wrong_worker;", + ) + .bind(("id", processing.id.clone())) + .bind(("wrong_worker", "wrong-worker")) + .await?; + + let task_id = processing.id.clone(); + let db_fix = db.clone(); + tokio::spawn(async move { + sleep(Duration::from_millis(5)).await; + let _ = db_fix + .client + .query( + "UPDATE type::thing('ingestion_task', $id) SET worker_id = $worker_id;", + ) + .bind(("id", task_id)) + .bind(("worker_id", worker_id)) + .await; + }); + + pipeline.finalize_succeeded(&processing).await?; + + let stored: IngestionTask = db + .get_item(&processing.id) + .await? + .context("task stored")?; + assert_eq!(stored.state, TaskState::Succeeded); + + Ok(()) + } +} diff --git a/ingestion-pipeline/src/pipeline/persistence.rs b/ingestion-pipeline/src/pipeline/persistence.rs index a4483f3..7eb9ab5 100644 --- a/ingestion-pipeline/src/pipeline/persistence.rs +++ b/ingestion-pipeline/src/pipeline/persistence.rs @@ -124,6 +124,13 @@ async fn execute_persist_transaction( db: &SurrealDbClient, payload: &PersistPayload, ) -> Result<(), AppError> { + #[cfg(test)] + if test_persist_should_fail() { + return Err(AppError::InternalError( + "Failed to commit transaction due to a read or write conflict".into(), + )); + } + let mut query = String::from("BEGIN TRANSACTION;\n"); query.push_str(TextContent::CLEAR_INGESTED_CHILD_ROWS_SURQL); query.push_str( @@ -236,6 +243,24 @@ fn is_retryable_conflict(error: &AppError) -> bool { .contains("Failed to commit transaction due to a read or write conflict") } +#[cfg(test)] +static TEST_PERSIST_FAILURES: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); + +#[cfg(test)] +fn set_test_persist_failures(count: usize) { + TEST_PERSIST_FAILURES.store(count, std::sync::atomic::Ordering::SeqCst); +} + +#[cfg(test)] +fn test_persist_should_fail() -> bool { + let remaining = TEST_PERSIST_FAILURES.load(std::sync::atomic::Ordering::SeqCst); + if remaining == 0 { + return false; + } + TEST_PERSIST_FAILURES.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); + true +} + #[cfg(test)] mod tests { use common::storage::types::text_content::TextContent; @@ -340,4 +365,44 @@ mod tests { Ok(()) } + + #[test] + fn is_retryable_conflict_matches_surreal_transaction_conflict() { + let err = AppError::InternalError( + "Failed to commit transaction due to a read or write conflict".into(), + ); + assert!(is_retryable_conflict(&err)); + } + + #[test] + fn is_retryable_conflict_rejects_unrelated_errors() { + let err = AppError::Validation("invalid payload".into()); + assert!(!is_retryable_conflict(&err)); + } + + #[tokio::test] + async fn persist_artifacts_retries_transient_conflicts() -> anyhow::Result<()> { + set_test_persist_failures(2); + + let db = setup_db().await?; + let source_id = uuid::Uuid::new_v4().to_string(); + let user_id = "persist-retry"; + let mut tuning = test_support::tuning(); + tuning.persist_attempts = 3; + tuning.persist_initial_backoff_ms = 1; + tuning.persist_max_backoff_ms = 1; + + let counts = persist_artifacts( + &db, + &tuning, + TEST_EMBEDDING_DIM, + sample_artifacts(&source_id, user_id), + ) + .await?; + + assert_eq!(counts.chunk_count, 1); + assert_eq!(count_chunks_for_source(&db, &source_id).await?, 1); + + Ok(()) + } } diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index b0c084c..7eb4485 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -51,7 +51,6 @@ pub trait PipelineServices: Send + Sync { &self, content: &TextContent, analysis: &LLMEnrichmentResult, - entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError>; async fn prepare_chunks( @@ -230,15 +229,9 @@ impl PipelineServices for DefaultPipelineServices { &self, content: &TextContent, analysis: &LLMEnrichmentResult, - entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { analysis - .to_database_entities( - content.id(), - &content.user_id, - entity_concurrency, - &self.embedding_provider, - ) + .to_database_entities(content.id(), &content.user_id, &self.embedding_provider) .await } diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 867e9d3..7483bb6 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -31,7 +31,7 @@ use super::{ IngestionPipeline, }; -struct MockServices { +pub(crate) struct MockServices { text_content: TextContent, similar_entities: Vec, analysis: LLMEnrichmentResult, @@ -42,7 +42,7 @@ struct MockServices { } impl MockServices { - fn new(user_id: &str) -> Self { + pub(crate) fn new(user_id: &str) -> Self { const TEST_EMBEDDING_DIM: usize = 1536; let text_content = TextContent::new( "Example document for ingestion pipeline.".into(), @@ -145,7 +145,6 @@ impl PipelineServices for MockServices { &self, content: &TextContent, _analysis: &LLMEnrichmentResult, - _entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { self.record("convert").await; let entities = self @@ -221,10 +220,9 @@ impl PipelineServices for FailingServices { &self, content: &TextContent, analysis: &LLMEnrichmentResult, - entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { self.inner - .convert_analysis(content, analysis, entity_concurrency) + .convert_analysis(content, analysis) .await } @@ -268,7 +266,6 @@ impl PipelineServices for ValidationServices { &self, _content: &TextContent, _analysis: &LLMEnrichmentResult, - _entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { unreachable!("convert_analysis should not be called after validation failure") } @@ -283,19 +280,18 @@ impl PipelineServices for ValidationServices { } } -fn pipeline_config() -> IngestionConfig { +pub(crate) fn pipeline_config() -> IngestionConfig { IngestionConfig { tuning: IngestionTuning { chunk_min_tokens: 4, chunk_max_tokens: 64, - entity_embedding_concurrency: 2, ..IngestionTuning::default() }, chunk_only: false, } } -async fn reserve_task( +pub(crate) async fn reserve_task( db: &SurrealDbClient, worker_id: &str, payload: IngestionPayload, @@ -459,6 +455,34 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> { Ok(()) } +#[tokio::test] +async fn produce_artifacts_returns_enriched_snapshot_without_persisting() -> anyhow::Result<()> { + let db = setup_db().await?; + let user_id = "user-produce"; + let services = Arc::new(MockServices::new(user_id)); + let pipeline = + IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)?; + + let payload = IngestionPayload::Text { + text: "Produce artifacts payload".into(), + context: "Context".into(), + category: "notes".into(), + user_id: user_id.into(), + }; + let task = IngestionTask::new(payload, user_id.to_string()); + + let artifacts = pipeline.produce_artifacts(&task).await?; + + assert_eq!(artifacts.text_content.user_id, user_id); + assert_eq!(artifacts.chunks.len(), 1); + assert_eq!(artifacts.entities.len(), 1); + assert_eq!(artifacts.relationships.len(), 1); + assert_eq!(count_chunks_for_source(&db, &task.id).await?, 0); + assert_eq!(count_entities_for_source(&db, &task.id).await?, 0); + + Ok(()) +} + #[tokio::test] async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> { let db = setup_db().await?;