diff --git a/CHANGELOG.md b/CHANGELOG.md index 399317e..c93fb48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog ## Unreleased +- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships +- Added knowledge entity search results to the global search +- Backend fixes for improved performance when ingesting and retrieval ## Version 0.2.4 (2025-10-15) - Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal. diff --git a/Cargo.lock b/Cargo.lock index 7a08320..c085446 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2919,6 +2919,7 @@ name = "ingestion-pipeline" version = "0.1.0" dependencies = [ "async-openai", + "async-trait", "axum", "axum_typed_multipart", "base64 0.22.1", @@ -2933,6 +2934,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "state-machines", "surrealdb", "tempfile", "text-splitter", diff --git a/ingestion-pipeline/Cargo.toml b/ingestion-pipeline/Cargo.toml index 68cbd28..9a200f2 100644 --- a/ingestion-pipeline/Cargo.toml +++ b/ingestion-pipeline/Cargo.toml @@ -31,5 +31,7 @@ lopdf = "0.32" common = { path = "../common" } composite-retrieval = { path = "../composite-retrieval" } +async-trait = { workspace = true } +state-machines = { workspace = true } [features] docker = [] diff --git a/ingestion-pipeline/src/enricher.rs b/ingestion-pipeline/src/enricher.rs deleted file mode 100644 index 6a71cd4..0000000 --- a/ingestion-pipeline/src/enricher.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::sync::Arc; - -use async_openai::types::{ - ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, - ResponseFormatJsonSchema, -}; -use common::{ - error::AppError, - storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, -}; -use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity}; - -use 
crate::{ - types::llm_enrichment_result::LLMEnrichmentResult, - utils::llm_instructions::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, -}; - -pub struct IngestionEnricher { - db_client: Arc, - openai_client: Arc>, -} - -impl IngestionEnricher { - pub const fn new( - db_client: Arc, - openai_client: Arc>, - ) -> Self { - Self { - db_client, - openai_client, - } - } - - pub async fn analyze_content( - &self, - category: &str, - context: Option<&str>, - text: &str, - user_id: &str, - ) -> Result { - let similar_entities = self - .find_similar_entities(category, context, text, user_id) - .await?; - let llm_request = self - .prepare_llm_request(category, context, text, &similar_entities) - .await?; - self.perform_analysis(llm_request).await - } - - async fn find_similar_entities( - &self, - category: &str, - context: Option<&str>, - text: &str, - user_id: &str, - ) -> Result, AppError> { - let input_text = - format!("content: {text}, category: {category}, user_context: {context:?}"); - - retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await - } - - async fn prepare_llm_request( - &self, - category: &str, - context: Option<&str>, - text: &str, - similar_entities: &[RetrievedEntity], - ) -> Result { - let settings = SystemSettings::get_current(&self.db_client).await?; - - let entities_json = retrieved_entities_to_json(similar_entities); - - let user_message = format!( - "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}" - ); - - let response_format = ResponseFormat::JsonSchema { - json_schema: ResponseFormatJsonSchema { - description: Some("Structured analysis of the submitted content".into()), - name: "content_analysis".into(), - schema: Some(get_ingress_analysis_schema()), - strict: Some(true), - }, - }; - - let request = CreateChatCompletionRequestArgs::default() - .model(&settings.processing_model) - .messages([ - 
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(), - ChatCompletionRequestUserMessage::from(user_message).into(), - ]) - .response_format(response_format) - .build()?; - - Ok(request) - } - - async fn perform_analysis( - &self, - request: CreateChatCompletionRequest, - ) -> Result { - let response = self.openai_client.chat().create(request).await?; - - let content = response - .choices - .first() - .and_then(|choice| choice.message.content.as_ref()) - .ok_or(AppError::LLMParsing( - "No content found in LLM response".into(), - ))?; - - serde_json::from_str::(content).map_err(|e| { - AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")) - }) - } -} diff --git a/ingestion-pipeline/src/lib.rs b/ingestion-pipeline/src/lib.rs index 0ff384a..4eb7e9f 100644 --- a/ingestion-pipeline/src/lib.rs +++ b/ingestion-pipeline/src/lib.rs @@ -1,6 +1,4 @@ -pub mod enricher; pub mod pipeline; -pub mod types; pub mod utils; use chrono::Utc; diff --git a/ingestion-pipeline/src/pipeline.rs b/ingestion-pipeline/src/pipeline.rs deleted file mode 100644 index 497a040..0000000 --- a/ingestion-pipeline/src/pipeline.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::{sync::Arc, time::Instant}; - -use text_splitter::TextSplitter; -use tokio::time::{sleep, Duration}; -use tracing::{debug, info, info_span, warn}; - -use common::{ - error::AppError, - storage::{ - db::SurrealDbClient, - types::{ - ingestion_task::{IngestionTask, TaskErrorInfo}, - knowledge_entity::KnowledgeEntity, - knowledge_relationship::KnowledgeRelationship, - text_chunk::TextChunk, - text_content::TextContent, - }, - }, - utils::{config::AppConfig, embedding::generate_embedding}, -}; - -use crate::{ - enricher::IngestionEnricher, - types::{llm_enrichment_result::LLMEnrichmentResult, to_text_content}, -}; - -pub struct IngestionPipeline { - db: Arc, - openai_client: Arc>, - config: AppConfig, -} - -impl IngestionPipeline { - pub async fn new( - db: Arc, - openai_client: Arc>, 
- config: AppConfig, - ) -> Result { - Ok(Self { - db, - openai_client, - config, - }) - } - pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> { - let task_id = task.id.clone(); - let attempt = task.attempts; - let worker_label = task - .worker_id - .clone() - .unwrap_or_else(|| "unknown-worker".to_string()); - let span = info_span!( - "ingestion_task", - %task_id, - attempt, - worker_id = %worker_label, - state = %task.state.as_str() - ); - let _enter = span.enter(); - let processing_task = task.mark_processing(&self.db).await?; - - let text_content = to_text_content( - processing_task.content.clone(), - &self.db, - &self.config, - &self.openai_client, - ) - .await?; - - let text_len = text_content.text.chars().count(); - let preview: String = text_content.text.chars().take(120).collect(); - let preview_clean = preview.replace("\n", " "); - let preview_len = preview_clean.chars().count(); - let truncated = text_len > preview_len; - let context_len = text_content - .context - .as_ref() - .map(|c| c.chars().count()) - .unwrap_or(0); - info!( - %task_id, - attempt, - user_id = %text_content.user_id, - category = %text_content.category, - text_chars = text_len, - context_chars = context_len, - attachments = text_content.file_info.is_some(), - "ingestion task input ready" - ); - debug!( - %task_id, - attempt, - preview = %preview_clean, - preview_truncated = truncated, - "ingestion task input preview" - ); - - match self.process(&text_content).await { - Ok(()) => { - processing_task.mark_succeeded(&self.db).await?; - info!(%task_id, attempt, "ingestion task succeeded"); - Ok(()) - } - Err(err) => { - let reason = err.to_string(); - let error_info = TaskErrorInfo { - code: None, - message: reason.clone(), - }; - - if processing_task.can_retry() { - let delay = Self::retry_delay(processing_task.attempts); - processing_task - .mark_failed(error_info, delay, &self.db) - .await?; - warn!( - %task_id, - attempt = processing_task.attempts, - 
retry_in_secs = delay.as_secs(), - "ingestion task failed; scheduled retry" - ); - } else { - processing_task - .mark_dead_letter(error_info, &self.db) - .await?; - warn!( - %task_id, - attempt = processing_task.attempts, - "ingestion task failed; moved to dead letter queue" - ); - } - - Err(AppError::Processing(reason)) - } - } - } - - fn retry_delay(attempt: u32) -> Duration { - const BASE_SECONDS: u64 = 30; - const MAX_SECONDS: u64 = 15 * 60; - - let capped_attempt = attempt.saturating_sub(1).min(5); - let multiplier = 2_u64.pow(capped_attempt); - let delay = BASE_SECONDS * multiplier; - - Duration::from_secs(delay.min(MAX_SECONDS)) - } - - pub async fn process(&self, content: &TextContent) -> Result<(), AppError> { - let now = Instant::now(); - - // Perform analyis, this step also includes retrieval - let analysis = self.perform_semantic_analysis(content).await?; - - let end = now.elapsed(); - info!( - "{:?} time elapsed during creation of entities and relationships", - end - ); - - // Convert analysis to application objects - let (entities, relationships) = analysis - .to_database_entities(&content.id, &content.user_id, &self.openai_client, &self.db) - .await?; - - // Store everything - tokio::try_join!( - self.store_graph_entities(entities, relationships), - self.store_vector_chunks(content), - )?; - - // Store original content - self.db.store_item(content.to_owned()).await?; - - self.db.rebuild_indexes().await?; - Ok(()) - } - - async fn perform_semantic_analysis( - &self, - content: &TextContent, - ) -> Result { - let analyser = IngestionEnricher::new(self.db.clone(), self.openai_client.clone()); - analyser - .analyze_content( - &content.category, - content.context.as_deref(), - &content.text, - &content.user_id, - ) - .await - } - - async fn store_graph_entities( - &self, - entities: Vec, - relationships: Vec, - ) -> Result<(), AppError> { - let entities = Arc::new(entities); - let relationships = Arc::new(relationships); - let entity_count = 
entities.len(); - let relationship_count = relationships.len(); - - const STORE_GRAPH_MUTATION: &str = r" - BEGIN TRANSACTION; - LET $entities = $entities; - LET $relationships = $relationships; - - FOR $entity IN $entities { - CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity; - }; - - FOR $relationship IN $relationships { - LET $in_node = type::thing('knowledge_entity', $relationship.in); - LET $out_node = type::thing('knowledge_entity', $relationship.out); - RELATE $in_node->relates_to->$out_node CONTENT { - id: type::thing('relates_to', $relationship.id), - metadata: $relationship.metadata - }; - }; - - COMMIT TRANSACTION; - "; - - const MAX_ATTEMPTS: usize = 3; - const INITIAL_BACKOFF_MS: u64 = 50; - const MAX_BACKOFF_MS: u64 = 800; - - let mut backoff_ms = INITIAL_BACKOFF_MS; - let mut success = false; - - for attempt in 0..MAX_ATTEMPTS { - let result = self - .db - .client - .query(STORE_GRAPH_MUTATION) - .bind(("entities", entities.clone())) - .bind(("relationships", relationships.clone())) - .await; - - match result { - Ok(_) => { - success = true; - break; - } - Err(err) => { - if Self::is_retryable_conflict(&err) && attempt + 1 < MAX_ATTEMPTS { - warn!( - attempt = attempt + 1, - "Transient SurrealDB conflict while storing graph data; retrying" - ); - sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); - continue; - } - - return Err(AppError::from(err)); - } - } - } - - if !success { - return Err(AppError::InternalError( - "Failed to store graph entities after retries".to_string(), - )); - } - - info!( - "Stored {} entities and {} relationships", - entity_count, relationship_count - ); - Ok(()) - } - - async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> { - let splitter = TextSplitter::new(500..2000); - let chunks = splitter.chunks(&content.text); - - // Could potentially process chunks in parallel with a bounded concurrent limit - for chunk in chunks { - let 
embedding = generate_embedding(&self.openai_client, chunk, &self.db).await?; - let text_chunk = TextChunk::new( - content.id.to_string(), - chunk.to_string(), - embedding, - content.user_id.to_string(), - ); - self.db.store_item(text_chunk).await?; - } - - Ok(()) - } - - fn is_retryable_conflict(error: &surrealdb::Error) -> bool { - error - .to_string() - .contains("Failed to commit transaction due to a read or write conflict") - } -} diff --git a/ingestion-pipeline/src/pipeline/config.rs b/ingestion-pipeline/src/pipeline/config.rs new file mode 100644 index 0000000..fa59454 --- /dev/null +++ b/ingestion-pipeline/src/pipeline/config.rs @@ -0,0 +1,35 @@ +#[derive(Debug, Clone)] +pub struct IngestionTuning { + pub retry_base_delay_secs: u64, + pub retry_max_delay_secs: u64, + pub retry_backoff_cap_exponent: u32, + pub graph_store_attempts: usize, + pub graph_initial_backoff_ms: u64, + pub graph_max_backoff_ms: u64, + pub chunk_min_chars: usize, + pub chunk_max_chars: usize, + pub chunk_insert_concurrency: usize, + pub entity_embedding_concurrency: usize, +} + +impl Default for IngestionTuning { + fn default() -> Self { + Self { + retry_base_delay_secs: 30, + retry_max_delay_secs: 15 * 60, + retry_backoff_cap_exponent: 5, + graph_store_attempts: 3, + graph_initial_backoff_ms: 50, + graph_max_backoff_ms: 800, + chunk_min_chars: 500, + chunk_max_chars: 2_000, + chunk_insert_concurrency: 8, + entity_embedding_concurrency: 4, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct IngestionConfig { + pub tuning: IngestionTuning, +} diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs new file mode 100644 index 0000000..74959b1 --- /dev/null +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -0,0 +1,76 @@ +use common::{ + error::AppError, + storage::{ + db::SurrealDbClient, + types::{ingestion_task::IngestionTask, text_content::TextContent}, + }, +}; +use composite_retrieval::RetrievedEntity; +use tracing::error; + +use 
super::enrichment_result::LLMEnrichmentResult; + +use super::{config::IngestionConfig, services::PipelineServices}; + +pub struct PipelineContext<'a> { + pub task: &'a IngestionTask, + pub task_id: String, + pub attempt: u32, + pub db: &'a SurrealDbClient, + pub pipeline_config: &'a IngestionConfig, + pub services: &'a dyn PipelineServices, + pub text_content: Option, + pub similar_entities: Vec, + pub analysis: Option, +} + +impl<'a> PipelineContext<'a> { + pub fn new( + task: &'a IngestionTask, + db: &'a SurrealDbClient, + pipeline_config: &'a IngestionConfig, + services: &'a dyn PipelineServices, + ) -> Self { + let task_id = task.id.clone(); + let attempt = task.attempts; + Self { + task, + task_id, + attempt, + db, + pipeline_config, + services, + text_content: None, + similar_entities: Vec::new(), + analysis: None, + } + } + + pub fn text_content(&self) -> Result<&TextContent, AppError> { + self.text_content + .as_ref() + .ok_or_else(|| AppError::InternalError("text content expected to be available".into())) + } + + pub fn take_text_content(&mut self) -> Result { + self.text_content.take().ok_or_else(|| { + AppError::InternalError("text content expected to be available for persistence".into()) + }) + } + + pub fn take_analysis(&mut self) -> Result { + self.analysis.take().ok_or_else(|| { + AppError::InternalError("analysis expected to be available for persistence".into()) + }) + } + + pub fn abort(&mut self, err: AppError) -> AppError { + error!( + task_id = %self.task_id, + attempt = self.attempt, + error = %err, + "ingestion pipeline aborted" + ); + err + } +} diff --git a/ingestion-pipeline/src/types/llm_enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs similarity index 55% rename from ingestion-pipeline/src/types/llm_enrichment_result.rs rename to ingestion-pipeline/src/pipeline/enrichment_result.rs index 94b2fb4..e73b28a 100644 --- a/ingestion-pipeline/src/types/llm_enrichment_result.rs +++ 
b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -1,8 +1,8 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use chrono::Utc; +use futures::stream::{self, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; -use tokio::task; use common::{ error::AppError, @@ -15,28 +15,25 @@ use common::{ }, utils::embedding::generate_embedding, }; -use futures::future::try_join_all; -use crate::utils::GraphMapper; +use crate::utils::graph_mapper::GraphMapper; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMKnowledgeEntity { - pub key: String, // Temporary identifier + pub key: String, pub name: String, pub description: String, - pub entity_type: String, // Should match KnowledgeEntityType variants + pub entity_type: String, } -/// Represents a single relationship from the LLM. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMRelationship { #[serde(rename = "type")] - pub type_: String, // e.g., RelatedTo, RelevantTo - pub source: String, // Key of the source entity - pub target: String, // Key of the target entity + pub type_: String, + pub source: String, + pub target: String, } -/// Represents the entire graph analysis result from the LLM. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMEnrichmentResult { pub knowledge_entities: Vec, @@ -44,27 +41,16 @@ pub struct LLMEnrichmentResult { } impl LLMEnrichmentResult { - /// Converts the LLM graph analysis result into database entities and relationships. - /// - /// # Arguments - /// - /// * `source_id` - A UUID representing the source identifier. - /// * `openai_client` - `OpenAI` client for LLM calls. - /// - /// # Returns - /// - /// * `Result<(Vec, Vec), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`. 
pub async fn to_database_entities( &self, source_id: &str, user_id: &str, openai_client: &async_openai::Client, db_client: &SurrealDbClient, + entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { - // Create mapper and pre-assign IDs - let mapper = Arc::new(Mutex::new(self.create_mapper()?)); + let mapper = Arc::new(self.create_mapper()?); - // Process entities let entities = self .process_entities( source_id, @@ -72,10 +58,10 @@ impl LLMEnrichmentResult { Arc::clone(&mapper), openai_client, db_client, + entity_concurrency, ) .await?; - // Process relationships let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?; Ok((entities, relationships)) @@ -84,7 +70,6 @@ impl LLMEnrichmentResult { fn create_mapper(&self) -> Result { let mut mapper = GraphMapper::new(); - // Pre-assign all IDs for entity in &self.knowledge_entities { mapper.assign_id(&entity.key); } @@ -96,57 +81,46 @@ impl LLMEnrichmentResult { &self, source_id: &str, user_id: &str, - mapper: Arc>, + mapper: Arc, openai_client: &async_openai::Client, db_client: &SurrealDbClient, + entity_concurrency: usize, ) -> Result, AppError> { - let futures: Vec<_> = self - .knowledge_entities - .iter() - .map(|entity| { - let mapper = Arc::clone(&mapper); - let openai_client = openai_client.clone(); - let source_id = source_id.to_string(); - let user_id = user_id.to_string(); - let entity = entity.clone(); - let db_client = db_client.clone(); + stream::iter(self.knowledge_entities.iter().cloned().map(|entity| { + let mapper = Arc::clone(&mapper); + let openai_client = openai_client.clone(); + let source_id = source_id.to_string(); + let user_id = user_id.to_string(); + let db_client = db_client.clone(); - task::spawn(async move { - create_single_entity( - &entity, - &source_id, - &user_id, - mapper, - &openai_client, - &db_client.clone(), - ) - .await - }) - }) - .collect(); - - let results = try_join_all(futures) - .await? 
- .into_iter() - .collect::, _>>()?; - - Ok(results) + async move { + create_single_entity( + &entity, + &source_id, + &user_id, + mapper, + &openai_client, + &db_client, + ) + .await + } + })) + .buffer_unordered(entity_concurrency.max(1)) + .try_collect() + .await } fn process_relationships( &self, source_id: &str, user_id: &str, - mapper: Arc>, + mapper: Arc, ) -> Result, AppError> { - let mapper_guard = mapper - .lock() - .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?; self.relationships .iter() .map(|rel| { - let source_db_id = mapper_guard.get_or_parse_id(&rel.source)?; - let target_db_id = mapper_guard.get_or_parse_id(&rel.target)?; + let source_db_id = mapper.get_or_parse_id(&rel.source)?; + let target_db_id = mapper.get_or_parse_id(&rel.target)?; Ok(KnowledgeRelationship::new( source_db_id.to_string(), @@ -159,20 +133,16 @@ impl LLMEnrichmentResult { .collect() } } + async fn create_single_entity( llm_entity: &LLMKnowledgeEntity, source_id: &str, user_id: &str, - mapper: Arc>, + mapper: Arc, openai_client: &async_openai::Client, db_client: &SurrealDbClient, ) -> Result { - let assigned_id = { - let mapper = mapper - .lock() - .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?; - mapper.get_id(&llm_entity.key)?.to_string() - }; + let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); let embedding_input = format!( "name: {}, description: {}, type: {}", diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs new file mode 100644 index 0000000..8f721c3 --- /dev/null +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -0,0 +1,221 @@ +mod config; +mod context; +mod enrichment_result; +mod preparation; +mod services; +mod stages; +mod state; + +pub use config::{IngestionConfig, IngestionTuning}; +pub use services::{DefaultPipelineServices, PipelineServices}; + +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +use async_openai::Client; +use common::{ + error::AppError, + 
storage::{ + db::SurrealDbClient, + types::{ + ingestion_payload::IngestionPayload, + ingestion_task::{IngestionTask, TaskErrorInfo}, + }, + }, + utils::config::AppConfig, +}; +use tracing::{debug, info, warn}; + +use self::{ + context::PipelineContext, + stages::{enrich, persist, prepare_content, retrieve_related}, + state::ready, +}; + +pub struct IngestionPipeline { + db: Arc, + pipeline_config: IngestionConfig, + services: Arc, +} + +impl IngestionPipeline { + pub async fn new( + db: Arc, + openai_client: Arc>, + config: AppConfig, + ) -> Result { + let services = + DefaultPipelineServices::new(db.clone(), openai_client.clone(), config.clone()); + + Self::with_services(db, IngestionConfig::default(), Arc::new(services)) + } + + pub fn with_services( + db: Arc, + pipeline_config: IngestionConfig, + services: Arc, + ) -> Result { + Ok(Self { + db, + pipeline_config, + services, + }) + } + + #[tracing::instrument( + skip_all, + fields( + task_id = %task.id, + attempt = task.attempts, + worker_id = task.worker_id.as_deref().unwrap_or("unknown-worker"), + user_id = %task.user_id + ) + )] + pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> { + let mut processing_task = task.mark_processing(&self.db).await?; + let payload = std::mem::replace( + &mut processing_task.content, + IngestionPayload::Text { + text: String::new(), + context: String::new(), + category: String::new(), + user_id: processing_task.user_id.clone(), + }, + ); + + match self + .drive_pipeline(&processing_task, payload) + .await + .map_err(|err| { + debug!( + task_id = %processing_task.id, + attempt = processing_task.attempts, + error = %err, + "ingestion pipeline failed" + ); + err + }) { + Ok(()) => { + processing_task.mark_succeeded(&self.db).await?; + tracing::info!( + task_id = %processing_task.id, + attempt = processing_task.attempts, + "ingestion task succeeded" + ); + Ok(()) + } + Err(err) => { + let reason = err.to_string(); + let retryable = !matches!(err, 
AppError::Validation(_)); + let error_info = TaskErrorInfo { + code: None, + message: reason.clone(), + }; + + if retryable && processing_task.can_retry() { + let delay = self.retry_delay(processing_task.attempts); + processing_task + .mark_failed(error_info, delay, &self.db) + .await?; + warn!( + task_id = %processing_task.id, + attempt = processing_task.attempts, + retry_in_secs = delay.as_secs(), + "ingestion task failed; scheduled retry" + ); + } else { + let failed_task = processing_task + .mark_failed(error_info.clone(), Duration::from_secs(0), &self.db) + .await?; + failed_task.mark_dead_letter(error_info, &self.db).await?; + warn!( + task_id = %failed_task.id, + attempt = failed_task.attempts, + "ingestion task failed; moved to dead letter queue" + ); + } + + Err(AppError::Processing(reason)) + } + } + } + + fn retry_delay(&self, attempt: u32) -> Duration { + let tuning = &self.pipeline_config.tuning; + let capped_attempt = attempt + .saturating_sub(1) + .min(tuning.retry_backoff_cap_exponent); + let multiplier = 2_u64.pow(capped_attempt); + let delay = tuning.retry_base_delay_secs * multiplier; + + Duration::from_secs(delay.min(tuning.retry_max_delay_secs)) + } + + #[tracing::instrument( + skip_all, + fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id) + )] + async fn drive_pipeline( + &self, + task: &IngestionTask, + payload: IngestionPayload, + ) -> Result<(), AppError> { + let mut ctx = PipelineContext::new( + task, + self.db.as_ref(), + &self.pipeline_config, + self.services.as_ref(), + ); + + let machine = ready(); + + let pipeline_started = Instant::now(); + + let stage_start = Instant::now(); + let machine = prepare_content(machine, &mut ctx, payload) + .await + .map_err(|err| ctx.abort(err))?; + let prepare_duration = stage_start.elapsed(); + + let stage_start = Instant::now(); + let machine = retrieve_related(machine, &mut ctx) + .await + .map_err(|err| ctx.abort(err))?; + let retrieve_duration = stage_start.elapsed(); + + 
let stage_start = Instant::now(); + let machine = enrich(machine, &mut ctx) + .await + .map_err(|err| ctx.abort(err))?; + let enrich_duration = stage_start.elapsed(); + + let stage_start = Instant::now(); + let _machine = persist(machine, &mut ctx) + .await + .map_err(|err| ctx.abort(err))?; + let persist_duration = stage_start.elapsed(); + + let total_duration = pipeline_started.elapsed(); + let prepare_ms = prepare_duration.as_millis() as u64; + let retrieve_ms = retrieve_duration.as_millis() as u64; + let enrich_ms = enrich_duration.as_millis() as u64; + let persist_ms = persist_duration.as_millis() as u64; + info!( + task_id = %ctx.task_id, + attempt = ctx.attempt, + total_ms = total_duration.as_millis() as u64, + prepare_ms, + retrieve_ms, + enrich_ms, + persist_ms, + "ingestion pipeline finished" + ); + + Ok(()) + } +} + +#[cfg(test)] +mod tests; diff --git a/ingestion-pipeline/src/pipeline/preparation.rs b/ingestion-pipeline/src/pipeline/preparation.rs new file mode 100644 index 0000000..2582bab --- /dev/null +++ b/ingestion-pipeline/src/pipeline/preparation.rs @@ -0,0 +1,74 @@ +use common::{ + error::AppError, + storage::{ + db::SurrealDbClient, + types::{ + ingestion_payload::IngestionPayload, + text_content::{TextContent, UrlInfo}, + }, + }, + utils::config::AppConfig, +}; + +use crate::utils::{ + file_text_extraction::extract_text_from_file, url_text_retrieval::extract_text_from_url, +}; + +pub(crate) async fn to_text_content( + ingestion_payload: IngestionPayload, + db: &SurrealDbClient, + config: &AppConfig, + openai_client: &async_openai::Client, +) -> Result { + match ingestion_payload { + IngestionPayload::Url { + url, + context, + category, + user_id, + } => { + let (article, file_info) = extract_text_from_url(&url, db, &user_id, config).await?; + Ok(TextContent::new( + article.text_content.into(), + Some(context), + category, + None, + Some(UrlInfo { + url, + title: article.title, + image_id: file_info.id, + }), + user_id, + )) + } + 
IngestionPayload::Text { + text, + context, + category, + user_id, + } => Ok(TextContent::new( + text, + Some(context), + category, + None, + None, + user_id, + )), + IngestionPayload::File { + file_info, + context, + category, + user_id, + } => { + let text = extract_text_from_file(&file_info, db, openai_client, config).await?; + Ok(TextContent::new( + text, + Some(context), + category, + Some(file_info), + None, + user_id, + )) + } + } +} diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs new file mode 100644 index 0000000..24f222c --- /dev/null +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -0,0 +1,213 @@ +use std::{ops::Range, sync::Arc}; + +use async_openai::types::{ + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, + ResponseFormatJsonSchema, +}; +use async_trait::async_trait; +use common::{ + error::AppError, + storage::{ + db::SurrealDbClient, + types::{ + ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, + knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings, + text_chunk::TextChunk, text_content::TextContent, + }, + }, + utils::{config::AppConfig, embedding::generate_embedding}, +}; +use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity}; +use text_splitter::TextSplitter; + +use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content}; +use crate::utils::llm_instructions::{ + get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE, +}; + +#[async_trait] +pub trait PipelineServices: Send + Sync { + async fn prepare_text_content( + &self, + payload: IngestionPayload, + ) -> Result; + + async fn retrieve_similar_entities( + &self, + content: &TextContent, + ) -> Result, AppError>; + + async fn run_enrichment( + &self, + content: &TextContent, + similar_entities: &[RetrievedEntity], + ) -> 
Result; + + async fn convert_analysis( + &self, + content: &TextContent, + analysis: &LLMEnrichmentResult, + entity_concurrency: usize, + ) -> Result<(Vec, Vec), AppError>; + + async fn prepare_chunks( + &self, + content: &TextContent, + range: Range, + ) -> Result, AppError>; +} + +pub struct DefaultPipelineServices { + db: Arc, + openai_client: Arc>, + config: AppConfig, +} + +impl DefaultPipelineServices { + pub fn new( + db: Arc, + openai_client: Arc>, + config: AppConfig, + ) -> Self { + Self { + db, + openai_client, + config, + } + } + + async fn prepare_llm_request( + &self, + category: &str, + context: Option<&str>, + text: &str, + similar_entities: &[RetrievedEntity], + ) -> Result { + let settings = SystemSettings::get_current(&self.db).await?; + + let entities_json = retrieved_entities_to_json(similar_entities); + + let user_message = format!( + "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}" + ); + + let response_format = ResponseFormat::JsonSchema { + json_schema: ResponseFormatJsonSchema { + description: Some("Structured analysis of the submitted content".into()), + name: "content_analysis".into(), + schema: Some(get_ingress_analysis_schema()), + strict: Some(true), + }, + }; + + let request = CreateChatCompletionRequestArgs::default() + .model(&settings.processing_model) + .messages([ + ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(), + ChatCompletionRequestUserMessage::from(user_message).into(), + ]) + .response_format(response_format) + .build()?; + + Ok(request) + } + + async fn perform_analysis( + &self, + request: CreateChatCompletionRequest, + ) -> Result { + let response = self.openai_client.chat().create(request).await?; + + let content = response + .choices + .first() + .and_then(|choice| choice.message.content.as_ref()) + .ok_or(AppError::LLMParsing( + "No content found in LLM response".into(), + ))?; + + 
serde_json::from_str::(content).map_err(|e| { + AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")) + }) + } +} + +#[async_trait] +impl PipelineServices for DefaultPipelineServices { + async fn prepare_text_content( + &self, + payload: IngestionPayload, + ) -> Result { + to_text_content(payload, &self.db, &self.config, &self.openai_client).await + } + + async fn retrieve_similar_entities( + &self, + content: &TextContent, + ) -> Result, AppError> { + let input_text = format!( + "content: {}, category: {}, user_context: {:?}", + content.text, content.category, content.context + ); + + retrieve_entities(&self.db, &self.openai_client, &input_text, &content.user_id).await + } + + async fn run_enrichment( + &self, + content: &TextContent, + similar_entities: &[RetrievedEntity], + ) -> Result { + let request = self + .prepare_llm_request( + &content.category, + content.context.as_deref(), + &content.text, + similar_entities, + ) + .await?; + self.perform_analysis(request).await + } + + async fn convert_analysis( + &self, + content: &TextContent, + analysis: &LLMEnrichmentResult, + entity_concurrency: usize, + ) -> Result<(Vec, Vec), AppError> { + analysis + .to_database_entities( + &content.id, + &content.user_id, + &self.openai_client, + &self.db, + entity_concurrency, + ) + .await + } + + async fn prepare_chunks( + &self, + content: &TextContent, + range: Range, + ) -> Result, AppError> { + let splitter = TextSplitter::new(range.clone()); + let chunk_texts: Vec = splitter + .chunks(&content.text) + .map(|chunk| chunk.to_string()) + .collect(); + + let mut chunks = Vec::with_capacity(chunk_texts.len()); + for chunk in chunk_texts { + let embedding = generate_embedding(&self.openai_client, &chunk, &self.db).await?; + chunks.push(TextChunk::new( + content.id.clone(), + chunk, + embedding, + content.user_id.clone(), + )); + } + Ok(chunks) + } +} diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs 
// --- ingestion-pipeline/src/pipeline/stages/mod.rs ---
//! Stage functions for the ingestion pipeline.
//!
//! Each stage consumes the typestate machine in its current state, performs
//! its side effects through the `PipelineContext`, and returns the machine
//! advanced to the next state (or an `AppError`). Guard failures from the
//! state machine are mapped to `AppError::InternalError` via
//! `map_guard_error`, since an invalid transition indicates a pipeline bug.

use std::sync::Arc;

use common::{
    error::AppError,
    storage::{
        db::SurrealDbClient,
        types::{
            ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
            knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
            text_content::TextContent,
        },
    },
};
use state_machines::core::GuardError;
use tokio::time::{sleep, Duration};
use tracing::{debug, instrument, warn};

use super::{
    context::PipelineContext,
    services::PipelineServices,
    state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
};

/// Stage 1: normalize the raw payload into `TextContent`, stash it on the
/// context, and advance `Ready -> ContentPrepared`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn prepare_content(
    machine: IngestionMachine<(), Ready>,
    ctx: &mut PipelineContext<'_>,
    payload: IngestionPayload,
) -> Result<IngestionMachine<(), ContentPrepared>, AppError> {
    let text_content = ctx.services.prepare_text_content(payload).await?;

    // Build a short single-line preview for logging. Char-based so multi-byte
    // characters are never split.
    let text_len = text_content.text.chars().count();
    let preview: String = text_content.text.chars().take(120).collect();
    let preview_clean = preview.replace('\n', " ");
    let preview_len = preview_clean.chars().count();
    let truncated = text_len > preview_len;
    let context_len = text_content
        .context
        .as_ref()
        .map(|c| c.chars().count())
        .unwrap_or(0);

    tracing::info!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        user_id = %text_content.user_id,
        category = %text_content.category,
        text_chars = text_len,
        context_chars = context_len,
        attachments = text_content.file_info.is_some(),
        "ingestion task input ready"
    );
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        preview = %preview_clean,
        preview_truncated = truncated,
        "ingestion task input preview"
    );

    ctx.text_content = Some(text_content);

    machine
        .prepare()
        .map_err(|(_, guard)| map_guard_error("prepare", guard))
}

/// Stage 2: look up knowledge entities similar to the prepared content and
/// store them on the context; advance `ContentPrepared -> Retrieved`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn retrieve_related(
    machine: IngestionMachine<(), ContentPrepared>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Retrieved>, AppError> {
    let content = ctx.text_content()?;
    let similar = ctx.services.retrieve_similar_entities(content).await?;

    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        similar_count = similar.len(),
        "ingestion retrieved similar entities"
    );

    ctx.similar_entities = similar;

    machine
        .retrieve()
        .map_err(|(_, guard)| map_guard_error("retrieve", guard))
}

