refactor: implement state machine for ingestion pipeline; improve performance

changelog

Additional restructuring

Moved files between modules
This commit is contained in:
Per Stark
2025-10-19 16:08:46 +02:00
parent 83d39afad4
commit 07b3e1a0e8
20 changed files with 1762 additions and 802 deletions

View File

@@ -1,5 +1,8 @@
# Changelog
## Unreleased
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
- Added knowledge entity search results to the global search
- Backend fixes to improve ingestion and retrieval performance
## Version 0.2.4 (2025-10-15)
- Improved retrieval performance. Ingestion and chat now utilize full text search, vector comparison and graph traversal.

2
Cargo.lock generated
View File

@@ -2919,6 +2919,7 @@ name = "ingestion-pipeline"
version = "0.1.0"
dependencies = [
"async-openai",
"async-trait",
"axum",
"axum_typed_multipart",
"base64 0.22.1",
@@ -2933,6 +2934,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"state-machines",
"surrealdb",
"tempfile",
"text-splitter",

View File

@@ -31,5 +31,7 @@ lopdf = "0.32"
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }
async-trait = { workspace = true }
state-machines = { workspace = true }
[features]
docker = []

View File

@@ -1,118 +0,0 @@
use std::sync::Arc;
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
use crate::{
types::llm_enrichment_result::LLMEnrichmentResult,
utils::llm_instructions::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
};
pub struct IngestionEnricher {
db_client: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
}
impl IngestionEnricher {
pub const fn new(
db_client: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Self {
Self {
db_client,
openai_client,
}
}
pub async fn analyze_content(
&self,
category: &str,
context: Option<&str>,
text: &str,
user_id: &str,
) -> Result<LLMEnrichmentResult, AppError> {
let similar_entities = self
.find_similar_entities(category, context, text, user_id)
.await?;
let llm_request = self
.prepare_llm_request(category, context, text, &similar_entities)
.await?;
self.perform_analysis(llm_request).await
}
async fn find_similar_entities(
&self,
category: &str,
context: Option<&str>,
text: &str,
user_id: &str,
) -> Result<Vec<RetrievedEntity>, AppError> {
let input_text =
format!("content: {text}, category: {category}, user_context: {context:?}");
retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await
}
async fn prepare_llm_request(
&self,
category: &str,
context: Option<&str>,
text: &str,
similar_entities: &[RetrievedEntity],
) -> Result<CreateChatCompletionRequest, AppError> {
let settings = SystemSettings::get_current(&self.db_client).await?;
let entities_json = retrieved_entities_to_json(similar_entities);
let user_message = format!(
"Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
);
let response_format = ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
description: Some("Structured analysis of the submitted content".into()),
name: "content_analysis".into(),
schema: Some(get_ingress_analysis_schema()),
strict: Some(true),
},
};
let request = CreateChatCompletionRequestArgs::default()
.model(&settings.processing_model)
.messages([
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
ChatCompletionRequestUserMessage::from(user_message).into(),
])
.response_format(response_format)
.build()?;
Ok(request)
}
async fn perform_analysis(
&self,
request: CreateChatCompletionRequest,
) -> Result<LLMEnrichmentResult, AppError> {
let response = self.openai_client.chat().create(request).await?;
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or(AppError::LLMParsing(
"No content found in LLM response".into(),
))?;
serde_json::from_str::<LLMEnrichmentResult>(content).map_err(|e| {
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
})
}
}

View File

@@ -1,6 +1,4 @@
pub mod enricher;
pub mod pipeline;
pub mod types;
pub mod utils;
use chrono::Utc;

View File

@@ -1,299 +0,0 @@
use std::{sync::Arc, time::Instant};
use text_splitter::TextSplitter;
use tokio::time::{sleep, Duration};
use tracing::{debug, info, info_span, warn};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_task::{IngestionTask, TaskErrorInfo},
knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
text_content::TextContent,
},
},
utils::{config::AppConfig, embedding::generate_embedding},
};
use crate::{
enricher::IngestionEnricher,
types::{llm_enrichment_result::LLMEnrichmentResult, to_text_content},
};
pub struct IngestionPipeline {
db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
config: AppConfig,
}
impl IngestionPipeline {
pub async fn new(
db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
config: AppConfig,
) -> Result<Self, AppError> {
Ok(Self {
db,
openai_client,
config,
})
}
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
let task_id = task.id.clone();
let attempt = task.attempts;
let worker_label = task
.worker_id
.clone()
.unwrap_or_else(|| "unknown-worker".to_string());
let span = info_span!(
"ingestion_task",
%task_id,
attempt,
worker_id = %worker_label,
state = %task.state.as_str()
);
let _enter = span.enter();
let processing_task = task.mark_processing(&self.db).await?;
let text_content = to_text_content(
processing_task.content.clone(),
&self.db,
&self.config,
&self.openai_client,
)
.await?;
let text_len = text_content.text.chars().count();
let preview: String = text_content.text.chars().take(120).collect();
let preview_clean = preview.replace("\n", " ");
let preview_len = preview_clean.chars().count();
let truncated = text_len > preview_len;
let context_len = text_content
.context
.as_ref()
.map(|c| c.chars().count())
.unwrap_or(0);
info!(
%task_id,
attempt,
user_id = %text_content.user_id,
category = %text_content.category,
text_chars = text_len,
context_chars = context_len,
attachments = text_content.file_info.is_some(),
"ingestion task input ready"
);
debug!(
%task_id,
attempt,
preview = %preview_clean,
preview_truncated = truncated,
"ingestion task input preview"
);
match self.process(&text_content).await {
Ok(()) => {
processing_task.mark_succeeded(&self.db).await?;
info!(%task_id, attempt, "ingestion task succeeded");
Ok(())
}
Err(err) => {
let reason = err.to_string();
let error_info = TaskErrorInfo {
code: None,
message: reason.clone(),
};
if processing_task.can_retry() {
let delay = Self::retry_delay(processing_task.attempts);
processing_task
.mark_failed(error_info, delay, &self.db)
.await?;
warn!(
%task_id,
attempt = processing_task.attempts,
retry_in_secs = delay.as_secs(),
"ingestion task failed; scheduled retry"
);
} else {
processing_task
.mark_dead_letter(error_info, &self.db)
.await?;
warn!(
%task_id,
attempt = processing_task.attempts,
"ingestion task failed; moved to dead letter queue"
);
}
Err(AppError::Processing(reason))
}
}
}
fn retry_delay(attempt: u32) -> Duration {
const BASE_SECONDS: u64 = 30;
const MAX_SECONDS: u64 = 15 * 60;
let capped_attempt = attempt.saturating_sub(1).min(5);
let multiplier = 2_u64.pow(capped_attempt);
let delay = BASE_SECONDS * multiplier;
Duration::from_secs(delay.min(MAX_SECONDS))
}
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
let now = Instant::now();
// Perform analyis, this step also includes retrieval
let analysis = self.perform_semantic_analysis(content).await?;
let end = now.elapsed();
info!(
"{:?} time elapsed during creation of entities and relationships",
end
);
// Convert analysis to application objects
let (entities, relationships) = analysis
.to_database_entities(&content.id, &content.user_id, &self.openai_client, &self.db)
.await?;
// Store everything
tokio::try_join!(
self.store_graph_entities(entities, relationships),
self.store_vector_chunks(content),
)?;
// Store original content
self.db.store_item(content.to_owned()).await?;
self.db.rebuild_indexes().await?;
Ok(())
}
async fn perform_semantic_analysis(
&self,
content: &TextContent,
) -> Result<LLMEnrichmentResult, AppError> {
let analyser = IngestionEnricher::new(self.db.clone(), self.openai_client.clone());
analyser
.analyze_content(
&content.category,
content.context.as_deref(),
&content.text,
&content.user_id,
)
.await
}
async fn store_graph_entities(
&self,
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
let entities = Arc::new(entities);
let relationships = Arc::new(relationships);
let entity_count = entities.len();
let relationship_count = relationships.len();
const STORE_GRAPH_MUTATION: &str = r"
BEGIN TRANSACTION;
LET $entities = $entities;
LET $relationships = $relationships;
FOR $entity IN $entities {
CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
};
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
";
const MAX_ATTEMPTS: usize = 3;
const INITIAL_BACKOFF_MS: u64 = 50;
const MAX_BACKOFF_MS: u64 = 800;
let mut backoff_ms = INITIAL_BACKOFF_MS;
let mut success = false;
for attempt in 0..MAX_ATTEMPTS {
let result = self
.db
.client
.query(STORE_GRAPH_MUTATION)
.bind(("entities", entities.clone()))
.bind(("relationships", relationships.clone()))
.await;
match result {
Ok(_) => {
success = true;
break;
}
Err(err) => {
if Self::is_retryable_conflict(&err) && attempt + 1 < MAX_ATTEMPTS {
warn!(
attempt = attempt + 1,
"Transient SurrealDB conflict while storing graph data; retrying"
);
sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
continue;
}
return Err(AppError::from(err));
}
}
}
if !success {
return Err(AppError::InternalError(
"Failed to store graph entities after retries".to_string(),
));
}
info!(
"Stored {} entities and {} relationships",
entity_count, relationship_count
);
Ok(())
}
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
let splitter = TextSplitter::new(500..2000);
let chunks = splitter.chunks(&content.text);
// Could potentially process chunks in parallel with a bounded concurrent limit
for chunk in chunks {
let embedding = generate_embedding(&self.openai_client, chunk, &self.db).await?;
let text_chunk = TextChunk::new(
content.id.to_string(),
chunk.to_string(),
embedding,
content.user_id.to_string(),
);
self.db.store_item(text_chunk).await?;
}
Ok(())
}
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
error
.to_string()
.contains("Failed to commit transaction due to a read or write conflict")
}
}

