perf: batch entity embeddings during ingest and expand retry tests.

Entity enrichment now uses embed_batch like chunks; the now-unused entity_embedding_concurrency knob is removed and ingest retry paths gain test coverage.
This commit is contained in:
Per Stark
2026-06-12 18:40:36 +02:00
parent 1013035731
commit adc04d8c6d
8 changed files with 267 additions and 96 deletions
+2
View File
@@ -2,10 +2,12 @@
## Unreleased
- Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables)
- Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id)
- Performance: entity embeddings during ingestion use batched `embed_batch`, matching chunk embedding
- Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed
- Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children`
- Fix: regression in relationship suggestions
- Internal: eval corpus DB seed uses `persist_artifacts` instead of a separate batched insert path
- Internal: removed the now-unused `entity_embedding_concurrency` ingest tuning knob
## 1.0.3 (2026-06-12)
- Search: filter results by type — knowledge entities, ingested content, or both
@@ -9,7 +9,6 @@ pub struct IngestionTuning {
pub chunk_min_tokens: usize,
pub chunk_max_tokens: usize,
pub chunk_overlap_tokens: usize,
pub entity_embedding_concurrency: usize,
/// Maximum characters of content body used to build the similarity-search query
/// during retrieval. Longer bodies are truncated to keep embedding inputs bounded.
pub embedding_query_char_limit: usize,
@@ -27,7 +26,6 @@ impl Default for IngestionTuning {
chunk_min_tokens: 256,
chunk_max_tokens: 512,
chunk_overlap_tokens: 50,
entity_embedding_concurrency: 4,
embedding_query_char_limit: 12_000,
}
}
+1 -5
View File
@@ -111,11 +111,7 @@ impl<'a> PipelineContext<'a> {
let (entities, relationships) = self
.services
.convert_analysis(
&content,
&analysis,
self.pipeline_config.tuning.entity_embedding_concurrency,
)
.convert_analysis(&content, &analysis)
.await?;
let chunk_range = self.chunk_token_range();
@@ -1,7 +1,4 @@
use std::sync::Arc;
use chrono::Utc;
use futures::stream::{self, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use common::{
@@ -43,22 +40,15 @@ impl LLMEnrichmentResult {
&self,
source_id: &str,
user_id: &str,
entity_concurrency: usize,
embedding_provider: &EmbeddingProvider,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
let mapper = Arc::new(self.create_mapper());
let mapper = self.create_mapper();
let entities = self
.process_entities(
source_id,
user_id,
Arc::clone(&mapper),
entity_concurrency,
embedding_provider,
)
.process_entities(source_id, user_id, &mapper, embedding_provider)
.await?;
let relationships = self.process_relationships(source_id, user_id, mapper.as_ref())?;
let relationships = self.process_relationships(source_id, user_id, &mapper)?;
Ok((entities, relationships))
}
@@ -77,36 +67,64 @@ impl LLMEnrichmentResult {
&self,
source_id: &str,
user_id: &str,
mapper: Arc<GraphMapper>,
entity_concurrency: usize,
mapper: &GraphMapper,
embedding_provider: &EmbeddingProvider,
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
let tasks: Vec<_> = self
.knowledge_entities
.iter()
.map(|entity| {
let llm_entity = entity.clone();
let mapper = Arc::clone(&mapper);
let source_id = source_id.to_string();
let user_id = user_id.to_string();
if self.knowledge_entities.is_empty() {
return Ok(Vec::new());
}
async move {
create_single_entity(
llm_entity,
&source_id,
&user_id,
mapper,
embedding_provider,
)
.await
}
})
.collect();
let now = Utc::now();
let mut prepared = Vec::with_capacity(self.knowledge_entities.len());
let mut embedding_inputs = Vec::with_capacity(self.knowledge_entities.len());
stream::iter(tasks)
.buffer_unordered(entity_concurrency.max(1))
.try_collect()
for llm_entity in &self.knowledge_entities {
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
let entity_type = KnowledgeEntityType::from(llm_entity.entity_type.clone());
embedding_inputs.push(KnowledgeEntity::embedding_input_text(
&llm_entity.name,
&llm_entity.description,
entity_type,
));
prepared.push((llm_entity, assigned_id, entity_type));
}
// Embed all entities from this document in one batch: a single lock acquisition and one
// blocking hop, letting the backend batch the inference internally.
let embeddings = embedding_provider
.embed_batch(&embedding_inputs)
.await
.map_err(|e| AppError::InternalError(format!("entity embedding batch failed: {e}")))?;
if embeddings.len() != prepared.len() {
return Err(AppError::InternalError(format!(
"embedding batch returned {} vectors for {} entities",
embeddings.len(),
prepared.len()
)));
}
let mut entities = Vec::with_capacity(prepared.len());
for ((llm_entity, assigned_id, entity_type), embedding) in
prepared.into_iter().zip(embeddings)
{
entities.push(EmbeddedKnowledgeEntity {
entity: KnowledgeEntity {
id: assigned_id,
created_at: now,
updated_at: now,
name: llm_entity.name.clone(),
description: llm_entity.description.clone(),
entity_type,
source_id: source_id.to_string(),
metadata: None,
user_id: user_id.to_string(),
},
embedding,
});
}
Ok(entities)
}
fn process_relationships(
@@ -133,44 +151,11 @@ impl LLMEnrichmentResult {
}
}
/// Builds one embedded knowledge entity from a single LLM-extracted entity.
///
/// The graph id is resolved through `mapper` using the entity's key, the text returned by
/// `KnowledgeEntity::embedding_input_text` is embedded with `embedding_provider`, and both
/// timestamps are set from one clock reading taken after the embedding call returns.
///
/// # Errors
/// Propagates the error from `mapper.get_id` or from `embedding_provider.embed`.
async fn create_single_entity(
    llm_entity: LLMKnowledgeEntity,
    source_id: &str,
    user_id: &str,
    mapper: Arc<GraphMapper>,
    embedding_provider: &EmbeddingProvider,
) -> Result<EmbeddedKnowledgeEntity, AppError> {
    let id = mapper.get_id(&llm_entity.key)?.to_string();
    let entity_type = KnowledgeEntityType::from(llm_entity.entity_type);

    let embedding = embedding_provider
        .embed(&KnowledgeEntity::embedding_input_text(
            &llm_entity.name,
            &llm_entity.description,
            entity_type,
        ))
        .await?;

    // Read the clock only once the embedding exists, so both stamps are identical.
    let timestamp = Utc::now();

    Ok(EmbeddedKnowledgeEntity {
        entity: KnowledgeEntity {
            id,
            created_at: timestamp,
            updated_at: timestamp,
            name: llm_entity.name,
            description: llm_entity.description,
            entity_type,
            source_id: source_id.to_string(),
            metadata: None,
            user_id: user_id.to_string(),
        },
        embedding,
    })
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
use super::*;
use common::utils::embedding::EmbeddingProvider;
use uuid::Uuid;
fn entity(key: &str) -> LLMKnowledgeEntity {
@@ -247,6 +232,32 @@ mod tests {
);
}
/// `process_entities` must return one embedded entity per input, in input order, each
/// carrying an embedding of the provider's dimension, with distinct inputs producing
/// distinct vectors.
#[tokio::test]
async fn process_entities_batches_embeddings_and_preserves_order() -> anyhow::Result<()> {
    let expected_names = ["name-k1", "name-k2", "name-k3"];

    let enrichment = LLMEnrichmentResult {
        knowledge_entities: vec![entity("k1"), entity("k2"), entity("k3")],
        relationships: Vec::new(),
    };
    let id_mapper = enrichment.create_mapper();
    let hashed_provider = EmbeddingProvider::new_hashed(8)?;

    let entities = enrichment
        .process_entities("source-1", "user-1", &id_mapper, &hashed_provider)
        .await?;

    assert_eq!(entities.len(), 3);
    // Output order has to follow input order, not completion order.
    for (item, expected) in entities.iter().zip(expected_names) {
        assert_eq!(item.entity.name, expected);
        assert_eq!(item.embedding.len(), 8);
    }

    let first = entities.first().expect("first entity");
    let second = entities.get(1).expect("second entity");
    assert_ne!(first.embedding, second.embedding);

    Ok(())
}
#[test]
fn process_relationships_errors_on_unknown_endpoint() {
let result = LLMEnrichmentResult {
+82
View File
@@ -347,3 +347,85 @@ mod test_support;
#[cfg(test)]
mod tests;
#[cfg(test)]
mod finalize_tests {
    use std::{sync::Arc, time::Duration};

    use common::storage::types::{
        ingestion_payload::IngestionPayload,
        ingestion_task::{IngestionTask, TaskState},
    };
    use tokio::time::sleep;

    use super::{
        config::IngestionTuning,
        test_support::setup_db,
        tests::{pipeline_config, reserve_task, MockServices},
        IngestionPipeline, PipelineServices,
    };

    /// `finalize_succeeded` must leave the task in `TaskState::Succeeded` even though the
    /// task row temporarily carries a different `worker_id` when finalization starts.
    ///
    /// A background task restores the correct `worker_id` after 5 ms; the pipeline is
    /// configured with 3 persist attempts and a 10 ms backoff.
    /// NOTE(review): presumably `mark_succeeded` rejects a mismatched worker, which is what
    /// forces the retry — confirm against `finalize_succeeded`.
    #[tokio::test]
    async fn finalize_succeeded_retries_mark_succeeded() -> anyhow::Result<()> {
        use anyhow::Context;

        let db = setup_db().await?;
        let worker_id = "worker-finalize-retry";
        let user_id = "user-finalize-retry";
        let services: Arc<dyn PipelineServices> = Arc::new(MockServices::new(user_id));

        let mut config = pipeline_config();
        config.tuning = IngestionTuning {
            persist_attempts: 3,
            persist_initial_backoff_ms: 10,
            persist_max_backoff_ms: 10,
            ..IngestionTuning::default()
        };
        let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services)?;

        let task = reserve_task(
            &db,
            worker_id,
            IngestionPayload::Text {
                text: "Finalize retry payload".into(),
                context: "Context".into(),
                category: "notes".into(),
                user_id: user_id.into(),
            },
            user_id,
        )
        .await?;
        let processing = task.mark_processing(&db).await?;

        // Hand the task to a different worker so finalization starts from a mismatched row.
        db.client
            .query("UPDATE type::thing('ingestion_task', $id) SET worker_id = $wrong_worker;")
            .bind(("id", processing.id.clone()))
            .bind(("wrong_worker", "wrong-worker"))
            .await?;

        // Restore the real worker shortly afterwards. The handle is kept so a failed or
        // panicked repair is reported instead of being silently dropped.
        let task_id = processing.id.clone();
        let db_fix = db.clone();
        let repair = tokio::spawn(async move {
            sleep(Duration::from_millis(5)).await;
            db_fix
                .client
                .query("UPDATE type::thing('ingestion_task', $id) SET worker_id = $worker_id;")
                .bind(("id", task_id))
                .bind(("worker_id", worker_id))
                .await
                .map(drop)
        });

        pipeline.finalize_succeeded(&processing).await?;

        // Join the repair task: the outer error is a panic/cancellation, the inner one is
        // the query failing.
        repair
            .await
            .context("worker_id repair task did not complete")?
            .context("worker_id repair query failed")?;

        let stored: IngestionTask = db
            .get_item(&processing.id)
            .await?
            .context("task stored")?;
        assert_eq!(stored.state, TaskState::Succeeded);

        Ok(())
    }
}
@@ -124,6 +124,13 @@ async fn execute_persist_transaction(
db: &SurrealDbClient,
payload: &PersistPayload,
) -> Result<(), AppError> {
#[cfg(test)]
if test_persist_should_fail() {
return Err(AppError::InternalError(
"Failed to commit transaction due to a read or write conflict".into(),
));
}
let mut query = String::from("BEGIN TRANSACTION;\n");
query.push_str(TextContent::CLEAR_INGESTED_CHILD_ROWS_SURQL);
query.push_str(
@@ -236,6 +243,24 @@ fn is_retryable_conflict(error: &AppError) -> bool {
.contains("Failed to commit transaction due to a read or write conflict")
}
/// Number of upcoming persist transactions that should fail with a synthetic conflict.
///
/// This is a single process-wide counter, so every test in the same test binary that
/// reaches the injection point observes (and consumes) the same value.
#[cfg(test)]
static TEST_PERSIST_FAILURES: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(0);

/// Arms the failure injector so that the next `count` calls to
/// `test_persist_should_fail` return `true`.
#[cfg(test)]
fn set_test_persist_failures(count: usize) {
    TEST_PERSIST_FAILURES.store(count, std::sync::atomic::Ordering::SeqCst);
}

/// Consumes one pending injected failure.
///
/// Returns `true` and decrements the counter when at least one failure is pending;
/// returns `false` and leaves the counter at zero otherwise.
///
/// The check and the decrement are performed as one atomic read-modify-write. A separate
/// `load` followed by `fetch_sub` would let two concurrent callers both observe `1` and
/// both subtract, wrapping the counter around to `usize::MAX`.
#[cfg(test)]
fn test_persist_should_fail() -> bool {
    TEST_PERSIST_FAILURES
        .fetch_update(
            std::sync::atomic::Ordering::SeqCst,
            std::sync::atomic::Ordering::SeqCst,
            // `checked_sub` yields `None` at zero, which aborts the update.
            |remaining| remaining.checked_sub(1),
        )
        .is_ok()
}
#[cfg(test)]
mod tests {
use common::storage::types::text_content::TextContent;
@@ -340,4 +365,44 @@ mod tests {
Ok(())
}
/// An internal error carrying the transaction-conflict message is classified as retryable.
#[test]
fn is_retryable_conflict_matches_surreal_transaction_conflict() {
    let message = String::from("Failed to commit transaction due to a read or write conflict");
    let conflict = AppError::InternalError(message);

    assert!(is_retryable_conflict(&conflict));
}
/// A validation error with an unrelated message is not treated as a retryable conflict.
#[test]
fn is_retryable_conflict_rejects_unrelated_errors() {
    let unrelated = AppError::Validation("invalid payload".into());
    let retryable = is_retryable_conflict(&unrelated);

    assert!(!retryable);
}
/// Checks that `persist_artifacts` still succeeds, and stores exactly one chunk for the
/// source, when two failures are injected and three persist attempts are allowed.
///
/// NOTE(review): presumably `persist_artifacts` reaches `execute_persist_transaction`,
/// where the injected conflict error is raised — confirm, since that call is not visible
/// here.
/// NOTE(review): the injected-failure counter looks process-wide; tests running in
/// parallel that hit the same persist path could consume it — confirm isolation.
#[tokio::test]
async fn persist_artifacts_retries_transient_conflicts() -> anyhow::Result<()> {
// Arm the injector before anything else so the failures are pending for the persist call.
set_test_persist_failures(2);
let db = setup_db().await?;
let source_id = uuid::Uuid::new_v4().to_string();
let user_id = "persist-retry";
let mut tuning = test_support::tuning();
// Three attempts cover the two injected failures plus one real attempt; 1 ms backoff
// keeps the test fast.
tuning.persist_attempts = 3;
tuning.persist_initial_backoff_ms = 1;
tuning.persist_max_backoff_ms = 1;
let counts = persist_artifacts(
&db,
&tuning,
TEST_EMBEDDING_DIM,
sample_artifacts(&source_id, user_id),
)
.await?;
// Both the reported count and the stored rows must show a single chunk.
assert_eq!(counts.chunk_count, 1);
assert_eq!(count_chunks_for_source(&db, &source_id).await?, 1);
Ok(())
}
}
+1 -8
View File
@@ -51,7 +51,6 @@ pub trait PipelineServices: Send + Sync {
&self,
content: &TextContent,
analysis: &LLMEnrichmentResult,
entity_concurrency: usize,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>;
async fn prepare_chunks(
@@ -230,15 +229,9 @@ impl PipelineServices for DefaultPipelineServices {
&self,
content: &TextContent,
analysis: &LLMEnrichmentResult,
entity_concurrency: usize,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
analysis
.to_database_entities(
content.id(),
&content.user_id,
entity_concurrency,
&self.embedding_provider,
)
.to_database_entities(content.id(), &content.user_id, &self.embedding_provider)
.await
}
+33 -9
View File
@@ -31,7 +31,7 @@ use super::{
IngestionPipeline,
};
struct MockServices {
pub(crate) struct MockServices {
text_content: TextContent,
similar_entities: Vec<RetrievedEntity>,
analysis: LLMEnrichmentResult,
@@ -42,7 +42,7 @@ struct MockServices {
}
impl MockServices {
fn new(user_id: &str) -> Self {
pub(crate) fn new(user_id: &str) -> Self {
const TEST_EMBEDDING_DIM: usize = 1536;
let text_content = TextContent::new(
"Example document for ingestion pipeline.".into(),
@@ -145,7 +145,6 @@ impl PipelineServices for MockServices {
&self,
content: &TextContent,
_analysis: &LLMEnrichmentResult,
_entity_concurrency: usize,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
self.record("convert").await;
let entities = self
@@ -221,10 +220,9 @@ impl PipelineServices for FailingServices {
&self,
content: &TextContent,
analysis: &LLMEnrichmentResult,
entity_concurrency: usize,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
self.inner
.convert_analysis(content, analysis, entity_concurrency)
.convert_analysis(content, analysis)
.await
}
@@ -268,7 +266,6 @@ impl PipelineServices for ValidationServices {
&self,
_content: &TextContent,
_analysis: &LLMEnrichmentResult,
_entity_concurrency: usize,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
unreachable!("convert_analysis should not be called after validation failure")
}
@@ -283,19 +280,18 @@ impl PipelineServices for ValidationServices {
}
}
fn pipeline_config() -> IngestionConfig {
pub(crate) fn pipeline_config() -> IngestionConfig {
IngestionConfig {
tuning: IngestionTuning {
chunk_min_tokens: 4,
chunk_max_tokens: 64,
entity_embedding_concurrency: 2,
..IngestionTuning::default()
},
chunk_only: false,
}
}
async fn reserve_task(
pub(crate) async fn reserve_task(
db: &SurrealDbClient,
worker_id: &str,
payload: IngestionPayload,
@@ -459,6 +455,34 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> {
Ok(())
}
/// `produce_artifacts` returns the text content, chunks, entities and relationships for a
/// task while leaving no chunk or entity rows in the database for that task id.
#[tokio::test]
async fn produce_artifacts_returns_enriched_snapshot_without_persisting() -> anyhow::Result<()> {
    let user_id = "user-produce";
    let db = setup_db().await?;
    let mock_services = Arc::new(MockServices::new(user_id));
    let pipeline = IngestionPipeline::with_services(
        Arc::new(db.clone()),
        pipeline_config(),
        mock_services,
    )?;

    // The task is only built in memory; it is never reserved or stored.
    let task = IngestionTask::new(
        IngestionPayload::Text {
            text: "Produce artifacts payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id.to_string(),
    );

    let snapshot = pipeline.produce_artifacts(&task).await?;

    // The in-memory snapshot is fully populated.
    assert_eq!(snapshot.text_content.user_id, user_id);
    assert_eq!(snapshot.chunks.len(), 1);
    assert_eq!(snapshot.entities.len(), 1);
    assert_eq!(snapshot.relationships.len(), 1);

    // Nothing was written for this task id.
    let stored_chunks = count_chunks_for_source(&db, &task.id).await?;
    assert_eq!(stored_chunks, 0);
    let stored_entities = count_entities_for_source(&db, &task.id).await?;
    assert_eq!(stored_entities, 0);

    Ok(())
}
#[tokio::test]
async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> {
let db = setup_db().await?;