/// Stage 3: run LLM enrichment over the content plus the retrieved similar
/// entities; advance `Retrieved -> Enriched`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn enrich(
    machine: IngestionMachine<(), Retrieved>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Enriched>, AppError> {
    let content = ctx.text_content()?;
    let analysis = ctx
        .services
        .run_enrichment(content, &ctx.similar_entities)
        .await?;

    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_suggestions = analysis.knowledge_entities.len(),
        relationship_suggestions = analysis.relationships.len(),
        "ingestion enrichment completed"
    );

    ctx.analysis = Some(analysis);

    machine
        .enrich()
        .map_err(|(_, guard)| map_guard_error("enrich", guard))
}

/// Stage 4: convert the analysis into graph entities/relationships and text
/// chunks, persist both concurrently, then store the source content and
/// rebuild the indexes; advance `Enriched -> Persisted`.
///
/// Takes ownership of the content and analysis off the context — this stage
/// is terminal for those values.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn persist(
    machine: IngestionMachine<(), Enriched>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Persisted>, AppError> {
    let content = ctx.take_text_content()?;
    let analysis = ctx.take_analysis()?;

    let (entities, relationships) = ctx
        .services
        .convert_analysis(
            &content,
            &analysis,
            ctx.pipeline_config.tuning.entity_embedding_concurrency,
        )
        .await?;

    let entity_count = entities.len();
    let relationship_count = relationships.len();

    let chunk_range =
        ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars;

    // Graph data and vector chunks are independent; persist them concurrently.
    let ((), chunk_count) = tokio::try_join!(
        store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
        store_vector_chunks(
            ctx.db,
            ctx.services,
            ctx.task_id.as_str(),
            &content,
            chunk_range,
            &ctx.pipeline_config.tuning
        )
    )?;

    ctx.db.store_item(content).await?;
    ctx.db.rebuild_indexes().await?;

    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_count,
        relationship_count,
        chunk_count,
        "ingestion persistence flushed to database"
    );

    machine
        .persist()
        .map_err(|(_, guard)| map_guard_error("persist", guard))
}