View File

@@ -0,0 +1,35 @@
/// Tunable parameters for the ingestion pipeline: task-retry backoff,
/// graph-store retry behaviour, text-chunk sizing, and concurrency limits.
#[derive(Debug, Clone)]
pub struct IngestionTuning {
    pub retry_base_delay_secs: u64,
    pub retry_max_delay_secs: u64,
    pub retry_backoff_cap_exponent: u32,
    pub graph_store_attempts: usize,
    pub graph_initial_backoff_ms: u64,
    pub graph_max_backoff_ms: u64,
    pub chunk_min_chars: usize,
    pub chunk_max_chars: usize,
    pub chunk_insert_concurrency: usize,
    pub entity_embedding_concurrency: usize,
}

impl Default for IngestionTuning {
    /// Conservative production defaults: 30 s base retry delay capped at
    /// 15 minutes, three graph-store attempts backing off 50 ms → 800 ms,
    /// 500–2000 character chunks, and modest concurrency.
    fn default() -> Self {
        Self {
            retry_base_delay_secs: 30,
            retry_max_delay_secs: 900, // 15 minutes
            retry_backoff_cap_exponent: 5,
            graph_store_attempts: 3,
            graph_initial_backoff_ms: 50,
            graph_max_backoff_ms: 800,
            chunk_min_chars: 500,
            chunk_max_chars: 2000,
            chunk_insert_concurrency: 8,
            entity_embedding_concurrency: 4,
        }
    }
}

/// Top-level pipeline configuration; currently only carries tuning knobs.
#[derive(Debug, Clone, Default)]
pub struct IngestionConfig {
    pub tuning: IngestionTuning,
}

View File

@@ -0,0 +1,76 @@
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{ingestion_task::IngestionTask, text_content::TextContent},
},
};
use composite_retrieval::RetrievedEntity;
use tracing::error;
use super::enrichment_result::LLMEnrichmentResult;
use super::{config::IngestionConfig, services::PipelineServices};
/// Mutable state threaded through the ingestion state machine: the task under
/// processing, shared handles, and the artifacts each stage produces.
pub struct PipelineContext<'a> {
    pub task: &'a IngestionTask,
    pub task_id: String,
    pub attempt: u32,
    pub db: &'a SurrealDbClient,
    pub pipeline_config: &'a IngestionConfig,
    pub services: &'a dyn PipelineServices,
    pub text_content: Option<TextContent>,
    pub similar_entities: Vec<RetrievedEntity>,
    pub analysis: Option<LLMEnrichmentResult>,
}

impl<'a> PipelineContext<'a> {
    /// Builds a fresh context with no stage outputs populated yet.
    pub fn new(
        task: &'a IngestionTask,
        db: &'a SurrealDbClient,
        pipeline_config: &'a IngestionConfig,
        services: &'a dyn PipelineServices,
    ) -> Self {
        Self {
            task_id: task.id.clone(),
            attempt: task.attempts,
            task,
            db,
            pipeline_config,
            services,
            text_content: None,
            similar_entities: Vec::new(),
            analysis: None,
        }
    }

    /// Borrows the prepared text content; errors if the prepare stage has not run.
    pub fn text_content(&self) -> Result<&TextContent, AppError> {
        match self.text_content.as_ref() {
            Some(content) => Ok(content),
            None => Err(AppError::InternalError(
                "text content expected to be available".into(),
            )),
        }
    }

    /// Moves the text content out of the context for persistence.
    pub fn take_text_content(&mut self) -> Result<TextContent, AppError> {
        match self.text_content.take() {
            Some(content) => Ok(content),
            None => Err(AppError::InternalError(
                "text content expected to be available for persistence".into(),
            )),
        }
    }

    /// Moves the enrichment analysis out of the context for persistence.
    pub fn take_analysis(&mut self) -> Result<LLMEnrichmentResult, AppError> {
        match self.analysis.take() {
            Some(analysis) => Ok(analysis),
            None => Err(AppError::InternalError(
                "analysis expected to be available for persistence".into(),
            )),
        }
    }

    /// Logs a pipeline abort with task context and hands the error back to the caller.
    pub fn abort(&mut self, err: AppError) -> AppError {
        error!(
            task_id = %self.task_id,
            attempt = self.attempt,
            error = %err,
            "ingestion pipeline aborted"
        );
        err
    }
}

View File

@@ -1,8 +1,8 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use chrono::Utc;
use futures::stream::{self, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use tokio::task;
use common::{
error::AppError,
@@ -15,28 +15,25 @@ use common::{
},
utils::embedding::generate_embedding,
};
use futures::future::try_join_all;
use crate::utils::GraphMapper;
use crate::utils::graph_mapper::GraphMapper;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
pub key: String, // Temporary identifier
pub key: String,
pub name: String,
pub description: String,
pub entity_type: String, // Should match KnowledgeEntityType variants
pub entity_type: String,
}
/// Represents a single relationship from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMRelationship {
#[serde(rename = "type")]
pub type_: String, // e.g., RelatedTo, RelevantTo
pub source: String, // Key of the source entity
pub target: String, // Key of the target entity
pub type_: String,
pub source: String,
pub target: String,
}
/// Represents the entire graph analysis result from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMEnrichmentResult {
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
@@ -44,27 +41,16 @@ pub struct LLMEnrichmentResult {
}
impl LLMEnrichmentResult {
/// Converts the LLM graph analysis result into database entities and relationships.
///
/// # Arguments
///
/// * `source_id` - A UUID representing the source identifier.
/// * `openai_client` - `OpenAI` client for LLM calls.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities(
&self,
source_id: &str,
user_id: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
db_client: &SurrealDbClient,
entity_concurrency: usize,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
// Create mapper and pre-assign IDs
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
let mapper = Arc::new(self.create_mapper()?);
// Process entities
let entities = self
.process_entities(
source_id,
@@ -72,10 +58,10 @@ impl LLMEnrichmentResult {
Arc::clone(&mapper),
openai_client,
db_client,
entity_concurrency,
)
.await?;
// Process relationships
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
Ok((entities, relationships))
@@ -84,7 +70,6 @@ impl LLMEnrichmentResult {
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
let mut mapper = GraphMapper::new();
// Pre-assign all IDs
for entity in &self.knowledge_entities {
mapper.assign_id(&entity.key);
}
@@ -96,57 +81,46 @@ impl LLMEnrichmentResult {
&self,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
mapper: Arc<GraphMapper>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
db_client: &SurrealDbClient,
entity_concurrency: usize,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let futures: Vec<_> = self
.knowledge_entities
.iter()
.map(|entity| {
let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone();
let source_id = source_id.to_string();
let user_id = user_id.to_string();
let entity = entity.clone();
let db_client = db_client.clone();
stream::iter(self.knowledge_entities.iter().cloned().map(|entity| {
let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone();
let source_id = source_id.to_string();
let user_id = user_id.to_string();
let db_client = db_client.clone();
task::spawn(async move {
create_single_entity(
&entity,
&source_id,
&user_id,
mapper,
&openai_client,
&db_client.clone(),
)
.await
})
})
.collect();
let results = try_join_all(futures)
.await?
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
Ok(results)
async move {
create_single_entity(
&entity,
&source_id,
&user_id,
mapper,
&openai_client,
&db_client,
)
.await
}
}))
.buffer_unordered(entity_concurrency.max(1))
.try_collect()
.await
}
fn process_relationships(
&self,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
mapper: Arc<GraphMapper>,
) -> Result<Vec<KnowledgeRelationship>, AppError> {
let mapper_guard = mapper
.lock()
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
self.relationships
.iter()
.map(|rel| {
let source_db_id = mapper_guard.get_or_parse_id(&rel.source)?;
let target_db_id = mapper_guard.get_or_parse_id(&rel.target)?;
let source_db_id = mapper.get_or_parse_id(&rel.source)?;
let target_db_id = mapper.get_or_parse_id(&rel.target)?;
Ok(KnowledgeRelationship::new(
source_db_id.to_string(),
@@ -159,20 +133,16 @@ impl LLMEnrichmentResult {
.collect()
}
}
async fn create_single_entity(
llm_entity: &LLMKnowledgeEntity,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
mapper: Arc<GraphMapper>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
db_client: &SurrealDbClient,
) -> Result<KnowledgeEntity, AppError> {
let assigned_id = {
let mapper = mapper
.lock()
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
mapper.get_id(&llm_entity.key)?.to_string()
};
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
let embedding_input = format!(
"name: {}, description: {}, type: {}",

View File

@@ -0,0 +1,221 @@
mod config;
mod context;
mod enrichment_result;
mod preparation;
mod services;
mod stages;
mod state;
pub use config::{IngestionConfig, IngestionTuning};
pub use services::{DefaultPipelineServices, PipelineServices};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use async_openai::Client;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_payload::IngestionPayload,
ingestion_task::{IngestionTask, TaskErrorInfo},
},
},
utils::config::AppConfig,
};
use tracing::{debug, info, warn};
use self::{
context::PipelineContext,
stages::{enrich, persist, prepare_content, retrieve_related},
state::ready,
};
/// Drives ingestion tasks through the staged state machine
/// (prepare → retrieve → enrich → persist) and records their outcome.
pub struct IngestionPipeline {
    // Shared database handle used for task state bookkeeping.
    db: Arc<SurrealDbClient>,
    // Retry/backoff, chunking and concurrency tuning.
    pipeline_config: IngestionConfig,
    // Stage implementations; injectable so tests can stub external services.
    services: Arc<dyn PipelineServices>,
}
impl IngestionPipeline {
    /// Creates a pipeline wired to the default service implementations
    /// (database, OpenAI client, app config) with default tuning.
    pub async fn new(
        db: Arc<SurrealDbClient>,
        openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
        config: AppConfig,
    ) -> Result<Self, AppError> {
        let services =
            DefaultPipelineServices::new(db.clone(), openai_client.clone(), config.clone());
        Self::with_services(db, IngestionConfig::default(), Arc::new(services))
    }

    /// Creates a pipeline with explicit tuning and service implementations
    /// (used by tests and by `new`).
    pub fn with_services(
        db: Arc<SurrealDbClient>,
        pipeline_config: IngestionConfig,
        services: Arc<dyn PipelineServices>,
    ) -> Result<Self, AppError> {
        Ok(Self {
            db,
            pipeline_config,
            services,
        })
    }

    /// Processes one ingestion task end-to-end and records the outcome on the
    /// task record: success, scheduled retry (with exponential backoff), or
    /// dead-letter when retries are exhausted or the error is a validation error.
    ///
    /// # Errors
    /// Returns `AppError::Processing` wrapping the underlying failure reason,
    /// or any error from updating the task's state in the database.
    #[tracing::instrument(
        skip_all,
        fields(
            task_id = %task.id,
            attempt = task.attempts,
            worker_id = task.worker_id.as_deref().unwrap_or("unknown-worker"),
            user_id = %task.user_id
        )
    )]
    pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
        let mut processing_task = task.mark_processing(&self.db).await?;
        // Move the payload out of the task, leaving a cheap empty placeholder,
        // so the pipeline can consume it without cloning.
        let payload = std::mem::replace(
            &mut processing_task.content,
            IngestionPayload::Text {
                text: String::new(),
                context: String::new(),
                category: String::new(),
                user_id: processing_task.user_id.clone(),
            },
        );
        match self
            .drive_pipeline(&processing_task, payload)
            .await
            .map_err(|err| {
                debug!(
                    task_id = %processing_task.id,
                    attempt = processing_task.attempts,
                    error = %err,
                    "ingestion pipeline failed"
                );
                err
            }) {
            Ok(()) => {
                processing_task.mark_succeeded(&self.db).await?;
                tracing::info!(
                    task_id = %processing_task.id,
                    attempt = processing_task.attempts,
                    "ingestion task succeeded"
                );
                Ok(())
            }
            Err(err) => {
                let reason = err.to_string();
                // Validation failures are deterministic; retrying cannot help.
                let retryable = !matches!(err, AppError::Validation(_));
                let error_info = TaskErrorInfo {
                    code: None,
                    message: reason.clone(),
                };
                if retryable && processing_task.can_retry() {
                    let delay = self.retry_delay(processing_task.attempts);
                    processing_task
                        .mark_failed(error_info, delay, &self.db)
                        .await?;
                    warn!(
                        task_id = %processing_task.id,
                        attempt = processing_task.attempts,
                        retry_in_secs = delay.as_secs(),
                        "ingestion task failed; scheduled retry"
                    );
                } else {
                    // Record the failure (zero delay) before dead-lettering so the
                    // error details are preserved on the task record.
                    let failed_task = processing_task
                        .mark_failed(error_info.clone(), Duration::from_secs(0), &self.db)
                        .await?;
                    failed_task.mark_dead_letter(error_info, &self.db).await?;
                    warn!(
                        task_id = %failed_task.id,
                        attempt = failed_task.attempts,
                        "ingestion task failed; moved to dead letter queue"
                    );
                }
                Err(AppError::Processing(reason))
            }
        }
    }

    /// Exponential backoff: base * 2^(attempt-1), with the exponent capped by
    /// `retry_backoff_cap_exponent` and the result capped at `retry_max_delay_secs`.
    fn retry_delay(&self, attempt: u32) -> Duration {
        let tuning = &self.pipeline_config.tuning;
        let capped_attempt = attempt
            .saturating_sub(1)
            .min(tuning.retry_backoff_cap_exponent);
        let multiplier = 2_u64.pow(capped_attempt);
        let delay = tuning.retry_base_delay_secs * multiplier;
        Duration::from_secs(delay.min(tuning.retry_max_delay_secs))
    }

    /// Runs the four pipeline stages in order, timing each one, and logs a
    /// per-stage duration summary on success. Any stage error is logged via
    /// `ctx.abort` and propagated.
    #[tracing::instrument(
        skip_all,
        fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id)
    )]
    async fn drive_pipeline(
        &self,
        task: &IngestionTask,
        payload: IngestionPayload,
    ) -> Result<(), AppError> {
        let mut ctx = PipelineContext::new(
            task,
            self.db.as_ref(),
            &self.pipeline_config,
            self.services.as_ref(),
        );
        // Typestate machine: each stage consumes the machine in one state and
        // returns it in the next, so stages cannot run out of order.
        let machine = ready();
        let pipeline_started = Instant::now();
        let stage_start = Instant::now();
        let machine = prepare_content(machine, &mut ctx, payload)
            .await
            .map_err(|err| ctx.abort(err))?;
        let prepare_duration = stage_start.elapsed();
        let stage_start = Instant::now();
        let machine = retrieve_related(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        let retrieve_duration = stage_start.elapsed();
        let stage_start = Instant::now();
        let machine = enrich(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        let enrich_duration = stage_start.elapsed();
        let stage_start = Instant::now();
        let _machine = persist(machine, &mut ctx)
            .await
            .map_err(|err| ctx.abort(err))?;
        let persist_duration = stage_start.elapsed();
        let total_duration = pipeline_started.elapsed();
        let prepare_ms = prepare_duration.as_millis() as u64;
        let retrieve_ms = retrieve_duration.as_millis() as u64;
        let enrich_ms = enrich_duration.as_millis() as u64;
        let persist_ms = persist_duration.as_millis() as u64;
        info!(
            task_id = %ctx.task_id,
            attempt = ctx.attempt,
            total_ms = total_duration.as_millis() as u64,
            prepare_ms,
            retrieve_ms,
            enrich_ms,
            persist_ms,
            "ingestion pipeline finished"
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,74 @@
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_payload::IngestionPayload,
text_content::{TextContent, UrlInfo},
},
},
utils::config::AppConfig,
};
use crate::utils::{
file_text_extraction::extract_text_from_file, url_text_retrieval::extract_text_from_url,
};
/// Normalizes an `IngestionPayload` (raw text, URL, or uploaded file) into a
/// `TextContent` record, performing text extraction where needed.
///
/// # Errors
/// Propagates any error from URL or file text extraction.
pub(crate) async fn to_text_content(
    ingestion_payload: IngestionPayload,
    db: &SurrealDbClient,
    config: &AppConfig,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<TextContent, AppError> {
    match ingestion_payload {
        // Plain text needs no extraction at all.
        IngestionPayload::Text {
            text,
            context,
            category,
            user_id,
        } => Ok(TextContent::new(
            text,
            Some(context),
            category,
            None,
            None,
            user_id,
        )),
        // Uploaded file: extract its text, keep the file info attached.
        IngestionPayload::File {
            file_info,
            context,
            category,
            user_id,
        } => {
            let extracted = extract_text_from_file(&file_info, db, openai_client, config).await?;
            Ok(TextContent::new(
                extracted,
                Some(context),
                category,
                Some(file_info),
                None,
                user_id,
            ))
        }
        // URL: fetch the article text and record url/title/preview-image info.
        IngestionPayload::Url {
            url,
            context,
            category,
            user_id,
        } => {
            let (article, stored_file) = extract_text_from_url(&url, db, &user_id, config).await?;
            let url_info = UrlInfo {
                url,
                title: article.title,
                image_id: stored_file.id,
            };
            Ok(TextContent::new(
                article.text_content.into(),
                Some(context),
                category,
                None,
                Some(url_info),
                user_id,
            ))
        }
    }
}

View File

@@ -0,0 +1,213 @@
use std::{ops::Range, sync::Arc};
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::{config::AppConfig, embedding::generate_embedding},
};
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
use text_splitter::TextSplitter;
use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content};
use crate::utils::llm_instructions::{
get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE,
};
/// External side-effecting operations the pipeline stages depend on,
/// abstracted behind a trait so tests can substitute stubs.
#[async_trait]
pub trait PipelineServices: Send + Sync {
    /// Converts a raw ingestion payload into normalized text content.
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError>;
    /// Finds existing knowledge entities similar to the given content.
    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError>;
    /// Runs LLM enrichment over the content, given the similar entities as context.
    async fn run_enrichment(
        &self,
        content: &TextContent,
        similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError>;
    /// Converts an enrichment result into database entities/relationships,
    /// generating embeddings with up to `entity_concurrency` in flight.
    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>;
    /// Splits the content into chunks within `range` characters and embeds each.
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError>;
}
/// Production implementation of `PipelineServices` backed by SurrealDB and
/// the OpenAI API.
pub struct DefaultPipelineServices {
    // Database handle for settings, retrieval and embeddings cache.
    db: Arc<SurrealDbClient>,
    // OpenAI client used for enrichment and embedding calls.
    openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
    // Application configuration used by text extraction helpers.
    config: AppConfig,
}
impl DefaultPipelineServices {
    /// Creates the production services from shared handles and config.
    pub fn new(
        db: Arc<SurrealDbClient>,
        openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
        config: AppConfig,
    ) -> Self {
        Self {
            db,
            openai_client,
            config,
        }
    }