/// An invalid transition means the pipeline drove the machine out of order —
/// surface it as an internal error naming the offending event.
fn map_guard_error(event: &str, guard: GuardError) -> AppError {
    AppError::InternalError(format!(
        "invalid ingestion pipeline transition during {event}: {guard:?}"
    ))
}

/// Stores entities and relationships in a single SurrealDB transaction,
/// retrying with exponential backoff on transient commit conflicts.
async fn store_graph_entities(
    db: &SurrealDbClient,
    tuning: &super::config::IngestionTuning,
    entities: Vec<KnowledgeEntity>,
    relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
    const STORE_GRAPH_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $entities = $entities;
        LET $relationships = $relationships;

        FOR $entity IN $entities {
            CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
        };

        FOR $relationship IN $relationships {
            LET $in_node = type::thing('knowledge_entity', $relationship.in);
            LET $out_node = type::thing('knowledge_entity', $relationship.out);
            RELATE $in_node->relates_to->$out_node CONTENT {
                id: type::thing('relates_to', $relationship.id),
                metadata: $relationship.metadata
            };
        };

        COMMIT TRANSACTION;
    ";

    // Arc so each retry can re-bind the same payload without re-cloning the
    // full vectors.
    let entities = Arc::new(entities);
    let relationships = Arc::new(relationships);

    let mut backoff_ms = tuning.graph_initial_backoff_ms;

    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_GRAPH_MUTATION)
            .bind(("entities", entities.clone()))
            .bind(("relationships", relationships.clone()))
            .await;

        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing graph data; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }

                return Err(AppError::from(err));
            }
        }
    }

    Err(AppError::InternalError(
        "Failed to store graph entities after retries".to_string(),
    ))
}