    /// Builds the chat-completion request for enrichment: system instructions,
    /// a user message carrying the content plus known similar entities, and a
    /// strict JSON-schema response format.
    ///
    /// # Errors
    /// Fails if system settings cannot be loaded or the request fails to build.
    async fn prepare_llm_request(
        &self,
        category: &str,
        context: Option<&str>,
        text: &str,
        similar_entities: &[RetrievedEntity],
    ) -> Result<CreateChatCompletionRequest, AppError> {
        let settings = SystemSettings::get_current(&self.db).await?;
        let entities_json = retrieved_entities_to_json(similar_entities);
        let user_message = format!(
            "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
        );
        // strict + JSON schema forces the model to emit parseable structured output.
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: Some("Structured analysis of the submitted content".into()),
                name: "content_analysis".into(),
                schema: Some(get_ingress_analysis_schema()),
                strict: Some(true),
            },
        };
        let request = CreateChatCompletionRequestArgs::default()
            .model(&settings.processing_model)
            .messages([
                ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
                ChatCompletionRequestUserMessage::from(user_message).into(),
            ])
            .response_format(response_format)
            .build()?;
        Ok(request)
    }

    /// Executes the chat completion and parses the first choice's content into
    /// an `LLMEnrichmentResult`.
    ///
    /// # Errors
    /// `AppError::LLMParsing` when the response has no content or the content
    /// is not valid JSON for the expected schema.
    async fn perform_analysis(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<LLMEnrichmentResult, AppError> {
        let response = self.openai_client.chat().create(request).await?;
        let content = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            // ok_or_else: build the error (and its String) only on the failure path,
            // instead of allocating it eagerly on every call (clippy::or_fun_call).
            .ok_or_else(|| AppError::LLMParsing("No content found in LLM response".into()))?;
        serde_json::from_str::<LLMEnrichmentResult>(content).map_err(|e| {
            AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
        })
    }
}
#[async_trait]
impl PipelineServices for DefaultPipelineServices {
    /// Delegates payload normalization (text/URL/file) to `to_text_content`.
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        to_text_content(payload, &self.db, &self.config, &self.openai_client).await
    }

    /// Retrieves similar entities using a combined text/category/context query.
    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        let input_text = format!(
            "content: {}, category: {}, user_context: {:?}",
            content.text, content.category, content.context
        );
        retrieve_entities(&self.db, &self.openai_client, &input_text, &content.user_id).await
    }

    /// Builds the enrichment request and runs the LLM analysis.
    async fn run_enrichment(
        &self,
        content: &TextContent,
        similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        let request = self
            .prepare_llm_request(
                &content.category,
                content.context.as_deref(),
                &content.text,
                similar_entities,
            )
            .await?;
        self.perform_analysis(request).await
    }

    /// Converts the LLM analysis into database entities and relationships.
    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        analysis
            .to_database_entities(
                &content.id,
                &content.user_id,
                &self.openai_client,
                &self.db,
                entity_concurrency,
            )
            .await
    }

    /// Splits the content into chunks of `range` characters and generates an
    /// embedding for each chunk sequentially.
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        // `range` is owned and unused afterwards — pass it directly instead of
        // cloning it (clippy::redundant_clone).
        let splitter = TextSplitter::new(range);
        // Collect owned chunk strings first so the borrow of `content.text`
        // does not live across the embedding awaits.
        let chunk_texts: Vec<String> = splitter
            .chunks(&content.text)
            .map(|chunk| chunk.to_string())
            .collect();
        let mut chunks = Vec::with_capacity(chunk_texts.len());
        for chunk in chunk_texts {
            let embedding = generate_embedding(&self.openai_client, &chunk, &self.db).await?;
            chunks.push(TextChunk::new(
                content.id.clone(),
                chunk,
                embedding,
                content.user_id.clone(),
            ));
        }
        Ok(chunks)
    }
}

View File