/// Prepares embedded text chunks for the content and stores them in batches.
/// Returns the number of chunks persisted.
async fn store_vector_chunks(
    db: &SurrealDbClient,
    services: &dyn PipelineServices,
    task_id: &str,
    content: &TextContent,
    chunk_range: std::ops::Range<usize>,
    tuning: &super::config::IngestionTuning,
) -> Result<usize, AppError> {
    let prepared_chunks = services.prepare_chunks(content, chunk_range).await?;
    let chunk_count = prepared_chunks.len();

    let batch_size = tuning.chunk_insert_concurrency.max(1);

    for batch in prepared_chunks.chunks(batch_size) {
        store_chunk_batch(db, batch, tuning).await?;
    }

    // FIX: log *after* the batches are written — previously this fired before
    // any database write, so "chunk persisted" could be logged for chunks
    // whose batch later failed.
    for chunk in &prepared_chunks {
        debug!(
            task_id = %task_id,
            chunk_id = %chunk.id,
            chunk_len = chunk.chunk.chars().count(),
            "chunk persisted"
        );
    }

    Ok(chunk_count)
}

/// SurrealDB reports optimistic-concurrency failures only via the error text;
/// match on the known conflict message.
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
    error
        .to_string()
        .contains("Failed to commit transaction due to a read or write conflict")
}

/// Stores one batch of chunks transactionally, with the same
/// conflict-retry/backoff policy as `store_graph_entities`.
async fn store_chunk_batch(
    db: &SurrealDbClient,
    batch: &[TextChunk],
    tuning: &super::config::IngestionTuning,
) -> Result<(), AppError> {
    if batch.is_empty() {
        return Ok(());
    }

    const STORE_CHUNKS_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $chunks = $chunks;

        FOR $chunk IN $chunks {
            CREATE type::thing('text_chunk', $chunk.id) CONTENT $chunk;
        };

        COMMIT TRANSACTION;
    ";

    let chunks = Arc::new(batch.to_vec());
    let mut backoff_ms = tuning.graph_initial_backoff_ms;

    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_CHUNKS_MUTATION)
            .bind(("chunks", chunks.clone()))
            .await;

        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing chunks; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }

                return Err(AppError::from(err));
            }
        }
    }

    Err(AppError::InternalError(
        "Failed to store text chunks after retries".to_string(),
    ))
}

// --- ingestion-pipeline/src/pipeline/state.rs ---
use state_machines::state_machine;

// Typestate machine for the ingestion pipeline. Stages advance linearly
// through prepare/retrieve/enrich/persist; `abort` is reachable from every
// state and lands in `Failed`.
state_machine! {
    name: IngestionMachine,
    state: IngestionState,
    initial: Ready,
    states: [Ready, ContentPrepared, Retrieved, Enriched, Persisted, Failed],
    events {
        prepare { transition: { from: Ready, to: ContentPrepared } }
        retrieve { transition: { from: ContentPrepared, to: Retrieved } }
        enrich { transition: { from: Retrieved, to: Enriched } }
        persist { transition: { from: Enriched, to: Persisted } }
        abort {
            transition: { from: Ready, to: Failed }
            transition: { from: ContentPrepared, to: Failed }
            transition: { from: Retrieved, to: Failed }
            transition: { from: Enriched, to: Failed }
            transition: { from: Persisted, to: Failed }
        }
    }
}