@@ -0,0 +1,338 @@
use std::sync::Arc;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
text_content::TextContent,
},
},
};
use state_machines::core::GuardError;
use tokio::time::{sleep, Duration};
use tracing::{debug, instrument, warn};
use super::{
context::PipelineContext,
services::PipelineServices,
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
};
/// Stage 1: normalize the raw payload into text content, log summary and
/// preview diagnostics, stash the content on the context, and advance the
/// state machine from `Ready` to `ContentPrepared`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn prepare_content(
    machine: IngestionMachine<(), Ready>,
    ctx: &mut PipelineContext<'_>,
    payload: IngestionPayload,
) -> Result<IngestionMachine<(), ContentPrepared>, AppError> {
    let text_content = ctx.services.prepare_text_content(payload).await?;
    let total_chars = text_content.text.chars().count();
    // First 120 chars with newlines flattened to spaces, for log readability.
    let snippet: String = text_content
        .text
        .chars()
        .take(120)
        .map(|c| if c == '\n' { ' ' } else { c })
        .collect();
    let truncated = total_chars > snippet.chars().count();
    let context_chars = text_content
        .context
        .as_ref()
        .map_or(0, |c| c.chars().count());
    tracing::info!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        user_id = %text_content.user_id,
        category = %text_content.category,
        text_chars = total_chars,
        context_chars = context_chars,
        attachments = text_content.file_info.is_some(),
        "ingestion task input ready"
    );
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        preview = %snippet,
        preview_truncated = truncated,
        "ingestion task input preview"
    );
    ctx.text_content = Some(text_content);
    machine
        .prepare()
        .map_err(|(_, guard)| map_guard_error("prepare", guard))
}
/// Stage 2: fetches entities similar to the prepared content and stores
/// them on the context, advancing `ContentPrepared -> Retrieved`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn retrieve_related(
    machine: IngestionMachine<(), ContentPrepared>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Retrieved>, AppError> {
    let prepared = ctx.text_content()?;
    let related_entities = ctx.services.retrieve_similar_entities(prepared).await?;
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        similar_count = related_entities.len(),
        "ingestion retrieved similar entities"
    );
    ctx.similar_entities = related_entities;
    // Typestate transition; a guard failure signals an illegal ordering.
    machine
        .retrieve()
        .map_err(|(_, guard)| map_guard_error("retrieve", guard))
}
/// Stage 3: runs LLM enrichment over the prepared content plus the
/// retrieved neighbours, advancing `Retrieved -> Enriched`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn enrich(
    machine: IngestionMachine<(), Retrieved>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Enriched>, AppError> {
    let prepared = ctx.text_content()?;
    let enrichment = ctx
        .services
        .run_enrichment(prepared, &ctx.similar_entities)
        .await?;
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_suggestions = enrichment.knowledge_entities.len(),
        relationship_suggestions = enrichment.relationships.len(),
        "ingestion enrichment completed"
    );
    ctx.analysis = Some(enrichment);
    // Typestate transition; a guard failure signals an illegal ordering.
    machine
        .enrich()
        .map_err(|(_, guard)| map_guard_error("enrich", guard))
}
/// Stage 4: converts the enrichment result into graph records, writes graph
/// entities and vector chunks concurrently, then stores the source content
/// and rebuilds indexes. Advances `Enriched -> Persisted`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn persist(
    machine: IngestionMachine<(), Enriched>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Persisted>, AppError> {
    // Move content/analysis out of the context; persistence consumes them.
    let content = ctx.take_text_content()?;
    let analysis = ctx.take_analysis()?;
    let (entities, relationships) = ctx
        .services
        .convert_analysis(
            &content,
            &analysis,
            ctx.pipeline_config.tuning.entity_embedding_concurrency,
        )
        .await?;
    let entity_count = entities.len();
    let relationship_count = relationships.len();
    let chunk_range =
        ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars;
    // Graph writes and chunk writes are independent, so run them concurrently;
    // either failure aborts the stage before the content row is stored.
    let ((), chunk_count) = tokio::try_join!(
        store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
        store_vector_chunks(
            ctx.db,
            ctx.services,
            ctx.task_id.as_str(),
            &content,
            chunk_range,
            &ctx.pipeline_config.tuning
        )
    )?;
    // Store the source content only after both write paths succeeded.
    ctx.db.store_item(content).await?;
    // Presumably refreshes FTS/vector indexes over the new rows — TODO confirm.
    ctx.db.rebuild_indexes().await?;
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_count,
        relationship_count,
        chunk_count,
        "ingestion persistence flushed to database"
    );
    machine
        .persist()
        .map_err(|(_, guard)| map_guard_error("persist", guard))
}
/// Translates a state-machine guard rejection during `event` into an
/// internal [`AppError`]; guard failures indicate a pipeline ordering bug.
fn map_guard_error(event: &str, guard: GuardError) -> AppError {
    let message = format!(
        "invalid ingestion pipeline transition during {}: {:?}",
        event, guard
    );
    AppError::InternalError(message)
}
/// Persists enrichment output (entities + relationships) in a single
/// SurrealDB transaction, retrying transient commit conflicts with
/// exponential backoff capped at `graph_max_backoff_ms`.
async fn store_graph_entities(
    db: &SurrealDbClient,
    tuning: &super::config::IngestionTuning,
    entities: Vec<KnowledgeEntity>,
    relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
    // One transaction so a partially-written graph is never visible.
    const STORE_GRAPH_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $entities = $entities;
        LET $relationships = $relationships;
        FOR $entity IN $entities {
            CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
        };
        FOR $relationship IN $relationships {
            LET $in_node = type::thing('knowledge_entity', $relationship.in);
            LET $out_node = type::thing('knowledge_entity', $relationship.out);
            RELATE $in_node->relates_to->$out_node CONTENT {
                id: type::thing('relates_to', $relationship.id),
                metadata: $relationship.metadata
            };
        };
        COMMIT TRANSACTION;
    ";
    // Arc so each retry can re-bind the payload with a cheap pointer clone.
    let entities = Arc::new(entities);
    let relationships = Arc::new(relationships);
    let mut backoff_ms = tuning.graph_initial_backoff_ms;
    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_GRAPH_MUTATION)
            .bind(("entities", entities.clone()))
            .bind(("relationships", relationships.clone()))
            .await;
        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                // Retry only optimistic-concurrency conflicts, and only while
                // attempts remain; double the backoff up to the configured cap.
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing graph data; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }
                return Err(AppError::from(err));
            }
        }
    }
    // Reached only if graph_store_attempts is 0 or the final attempt conflicted.
    Err(AppError::InternalError(
        "Failed to store graph entities after retries".to_string(),
    ))
}
/// Splits and embeds the content via the service layer, then inserts the
/// resulting chunks into SurrealDB in sequential batches.
///
/// Returns the number of chunks written.
///
/// Fix: the previous version logged "chunk persisted" for every chunk
/// *before* any database write happened, so logs claimed persistence even
/// when `store_chunk_batch` subsequently failed. Logging now happens per
/// batch, after that batch's transaction has committed.
async fn store_vector_chunks(
    db: &SurrealDbClient,
    services: &dyn PipelineServices,
    task_id: &str,
    content: &TextContent,
    chunk_range: std::ops::Range<usize>,
    tuning: &super::config::IngestionTuning,
) -> Result<usize, AppError> {
    let prepared_chunks = services.prepare_chunks(content, chunk_range).await?;
    let chunk_count = prepared_chunks.len();
    // NOTE: `chunk_insert_concurrency` is used as the batch size here; the
    // batches themselves are written sequentially, not in parallel.
    let batch_size = tuning.chunk_insert_concurrency.max(1);
    for batch in prepared_chunks.chunks(batch_size) {
        store_chunk_batch(db, batch, tuning).await?;
        // Log only after the batch transaction committed, so the message
        // is accurate.
        for chunk in batch {
            debug!(
                task_id = %task_id,
                chunk_id = %chunk.id,
                chunk_len = chunk.chunk.chars().count(),
                "chunk persisted"
            );
        }
    }
    Ok(chunk_count)
}
/// Returns true when the SurrealDB error is a transient read/write commit
/// conflict that is safe to retry.
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
    let rendered = error.to_string();
    rendered.contains("Failed to commit transaction due to a read or write conflict")
}
/// Inserts one batch of text chunks inside a single SurrealDB transaction,
/// retrying transient commit conflicts with exponential backoff.
///
/// NOTE(review): reuses the `graph_*` tuning knobs for backoff/attempts —
/// confirm chunk writes are intended to share the graph retry policy.
async fn store_chunk_batch(
    db: &SurrealDbClient,
    batch: &[TextChunk],
    tuning: &super::config::IngestionTuning,
) -> Result<(), AppError> {
    if batch.is_empty() {
        return Ok(());
    }
    const STORE_CHUNKS_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $chunks = $chunks;
        FOR $chunk IN $chunks {
            CREATE type::thing('text_chunk', $chunk.id) CONTENT $chunk;
        };
        COMMIT TRANSACTION;
    ";
    // Arc so retries re-bind the batch without copying the chunk payloads.
    let chunks = Arc::new(batch.to_vec());
    let mut backoff_ms = tuning.graph_initial_backoff_ms;
    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_CHUNKS_MUTATION)
            .bind(("chunks", chunks.clone()))
            .await;
        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                // Retry only optimistic-concurrency conflicts while attempts
                // remain; double the backoff up to the configured cap.
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing chunks; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }
                return Err(AppError::from(err));
            }
        }
    }
    // Reached only if graph_store_attempts is 0 or the final attempt conflicted.
    Err(AppError::InternalError(
        "Failed to store text chunks after retries".to_string(),
    ))
}

View File

@@ -0,0 +1,25 @@
use state_machines::state_machine;
// Typestate machine for the ingestion pipeline. Each event consumes the
// machine at its `from` state and returns it at the `to` state, so the
// stage ordering (prepare -> retrieve -> enrich -> persist) is enforced at
// compile time. `abort` is permitted from every state and funnels into
// `Failed`.
state_machine! {
    name: IngestionMachine,
    state: IngestionState,
    initial: Ready,
    states: [Ready, ContentPrepared, Retrieved, Enriched, Persisted, Failed],
    events {
        prepare { transition: { from: Ready, to: ContentPrepared } }
        retrieve { transition: { from: ContentPrepared, to: Retrieved } }
        enrich { transition: { from: Retrieved, to: Enriched } }
        persist { transition: { from: Enriched, to: Persisted } }
        abort {
            transition: { from: Ready, to: Failed }
            transition: { from: ContentPrepared, to: Failed }
            transition: { from: Retrieved, to: Failed }
            transition: { from: Enriched, to: Failed }
            transition: { from: Persisted, to: Failed }
        }
    }
}
/// Constructs a fresh ingestion machine in the initial `Ready` state with a
/// unit context.
pub fn ready() -> IngestionMachine<(), Ready> {
    IngestionMachine::new(())
}

View File