/// Convenience constructor: a fresh machine in the initial `Ready` state with
/// unit context.
pub fn ready() -> IngestionMachine<(), Ready> {
    IngestionMachine::new(())
}

// --- ingestion-pipeline/src/pipeline/tests.rs (continues past this span) ---
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{Duration as ChronoDuration,
Utc};
use common::{
    error::AppError,
    storage::{
        db::SurrealDbClient,
        types::{
            ingestion_payload::IngestionPayload,
            ingestion_task::{IngestionTask, TaskState},
            knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
            knowledge_relationship::KnowledgeRelationship,
            text_chunk::TextChunk,
            text_content::TextContent,
        },
    },
};
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
use tokio::sync::Mutex;
use uuid::Uuid;

use super::{
    config::{IngestionConfig, IngestionTuning},
    enrichment_result::LLMEnrichmentResult,
    services::PipelineServices,
    IngestionPipeline,
};

/// Fully scripted `PipelineServices` implementation: every stage returns
/// canned data and records its name so tests can assert call order.
struct MockServices {
    text_content: TextContent,
    similar_entities: Vec<RetrievedEntity>,
    analysis: LLMEnrichmentResult,
    chunk_embedding: Vec<f32>,
    graph_entities: Vec<KnowledgeEntity>,
    graph_relationships: Vec<KnowledgeRelationship>,
    calls: Mutex<Vec<&'static str>>,
}

impl MockServices {
    fn new(user_id: &str) -> Self {
        const TEST_EMBEDDING_DIM: usize = 1536;

        let text_content = TextContent::new(
            "Example document for ingestion pipeline.".into(),
            Some("light context".into()),
            "notes".into(),
            None,
            None,
            user_id.into(),
        );

        // One pre-existing entity + chunk the retrieval stage will "find".
        let retrieved_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Existing Entity".into(),
            "Previously known context".into(),
            KnowledgeEntityType::Document,
            None,
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        let retrieved_chunk = TextChunk::new(
            retrieved_entity.source_id.clone(),
            "existing chunk".into(),
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        let similar_entities = vec![RetrievedEntity {
            entity: retrieved_entity,
            score: 0.8,
            chunks: vec![RetrievedChunk {
                chunk: retrieved_chunk,
                score: 0.7,
            }],
        }];

        // Enrichment yields nothing; conversion yields one generated
        // entity related to itself.
        let analysis = LLMEnrichmentResult {
            knowledge_entities: Vec::new(),
            relationships: Vec::new(),
        };
        let graph_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Generated Entity".into(),
            "Entity from enrichment".into(),
            KnowledgeEntityType::Idea,
            None,
            vec![0.2; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        let graph_relationship = KnowledgeRelationship::new(
            graph_entity.id.clone(),
            graph_entity.id.clone(),
            user_id.into(),
            text_content.id.clone(),
            "related_to".into(),
        );

        Self {
            text_content,
            similar_entities,
            analysis,
            chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM],
            graph_entities: vec![graph_entity],
            graph_relationships: vec![graph_relationship],
            calls: Mutex::new(Vec::new()),
        }
    }

    /// Append the stage name to the shared call log.
    async fn record(&self, stage: &'static str) {
        self.calls.lock().await.push(stage);
    }
}

#[async_trait]
impl PipelineServices for MockServices {
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.record("prepare").await;
        Ok(self.text_content.clone())
    }

    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.record("retrieve").await;
        Ok(self.similar_entities.clone())
    }

    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        self.record("enrich").await;
        Ok(self.analysis.clone())
    }

    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.record("convert").await;
        Ok((
            self.graph_entities.clone(),
            self.graph_relationships.clone(),
        ))
    }

    async fn prepare_chunks(
        &self,
        content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.record("chunk").await;
        Ok(vec![TextChunk::new(
            content.id.clone(),
            "chunk from mock services".into(),
            self.chunk_embedding.clone(),
            content.user_id.clone(),
        )])
    }
}

/// Delegates to `MockServices` except for enrichment, which always fails —
/// used to drive the retry path.
struct FailingServices {
    inner: MockServices,
}

/// Fails during content preparation with a validation error; every later
/// stage is unreachable — used to drive the dead-letter path.
struct ValidationServices;

#[async_trait]
impl PipelineServices for FailingServices {
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.inner.prepare_text_content(payload).await
    }

    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.inner.retrieve_similar_entities(content).await
    }

    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        Err(AppError::Processing("mock enrichment failure".to_string()))
    }

    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.inner
            .convert_analysis(content, analysis, entity_concurrency)
            .await
    }

    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.inner.prepare_chunks(content, range).await
    }
}

#[async_trait]
impl PipelineServices for ValidationServices {
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        Err(AppError::Validation("unsupported".to_string()))
    }

    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        unreachable!("retrieve_similar_entities should not be called after validation failure")
    }

    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        unreachable!("run_enrichment should not be called after validation failure")
    }

    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        unreachable!("convert_analysis should not be called after validation failure")
    }

    async fn prepare_chunks(
        &self,
        _content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        unreachable!("prepare_chunks should not be called after validation failure")
    }
}

/// Fresh in-memory SurrealDB with migrations applied; a random database name
/// isolates concurrently running tests.
async fn setup_db() -> SurrealDbClient {
    let namespace = "pipeline_test";
    let database = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory(namespace, &database)
        .await
        .expect("Failed to create in-memory SurrealDB");
    db.apply_migrations()
        .await
        .expect("Failed to apply migrations");
    db
}

/// Small chunk sizes and low concurrency so tests exercise batching without
/// large fixtures.
fn pipeline_config() -> IngestionConfig {
    IngestionConfig {
        tuning: IngestionTuning {
            chunk_min_chars: 4,
            chunk_max_chars: 64,
            chunk_insert_concurrency: 4,
            entity_embedding_concurrency: 2,
            ..IngestionTuning::default()
        },
    }
}

/// Enqueue a task and immediately claim it for `worker_id`.
async fn reserve_task(
    db: &SurrealDbClient,
    worker_id: &str,
    payload: IngestionPayload,
    user_id: &str,
) -> IngestionTask {
    let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db)
        .await
        .expect("task created");
    let lease = task.lease_duration();
    IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
        .await
        .expect("claim succeeds")
        .expect("task claimed")
}

#[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() {
    let db = setup_db().await;
    let worker_id = "worker-happy";
    let user_id = "user-123";
    let services = Arc::new(MockServices::new(user_id));
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone())
            .expect("pipeline");

    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    pipeline
        .process_task(task.clone())
        .await
        .expect("pipeline succeeds");

    // Task must be marked done, and both graph + vector data persisted.
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Succeeded);

    let stored_entities: Vec<KnowledgeEntity> = db
        .get_all_stored_items::<KnowledgeEntity>()
        .await
        .expect("entities stored");
    assert!(!stored_entities.is_empty(), "entities should be stored");

    let stored_chunks: Vec<TextChunk> = db
        .get_all_stored_items::<TextChunk>()
        .await
        .expect("chunks stored");
    assert!(
        !stored_chunks.is_empty(),
        "chunks should be stored for ingestion text"
    );

    // Stages must have run in order, followed only by chunk preparation.
    let call_log = services.calls.lock().await.clone();
    assert!(
        call_log.len() >= 5,
        "expected at least one chunk embedding call"
    );
    assert_eq!(
        &call_log[0..4],
        ["prepare", "retrieve", "enrich", "convert"]
    );
    assert!(call_log[4..].iter().all(|entry| *entry == "chunk"));
}

#[tokio::test]
async fn ingestion_pipeline_failure_marks_retry() {
    let db = setup_db().await;
    let worker_id = "worker-fail";
    let user_id = "user-456";
    let services = Arc::new(FailingServices {
        inner: MockServices::new(user_id),
    });
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");

    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example failure payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "failure services should bubble error from pipeline"
    );

    // A processing failure marks the task Failed and schedules a retry.
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Failed);
    assert!(
        stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
        "failed task should schedule retry in the future"
    );
}