@@ -0,0 +1,440 @@
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{Duration as ChronoDuration, Utc};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_payload::IngestionPayload,
ingestion_task::{IngestionTask, TaskState},
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
text_content::TextContent,
},
},
};
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
use tokio::sync::Mutex;
use uuid::Uuid;
use super::{
config::{IngestionConfig, IngestionTuning},
enrichment_result::LLMEnrichmentResult,
services::PipelineServices,
IngestionPipeline,
};
/// Canned [`PipelineServices`] implementation returning fixed fixtures and
/// recording the order in which pipeline stages invoke it.
struct MockServices {
    text_content: TextContent,
    similar_entities: Vec<RetrievedEntity>,
    analysis: LLMEnrichmentResult,
    chunk_embedding: Vec<f32>,
    graph_entities: Vec<KnowledgeEntity>,
    graph_relationships: Vec<KnowledgeRelationship>,
    // Stage names in invocation order; inspected by the happy-path test.
    calls: Mutex<Vec<&'static str>>,
}
impl MockServices {
    /// Builds the fixture set for `user_id`: one text content, one
    /// previously-retrieved entity with a chunk, an empty enrichment
    /// result, and one generated entity with a self-referencing
    /// relationship.
    fn new(user_id: &str) -> Self {
        // Embedding width used by all fixtures; presumably mirrors the
        // production embedding dimension — TODO confirm.
        const TEST_EMBEDDING_DIM: usize = 1536;
        let text_content = TextContent::new(
            "Example document for ingestion pipeline.".into(),
            Some("light context".into()),
            "notes".into(),
            None,
            None,
            user_id.into(),
        );
        let retrieved_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Existing Entity".into(),
            "Previously known context".into(),
            KnowledgeEntityType::Document,
            None,
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        let retrieved_chunk = TextChunk::new(
            retrieved_entity.source_id.clone(),
            "existing chunk".into(),
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        // Empty enrichment result: conversion output is supplied directly
        // via graph_entities/graph_relationships below.
        let analysis = LLMEnrichmentResult {
            knowledge_entities: Vec::new(),
            relationships: Vec::new(),
        };
        let graph_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Generated Entity".into(),
            "Entity from enrichment".into(),
            KnowledgeEntityType::Idea,
            None,
            vec![0.2; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        // Self-referencing relationship keeps the fixture minimal.
        let graph_relationship = KnowledgeRelationship::new(
            graph_entity.id.clone(),
            graph_entity.id.clone(),
            user_id.into(),
            text_content.id.clone(),
            "related_to".into(),
        );
        Self {
            text_content,
            similar_entities: vec![RetrievedEntity {
                entity: retrieved_entity,
                score: 0.8,
                chunks: vec![RetrievedChunk {
                    chunk: retrieved_chunk,
                    score: 0.7,
                }],
            }],
            analysis,
            chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM],
            graph_entities: vec![graph_entity],
            graph_relationships: vec![graph_relationship],
            calls: Mutex::new(Vec::new()),
        }
    }
    /// Appends a stage marker so tests can assert call order.
    async fn record(&self, stage: &'static str) {
        self.calls.lock().await.push(stage);
    }
}
#[async_trait]
impl PipelineServices for MockServices {
    // Every stage records its name and returns the canned fixture.
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.record("prepare").await;
        Ok(self.text_content.clone())
    }
    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.record("retrieve").await;
        Ok(self.similar_entities.clone())
    }
    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        self.record("enrich").await;
        Ok(self.analysis.clone())
    }
    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.record("convert").await;
        Ok((
            self.graph_entities.clone(),
            self.graph_relationships.clone(),
        ))
    }
    // Produces one chunk per call so the test can count "chunk" markers.
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.record("chunk").await;
        Ok(vec![TextChunk::new(
            content.id.clone(),
            "chunk from mock services".into(),
            self.chunk_embedding.clone(),
            content.user_id.clone(),
        )])
    }
}
/// Delegates to [`MockServices`] but fails the enrichment stage, exercising
/// the retry/failed path.
struct FailingServices {
    inner: MockServices,
}
/// Fails content preparation with a validation error, exercising the
/// dead-letter path.
struct ValidationServices;
#[async_trait]
impl PipelineServices for FailingServices {
    // All stages delegate to the inner mock except enrichment, which fails
    // with a (retryable) processing error.
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.inner.prepare_text_content(payload).await
    }
    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.inner.retrieve_similar_entities(content).await
    }
    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        Err(AppError::Processing("mock enrichment failure".to_string()))
    }
    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.inner
            .convert_analysis(content, analysis, entity_concurrency)
            .await
    }
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.inner.prepare_chunks(content, range).await
    }
}
#[async_trait]
impl PipelineServices for ValidationServices {
    // Preparation fails with a validation error; the pipeline must stop
    // there, so every later stage is an unreachable!() tripwire.
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        Err(AppError::Validation("unsupported".to_string()))
    }
    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        unreachable!("retrieve_similar_entities should not be called after validation failure")
    }
    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        unreachable!("run_enrichment should not be called after validation failure")
    }
    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        unreachable!("convert_analysis should not be called after validation failure")
    }
    async fn prepare_chunks(
        &self,
        _content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        unreachable!("prepare_chunks should not be called after validation failure")
    }
}
/// Spins up an isolated in-memory SurrealDB (random database name per test)
/// with all migrations applied.
async fn setup_db() -> SurrealDbClient {
    let database_name = Uuid::new_v4().to_string();
    let client = SurrealDbClient::memory("pipeline_test", &database_name)
        .await
        .expect("Failed to create in-memory SurrealDB");
    client
        .apply_migrations()
        .await
        .expect("Failed to apply migrations");
    client
}
/// Pipeline config with tiny chunk sizes and low concurrency to keep the
/// tests fast; everything else stays at the tuning defaults.
fn pipeline_config() -> IngestionConfig {
    IngestionConfig {
        tuning: IngestionTuning {
            chunk_min_chars: 4,
            chunk_max_chars: 64,
            chunk_insert_concurrency: 4,
            entity_embedding_concurrency: 2,
            ..IngestionTuning::default()
        },
    }
}
/// Enqueues `payload` as an ingestion task and immediately claims it for
/// `worker_id`, mirroring the worker loop's reservation step.
async fn reserve_task(
    db: &SurrealDbClient,
    worker_id: &str,
    payload: IngestionPayload,
    user_id: &str,
) -> IngestionTask {
    let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db)
        .await
        .expect("task created");
    let lease = task.lease_duration();
    // Claim the task we just created; the DB is empty so it must be next.
    IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
        .await
        .expect("claim succeeds")
        .expect("task claimed")
}
// End-to-end run with canned services: the task must succeed, entities and
// chunks must be persisted, and the stages must run in pipeline order.
#[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() {
    let db = setup_db().await;
    let worker_id = "worker-happy";
    let user_id = "user-123";
    let services = Arc::new(MockServices::new(user_id));
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone())
            .expect("pipeline");
    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;
    pipeline
        .process_task(task.clone())
        .await
        .expect("pipeline succeeds");
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Succeeded);
    let stored_entities: Vec<KnowledgeEntity> = db
        .get_all_stored_items::<KnowledgeEntity>()
        .await
        .expect("entities stored");
    assert!(!stored_entities.is_empty(), "entities should be stored");
    let stored_chunks: Vec<TextChunk> = db
        .get_all_stored_items::<TextChunk>()
        .await
        .expect("chunks stored");
    assert!(
        !stored_chunks.is_empty(),
        "chunks should be stored for ingestion text"
    );
    // First four stages run exactly once, in order; everything after is one
    // or more "chunk" calls.
    let call_log = services.calls.lock().await.clone();
    assert!(
        call_log.len() >= 5,
        "expected at least one chunk embedding call"
    );
    assert_eq!(
        &call_log[0..4],
        ["prepare", "retrieve", "enrich", "convert"]
    );
    assert!(call_log[4..].iter().all(|entry| *entry == "chunk"));
}
// An enrichment failure must surface as an error, mark the task Failed, and
// schedule a retry in the future.
#[tokio::test]
async fn ingestion_pipeline_failure_marks_retry() {
    let db = setup_db().await;
    let worker_id = "worker-fail";
    let user_id = "user-456";
    let services = Arc::new(FailingServices {
        inner: MockServices::new(user_id),
    });
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");
    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example failure payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;
    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "failure services should bubble error from pipeline"
    );
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Failed);
    assert!(
        stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
        "failed task should schedule retry in the future"
    );
}
// A validation failure is non-retryable: the task must land in DeadLetter
// and no later stage may run (enforced by ValidationServices' tripwires).
#[tokio::test]
async fn ingestion_pipeline_validation_failure_dead_letters_task() {
    let db = setup_db().await;
    let worker_id = "worker-validation";
    let user_id = "user-789";
    let services = Arc::new(ValidationServices);
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");
    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "irrelevant".into(),
            context: "".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;
    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "validation failure should surface as error"
    );
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::DeadLetter);
}

View File

@@ -1,259 +0,0 @@
pub mod llm_enrichment_result;
use std::io::Write;
use std::time::Instant;
use axum::http::HeaderMap;
use axum_typed_multipart::{FieldData, FieldMetadata};
use chrono::Utc;
use common::storage::db::SurrealDbClient;
use common::{
error::AppError,
storage::{
store,
types::{
file_info::FileInfo,
ingestion_payload::IngestionPayload,
text_content::{TextContent, UrlInfo},
},
},
utils::config::AppConfig,
};
use dom_smoothie::{Article, Readability, TextMode};
use headless_chrome::Browser;
use std::io::{Seek, SeekFrom};
use tempfile::NamedTempFile;
use tracing::{error, info};
use crate::utils::{
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
pdf_ingestion::extract_pdf_content,
};
/// Normalizes an [`IngestionPayload`] variant (URL, raw text, or uploaded
/// file) into a [`TextContent`] record ready for enrichment.
pub async fn to_text_content(
    ingestion_payload: IngestionPayload,
    db: &SurrealDbClient,
    config: &AppConfig,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<TextContent, AppError> {
    match ingestion_payload {
        // URL: fetch the page, extract the article text, and attach the
        // stored screenshot's metadata.
        IngestionPayload::Url {
            url,
            context,
            category,
            user_id,
        } => {
            let (article, file_info) = fetch_article_from_url(&url, db, &user_id, config).await?;
            Ok(TextContent::new(
                article.text_content.into(),
                Some(context),
                category,
                None,
                Some(UrlInfo {
                    url,
                    title: article.title,
                    image_id: file_info.id,
                }),
                user_id,
            ))
        }
        // Raw text: wrap as-is.
        IngestionPayload::Text {
            text,
            context,
            category,
            user_id,
        } => Ok(TextContent::new(
            text,
            Some(context),
            category,
            None,
            None,
            user_id,
        )),
        // Uploaded file: extract text by MIME type, keep the file reference.
        IngestionPayload::File {
            file_info,
            context,
            category,
            user_id,
        } => {
            let text = extract_text_from_file(&file_info, db, openai_client, config).await?;
            Ok(TextContent::new(
                text,
                Some(context),
                category,
                Some(file_info),
                None,
                user_id,
            ))
        }
    }
}
/// Fetches web content from a URL, extracts the main article text as Markdown,
/// captures a screenshot, and stores the screenshot returning [`FileInfo`].
///
/// This function handles browser automation, content extraction via Readability,
/// screenshot capture, temporary file handling, and persisting the screenshot
/// details (including deduplication based on content hash via [`FileInfo::new`]).
///
/// # Arguments
///
/// * `url` - The URL of the web page to fetch.
/// * `db` - A reference to the database client (`SurrealDbClient`).
/// * `user_id` - The ID of the user performing the action, used for associating the file.
///
/// # Returns
///
/// A `Result` containing:
/// * Ok: A tuple `(Article, FileInfo)` where `Article` contains the parsed markdown
///   content and metadata, and `FileInfo` contains the details of the stored screenshot.
/// * Err: An `AppError` if any step fails (navigation, screenshot, file handling, DB operation).
async fn fetch_article_from_url(
    url: &str,
    db: &SurrealDbClient,
    user_id: &str,
    config: &AppConfig,
) -> Result<(Article, FileInfo), AppError> {
    info!("Fetching URL: {}", url);
    // NOTE(review): the URL is only parsed (and thus syntax-validated) after
    // navigation below — the browser fetches it first. No host allow-listing
    // is performed here at all.
    // Instantiate timer
    let now = Instant::now();
    // Setup browser, navigate and wait
    let browser = {
        #[cfg(feature = "docker")]
        {
            // Use this when compiling for docker
            let options = headless_chrome::LaunchOptionsBuilder::default()
                .sandbox(false)
                .build()
                .map_err(|e| AppError::InternalError(e.to_string()))?;
            Browser::new(options)?
        }
        #[cfg(not(feature = "docker"))]
        {
            // Use this otherwise
            Browser::default()?
        }
    };
    let tab = browser.new_tab()?;
    let page = tab.navigate_to(url)?;
    let loaded_page = page.wait_until_navigated()?;
    // Get content
    let raw_content = loaded_page.get_content()?;
    // Get screenshot
    let screenshot = loaded_page.capture_screenshot(
        headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg,
        None,
        None,
        true,
    )?;
    // Create temp file
    let mut tmp_file = NamedTempFile::new()?;
    let temp_path_str = format!("{:?}", tmp_file.path());
    // Write screenshot TO the temp file
    tmp_file.write_all(&screenshot)?;
    // Ensure the OS buffer is written to the file system _before_ we proceed.
    tmp_file.as_file().sync_all()?;
    // Ensure the file handle's read cursor is at the beginning before hashing occurs.
    if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) {
        error!("URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.", url, temp_path_str, e);
    }
    // Prepare file metadata
    let parsed_url =
        url::Url::parse(url).map_err(|_| AppError::Processing("Invalid URL".to_string()))?;
    let domain = parsed_url
        .host_str()
        .unwrap_or("unknown")
        .replace(|c: char| !c.is_alphanumeric(), "_");
    let timestamp = Utc::now().format("%Y%m%d%H%M%S");
    let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp);
    // Construct FieldData and FieldMetadata
    let metadata = FieldMetadata {
        file_name: Some(file_name),
        content_type: Some("image/jpeg".to_string()),
        name: None,
        headers: HeaderMap::new(),
    };
    let field_data = FieldData {
        contents: tmp_file,
        metadata,
    };
    // Store screenshot
    let file_info = FileInfo::new(field_data, db, user_id, config).await?;
    // Parse content...
    let config = dom_smoothie::Config {
        text_mode: TextMode::Markdown,
        ..Default::default()
    };
    let mut readability = Readability::new(raw_content, None, Some(config))?;
    let article: Article = readability.parse()?;
    let end = now.elapsed();
    info!(
        "URL: {}. Total time: {:?}. Final File ID: {}",
        url, end, file_info.id
    );
    Ok((article, file_info))
}
/// Extracts text from a stored file by MIME type.
///
/// Text-like types are read directly; PDFs go through `extract_pdf_content`,
/// images through `extract_text_from_image`, audio through
/// `transcribe_audio_file`. Unknown types yield `AppError::NotFound`.
async fn extract_text_from_file(
    file_info: &FileInfo,
    db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    config: &AppConfig,
) -> Result<String, AppError> {
    // Stored paths are relative to the configured storage root.
    let base_path = store::resolve_base_dir(config);
    let absolute_path = base_path.join(&file_info.path);
    match file_info.mime_type.as_str() {
        // NOTE(review): application/octet-stream is optimistically read as
        // UTF-8 text — confirm this is intended for arbitrary uploads.
        "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
            let content = tokio::fs::read_to_string(&absolute_path).await?;
            Ok(content)
        }
        "application/pdf" => {
            extract_pdf_content(
                &absolute_path,
                db_client,
                openai_client,
                &config.pdf_ingest_mode,
            )
            .await
        }
        "image/png" | "image/jpeg" => {
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading image {}",
                        file_info.id
                    ))
                })?
                .to_string();
            let content = extract_text_from_image(&path_str, db_client, openai_client).await?;
            Ok(content)
        }
        "audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4"
        | "audio/ogg" | "audio/flac" => {
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading audio {}",
                        file_info.id
                    ))
                })?
                .to_string();
            transcribe_audio_file(&path_str, db_client, openai_client).await
        }
        // Handle other MIME types as needed
        _ => Err(AppError::NotFound(file_info.mime_type.clone())),
    }
}

View File

@@ -0,0 +1,63 @@
use common::{
error::AppError,
storage::{db::SurrealDbClient, store, types::file_info::FileInfo},
utils::config::AppConfig,
};
use super::{
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
pdf_ingestion::extract_pdf_content,
};
/// Extracts plain text from a stored file, dispatching on its MIME type:
/// text-like files are read directly, PDFs go through
/// `extract_pdf_content`, images through `extract_text_from_image`, audio
/// through `transcribe_audio_file`. Unknown types yield
/// `AppError::NotFound`.
pub async fn extract_text_from_file(
    file_info: &FileInfo,
    db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    config: &AppConfig,
) -> Result<String, AppError> {
    // Stored paths are relative to the configured storage root.
    let base_path = store::resolve_base_dir(config);
    let absolute_path = base_path.join(&file_info.path);
    match file_info.mime_type.as_str() {
        // NOTE(review): application/octet-stream is optimistically read as
        // UTF-8 text — confirm this is intended for arbitrary uploads.
        "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
            let content = tokio::fs::read_to_string(&absolute_path).await?;
            Ok(content)
        }
        "application/pdf" => {
            extract_pdf_content(
                &absolute_path,
                db_client,
                openai_client,
                &config.pdf_ingest_mode,
            )
            .await
        }
        "image/png" | "image/jpeg" => {
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading image {}",
                        file_info.id
                    ))
                })?
                .to_string();
            let content = extract_text_from_image(&path_str, db_client, openai_client).await?;
            Ok(content)
        }
        "audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4"
        | "audio/ogg" | "audio/flac" => {
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading audio {}",
                        file_info.id
                    ))
                })?
                .to_string();
            transcribe_audio_file(&path_str, db_client, openai_client).await
        }
        _ => Err(AppError::NotFound(file_info.mime_type.clone())),
    }
}

View File

@@ -0,0 +1,53 @@
use common::error::AppError;
use std::collections::HashMap;
use uuid::Uuid;
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
#[derive(Clone)]
pub struct GraphMapper {
    pub key_to_id: HashMap<String, Uuid>,
}
impl Default for GraphMapper {
    fn default() -> Self {
        Self::new()
    }
}
impl GraphMapper {
    /// Creates an empty mapper.
    pub fn new() -> Self {
        Self {
            key_to_id: HashMap::new(),
        }
    }
    /// Resolves `key` either as a literal UUID or, failing that, via the
    /// internal map.
    pub fn get_or_parse_id(&self, key: &str) -> Result<Uuid, AppError> {
        match Uuid::parse_str(key) {
            Ok(parsed_uuid) => Ok(parsed_uuid),
            Err(_) => self.key_to_id.get(key).copied().ok_or_else(|| {
                AppError::GraphMapper(format!(
                    "Key '{key}' is not a valid UUID and was not found in the map."
                ))
            }),
        }
    }
    /// Generates a fresh UUID for `key`, overwriting any previous mapping.
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let fresh = Uuid::new_v4();
        self.key_to_id.insert(key.to_owned(), fresh);
        fresh
    }
    /// Looks up the UUID previously assigned to `key`.
    pub fn get_id(&self, key: &str) -> Result<Uuid, AppError> {
        match self.key_to_id.get(key) {
            Some(id) => Ok(*id),
            None => Err(AppError::GraphMapper(format!(
                "Key '{key}' not found in map."
            ))),
        }
    }
}