#[tokio::test]
async fn ingestion_pipeline_validation_failure_dead_letters_task() {
    let db = setup_db().await;
    let worker_id = "worker-validation";
    let user_id = "user-789";
    let services = Arc::new(ValidationServices);
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");

    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "irrelevant".into(),
            context: "".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "validation failure should surface as error"
    );

    // Validation errors are not retryable: the task goes to the dead letter
    // state instead of being rescheduled.
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::DeadLetter);
}
)), - IngestionPayload::File { - file_info, - context, - category, - user_id, - } => { - let text = extract_text_from_file(&file_info, db, openai_client, config).await?; - Ok(TextContent::new( - text, - Some(context), - category, - Some(file_info), - None, - user_id, - )) - } - } -} - -/// Fetches web content from a URL, extracts the main article text as Markdown, -/// captures a screenshot, and stores the screenshot returning [`FileInfo`]. -/// -/// This function handles browser automation, content extraction via Readability, -/// screenshot capture, temporary file handling, and persisting the screenshot -/// details (including deduplication based on content hash via [`FileInfo::new`]). -/// -/// # Arguments -/// -/// * `url` - The URL of the web page to fetch. -/// * `db` - A reference to the database client (`SurrealDbClient`). -/// * `user_id` - The ID of the user performing the action, used for associating the file. -/// -/// # Returns -/// -/// A `Result` containing: -/// * Ok: A tuple `(Article, FileInfo)` where `Article` contains the parsed markdown -/// content and metadata, and `FileInfo` contains the details of the stored screenshot. -/// * Err: An `AppError` if any step fails (navigation, screenshot, file handling, DB operation). -async fn fetch_article_from_url( - url: &str, - db: &SurrealDbClient, - user_id: &str, - config: &AppConfig, -) -> Result<(Article, FileInfo), AppError> { - info!("Fetching URL: {}", url); - // Instantiate timer - let now = Instant::now(); - // Setup browser, navigate and wait - let browser = { - #[cfg(feature = "docker")] - { - // Use this when compiling for docker - let options = headless_chrome::LaunchOptionsBuilder::default() - .sandbox(false) - .build() - .map_err(|e| AppError::InternalError(e.to_string()))?; - Browser::new(options)? - } - #[cfg(not(feature = "docker"))] - { - // Use this otherwise - Browser::default()? 
- } - }; - let tab = browser.new_tab()?; - let page = tab.navigate_to(url)?; - let loaded_page = page.wait_until_navigated()?; - // Get content - let raw_content = loaded_page.get_content()?; - // Get screenshot - let screenshot = loaded_page.capture_screenshot( - headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg, - None, - None, - true, - )?; - - // Create temp file - let mut tmp_file = NamedTempFile::new()?; - let temp_path_str = format!("{:?}", tmp_file.path()); - - // Write screenshot TO the temp file - tmp_file.write_all(&screenshot)?; - - // Ensure the OS buffer is written to the file system _before_ we proceed. - tmp_file.as_file().sync_all()?; - - // Ensure the file handle's read cursor is at the beginning before hashing occurs. - if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) { - error!("URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.", url, temp_path_str, e); - } - - // Prepare file metadata - let parsed_url = - url::Url::parse(url).map_err(|_| AppError::Processing("Invalid URL".to_string()))?; - let domain = parsed_url - .host_str() - .unwrap_or("unknown") - .replace(|c: char| !c.is_alphanumeric(), "_"); - let timestamp = Utc::now().format("%Y%m%d%H%M%S"); - let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp); - - // Construct FieldData and FieldMetadata - let metadata = FieldMetadata { - file_name: Some(file_name), - content_type: Some("image/jpeg".to_string()), - name: None, - headers: HeaderMap::new(), - }; - let field_data = FieldData { - contents: tmp_file, - metadata, - }; - - // Store screenshot - let file_info = FileInfo::new(field_data, db, user_id, config).await?; - - // Parse content... - let config = dom_smoothie::Config { - text_mode: TextMode::Markdown, - ..Default::default() - }; - let mut readability = Readability::new(raw_content, None, Some(config))?; - let article: Article = readability.parse()?; - let end = now.elapsed(); - info!( - "URL: {}. 
Total time: {:?}. Final File ID: {}", - url, end, file_info.id - ); - - Ok((article, file_info)) -} - -/// Extracts text from a stored file by MIME type. -async fn extract_text_from_file( - file_info: &FileInfo, - db_client: &SurrealDbClient, - openai_client: &async_openai::Client, - config: &AppConfig, -) -> Result { - let base_path = store::resolve_base_dir(config); - let absolute_path = base_path.join(&file_info.path); - - match file_info.mime_type.as_str() { - "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => { - let content = tokio::fs::read_to_string(&absolute_path).await?; - Ok(content) - } - "application/pdf" => { - extract_pdf_content( - &absolute_path, - db_client, - openai_client, - &config.pdf_ingest_mode, - ) - .await - } - "image/png" | "image/jpeg" => { - let path_str = absolute_path - .to_str() - .ok_or_else(|| { - AppError::Processing(format!( - "Encountered a non-UTF8 path while reading image {}", - file_info.id - )) - })? - .to_string(); - let content = extract_text_from_image(&path_str, db_client, openai_client).await?; - Ok(content) - } - "audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4" - | "audio/ogg" | "audio/flac" => { - let path_str = absolute_path - .to_str() - .ok_or_else(|| { - AppError::Processing(format!( - "Encountered a non-UTF8 path while reading audio {}", - file_info.id - )) - })? 
- .to_string(); - transcribe_audio_file(&path_str, db_client, openai_client).await - } - // Handle other MIME types as needed - _ => Err(AppError::NotFound(file_info.mime_type.clone())), - } -} diff --git a/ingestion-pipeline/src/utils/file_text_extraction.rs b/ingestion-pipeline/src/utils/file_text_extraction.rs new file mode 100644 index 0000000..21d6479 --- /dev/null +++ b/ingestion-pipeline/src/utils/file_text_extraction.rs @@ -0,0 +1,63 @@ +use common::{ + error::AppError, + storage::{db::SurrealDbClient, store, types::file_info::FileInfo}, + utils::config::AppConfig, +}; + +use super::{ + audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image, + pdf_ingestion::extract_pdf_content, +}; + +pub async fn extract_text_from_file( + file_info: &FileInfo, + db_client: &SurrealDbClient, + openai_client: &async_openai::Client, + config: &AppConfig, +) -> Result { + let base_path = store::resolve_base_dir(config); + let absolute_path = base_path.join(&file_info.path); + + match file_info.mime_type.as_str() { + "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => { + let content = tokio::fs::read_to_string(&absolute_path).await?; + Ok(content) + } + "application/pdf" => { + extract_pdf_content( + &absolute_path, + db_client, + openai_client, + &config.pdf_ingest_mode, + ) + .await + } + "image/png" | "image/jpeg" => { + let path_str = absolute_path + .to_str() + .ok_or_else(|| { + AppError::Processing(format!( + "Encountered a non-UTF8 path while reading image {}", + file_info.id + )) + })? + .to_string(); + let content = extract_text_from_image(&path_str, db_client, openai_client).await?; + Ok(content) + } + "audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4" + | "audio/ogg" | "audio/flac" => { + let path_str = absolute_path + .to_str() + .ok_or_else(|| { + AppError::Processing(format!( + "Encountered a non-UTF8 path while reading audio {}", + file_info.id + )) + })? 
+ .to_string(); + transcribe_audio_file(&path_str, db_client, openai_client).await + } + _ => Err(AppError::NotFound(file_info.mime_type.clone())), + } +} diff --git a/ingestion-pipeline/src/utils/graph_mapper.rs b/ingestion-pipeline/src/utils/graph_mapper.rs new file mode 100644 index 0000000..5b6a429 --- /dev/null +++ b/ingestion-pipeline/src/utils/graph_mapper.rs @@ -0,0 +1,53 @@ +use common::error::AppError; +use std::collections::HashMap; +use uuid::Uuid; + +/// Intermediate struct to hold mapping between LLM keys and generated IDs. +#[derive(Clone)] +pub struct GraphMapper { + pub key_to_id: HashMap, +} + +impl Default for GraphMapper { + fn default() -> Self { + Self::new() + } +} + +impl GraphMapper { + pub fn new() -> Self { + Self { + key_to_id: HashMap::new(), + } + } + /// Tries to get an ID by first parsing the key as a UUID, + /// and if that fails, looking it up in the internal map. + pub fn get_or_parse_id(&self, key: &str) -> Result { + // First, try to parse the key as a UUID. + if let Ok(parsed_uuid) = Uuid::parse_str(key) { + return Ok(parsed_uuid); + } + + // If parsing fails, look it up in the map. + self.key_to_id.get(key).copied().ok_or_else(|| { + AppError::GraphMapper(format!( + "Key '{key}' is not a valid UUID and was not found in the map." + )) + }) + } + + /// Assigns a new UUID for a given key. (No changes needed here) + pub fn assign_id(&mut self, key: &str) -> Uuid { + let id = Uuid::new_v4(); + self.key_to_id.insert(key.to_string(), id); + id + } + + /// Retrieves the UUID for a given key, returning a Result for consistency. 
+ pub fn get_id(&self, key: &str) -> Result { + self.key_to_id + .get(key) + .copied() + .ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map."))) + } +} diff --git a/ingestion-pipeline/src/utils/mod.rs b/ingestion-pipeline/src/utils/mod.rs index 0a0b5c2..d334d12 100644 --- a/ingestion-pipeline/src/utils/mod.rs +++ b/ingestion-pipeline/src/utils/mod.rs @@ -1,58 +1,7 @@ pub mod audio_transcription; +pub mod file_text_extraction; +pub mod graph_mapper; pub mod image_parsing; pub mod llm_instructions; pub mod pdf_ingestion; - -use common::error::AppError; -use std::collections::HashMap; -use uuid::Uuid; - -/// Intermediate struct to hold mapping between LLM keys and generated IDs. -#[derive(Clone)] -pub struct GraphMapper { - pub key_to_id: HashMap, -} - -impl Default for GraphMapper { - fn default() -> Self { - Self::new() - } -} - -impl GraphMapper { - pub fn new() -> Self { - Self { - key_to_id: HashMap::new(), - } - } - /// Tries to get an ID by first parsing the key as a UUID, - /// and if that fails, looking it up in the internal map. - pub fn get_or_parse_id(&self, key: &str) -> Result { - // First, try to parse the key as a UUID. - if let Ok(parsed_uuid) = Uuid::parse_str(key) { - return Ok(parsed_uuid); - } - - // If parsing fails, look it up in the map. - self.key_to_id.get(key).copied().ok_or_else(|| { - AppError::GraphMapper(format!( - "Key '{key}' is not a valid UUID and was not found in the map." - )) - }) - } - - /// Assigns a new UUID for a given key. (No changes needed here) - pub fn assign_id(&mut self, key: &str) -> Uuid { - let id = Uuid::new_v4(); - self.key_to_id.insert(key.to_string(), id); - id - } - - /// Retrieves the UUID for a given key, returning a Result for consistency. 
- pub fn get_id(&self, key: &str) -> Result { - self.key_to_id - .get(key) - .copied() - .ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map."))) - } -} +pub mod url_text_retrieval; diff --git a/ingestion-pipeline/src/utils/url_text_retrieval.rs b/ingestion-pipeline/src/utils/url_text_retrieval.rs new file mode 100644 index 0000000..d8cca7e --- /dev/null +++ b/ingestion-pipeline/src/utils/url_text_retrieval.rs @@ -0,0 +1,174 @@ +use axum::http::HeaderMap; +use axum_typed_multipart::{FieldData, FieldMetadata}; +use chrono::Utc; +use common::{ + error::AppError, + storage::{db::SurrealDbClient, types::file_info::FileInfo}, + utils::config::AppConfig, +}; +use dom_smoothie::{Article, Readability, TextMode}; +use headless_chrome::Browser; +use std::{ + io::{Seek, SeekFrom, Write}, + net::IpAddr, + time::Instant, +}; +use tempfile::NamedTempFile; +use tracing::{error, info, warn}; +pub async fn extract_text_from_url( + url: &str, + db: &SurrealDbClient, + user_id: &str, + config: &AppConfig, +) -> Result<(Article, FileInfo), AppError> { + info!("Fetching URL: {}", url); + let now = Instant::now(); + + let browser = { + #[cfg(feature = "docker")] + { + let options = headless_chrome::LaunchOptionsBuilder::default() + .sandbox(false) + .build() + .map_err(|e| AppError::InternalError(e.to_string()))?; + Browser::new(options)? + } + #[cfg(not(feature = "docker"))] + { + Browser::default()? 
+ } + }; + + let tab = browser.new_tab()?; + let page = tab.navigate_to(url)?; + let loaded_page = page.wait_until_navigated()?; + let raw_content = loaded_page.get_content()?; + let screenshot = loaded_page.capture_screenshot( + headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg, + None, + None, + true, + )?; + + let mut tmp_file = NamedTempFile::new()?; + let temp_path_str = format!("{:?}", tmp_file.path()); + + tmp_file.write_all(&screenshot)?; + tmp_file.as_file().sync_all()?; + + if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) { + error!( + "URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.", + url, temp_path_str, e + ); + } + + let parsed_url = + url::Url::parse(url).map_err(|_| AppError::Validation("Invalid URL".to_string()))?; + + let domain = ensure_ingestion_url_allowed(&parsed_url)?; + let timestamp = Utc::now().format("%Y%m%d%H%M%S"); + let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp); + + let metadata = FieldMetadata { + file_name: Some(file_name), + content_type: Some("image/jpeg".to_string()), + name: None, + headers: HeaderMap::new(), + }; + let field_data = FieldData { + contents: tmp_file, + metadata, + }; + + let file_info = FileInfo::new(field_data, db, user_id, config).await?; + + let config = dom_smoothie::Config { + text_mode: TextMode::Markdown, + ..Default::default() + }; + let mut readability = Readability::new(raw_content, None, Some(config))?; + let article: Article = readability.parse()?; + let end = now.elapsed(); + info!( + "URL: {}. Total time: {:?}. 
Final File ID: {}", + url, end, file_info.id + ); + + Ok((article, file_info)) +} + +fn ensure_ingestion_url_allowed(url: &url::Url) -> Result { + match url.scheme() { + "http" | "https" => {} + scheme => { + warn!(%url, %scheme, "Rejected ingestion URL due to unsupported scheme"); + return Err(AppError::Validation( + "Unsupported URL scheme for ingestion".to_string(), + )); + } + } + + let host = match url.host_str() { + Some(host) => host, + None => { + warn!(%url, "Rejected ingestion URL missing host"); + return Err(AppError::Validation( + "URL is missing a host component".to_string(), + )); + } + }; + + if host.eq_ignore_ascii_case("localhost") { + warn!(%url, host, "Rejected ingestion URL to localhost"); + return Err(AppError::Validation( + "Ingestion URL host is not allowed".to_string(), + )); + } + + if let Ok(ip) = host.parse::() { + let is_disallowed = match ip { + IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(), + IpAddr::V6(v6) => v6.is_unique_local() || v6.is_unicast_link_local(), + }; + + if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() || is_disallowed { + warn!(%url, host, %ip, "Rejected ingestion URL pointing to restricted network range"); + return Err(AppError::Validation( + "Ingestion URL host is not allowed".to_string(), + )); + } + } + + Ok(host.replace(|c: char| !c.is_alphanumeric(), "_")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_unsupported_scheme() { + let url = url::Url::parse("ftp://example.com").expect("url"); + assert!(ensure_ingestion_url_allowed(&url).is_err()); + } + + #[test] + fn rejects_localhost() { + let url = url::Url::parse("http://localhost/resource").expect("url"); + assert!(ensure_ingestion_url_allowed(&url).is_err()); + } + + #[test] + fn rejects_private_ipv4() { + let url = url::Url::parse("http://192.168.1.10/index.html").expect("url"); + assert!(ensure_ingestion_url_allowed(&url).is_err()); + } + + #[test] + fn allows_public_domain_and_sanitizes() { + let url = 
url::Url::parse("https://sub.example.com/path").expect("url"); + let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed"); + assert_eq!(sanitized, "sub_example_com"); + } +}