View File

@@ -1,58 +1,7 @@
// Utility submodules used by the ingestion pipeline.
pub mod audio_transcription;
pub mod file_text_extraction;
pub mod graph_mapper;
pub mod image_parsing;
pub mod llm_instructions;
pub mod pdf_ingestion;
use common::error::AppError;
use std::collections::HashMap;
use uuid::Uuid;
// NOTE(review): this GraphMapper duplicates utils/graph_mapper.rs — keep a
// single copy and re-export it to avoid divergence.
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
#[derive(Clone)]
pub struct GraphMapper {
    pub key_to_id: HashMap<String, Uuid>,
}
impl Default for GraphMapper {
    fn default() -> Self {
        Self::new()
    }
}
impl GraphMapper {
    pub fn new() -> Self {
        Self {
            key_to_id: HashMap::new(),
        }
    }
    /// Tries to get an ID by first parsing the key as a UUID,
    /// and if that fails, looking it up in the internal map.
    pub fn get_or_parse_id(&self, key: &str) -> Result<Uuid, AppError> {
        // First, try to parse the key as a UUID.
        if let Ok(parsed_uuid) = Uuid::parse_str(key) {
            return Ok(parsed_uuid);
        }
        // If parsing fails, look it up in the map.
        self.key_to_id.get(key).copied().ok_or_else(|| {
            AppError::GraphMapper(format!(
                "Key '{key}' is not a valid UUID and was not found in the map."
            ))
        })
    }
    /// Assigns a new UUID for a given key. (No changes needed here)
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let id = Uuid::new_v4();
        self.key_to_id.insert(key.to_string(), id);
        id
    }
    /// Retrieves the UUID for a given key, returning a Result for consistency.
    pub fn get_id(&self, key: &str) -> Result<Uuid, AppError> {
        self.key_to_id
            .get(key)
            .copied()
            .ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map.")))
    }
}
pub mod url_text_retrieval;

View File

@@ -0,0 +1,174 @@
use axum::http::HeaderMap;
use axum_typed_multipart::{FieldData, FieldMetadata};
use chrono::Utc;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::file_info::FileInfo},
utils::config::AppConfig,
};
use dom_smoothie::{Article, Readability, TextMode};
use headless_chrome::Browser;
use std::{
io::{Seek, SeekFrom, Write},
net::IpAddr,
time::Instant,
};
use tempfile::NamedTempFile;
use tracing::{error, info, warn};
/// Fetches a web page with a headless browser, extracts the main article as
/// Markdown, captures a full-page screenshot, and stores the screenshot,
/// returning the parsed [`Article`] together with the screenshot's
/// [`FileInfo`].
///
/// # Errors
/// Returns a validation error for malformed or disallowed URLs, and
/// propagates browser, I/O, storage, and readability failures.
pub async fn extract_text_from_url(
    url: &str,
    db: &SurrealDbClient,
    user_id: &str,
    config: &AppConfig,
) -> Result<(Article, FileInfo), AppError> {
    info!("Fetching URL: {}", url);
    // SECURITY FIX: validate the URL *before* any network activity. The
    // previous flow navigated the browser first and checked the allow-list
    // only afterwards, so localhost/private-range targets were actually
    // fetched before being rejected (SSRF).
    let parsed_url =
        url::Url::parse(url).map_err(|_| AppError::Validation("Invalid URL".to_string()))?;
    let domain = ensure_ingestion_url_allowed(&parsed_url)?;
    let now = Instant::now();
    let browser = {
        #[cfg(feature = "docker")]
        {
            // Chromium inside the container must run without the sandbox.
            let options = headless_chrome::LaunchOptionsBuilder::default()
                .sandbox(false)
                .build()
                .map_err(|e| AppError::InternalError(e.to_string()))?;
            Browser::new(options)?
        }
        #[cfg(not(feature = "docker"))]
        {
            Browser::default()?
        }
    };
    let tab = browser.new_tab()?;
    let page = tab.navigate_to(url)?;
    let loaded_page = page.wait_until_navigated()?;
    let raw_content = loaded_page.get_content()?;
    // Full-page JPEG screenshot stored as the page's preview image.
    let screenshot = loaded_page.capture_screenshot(
        headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg,
        None,
        None,
        true,
    )?;
    let mut tmp_file = NamedTempFile::new()?;
    let temp_path_str = format!("{:?}", tmp_file.path());
    tmp_file.write_all(&screenshot)?;
    // Flush OS buffers so the subsequent content hash reads the full file.
    tmp_file.as_file().sync_all()?;
    // Rewind so hashing starts from the beginning of the file.
    if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) {
        error!(
            "URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.",
            url, temp_path_str, e
        );
    }
    let timestamp = Utc::now().format("%Y%m%d%H%M%S");
    let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp);
    let metadata = FieldMetadata {
        file_name: Some(file_name),
        content_type: Some("image/jpeg".to_string()),
        name: None,
        headers: HeaderMap::new(),
    };
    let field_data = FieldData {
        contents: tmp_file,
        metadata,
    };
    // FileInfo::new hashes the screenshot and dedupes by content hash.
    let file_info = FileInfo::new(field_data, db, user_id, config).await?;
    let config = dom_smoothie::Config {
        text_mode: TextMode::Markdown,
        ..Default::default()
    };
    let mut readability = Readability::new(raw_content, None, Some(config))?;
    let article: Article = readability.parse()?;
    let end = now.elapsed();
    info!(
        "URL: {}. Total time: {:?}. Final File ID: {}",
        url, end, file_info.id
    );
    Ok((article, file_info))
}
/// Validates that `url` is safe to fetch for ingestion and returns its host
/// sanitized for use in a file name (non-alphanumerics replaced with `_`).
///
/// Rejects non-http(s) schemes, missing hosts, `localhost`, and literal IPs
/// in loopback, unspecified, multicast, private, link-local, or
/// unique-local ranges.
///
/// Fixes two bypasses in the IP checks:
/// * IPv4-mapped IPv6 literals (e.g. `::ffff:192.168.0.1`) skipped the IPv4
///   private/link-local checks; they are now folded back to IPv4 first.
/// * `Url::host_str` keeps the square brackets on IPv6 literals, so they
///   never parsed as `IpAddr` and the IPv6 branch was dead; brackets are now
///   stripped before parsing.
fn ensure_ingestion_url_allowed(url: &url::Url) -> Result<String, AppError> {
    match url.scheme() {
        "http" | "https" => {}
        scheme => {
            warn!(%url, %scheme, "Rejected ingestion URL due to unsupported scheme");
            return Err(AppError::Validation(
                "Unsupported URL scheme for ingestion".to_string(),
            ));
        }
    }
    let host = match url.host_str() {
        Some(host) => host,
        None => {
            warn!(%url, "Rejected ingestion URL missing host");
            return Err(AppError::Validation(
                "URL is missing a host component".to_string(),
            ));
        }
    };
    if host.eq_ignore_ascii_case("localhost") {
        warn!(%url, host, "Rejected ingestion URL to localhost");
        return Err(AppError::Validation(
            "Ingestion URL host is not allowed".to_string(),
        ));
    }
    // Strip the brackets url keeps around IPv6 literals so they parse.
    let ip_candidate = host.trim_start_matches('[').trim_end_matches(']');
    if let Ok(parsed_ip) = ip_candidate.parse::<IpAddr>() {
        // Fold IPv4-mapped IPv6 back to IPv4 so the IPv4 range checks below
        // cannot be bypassed via the mapped form.
        let ip = match parsed_ip {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(parsed_ip, IpAddr::V4),
            IpAddr::V4(_) => parsed_ip,
        };
        let is_disallowed = match ip {
            IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unique_local() || v6.is_unicast_link_local(),
        };
        if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() || is_disallowed {
            warn!(%url, host, %ip, "Rejected ingestion URL pointing to restricted network range");
            return Err(AppError::Validation(
                "Ingestion URL host is not allowed".to_string(),
            ));
        }
    }
    Ok(host.replace(|c: char| !c.is_alphanumeric(), "_"))
}
// Unit tests for the ingestion-URL allow-list (SSRF guard).
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn rejects_unsupported_scheme() {
        let url = url::Url::parse("ftp://example.com").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }
    #[test]
    fn rejects_localhost() {
        let url = url::Url::parse("http://localhost/resource").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }
    #[test]
    fn rejects_private_ipv4() {
        let url = url::Url::parse("http://192.168.1.10/index.html").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }
    // Sanitization replaces non-alphanumerics with underscores for file names.
    #[test]
    fn allows_public_domain_and_sanitizes() {
        let url = url::Url::parse("https://sub.example.com/path").expect("url");
        let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed");
        assert_eq!(sanitized, "sub_example_com");
    }
}