mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-25 10:18:38 +02:00
refactor: implemented state machine for ingestion pipeline, improved performance
changelog additional moving around moved files around a bit
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
- Added knowledge entity search results to the global search
|
||||
- Backend fixes for improved performance when ingesting and retrieval
|
||||
|
||||
## Version 0.2.4 (2025-10-15)
|
||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2919,6 +2919,7 @@ name = "ingestion-pipeline"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum_typed_multipart",
|
||||
"base64 0.22.1",
|
||||
@@ -2933,6 +2934,7 @@ dependencies = [
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"state-machines",
|
||||
"surrealdb",
|
||||
"tempfile",
|
||||
"text-splitter",
|
||||
|
||||
@@ -31,5 +31,7 @@ lopdf = "0.32"
|
||||
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
async-trait = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
[features]
|
||||
docker = []
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
|
||||
|
||||
use crate::{
|
||||
types::llm_enrichment_result::LLMEnrichmentResult,
|
||||
utils::llm_instructions::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||
};
|
||||
|
||||
pub struct IngestionEnricher {
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
}
|
||||
|
||||
impl IngestionEnricher {
|
||||
pub const fn new(
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn analyze_content(
|
||||
&self,
|
||||
category: &str,
|
||||
context: Option<&str>,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<LLMEnrichmentResult, AppError> {
|
||||
let similar_entities = self
|
||||
.find_similar_entities(category, context, text, user_id)
|
||||
.await?;
|
||||
let llm_request = self
|
||||
.prepare_llm_request(category, context, text, &similar_entities)
|
||||
.await?;
|
||||
self.perform_analysis(llm_request).await
|
||||
}
|
||||
|
||||
async fn find_similar_entities(
|
||||
&self,
|
||||
category: &str,
|
||||
context: Option<&str>,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let input_text =
|
||||
format!("content: {text}, category: {category}, user_context: {context:?}");
|
||||
|
||||
retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await
|
||||
}
|
||||
|
||||
async fn prepare_llm_request(
|
||||
&self,
|
||||
category: &str,
|
||||
context: Option<&str>,
|
||||
text: &str,
|
||||
similar_entities: &[RetrievedEntity],
|
||||
) -> Result<CreateChatCompletionRequest, AppError> {
|
||||
let settings = SystemSettings::get_current(&self.db_client).await?;
|
||||
|
||||
let entities_json = retrieved_entities_to_json(similar_entities);
|
||||
|
||||
let user_message = format!(
|
||||
"Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
|
||||
);
|
||||
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Structured analysis of the submitted content".into()),
|
||||
name: "content_analysis".into(),
|
||||
schema: Some(get_ingress_analysis_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.model(&settings.processing_model)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()?;
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
) -> Result<LLMEnrichmentResult, AppError> {
|
||||
let response = self.openai_client.chat().create(request).await?;
|
||||
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))?;
|
||||
|
||||
serde_json::from_str::<LLMEnrichmentResult>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
pub mod enricher;
|
||||
pub mod pipeline;
|
||||
pub mod types;
|
||||
pub mod utils;
|
||||
|
||||
use chrono::Utc;
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use text_splitter::TextSplitter;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{debug, info, info_span, warn};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_task::{IngestionTask, TaskErrorInfo},
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
},
|
||||
utils::{config::AppConfig, embedding::generate_embedding},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
enricher::IngestionEnricher,
|
||||
types::{llm_enrichment_result::LLMEnrichmentResult, to_text_content},
|
||||
};
|
||||
|
||||
pub struct IngestionPipeline {
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
}
|
||||
|
||||
impl IngestionPipeline {
|
||||
pub async fn new(
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
) -> Result<Self, AppError> {
|
||||
Ok(Self {
|
||||
db,
|
||||
openai_client,
|
||||
config,
|
||||
})
|
||||
}
|
||||
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
|
||||
let task_id = task.id.clone();
|
||||
let attempt = task.attempts;
|
||||
let worker_label = task
|
||||
.worker_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown-worker".to_string());
|
||||
let span = info_span!(
|
||||
"ingestion_task",
|
||||
%task_id,
|
||||
attempt,
|
||||
worker_id = %worker_label,
|
||||
state = %task.state.as_str()
|
||||
);
|
||||
let _enter = span.enter();
|
||||
let processing_task = task.mark_processing(&self.db).await?;
|
||||
|
||||
let text_content = to_text_content(
|
||||
processing_task.content.clone(),
|
||||
&self.db,
|
||||
&self.config,
|
||||
&self.openai_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let text_len = text_content.text.chars().count();
|
||||
let preview: String = text_content.text.chars().take(120).collect();
|
||||
let preview_clean = preview.replace("\n", " ");
|
||||
let preview_len = preview_clean.chars().count();
|
||||
let truncated = text_len > preview_len;
|
||||
let context_len = text_content
|
||||
.context
|
||||
.as_ref()
|
||||
.map(|c| c.chars().count())
|
||||
.unwrap_or(0);
|
||||
info!(
|
||||
%task_id,
|
||||
attempt,
|
||||
user_id = %text_content.user_id,
|
||||
category = %text_content.category,
|
||||
text_chars = text_len,
|
||||
context_chars = context_len,
|
||||
attachments = text_content.file_info.is_some(),
|
||||
"ingestion task input ready"
|
||||
);
|
||||
debug!(
|
||||
%task_id,
|
||||
attempt,
|
||||
preview = %preview_clean,
|
||||
preview_truncated = truncated,
|
||||
"ingestion task input preview"
|
||||
);
|
||||
|
||||
match self.process(&text_content).await {
|
||||
Ok(()) => {
|
||||
processing_task.mark_succeeded(&self.db).await?;
|
||||
info!(%task_id, attempt, "ingestion task succeeded");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let reason = err.to_string();
|
||||
let error_info = TaskErrorInfo {
|
||||
code: None,
|
||||
message: reason.clone(),
|
||||
};
|
||||
|
||||
if processing_task.can_retry() {
|
||||
let delay = Self::retry_delay(processing_task.attempts);
|
||||
processing_task
|
||||
.mark_failed(error_info, delay, &self.db)
|
||||
.await?;
|
||||
warn!(
|
||||
%task_id,
|
||||
attempt = processing_task.attempts,
|
||||
retry_in_secs = delay.as_secs(),
|
||||
"ingestion task failed; scheduled retry"
|
||||
);
|
||||
} else {
|
||||
processing_task
|
||||
.mark_dead_letter(error_info, &self.db)
|
||||
.await?;
|
||||
warn!(
|
||||
%task_id,
|
||||
attempt = processing_task.attempts,
|
||||
"ingestion task failed; moved to dead letter queue"
|
||||
);
|
||||
}
|
||||
|
||||
Err(AppError::Processing(reason))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_delay(attempt: u32) -> Duration {
|
||||
const BASE_SECONDS: u64 = 30;
|
||||
const MAX_SECONDS: u64 = 15 * 60;
|
||||
|
||||
let capped_attempt = attempt.saturating_sub(1).min(5);
|
||||
let multiplier = 2_u64.pow(capped_attempt);
|
||||
let delay = BASE_SECONDS * multiplier;
|
||||
|
||||
Duration::from_secs(delay.min(MAX_SECONDS))
|
||||
}
|
||||
|
||||
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
|
||||
let now = Instant::now();
|
||||
|
||||
// Perform analyis, this step also includes retrieval
|
||||
let analysis = self.perform_semantic_analysis(content).await?;
|
||||
|
||||
let end = now.elapsed();
|
||||
info!(
|
||||
"{:?} time elapsed during creation of entities and relationships",
|
||||
end
|
||||
);
|
||||
|
||||
// Convert analysis to application objects
|
||||
let (entities, relationships) = analysis
|
||||
.to_database_entities(&content.id, &content.user_id, &self.openai_client, &self.db)
|
||||
.await?;
|
||||
|
||||
// Store everything
|
||||
tokio::try_join!(
|
||||
self.store_graph_entities(entities, relationships),
|
||||
self.store_vector_chunks(content),
|
||||
)?;
|
||||
|
||||
// Store original content
|
||||
self.db.store_item(content.to_owned()).await?;
|
||||
|
||||
self.db.rebuild_indexes().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn perform_semantic_analysis(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<LLMEnrichmentResult, AppError> {
|
||||
let analyser = IngestionEnricher::new(self.db.clone(), self.openai_client.clone());
|
||||
analyser
|
||||
.analyze_content(
|
||||
&content.category,
|
||||
content.context.as_deref(),
|
||||
&content.text,
|
||||
&content.user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_graph_entities(
|
||||
&self,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
) -> Result<(), AppError> {
|
||||
let entities = Arc::new(entities);
|
||||
let relationships = Arc::new(relationships);
|
||||
let entity_count = entities.len();
|
||||
let relationship_count = relationships.len();
|
||||
|
||||
const STORE_GRAPH_MUTATION: &str = r"
|
||||
BEGIN TRANSACTION;
|
||||
LET $entities = $entities;
|
||||
LET $relationships = $relationships;
|
||||
|
||||
FOR $entity IN $entities {
|
||||
CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
|
||||
};
|
||||
|
||||
FOR $relationship IN $relationships {
|
||||
LET $in_node = type::thing('knowledge_entity', $relationship.in);
|
||||
LET $out_node = type::thing('knowledge_entity', $relationship.out);
|
||||
RELATE $in_node->relates_to->$out_node CONTENT {
|
||||
id: type::thing('relates_to', $relationship.id),
|
||||
metadata: $relationship.metadata
|
||||
};
|
||||
};
|
||||
|
||||
COMMIT TRANSACTION;
|
||||
";
|
||||
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
const INITIAL_BACKOFF_MS: u64 = 50;
|
||||
const MAX_BACKOFF_MS: u64 = 800;
|
||||
|
||||
let mut backoff_ms = INITIAL_BACKOFF_MS;
|
||||
let mut success = false;
|
||||
|
||||
for attempt in 0..MAX_ATTEMPTS {
|
||||
let result = self
|
||||
.db
|
||||
.client
|
||||
.query(STORE_GRAPH_MUTATION)
|
||||
.bind(("entities", entities.clone()))
|
||||
.bind(("relationships", relationships.clone()))
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
if Self::is_retryable_conflict(&err) && attempt + 1 < MAX_ATTEMPTS {
|
||||
warn!(
|
||||
attempt = attempt + 1,
|
||||
"Transient SurrealDB conflict while storing graph data; retrying"
|
||||
);
|
||||
sleep(Duration::from_millis(backoff_ms)).await;
|
||||
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(AppError::from(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !success {
|
||||
return Err(AppError::InternalError(
|
||||
"Failed to store graph entities after retries".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
info!(
|
||||
"Stored {} entities and {} relationships",
|
||||
entity_count, relationship_count
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
|
||||
let splitter = TextSplitter::new(500..2000);
|
||||
let chunks = splitter.chunks(&content.text);
|
||||
|
||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||
for chunk in chunks {
|
||||
let embedding = generate_embedding(&self.openai_client, chunk, &self.db).await?;
|
||||
let text_chunk = TextChunk::new(
|
||||
content.id.to_string(),
|
||||
chunk.to_string(),
|
||||
embedding,
|
||||
content.user_id.to_string(),
|
||||
);
|
||||
self.db.store_item(text_chunk).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
|
||||
error
|
||||
.to_string()
|
||||
.contains("Failed to commit transaction due to a read or write conflict")
|
||||
}
|
||||
}
|
||||
35
ingestion-pipeline/src/pipeline/config.rs
Normal file
35
ingestion-pipeline/src/pipeline/config.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IngestionTuning {
|
||||
pub retry_base_delay_secs: u64,
|
||||
pub retry_max_delay_secs: u64,
|
||||
pub retry_backoff_cap_exponent: u32,
|
||||
pub graph_store_attempts: usize,
|
||||
pub graph_initial_backoff_ms: u64,
|
||||
pub graph_max_backoff_ms: u64,
|
||||
pub chunk_min_chars: usize,
|
||||
pub chunk_max_chars: usize,
|
||||
pub chunk_insert_concurrency: usize,
|
||||
pub entity_embedding_concurrency: usize,
|
||||
}
|
||||
|
||||
impl Default for IngestionTuning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
retry_base_delay_secs: 30,
|
||||
retry_max_delay_secs: 15 * 60,
|
||||
retry_backoff_cap_exponent: 5,
|
||||
graph_store_attempts: 3,
|
||||
graph_initial_backoff_ms: 50,
|
||||
graph_max_backoff_ms: 800,
|
||||
chunk_min_chars: 500,
|
||||
chunk_max_chars: 2_000,
|
||||
chunk_insert_concurrency: 8,
|
||||
entity_embedding_concurrency: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct IngestionConfig {
|
||||
pub tuning: IngestionTuning,
|
||||
}
|
||||
76
ingestion-pipeline/src/pipeline/context.rs
Normal file
76
ingestion-pipeline/src/pipeline/context.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{ingestion_task::IngestionTask, text_content::TextContent},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::RetrievedEntity;
|
||||
use tracing::error;
|
||||
|
||||
use super::enrichment_result::LLMEnrichmentResult;
|
||||
|
||||
use super::{config::IngestionConfig, services::PipelineServices};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
pub task: &'a IngestionTask,
|
||||
pub task_id: String,
|
||||
pub attempt: u32,
|
||||
pub db: &'a SurrealDbClient,
|
||||
pub pipeline_config: &'a IngestionConfig,
|
||||
pub services: &'a dyn PipelineServices,
|
||||
pub text_content: Option<TextContent>,
|
||||
pub similar_entities: Vec<RetrievedEntity>,
|
||||
pub analysis: Option<LLMEnrichmentResult>,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
task: &'a IngestionTask,
|
||||
db: &'a SurrealDbClient,
|
||||
pipeline_config: &'a IngestionConfig,
|
||||
services: &'a dyn PipelineServices,
|
||||
) -> Self {
|
||||
let task_id = task.id.clone();
|
||||
let attempt = task.attempts;
|
||||
Self {
|
||||
task,
|
||||
task_id,
|
||||
attempt,
|
||||
db,
|
||||
pipeline_config,
|
||||
services,
|
||||
text_content: None,
|
||||
similar_entities: Vec::new(),
|
||||
analysis: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn text_content(&self) -> Result<&TextContent, AppError> {
|
||||
self.text_content
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::InternalError("text content expected to be available".into()))
|
||||
}
|
||||
|
||||
pub fn take_text_content(&mut self) -> Result<TextContent, AppError> {
|
||||
self.text_content.take().ok_or_else(|| {
|
||||
AppError::InternalError("text content expected to be available for persistence".into())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn take_analysis(&mut self) -> Result<LLMEnrichmentResult, AppError> {
|
||||
self.analysis.take().ok_or_else(|| {
|
||||
AppError::InternalError("analysis expected to be available for persistence".into())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn abort(&mut self, err: AppError) -> AppError {
|
||||
error!(
|
||||
task_id = %self.task_id,
|
||||
attempt = self.attempt,
|
||||
error = %err,
|
||||
"ingestion pipeline aborted"
|
||||
);
|
||||
err
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use futures::stream::{self, StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -15,28 +15,25 @@ use common::{
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::future::try_join_all;
|
||||
|
||||
use crate::utils::GraphMapper;
|
||||
use crate::utils::graph_mapper::GraphMapper;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMKnowledgeEntity {
|
||||
pub key: String, // Temporary identifier
|
||||
pub key: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub entity_type: String, // Should match KnowledgeEntityType variants
|
||||
pub entity_type: String,
|
||||
}
|
||||
|
||||
/// Represents a single relationship from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMRelationship {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String, // e.g., RelatedTo, RelevantTo
|
||||
pub source: String, // Key of the source entity
|
||||
pub target: String, // Key of the target entity
|
||||
pub type_: String,
|
||||
pub source: String,
|
||||
pub target: String,
|
||||
}
|
||||
|
||||
/// Represents the entire graph analysis result from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMEnrichmentResult {
|
||||
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
|
||||
@@ -44,27 +41,16 @@ pub struct LLMEnrichmentResult {
|
||||
}
|
||||
|
||||
impl LLMEnrichmentResult {
|
||||
/// Converts the LLM graph analysis result into database entities and relationships.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - A UUID representing the source identifier.
|
||||
/// * `openai_client` - `OpenAI` client for LLM calls.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
|
||||
pub async fn to_database_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
entity_concurrency: usize,
|
||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||
// Create mapper and pre-assign IDs
|
||||
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
|
||||
let mapper = Arc::new(self.create_mapper()?);
|
||||
|
||||
// Process entities
|
||||
let entities = self
|
||||
.process_entities(
|
||||
source_id,
|
||||
@@ -72,10 +58,10 @@ impl LLMEnrichmentResult {
|
||||
Arc::clone(&mapper),
|
||||
openai_client,
|
||||
db_client,
|
||||
entity_concurrency,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Process relationships
|
||||
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
|
||||
|
||||
Ok((entities, relationships))
|
||||
@@ -84,7 +70,6 @@ impl LLMEnrichmentResult {
|
||||
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
|
||||
let mut mapper = GraphMapper::new();
|
||||
|
||||
// Pre-assign all IDs
|
||||
for entity in &self.knowledge_entities {
|
||||
mapper.assign_id(&entity.key);
|
||||
}
|
||||
@@ -96,57 +81,46 @@ impl LLMEnrichmentResult {
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
mapper: Arc<GraphMapper>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
entity_concurrency: usize,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let futures: Vec<_> = self
|
||||
.knowledge_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let entity = entity.clone();
|
||||
let db_client = db_client.clone();
|
||||
stream::iter(self.knowledge_entities.iter().cloned().map(|entity| {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let db_client = db_client.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
create_single_entity(
|
||||
&entity,
|
||||
&source_id,
|
||||
&user_id,
|
||||
mapper,
|
||||
&openai_client,
|
||||
&db_client.clone(),
|
||||
)
|
||||
.await
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = try_join_all(futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(results)
|
||||
async move {
|
||||
create_single_entity(
|
||||
&entity,
|
||||
&source_id,
|
||||
&user_id,
|
||||
mapper,
|
||||
&openai_client,
|
||||
&db_client,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}))
|
||||
.buffer_unordered(entity_concurrency.max(1))
|
||||
.try_collect()
|
||||
.await
|
||||
}
|
||||
|
||||
fn process_relationships(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
mapper: Arc<GraphMapper>,
|
||||
) -> Result<Vec<KnowledgeRelationship>, AppError> {
|
||||
let mapper_guard = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
self.relationships
|
||||
.iter()
|
||||
.map(|rel| {
|
||||
let source_db_id = mapper_guard.get_or_parse_id(&rel.source)?;
|
||||
let target_db_id = mapper_guard.get_or_parse_id(&rel.target)?;
|
||||
let source_db_id = mapper.get_or_parse_id(&rel.source)?;
|
||||
let target_db_id = mapper.get_or_parse_id(&rel.target)?;
|
||||
|
||||
Ok(KnowledgeRelationship::new(
|
||||
source_db_id.to_string(),
|
||||
@@ -159,20 +133,16 @@ impl LLMEnrichmentResult {
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_single_entity(
|
||||
llm_entity: &LLMKnowledgeEntity,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
mapper: Arc<GraphMapper>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<KnowledgeEntity, AppError> {
|
||||
let assigned_id = {
|
||||
let mapper = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
mapper.get_id(&llm_entity.key)?.to_string()
|
||||
};
|
||||
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
|
||||
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {}",
|
||||
221
ingestion-pipeline/src/pipeline/mod.rs
Normal file
221
ingestion-pipeline/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
mod config;
|
||||
mod context;
|
||||
mod enrichment_result;
|
||||
mod preparation;
|
||||
mod services;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{IngestionConfig, IngestionTuning};
|
||||
pub use services::{DefaultPipelineServices, PipelineServices};
|
||||
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload,
|
||||
ingestion_task::{IngestionTask, TaskErrorInfo},
|
||||
},
|
||||
},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use self::{
|
||||
context::PipelineContext,
|
||||
stages::{enrich, persist, prepare_content, retrieve_related},
|
||||
state::ready,
|
||||
};
|
||||
|
||||
pub struct IngestionPipeline {
|
||||
db: Arc<SurrealDbClient>,
|
||||
pipeline_config: IngestionConfig,
|
||||
services: Arc<dyn PipelineServices>,
|
||||
}
|
||||
|
||||
impl IngestionPipeline {
|
||||
pub async fn new(
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
config: AppConfig,
|
||||
) -> Result<Self, AppError> {
|
||||
let services =
|
||||
DefaultPipelineServices::new(db.clone(), openai_client.clone(), config.clone());
|
||||
|
||||
Self::with_services(db, IngestionConfig::default(), Arc::new(services))
|
||||
}
|
||||
|
||||
pub fn with_services(
|
||||
db: Arc<SurrealDbClient>,
|
||||
pipeline_config: IngestionConfig,
|
||||
services: Arc<dyn PipelineServices>,
|
||||
) -> Result<Self, AppError> {
|
||||
Ok(Self {
|
||||
db,
|
||||
pipeline_config,
|
||||
services,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
task_id = %task.id,
|
||||
attempt = task.attempts,
|
||||
worker_id = task.worker_id.as_deref().unwrap_or("unknown-worker"),
|
||||
user_id = %task.user_id
|
||||
)
|
||||
)]
|
||||
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
|
||||
let mut processing_task = task.mark_processing(&self.db).await?;
|
||||
let payload = std::mem::replace(
|
||||
&mut processing_task.content,
|
||||
IngestionPayload::Text {
|
||||
text: String::new(),
|
||||
context: String::new(),
|
||||
category: String::new(),
|
||||
user_id: processing_task.user_id.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
match self
|
||||
.drive_pipeline(&processing_task, payload)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
debug!(
|
||||
task_id = %processing_task.id,
|
||||
attempt = processing_task.attempts,
|
||||
error = %err,
|
||||
"ingestion pipeline failed"
|
||||
);
|
||||
err
|
||||
}) {
|
||||
Ok(()) => {
|
||||
processing_task.mark_succeeded(&self.db).await?;
|
||||
tracing::info!(
|
||||
task_id = %processing_task.id,
|
||||
attempt = processing_task.attempts,
|
||||
"ingestion task succeeded"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let reason = err.to_string();
|
||||
let retryable = !matches!(err, AppError::Validation(_));
|
||||
let error_info = TaskErrorInfo {
|
||||
code: None,
|
||||
message: reason.clone(),
|
||||
};
|
||||
|
||||
if retryable && processing_task.can_retry() {
|
||||
let delay = self.retry_delay(processing_task.attempts);
|
||||
processing_task
|
||||
.mark_failed(error_info, delay, &self.db)
|
||||
.await?;
|
||||
warn!(
|
||||
task_id = %processing_task.id,
|
||||
attempt = processing_task.attempts,
|
||||
retry_in_secs = delay.as_secs(),
|
||||
"ingestion task failed; scheduled retry"
|
||||
);
|
||||
} else {
|
||||
let failed_task = processing_task
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(0), &self.db)
|
||||
.await?;
|
||||
failed_task.mark_dead_letter(error_info, &self.db).await?;
|
||||
warn!(
|
||||
task_id = %failed_task.id,
|
||||
attempt = failed_task.attempts,
|
||||
"ingestion task failed; moved to dead letter queue"
|
||||
);
|
||||
}
|
||||
|
||||
Err(AppError::Processing(reason))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_delay(&self, attempt: u32) -> Duration {
|
||||
let tuning = &self.pipeline_config.tuning;
|
||||
let capped_attempt = attempt
|
||||
.saturating_sub(1)
|
||||
.min(tuning.retry_backoff_cap_exponent);
|
||||
let multiplier = 2_u64.pow(capped_attempt);
|
||||
let delay = tuning.retry_base_delay_secs * multiplier;
|
||||
|
||||
Duration::from_secs(delay.min(tuning.retry_max_delay_secs))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id)
|
||||
)]
|
||||
async fn drive_pipeline(
|
||||
&self,
|
||||
task: &IngestionTask,
|
||||
payload: IngestionPayload,
|
||||
) -> Result<(), AppError> {
|
||||
let mut ctx = PipelineContext::new(
|
||||
task,
|
||||
self.db.as_ref(),
|
||||
&self.pipeline_config,
|
||||
self.services.as_ref(),
|
||||
);
|
||||
|
||||
let machine = ready();
|
||||
|
||||
let pipeline_started = Instant::now();
|
||||
|
||||
let stage_start = Instant::now();
|
||||
let machine = prepare_content(machine, &mut ctx, payload)
|
||||
.await
|
||||
.map_err(|err| ctx.abort(err))?;
|
||||
let prepare_duration = stage_start.elapsed();
|
||||
|
||||
let stage_start = Instant::now();
|
||||
let machine = retrieve_related(machine, &mut ctx)
|
||||
.await
|
||||
.map_err(|err| ctx.abort(err))?;
|
||||
let retrieve_duration = stage_start.elapsed();
|
||||
|
||||
let stage_start = Instant::now();
|
||||
let machine = enrich(machine, &mut ctx)
|
||||
.await
|
||||
.map_err(|err| ctx.abort(err))?;
|
||||
let enrich_duration = stage_start.elapsed();
|
||||
|
||||
let stage_start = Instant::now();
|
||||
let _machine = persist(machine, &mut ctx)
|
||||
.await
|
||||
.map_err(|err| ctx.abort(err))?;
|
||||
let persist_duration = stage_start.elapsed();
|
||||
|
||||
let total_duration = pipeline_started.elapsed();
|
||||
let prepare_ms = prepare_duration.as_millis() as u64;
|
||||
let retrieve_ms = retrieve_duration.as_millis() as u64;
|
||||
let enrich_ms = enrich_duration.as_millis() as u64;
|
||||
let persist_ms = persist_duration.as_millis() as u64;
|
||||
info!(
|
||||
task_id = %ctx.task_id,
|
||||
attempt = ctx.attempt,
|
||||
total_ms = total_duration.as_millis() as u64,
|
||||
prepare_ms,
|
||||
retrieve_ms,
|
||||
enrich_ms,
|
||||
persist_ms,
|
||||
"ingestion pipeline finished"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
74
ingestion-pipeline/src/pipeline/preparation.rs
Normal file
74
ingestion-pipeline/src/pipeline/preparation.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload,
|
||||
text_content::{TextContent, UrlInfo},
|
||||
},
|
||||
},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
use crate::utils::{
|
||||
file_text_extraction::extract_text_from_file, url_text_retrieval::extract_text_from_url,
|
||||
};
|
||||
|
||||
pub(crate) async fn to_text_content(
|
||||
ingestion_payload: IngestionPayload,
|
||||
db: &SurrealDbClient,
|
||||
config: &AppConfig,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<TextContent, AppError> {
|
||||
match ingestion_payload {
|
||||
IngestionPayload::Url {
|
||||
url,
|
||||
context,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let (article, file_info) = extract_text_from_url(&url, db, &user_id, config).await?;
|
||||
Ok(TextContent::new(
|
||||
article.text_content.into(),
|
||||
Some(context),
|
||||
category,
|
||||
None,
|
||||
Some(UrlInfo {
|
||||
url,
|
||||
title: article.title,
|
||||
image_id: file_info.id,
|
||||
}),
|
||||
user_id,
|
||||
))
|
||||
}
|
||||
IngestionPayload::Text {
|
||||
text,
|
||||
context,
|
||||
category,
|
||||
user_id,
|
||||
} => Ok(TextContent::new(
|
||||
text,
|
||||
Some(context),
|
||||
category,
|
||||
None,
|
||||
None,
|
||||
user_id,
|
||||
)),
|
||||
IngestionPayload::File {
|
||||
file_info,
|
||||
context,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let text = extract_text_from_file(&file_info, db, openai_client, config).await?;
|
||||
Ok(TextContent::new(
|
||||
text,
|
||||
Some(context),
|
||||
category,
|
||||
Some(file_info),
|
||||
None,
|
||||
user_id,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
213
ingestion-pipeline/src/pipeline/services.rs
Normal file
213
ingestion-pipeline/src/pipeline/services.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use async_openai::types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
},
|
||||
},
|
||||
utils::{config::AppConfig, embedding::generate_embedding},
|
||||
};
|
||||
use composite_retrieval::{retrieve_entities, retrieved_entities_to_json, RetrievedEntity};
|
||||
use text_splitter::TextSplitter;
|
||||
|
||||
use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content};
|
||||
use crate::utils::llm_instructions::{
|
||||
get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE,
|
||||
};
|
||||
|
||||
/// Abstraction over the side-effecting operations the ingestion pipeline
/// depends on (text extraction, similarity retrieval, LLM enrichment and
/// chunking). Implemented by `DefaultPipelineServices` in production and by
/// lightweight mocks in the pipeline tests.
#[async_trait]
pub trait PipelineServices: Send + Sync {
    /// Converts the raw ingestion payload (text, URL or file) into
    /// normalized `TextContent`.
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError>;

    /// Finds knowledge entities already stored for this user that are
    /// similar to `content`.
    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError>;

    /// Runs the LLM analysis over `content`, providing the previously
    /// retrieved `similar_entities` as context for deduplication/linking.
    async fn run_enrichment(
        &self,
        content: &TextContent,
        similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError>;

    /// Converts the LLM analysis into database-ready entities and
    /// relationships; `entity_concurrency` bounds parallel embedding work.
    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>;

    /// Splits `content` into chunks whose size (in characters) falls within
    /// `range`, generating an embedding for each chunk.
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError>;
}
|
||||
|
||||
/// Production implementation of [`PipelineServices`] backed by SurrealDB and
/// an OpenAI-compatible chat/embedding client.
pub struct DefaultPipelineServices {
    // Shared database handle used for settings, retrieval and persistence.
    db: Arc<SurrealDbClient>,
    // Chat + embeddings client (OpenAI-compatible endpoint).
    openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
    // Application configuration (extraction options, paths, etc.).
    config: AppConfig,
}
|
||||
|
||||
impl DefaultPipelineServices {
    /// Builds the default service bundle from shared handles and config.
    pub fn new(
        db: Arc<SurrealDbClient>,
        openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
        config: AppConfig,
    ) -> Self {
        Self {
            db,
            openai_client,
            config,
        }
    }

    /// Assembles the structured-output chat request for ingress analysis:
    /// system prompt + user message containing the content, its category,
    /// optional user context and the already-known similar entities as JSON.
    ///
    /// The model name is read from the current `SystemSettings` on every
    /// call, so admin changes take effect without a restart.
    async fn prepare_llm_request(
        &self,
        category: &str,
        context: Option<&str>,
        text: &str,
        similar_entities: &[RetrievedEntity],
    ) -> Result<CreateChatCompletionRequest, AppError> {
        let settings = SystemSettings::get_current(&self.db).await?;

        let entities_json = retrieved_entities_to_json(similar_entities);

        let user_message = format!(
            "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}"
        );

        // Strict JSON-schema response format so the reply parses directly
        // into `LLMEnrichmentResult`.
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: Some("Structured analysis of the submitted content".into()),
                name: "content_analysis".into(),
                schema: Some(get_ingress_analysis_schema()),
                strict: Some(true),
            },
        };

        let request = CreateChatCompletionRequestArgs::default()
            .model(&settings.processing_model)
            .messages([
                ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
                ChatCompletionRequestUserMessage::from(user_message).into(),
            ])
            .response_format(response_format)
            .build()?;

        Ok(request)
    }

    /// Executes the chat request and deserializes the first choice's message
    /// content into an `LLMEnrichmentResult`. An empty response or malformed
    /// JSON becomes `AppError::LLMParsing`.
    async fn perform_analysis(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<LLMEnrichmentResult, AppError> {
        let response = self.openai_client.chat().create(request).await?;

        let content = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            .ok_or(AppError::LLMParsing(
                "No content found in LLM response".into(),
            ))?;

        serde_json::from_str::<LLMEnrichmentResult>(content).map_err(|e| {
            AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
        })
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl PipelineServices for DefaultPipelineServices {
|
||||
async fn prepare_text_content(
|
||||
&self,
|
||||
payload: IngestionPayload,
|
||||
) -> Result<TextContent, AppError> {
|
||||
to_text_content(payload, &self.db, &self.config, &self.openai_client).await
|
||||
}
|
||||
|
||||
async fn retrieve_similar_entities(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_context: {:?}",
|
||||
content.text, content.category, content.context
|
||||
);
|
||||
|
||||
retrieve_entities(&self.db, &self.openai_client, &input_text, &content.user_id).await
|
||||
}
|
||||
|
||||
async fn run_enrichment(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
similar_entities: &[RetrievedEntity],
|
||||
) -> Result<LLMEnrichmentResult, AppError> {
|
||||
let request = self
|
||||
.prepare_llm_request(
|
||||
&content.category,
|
||||
content.context.as_deref(),
|
||||
&content.text,
|
||||
similar_entities,
|
||||
)
|
||||
.await?;
|
||||
self.perform_analysis(request).await
|
||||
}
|
||||
|
||||
async fn convert_analysis(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
analysis: &LLMEnrichmentResult,
|
||||
entity_concurrency: usize,
|
||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||
analysis
|
||||
.to_database_entities(
|
||||
&content.id,
|
||||
&content.user_id,
|
||||
&self.openai_client,
|
||||
&self.db,
|
||||
entity_concurrency,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn prepare_chunks(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
range: Range<usize>,
|
||||
) -> Result<Vec<TextChunk>, AppError> {
|
||||
let splitter = TextSplitter::new(range.clone());
|
||||
let chunk_texts: Vec<String> = splitter
|
||||
.chunks(&content.text)
|
||||
.map(|chunk| chunk.to_string())
|
||||
.collect();
|
||||
|
||||
let mut chunks = Vec::with_capacity(chunk_texts.len());
|
||||
for chunk in chunk_texts {
|
||||
let embedding = generate_embedding(&self.openai_client, &chunk, &self.db).await?;
|
||||
chunks.push(TextChunk::new(
|
||||
content.id.clone(),
|
||||
chunk,
|
||||
embedding,
|
||||
content.user_id.clone(),
|
||||
));
|
||||
}
|
||||
Ok(chunks)
|
||||
}
|
||||
}
|
||||
338
ingestion-pipeline/src/pipeline/stages/mod.rs
Normal file
338
ingestion-pipeline/src/pipeline/stages/mod.rs
Normal file
@@ -0,0 +1,338 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
},
|
||||
};
|
||||
use state_machines::core::GuardError;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use super::{
|
||||
context::PipelineContext,
|
||||
services::PipelineServices,
|
||||
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
|
||||
};
|
||||
|
||||
/// Preparation stage: converts the raw payload into `TextContent`, logs an
/// input summary (plus a short debug preview), stores the result on the
/// context and advances the machine `Ready -> ContentPrepared`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn prepare_content(
    machine: IngestionMachine<(), Ready>,
    ctx: &mut PipelineContext<'_>,
    payload: IngestionPayload,
) -> Result<IngestionMachine<(), ContentPrepared>, AppError> {
    let text_content = ctx.services.prepare_text_content(payload).await?;

    // Build a single-line, 120-char preview for debug logging; `truncated`
    // flags whether the full text exceeds the preview.
    let text_len = text_content.text.chars().count();
    let preview: String = text_content.text.chars().take(120).collect();
    let preview_clean = preview.replace('\n', " ");
    let preview_len = preview_clean.chars().count();
    let truncated = text_len > preview_len;
    let context_len = text_content
        .context
        .as_ref()
        .map(|c| c.chars().count())
        .unwrap_or(0);

    tracing::info!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        user_id = %text_content.user_id,
        category = %text_content.category,
        text_chars = text_len,
        context_chars = context_len,
        attachments = text_content.file_info.is_some(),
        "ingestion task input ready"
    );
    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        preview = %preview_clean,
        preview_truncated = truncated,
        "ingestion task input preview"
    );

    ctx.text_content = Some(text_content);

    // Typestate transition; a guard rejection here indicates a pipeline bug.
    machine
        .prepare()
        .map_err(|(_, guard)| map_guard_error("prepare", guard))
}
|
||||
|
||||
#[instrument(
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
|
||||
)]
|
||||
pub async fn retrieve_related(
|
||||
machine: IngestionMachine<(), ContentPrepared>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<IngestionMachine<(), Retrieved>, AppError> {
|
||||
let content = ctx.text_content()?;
|
||||
let similar = ctx.services.retrieve_similar_entities(content).await?;
|
||||
|
||||
debug!(
|
||||
task_id = %ctx.task_id,
|
||||
attempt = ctx.attempt,
|
||||
similar_count = similar.len(),
|
||||
"ingestion retrieved similar entities"
|
||||
);
|
||||
|
||||
ctx.similar_entities = similar;
|
||||
|
||||
machine
|
||||
.retrieve()
|
||||
.map_err(|(_, guard)| map_guard_error("retrieve", guard))
|
||||
}
|
||||
|
||||
/// Enrichment stage: runs the LLM analysis over the prepared content (with
/// the retrieved similar entities as context), stores the analysis on the
/// context and advances the machine `Retrieved -> Enriched`.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn enrich(
    machine: IngestionMachine<(), Retrieved>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Enriched>, AppError> {
    let content = ctx.text_content()?;
    let analysis = ctx
        .services
        .run_enrichment(content, &ctx.similar_entities)
        .await?;

    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_suggestions = analysis.knowledge_entities.len(),
        relationship_suggestions = analysis.relationships.len(),
        "ingestion enrichment completed"
    );

    ctx.analysis = Some(analysis);

    machine
        .enrich()
        .map_err(|(_, guard)| map_guard_error("enrich", guard))
}
|
||||
|
||||
/// Persistence stage: converts the analysis into database entities and
/// relationships, then writes the knowledge graph and the embedded text
/// chunks concurrently, followed by the source content and an index rebuild.
/// Advances the machine `Enriched -> Persisted`.
///
/// Note: `take_*` moves the content/analysis out of the context, so this
/// stage can only run once per pipeline pass.
#[instrument(
    level = "trace",
    skip_all,
    fields(task_id = %ctx.task_id, attempt = ctx.attempt, user_id = %ctx.task.user_id)
)]
pub async fn persist(
    machine: IngestionMachine<(), Enriched>,
    ctx: &mut PipelineContext<'_>,
) -> Result<IngestionMachine<(), Persisted>, AppError> {
    let content = ctx.take_text_content()?;
    let analysis = ctx.take_analysis()?;

    let (entities, relationships) = ctx
        .services
        .convert_analysis(
            &content,
            &analysis,
            ctx.pipeline_config.tuning.entity_embedding_concurrency,
        )
        .await?;

    let entity_count = entities.len();
    let relationship_count = relationships.len();

    // Target chunk size window (in characters) from tuning.
    let chunk_range =
        ctx.pipeline_config.tuning.chunk_min_chars..ctx.pipeline_config.tuning.chunk_max_chars;

    // Graph writes and chunk writes are independent, so run them in
    // parallel; `try_join!` fails fast on the first error.
    let ((), chunk_count) = tokio::try_join!(
        store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
        store_vector_chunks(
            ctx.db,
            ctx.services,
            ctx.task_id.as_str(),
            &content,
            chunk_range,
            &ctx.pipeline_config.tuning
        )
    )?;

    // Store the source content last, then refresh search indexes.
    ctx.db.store_item(content).await?;
    ctx.db.rebuild_indexes().await?;

    debug!(
        task_id = %ctx.task_id,
        attempt = ctx.attempt,
        entity_count,
        relationship_count,
        chunk_count,
        "ingestion persistence flushed to database"
    );

    machine
        .persist()
        .map_err(|(_, guard)| map_guard_error("persist", guard))
}
|
||||
|
||||
fn map_guard_error(event: &str, guard: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"invalid ingestion pipeline transition during {event}: {guard:?}"
|
||||
))
|
||||
}
|
||||
|
||||
/// Writes all generated knowledge entities and their relationships in a
/// single SurrealDB transaction, retrying with exponential backoff on
/// transient read/write conflicts.
///
/// Retry policy comes from `tuning`: up to `graph_store_attempts` tries,
/// backoff doubling from `graph_initial_backoff_ms` and capped at
/// `graph_max_backoff_ms`. Non-retryable errors are returned immediately.
async fn store_graph_entities(
    db: &SurrealDbClient,
    tuning: &super::config::IngestionTuning,
    entities: Vec<KnowledgeEntity>,
    relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
    // Single server-side transaction: create every entity record, then
    // RELATE the relationship edges between them.
    const STORE_GRAPH_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $entities = $entities;
        LET $relationships = $relationships;

        FOR $entity IN $entities {
            CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
        };

        FOR $relationship IN $relationships {
            LET $in_node = type::thing('knowledge_entity', $relationship.in);
            LET $out_node = type::thing('knowledge_entity', $relationship.out);
            RELATE $in_node->relates_to->$out_node CONTENT {
                id: type::thing('relates_to', $relationship.id),
                metadata: $relationship.metadata
            };
        };

        COMMIT TRANSACTION;
    ";

    // Arc so each retry can bind the same payload without re-cloning the data.
    let entities = Arc::new(entities);
    let relationships = Arc::new(relationships);

    let mut backoff_ms = tuning.graph_initial_backoff_ms;

    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_GRAPH_MUTATION)
            .bind(("entities", entities.clone()))
            .bind(("relationships", relationships.clone()))
            .await;

        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                // Retry only on known-transient commit conflicts, and only
                // while attempts remain.
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing graph data; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }

                return Err(AppError::from(err));
            }
        }
    }

    // Reached only if every attempt hit a retryable conflict.
    Err(AppError::InternalError(
        "Failed to store graph entities after retries".to_string(),
    ))
}
|
||||
|
||||
async fn store_vector_chunks(
|
||||
db: &SurrealDbClient,
|
||||
services: &dyn PipelineServices,
|
||||
task_id: &str,
|
||||
content: &TextContent,
|
||||
chunk_range: std::ops::Range<usize>,
|
||||
tuning: &super::config::IngestionTuning,
|
||||
) -> Result<usize, AppError> {
|
||||
let prepared_chunks = services.prepare_chunks(content, chunk_range).await?;
|
||||
let chunk_count = prepared_chunks.len();
|
||||
|
||||
let batch_size = tuning.chunk_insert_concurrency.max(1);
|
||||
for chunk in &prepared_chunks {
|
||||
debug!(
|
||||
task_id = %task_id,
|
||||
chunk_id = %chunk.id,
|
||||
chunk_len = chunk.chunk.chars().count(),
|
||||
"chunk persisted"
|
||||
);
|
||||
}
|
||||
|
||||
for batch in prepared_chunks.chunks(batch_size) {
|
||||
store_chunk_batch(db, batch, tuning).await?;
|
||||
}
|
||||
|
||||
Ok(chunk_count)
|
||||
}
|
||||
|
||||
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
|
||||
error
|
||||
.to_string()
|
||||
.contains("Failed to commit transaction due to a read or write conflict")
|
||||
}
|
||||
|
||||
/// Writes one batch of text chunks in a single SurrealDB transaction,
/// retrying transient commit conflicts with the same exponential-backoff
/// policy as `store_graph_entities` (the graph tuning knobs are reused).
async fn store_chunk_batch(
    db: &SurrealDbClient,
    batch: &[TextChunk],
    tuning: &super::config::IngestionTuning,
) -> Result<(), AppError> {
    // Nothing to do for an empty batch.
    if batch.is_empty() {
        return Ok(());
    }

    const STORE_CHUNKS_MUTATION: &str = r"
        BEGIN TRANSACTION;
        LET $chunks = $chunks;

        FOR $chunk IN $chunks {
            CREATE type::thing('text_chunk', $chunk.id) CONTENT $chunk;
        };

        COMMIT TRANSACTION;
    ";

    // Arc so retries re-bind the same payload without copying the chunks again.
    let chunks = Arc::new(batch.to_vec());
    let mut backoff_ms = tuning.graph_initial_backoff_ms;

    for attempt in 0..tuning.graph_store_attempts {
        let result = db
            .client
            .query(STORE_CHUNKS_MUTATION)
            .bind(("chunks", chunks.clone()))
            .await;

        match result {
            Ok(_) => return Ok(()),
            Err(err) => {
                // Retry only known-transient conflicts while attempts remain.
                if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts {
                    warn!(
                        attempt = attempt + 1,
                        "Transient SurrealDB conflict while storing chunks; retrying"
                    );
                    sleep(Duration::from_millis(backoff_ms)).await;
                    backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms);
                    continue;
                }

                return Err(AppError::from(err));
            }
        }
    }

    // Reached only if every attempt hit a retryable conflict.
    Err(AppError::InternalError(
        "Failed to store text chunks after retries".to_string(),
    ))
}
|
||||
25
ingestion-pipeline/src/pipeline/state.rs
Normal file
25
ingestion-pipeline/src/pipeline/state.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
// Typestate machine for the ingestion pipeline lifecycle. Each stage event
// (`prepare`, `retrieve`, `enrich`, `persist`) is accepted only from its
// listed predecessor state, so stages cannot run out of order at compile
// time; `abort` is reachable from every state so failures can always be
// recorded.
state_machine! {
    name: IngestionMachine,
    state: IngestionState,
    initial: Ready,
    states: [Ready, ContentPrepared, Retrieved, Enriched, Persisted, Failed],
    events {
        prepare { transition: { from: Ready, to: ContentPrepared } }
        retrieve { transition: { from: ContentPrepared, to: Retrieved } }
        enrich { transition: { from: Retrieved, to: Enriched } }
        persist { transition: { from: Enriched, to: Persisted } }
        abort {
            transition: { from: Ready, to: Failed }
            transition: { from: ContentPrepared, to: Failed }
            transition: { from: Retrieved, to: Failed }
            transition: { from: Enriched, to: Failed }
            transition: { from: Persisted, to: Failed }
        }
    }
}
|
||||
|
||||
/// Creates a fresh ingestion machine in its initial `Ready` state with no
/// attached payload data.
pub fn ready() -> IngestionMachine<(), Ready> {
    IngestionMachine::new(())
}
|
||||
440
ingestion-pipeline/src/pipeline/tests.rs
Normal file
440
ingestion-pipeline/src/pipeline/tests.rs
Normal file
@@ -0,0 +1,440 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_payload::IngestionPayload,
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
},
|
||||
};
|
||||
use composite_retrieval::{RetrievedChunk, RetrievedEntity};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
config::{IngestionConfig, IngestionTuning},
|
||||
enrichment_result::LLMEnrichmentResult,
|
||||
services::PipelineServices,
|
||||
IngestionPipeline,
|
||||
};
|
||||
|
||||
/// Scripted [`PipelineServices`] test double: returns canned data from every
/// stage and records the order in which stages were invoked.
struct MockServices {
    // Canned output of `prepare_text_content`.
    text_content: TextContent,
    // Canned output of `retrieve_similar_entities`.
    similar_entities: Vec<RetrievedEntity>,
    // Canned output of `run_enrichment`.
    analysis: LLMEnrichmentResult,
    // Embedding vector attached to every prepared chunk.
    chunk_embedding: Vec<f32>,
    // Canned `convert_analysis` results.
    graph_entities: Vec<KnowledgeEntity>,
    graph_relationships: Vec<KnowledgeRelationship>,
    // Ordered log of stage names, asserted on by the tests.
    calls: Mutex<Vec<&'static str>>,
}
|
||||
|
||||
impl MockServices {
    /// Builds a fully-populated mock for `user_id` with one pre-existing
    /// similar entity (plus a chunk) and one generated entity/relationship.
    fn new(user_id: &str) -> Self {
        // Matches the production embedding dimensionality.
        const TEST_EMBEDDING_DIM: usize = 1536;
        let text_content = TextContent::new(
            "Example document for ingestion pipeline.".into(),
            Some("light context".into()),
            "notes".into(),
            None,
            None,
            user_id.into(),
        );
        // Entity that the "retrieval" stage will claim already exists.
        let retrieved_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Existing Entity".into(),
            "Previously known context".into(),
            KnowledgeEntityType::Document,
            None,
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );

        let retrieved_chunk = TextChunk::new(
            retrieved_entity.source_id.clone(),
            "existing chunk".into(),
            vec![0.1; TEST_EMBEDDING_DIM],
            user_id.into(),
        );

        // Empty analysis: the mock's convert step supplies the entities.
        let analysis = LLMEnrichmentResult {
            knowledge_entities: Vec::new(),
            relationships: Vec::new(),
        };

        // Entity/relationship that `convert_analysis` will "generate".
        let graph_entity = KnowledgeEntity::new(
            text_content.id.clone(),
            "Generated Entity".into(),
            "Entity from enrichment".into(),
            KnowledgeEntityType::Idea,
            None,
            vec![0.2; TEST_EMBEDDING_DIM],
            user_id.into(),
        );
        let graph_relationship = KnowledgeRelationship::new(
            graph_entity.id.clone(),
            graph_entity.id.clone(),
            user_id.into(),
            text_content.id.clone(),
            "related_to".into(),
        );

        Self {
            text_content,
            similar_entities: vec![RetrievedEntity {
                entity: retrieved_entity,
                score: 0.8,
                chunks: vec![RetrievedChunk {
                    chunk: retrieved_chunk,
                    score: 0.7,
                }],
            }],
            analysis,
            chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM],
            graph_entities: vec![graph_entity],
            graph_relationships: vec![graph_relationship],
            calls: Mutex::new(Vec::new()),
        }
    }

    /// Appends a stage name to the ordered call log.
    async fn record(&self, stage: &'static str) {
        self.calls.lock().await.push(stage);
    }
}
|
||||
|
||||
/// Every method records its stage name and returns the canned fixture data.
#[async_trait]
impl PipelineServices for MockServices {
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.record("prepare").await;
        Ok(self.text_content.clone())
    }

    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.record("retrieve").await;
        Ok(self.similar_entities.clone())
    }

    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        self.record("enrich").await;
        Ok(self.analysis.clone())
    }

    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.record("convert").await;
        Ok((
            self.graph_entities.clone(),
            self.graph_relationships.clone(),
        ))
    }

    // Produces a single chunk tied to `content`, using the fixture embedding.
    async fn prepare_chunks(
        &self,
        content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.record("chunk").await;
        Ok(vec![TextChunk::new(
            content.id.clone(),
            "chunk from mock services".into(),
            self.chunk_embedding.clone(),
            content.user_id.clone(),
        )])
    }
}
|
||||
|
||||
/// Test double that behaves like `MockServices` except `run_enrichment`
/// always fails, exercising the pipeline's retry path.
struct FailingServices {
    inner: MockServices,
}
|
||||
|
||||
/// Test double whose first stage fails with a validation error, exercising
/// the pipeline's dead-letter path; later stages must never run.
struct ValidationServices;
|
||||
/// Delegates every stage to the inner mock except `run_enrichment`, which
/// always returns a processing error.
#[async_trait]
impl PipelineServices for FailingServices {
    async fn prepare_text_content(
        &self,
        payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        self.inner.prepare_text_content(payload).await
    }

    async fn retrieve_similar_entities(
        &self,
        content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        self.inner.retrieve_similar_entities(content).await
    }

    // The injected failure: enrichment always errors.
    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        Err(AppError::Processing("mock enrichment failure".to_string()))
    }

    async fn convert_analysis(
        &self,
        content: &TextContent,
        analysis: &LLMEnrichmentResult,
        entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        self.inner
            .convert_analysis(content, analysis, entity_concurrency)
            .await
    }

    async fn prepare_chunks(
        &self,
        content: &TextContent,
        range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        self.inner.prepare_chunks(content, range).await
    }
}
|
||||
|
||||
/// Fails immediately in the preparation stage with a validation error; all
/// later stages are `unreachable!` so the test catches any out-of-order
/// execution after the early failure.
#[async_trait]
impl PipelineServices for ValidationServices {
    async fn prepare_text_content(
        &self,
        _payload: IngestionPayload,
    ) -> Result<TextContent, AppError> {
        Err(AppError::Validation("unsupported".to_string()))
    }

    async fn retrieve_similar_entities(
        &self,
        _content: &TextContent,
    ) -> Result<Vec<RetrievedEntity>, AppError> {
        unreachable!("retrieve_similar_entities should not be called after validation failure")
    }

    async fn run_enrichment(
        &self,
        _content: &TextContent,
        _similar_entities: &[RetrievedEntity],
    ) -> Result<LLMEnrichmentResult, AppError> {
        unreachable!("run_enrichment should not be called after validation failure")
    }

    async fn convert_analysis(
        &self,
        _content: &TextContent,
        _analysis: &LLMEnrichmentResult,
        _entity_concurrency: usize,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        unreachable!("convert_analysis should not be called after validation failure")
    }

    async fn prepare_chunks(
        &self,
        _content: &TextContent,
        _range: std::ops::Range<usize>,
    ) -> Result<Vec<TextChunk>, AppError> {
        unreachable!("prepare_chunks should not be called after validation failure")
    }
}
|
||||
|
||||
/// Spins up an isolated in-memory SurrealDB (fresh random database name per
/// test) with all migrations applied.
async fn setup_db() -> SurrealDbClient {
    let namespace = "pipeline_test";
    // Unique database per test run so tests cannot see each other's data.
    let database = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory(namespace, &database)
        .await
        .expect("Failed to create in-memory SurrealDB");
    db.apply_migrations()
        .await
        .expect("Failed to apply migrations");
    db
}
|
||||
|
||||
/// Small, test-friendly tuning: tiny chunk sizes and low concurrency so the
/// pipeline exercises batching without heavy work.
fn pipeline_config() -> IngestionConfig {
    IngestionConfig {
        tuning: IngestionTuning {
            chunk_min_chars: 4,
            chunk_max_chars: 64,
            chunk_insert_concurrency: 4,
            entity_embedding_concurrency: 2,
            ..IngestionTuning::default()
        },
    }
}
|
||||
|
||||
/// Creates an ingestion task in the database and immediately claims it for
/// `worker_id`, mirroring what the worker loop does in production.
async fn reserve_task(
    db: &SurrealDbClient,
    worker_id: &str,
    payload: IngestionPayload,
    user_id: &str,
) -> IngestionTask {
    let task = IngestionTask::create_and_add_to_db(payload, user_id.into(), db)
        .await
        .expect("task created");
    let lease = task.lease_duration();
    // Claim the task we just inserted (it is the only ready one).
    IngestionTask::claim_next_ready(db, worker_id, Utc::now(), lease)
        .await
        .expect("claim succeeds")
        .expect("task claimed")
}
|
||||
|
||||
/// Happy path: a text payload flows through every stage, the task ends in
/// `Succeeded`, entities and chunks land in the database, and the mock's
/// call log shows the stages ran in order.
#[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() {
    let db = setup_db().await;
    let worker_id = "worker-happy";
    let user_id = "user-123";
    let services = Arc::new(MockServices::new(user_id));
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services.clone())
            .expect("pipeline");

    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    pipeline
        .process_task(task.clone())
        .await
        .expect("pipeline succeeds");

    // The task record must be marked as succeeded.
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Succeeded);

    // Graph entities and text chunks must have been persisted.
    let stored_entities: Vec<KnowledgeEntity> = db
        .get_all_stored_items::<KnowledgeEntity>()
        .await
        .expect("entities stored");
    assert!(!stored_entities.is_empty(), "entities should be stored");

    let stored_chunks: Vec<TextChunk> = db
        .get_all_stored_items::<TextChunk>()
        .await
        .expect("chunks stored");
    assert!(
        !stored_chunks.is_empty(),
        "chunks should be stored for ingestion text"
    );

    // Stage ordering: the four pipeline stages first, then one or more
    // chunk-preparation calls.
    let call_log = services.calls.lock().await.clone();
    assert!(
        call_log.len() >= 5,
        "expected at least one chunk embedding call"
    );
    assert_eq!(
        &call_log[0..4],
        ["prepare", "retrieve", "enrich", "convert"]
    );
    assert!(call_log[4..].iter().all(|entry| *entry == "chunk"));
}
|
||||
|
||||
/// A transient processing failure (enrichment error) must bubble out of the
/// pipeline and leave the task in `Failed` with a retry scheduled in the
/// future.
#[tokio::test]
async fn ingestion_pipeline_failure_marks_retry() {
    let db = setup_db().await;
    let worker_id = "worker-fail";
    let user_id = "user-456";
    let services = Arc::new(FailingServices {
        inner: MockServices::new(user_id),
    });
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");

    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "Example failure payload".into(),
            context: "Context".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "failure services should bubble error from pipeline"
    );

    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::Failed);
    // Retry must be scheduled at (or after) roughly now, not in the past.
    assert!(
        stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
        "failed task should schedule retry in the future"
    );
}
|
||||
|
||||
#[tokio::test]
async fn ingestion_pipeline_validation_failure_dead_letters_task() {
    // Exercises the non-retryable path: a *validation* failure (as opposed to
    // a transient service failure) must dead-letter the task rather than
    // schedule a retry.
    let db = setup_db().await;
    let worker_id = "worker-validation";
    let user_id = "user-789";
    // `ValidationServices` is a stub that reports a validation error;
    // its behavior is defined where the type is declared.
    let services = Arc::new(ValidationServices);
    let pipeline =
        IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services)
            .expect("pipeline");

    // Payload content is irrelevant here — the stubbed services fail
    // regardless of input.
    let task = reserve_task(
        &db,
        worker_id,
        IngestionPayload::Text {
            text: "irrelevant".into(),
            context: "".into(),
            category: "notes".into(),
            user_id: user_id.into(),
        },
        user_id,
    )
    .await;

    // The validation error should still be returned to the caller.
    let result = pipeline.process_task(task.clone()).await;
    assert!(
        result.is_err(),
        "validation failure should surface as error"
    );

    // Unlike transient failures, validation failures are terminal:
    // the stored task ends up in DeadLetter, with no retry.
    let stored_task: IngestionTask = db
        .get_item(&task.id)
        .await
        .expect("retrieve task")
        .expect("task present");
    assert_eq!(stored_task.state, TaskState::DeadLetter);
}
|
||||
@@ -1,259 +0,0 @@
|
||||
pub mod llm_enrichment_result;
|
||||
|
||||
use std::io::Write;
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::{FieldData, FieldMetadata};
|
||||
use chrono::Utc;
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
store,
|
||||
types::{
|
||||
file_info::FileInfo,
|
||||
ingestion_payload::IngestionPayload,
|
||||
text_content::{TextContent, UrlInfo},
|
||||
},
|
||||
},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
use dom_smoothie::{Article, Readability, TextMode};
|
||||
use headless_chrome::Browser;
|
||||
use std::io::{Seek, SeekFrom};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::utils::{
|
||||
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
|
||||
pdf_ingestion::extract_pdf_content,
|
||||
};
|
||||
|
||||
/// Normalizes any [`IngestionPayload`] variant into a single [`TextContent`]
/// record ready for downstream processing.
///
/// * `Url` — fetches the page with a headless browser, uses the extracted
///   article text as content, and attaches [`UrlInfo`] (title + screenshot id).
/// * `Text` — wraps the provided text as-is.
/// * `File` — extracts text from the stored file according to its MIME type
///   and attaches the originating [`FileInfo`].
///
/// # Errors
/// Returns an [`AppError`] if URL fetching, file reading, or text extraction
/// fails.
pub async fn to_text_content(
    ingestion_payload: IngestionPayload,
    db: &SurrealDbClient,
    config: &AppConfig,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<TextContent, AppError> {
    match ingestion_payload {
        IngestionPayload::Url {
            url,
            context,
            category,
            user_id,
        } => {
            // Browser fetch + readability extraction + stored screenshot.
            let (article, file_info) = fetch_article_from_url(&url, db, &user_id, config).await?;
            Ok(TextContent::new(
                article.text_content.into(),
                Some(context),
                category,
                None,
                Some(UrlInfo {
                    url,
                    title: article.title,
                    image_id: file_info.id,
                }),
                user_id,
            ))
        }
        IngestionPayload::Text {
            text,
            context,
            category,
            user_id,
        } => Ok(TextContent::new(
            text,
            Some(context),
            category,
            None,
            None,
            user_id,
        )),
        IngestionPayload::File {
            file_info,
            context,
            category,
            user_id,
        } => {
            // MIME-type-aware extraction (plain text, PDF, image OCR, audio).
            let text = extract_text_from_file(&file_info, db, openai_client, config).await?;
            Ok(TextContent::new(
                text,
                Some(context),
                category,
                Some(file_info),
                None,
                user_id,
            ))
        }
    }
}
|
||||
|
||||
/// Fetches web content from a URL, extracts the main article text as Markdown,
|
||||
/// captures a screenshot, and stores the screenshot returning [`FileInfo`].
|
||||
///
|
||||
/// This function handles browser automation, content extraction via Readability,
|
||||
/// screenshot capture, temporary file handling, and persisting the screenshot
|
||||
/// details (including deduplication based on content hash via [`FileInfo::new`]).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `url` - The URL of the web page to fetch.
|
||||
/// * `db` - A reference to the database client (`SurrealDbClient`).
|
||||
/// * `user_id` - The ID of the user performing the action, used for associating the file.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Result` containing:
|
||||
/// * Ok: A tuple `(Article, FileInfo)` where `Article` contains the parsed markdown
|
||||
/// content and metadata, and `FileInfo` contains the details of the stored screenshot.
|
||||
/// * Err: An `AppError` if any step fails (navigation, screenshot, file handling, DB operation).
|
||||
async fn fetch_article_from_url(
|
||||
url: &str,
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
config: &AppConfig,
|
||||
) -> Result<(Article, FileInfo), AppError> {
|
||||
info!("Fetching URL: {}", url);
|
||||
// Instantiate timer
|
||||
let now = Instant::now();
|
||||
// Setup browser, navigate and wait
|
||||
let browser = {
|
||||
#[cfg(feature = "docker")]
|
||||
{
|
||||
// Use this when compiling for docker
|
||||
let options = headless_chrome::LaunchOptionsBuilder::default()
|
||||
.sandbox(false)
|
||||
.build()
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
Browser::new(options)?
|
||||
}
|
||||
#[cfg(not(feature = "docker"))]
|
||||
{
|
||||
// Use this otherwise
|
||||
Browser::default()?
|
||||
}
|
||||
};
|
||||
let tab = browser.new_tab()?;
|
||||
let page = tab.navigate_to(url)?;
|
||||
let loaded_page = page.wait_until_navigated()?;
|
||||
// Get content
|
||||
let raw_content = loaded_page.get_content()?;
|
||||
// Get screenshot
|
||||
let screenshot = loaded_page.capture_screenshot(
|
||||
headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)?;
|
||||
|
||||
// Create temp file
|
||||
let mut tmp_file = NamedTempFile::new()?;
|
||||
let temp_path_str = format!("{:?}", tmp_file.path());
|
||||
|
||||
// Write screenshot TO the temp file
|
||||
tmp_file.write_all(&screenshot)?;
|
||||
|
||||
// Ensure the OS buffer is written to the file system _before_ we proceed.
|
||||
tmp_file.as_file().sync_all()?;
|
||||
|
||||
// Ensure the file handle's read cursor is at the beginning before hashing occurs.
|
||||
if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) {
|
||||
error!("URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.", url, temp_path_str, e);
|
||||
}
|
||||
|
||||
// Prepare file metadata
|
||||
let parsed_url =
|
||||
url::Url::parse(url).map_err(|_| AppError::Processing("Invalid URL".to_string()))?;
|
||||
let domain = parsed_url
|
||||
.host_str()
|
||||
.unwrap_or("unknown")
|
||||
.replace(|c: char| !c.is_alphanumeric(), "_");
|
||||
let timestamp = Utc::now().format("%Y%m%d%H%M%S");
|
||||
let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp);
|
||||
|
||||
// Construct FieldData and FieldMetadata
|
||||
let metadata = FieldMetadata {
|
||||
file_name: Some(file_name),
|
||||
content_type: Some("image/jpeg".to_string()),
|
||||
name: None,
|
||||
headers: HeaderMap::new(),
|
||||
};
|
||||
let field_data = FieldData {
|
||||
contents: tmp_file,
|
||||
metadata,
|
||||
};
|
||||
|
||||
// Store screenshot
|
||||
let file_info = FileInfo::new(field_data, db, user_id, config).await?;
|
||||
|
||||
// Parse content...
|
||||
let config = dom_smoothie::Config {
|
||||
text_mode: TextMode::Markdown,
|
||||
..Default::default()
|
||||
};
|
||||
let mut readability = Readability::new(raw_content, None, Some(config))?;
|
||||
let article: Article = readability.parse()?;
|
||||
let end = now.elapsed();
|
||||
info!(
|
||||
"URL: {}. Total time: {:?}. Final File ID: {}",
|
||||
url, end, file_info.id
|
||||
);
|
||||
|
||||
Ok((article, file_info))
|
||||
}
|
||||
|
||||
/// Extracts text from a stored file by MIME type.
///
/// Resolves the file's on-disk location from the configured base directory,
/// then dispatches:
/// * plain text / markdown / octet-stream / rust source — read directly;
/// * PDF — [`extract_pdf_content`] (mode controlled by `config.pdf_ingest_mode`);
/// * PNG/JPEG — OCR/vision extraction via [`extract_text_from_image`];
/// * common audio formats — transcription via [`transcribe_audio_file`];
/// * anything else — `AppError::NotFound` carrying the MIME type.
async fn extract_text_from_file(
    file_info: &FileInfo,
    db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    config: &AppConfig,
) -> Result<String, AppError> {
    // `file_info.path` is stored relative to the configured base directory.
    let base_path = store::resolve_base_dir(config);
    let absolute_path = base_path.join(&file_info.path);

    match file_info.mime_type.as_str() {
        "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
            let content = tokio::fs::read_to_string(&absolute_path).await?;
            Ok(content)
        }
        "application/pdf" => {
            extract_pdf_content(
                &absolute_path,
                db_client,
                openai_client,
                &config.pdf_ingest_mode,
            )
            .await
        }
        "image/png" | "image/jpeg" => {
            // The image helper takes &str, so the path must be valid UTF-8.
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading image {}",
                        file_info.id
                    ))
                })?
                .to_string();
            let content = extract_text_from_image(&path_str, db_client, openai_client).await?;
            Ok(content)
        }
        "audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4"
        | "audio/ogg" | "audio/flac" => {
            // Same UTF-8 requirement as the image branch.
            let path_str = absolute_path
                .to_str()
                .ok_or_else(|| {
                    AppError::Processing(format!(
                        "Encountered a non-UTF8 path while reading audio {}",
                        file_info.id
                    ))
                })?
                .to_string();
            transcribe_audio_file(&path_str, db_client, openai_client).await
        }
        // Handle other MIME types as needed
        _ => Err(AppError::NotFound(file_info.mime_type.clone())),
    }
}
|
||||
63
ingestion-pipeline/src/utils/file_text_extraction.rs
Normal file
63
ingestion-pipeline/src/utils/file_text_extraction.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, store, types::file_info::FileInfo},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
use super::{
|
||||
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
|
||||
pdf_ingestion::extract_pdf_content,
|
||||
};
|
||||
|
||||
pub async fn extract_text_from_file(
|
||||
file_info: &FileInfo,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
config: &AppConfig,
|
||||
) -> Result<String, AppError> {
|
||||
let base_path = store::resolve_base_dir(config);
|
||||
let absolute_path = base_path.join(&file_info.path);
|
||||
|
||||
match file_info.mime_type.as_str() {
|
||||
"text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
|
||||
let content = tokio::fs::read_to_string(&absolute_path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
"application/pdf" => {
|
||||
extract_pdf_content(
|
||||
&absolute_path,
|
||||
db_client,
|
||||
openai_client,
|
||||
&config.pdf_ingest_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"image/png" | "image/jpeg" => {
|
||||
let path_str = absolute_path
|
||||
.to_str()
|
||||
.ok_or_else(|| {
|
||||
AppError::Processing(format!(
|
||||
"Encountered a non-UTF8 path while reading image {}",
|
||||
file_info.id
|
||||
))
|
||||
})?
|
||||
.to_string();
|
||||
let content = extract_text_from_image(&path_str, db_client, openai_client).await?;
|
||||
Ok(content)
|
||||
}
|
||||
"audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4"
|
||||
| "audio/ogg" | "audio/flac" => {
|
||||
let path_str = absolute_path
|
||||
.to_str()
|
||||
.ok_or_else(|| {
|
||||
AppError::Processing(format!(
|
||||
"Encountered a non-UTF8 path while reading audio {}",
|
||||
file_info.id
|
||||
))
|
||||
})?
|
||||
.to_string();
|
||||
transcribe_audio_file(&path_str, db_client, openai_client).await
|
||||
}
|
||||
_ => Err(AppError::NotFound(file_info.mime_type.clone())),
|
||||
}
|
||||
}
|
||||
53
ingestion-pipeline/src/utils/graph_mapper.rs
Normal file
53
ingestion-pipeline/src/utils/graph_mapper.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use common::error::AppError;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
|
||||
#[derive(Clone)]
|
||||
pub struct GraphMapper {
|
||||
pub key_to_id: HashMap<String, Uuid>,
|
||||
}
|
||||
|
||||
impl Default for GraphMapper {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
key_to_id: HashMap::new(),
|
||||
}
|
||||
}
|
||||
/// Tries to get an ID by first parsing the key as a UUID,
|
||||
/// and if that fails, looking it up in the internal map.
|
||||
pub fn get_or_parse_id(&self, key: &str) -> Result<Uuid, AppError> {
|
||||
// First, try to parse the key as a UUID.
|
||||
if let Ok(parsed_uuid) = Uuid::parse_str(key) {
|
||||
return Ok(parsed_uuid);
|
||||
}
|
||||
|
||||
// If parsing fails, look it up in the map.
|
||||
self.key_to_id.get(key).copied().ok_or_else(|| {
|
||||
AppError::GraphMapper(format!(
|
||||
"Key '{key}' is not a valid UUID and was not found in the map."
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Assigns a new UUID for a given key. (No changes needed here)
|
||||
pub fn assign_id(&mut self, key: &str) -> Uuid {
|
||||
let id = Uuid::new_v4();
|
||||
self.key_to_id.insert(key.to_string(), id);
|
||||
id
|
||||
}
|
||||
|
||||
/// Retrieves the UUID for a given key, returning a Result for consistency.
|
||||
pub fn get_id(&self, key: &str) -> Result<Uuid, AppError> {
|
||||
self.key_to_id
|
||||
.get(key)
|
||||
.copied()
|
||||
.ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map.")))
|
||||
}
|
||||
}
|
||||
@@ -1,58 +1,7 @@
|
||||
pub mod audio_transcription;
|
||||
pub mod file_text_extraction;
|
||||
pub mod graph_mapper;
|
||||
pub mod image_parsing;
|
||||
pub mod llm_instructions;
|
||||
pub mod pdf_ingestion;
|
||||
|
||||
use common::error::AppError;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
///
/// The LLM refers to graph nodes by ad-hoc string keys; this mapper assigns
/// each key a stable [`Uuid`] so entities and relationships can be persisted
/// with real identifiers.
#[derive(Clone)]
pub struct GraphMapper {
    // LLM-provided key -> generated UUID.
    pub key_to_id: HashMap<String, Uuid>,
}

impl Default for GraphMapper {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphMapper {
    /// Creates an empty mapper.
    pub fn new() -> Self {
        Self {
            key_to_id: HashMap::new(),
        }
    }
    /// Tries to get an ID by first parsing the key as a UUID,
    /// and if that fails, looking it up in the internal map.
    pub fn get_or_parse_id(&self, key: &str) -> Result<Uuid, AppError> {
        // First, try to parse the key as a UUID (it may reference an
        // already-persisted entity).
        if let Ok(parsed_uuid) = Uuid::parse_str(key) {
            return Ok(parsed_uuid);
        }

        // If parsing fails, look it up in the map.
        self.key_to_id.get(key).copied().ok_or_else(|| {
            AppError::GraphMapper(format!(
                "Key '{key}' is not a valid UUID and was not found in the map."
            ))
        })
    }

    /// Assigns a new UUID for a given key, overwriting any previous mapping.
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let id = Uuid::new_v4();
        self.key_to_id.insert(key.to_string(), id);
        id
    }

    /// Retrieves the UUID for a given key, returning a Result for consistency.
    pub fn get_id(&self, key: &str) -> Result<Uuid, AppError> {
        self.key_to_id
            .get(key)
            .copied()
            .ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map.")))
    }
}
|
||||
pub mod url_text_retrieval;
|
||||
|
||||
174
ingestion-pipeline/src/utils/url_text_retrieval.rs
Normal file
174
ingestion-pipeline/src/utils/url_text_retrieval.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::{FieldData, FieldMetadata};
|
||||
use chrono::Utc;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::file_info::FileInfo},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
use dom_smoothie::{Article, Readability, TextMode};
|
||||
use headless_chrome::Browser;
|
||||
use std::{
|
||||
io::{Seek, SeekFrom, Write},
|
||||
net::IpAddr,
|
||||
time::Instant,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::{error, info, warn};
|
||||
/// Fetches a web page with a headless browser, extracts the main readable
/// article text as Markdown, captures a full-page screenshot, and stores the
/// screenshot, returning both the parsed [`Article`] and the stored
/// screenshot's [`FileInfo`].
///
/// The URL is validated via `ensure_ingestion_url_allowed` (scheme and host
/// restrictions) before the screenshot is persisted.
///
/// # Errors
/// Returns an [`AppError`] if the URL is invalid/disallowed, or if browser
/// navigation, screenshot capture, file handling, or parsing fails.
pub async fn extract_text_from_url(
    url: &str,
    db: &SurrealDbClient,
    user_id: &str,
    config: &AppConfig,
) -> Result<(Article, FileInfo), AppError> {
    info!("Fetching URL: {}", url);
    // Timer for the end-of-function log line.
    let now = Instant::now();

    let browser = {
        #[cfg(feature = "docker")]
        {
            // In docker the Chrome sandbox must be disabled.
            let options = headless_chrome::LaunchOptionsBuilder::default()
                .sandbox(false)
                .build()
                .map_err(|e| AppError::InternalError(e.to_string()))?;
            Browser::new(options)?
        }
        #[cfg(not(feature = "docker"))]
        {
            Browser::default()?
        }
    };

    // Navigate and wait for the page to finish loading before scraping.
    let tab = browser.new_tab()?;
    let page = tab.navigate_to(url)?;
    let loaded_page = page.wait_until_navigated()?;
    let raw_content = loaded_page.get_content()?;
    // Full-page JPEG screenshot.
    let screenshot = loaded_page.capture_screenshot(
        headless_chrome::protocol::cdp::Page::CaptureScreenshotFormatOption::Jpeg,
        None,
        None,
        true,
    )?;

    // Spool the screenshot through a temp file so FileInfo::new can hash it.
    let mut tmp_file = NamedTempFile::new()?;
    let temp_path_str = format!("{:?}", tmp_file.path());

    tmp_file.write_all(&screenshot)?;
    // Flush OS buffers before hashing/reading the file back.
    tmp_file.as_file().sync_all()?;

    // Rewind so the subsequent read (hashing) starts at the beginning;
    // a seek failure is logged but deliberately not fatal.
    if let Err(e) = tmp_file.seek(SeekFrom::Start(0)) {
        error!(
            "URL: {}. Failed to seek temp file {} to start: {:?}. Proceeding, but hashing might fail.",
            url, temp_path_str, e
        );
    }

    let parsed_url =
        url::Url::parse(url).map_err(|_| AppError::Validation("Invalid URL".to_string()))?;

    // SSRF guard; also returns the host sanitized for use in a file name.
    let domain = ensure_ingestion_url_allowed(&parsed_url)?;
    let timestamp = Utc::now().format("%Y%m%d%H%M%S");
    let file_name = format!("{}_{}_{}.jpg", domain, "screenshot", timestamp);

    let metadata = FieldMetadata {
        file_name: Some(file_name),
        content_type: Some("image/jpeg".to_string()),
        name: None,
        headers: HeaderMap::new(),
    };
    let field_data = FieldData {
        contents: tmp_file,
        metadata,
    };

    // Persist the screenshot (deduplicated by content hash in FileInfo::new).
    let file_info = FileInfo::new(field_data, db, user_id, config).await?;

    // Run Readability over the raw HTML, emitting Markdown.
    let config = dom_smoothie::Config {
        text_mode: TextMode::Markdown,
        ..Default::default()
    };
    let mut readability = Readability::new(raw_content, None, Some(config))?;
    let article: Article = readability.parse()?;
    let end = now.elapsed();
    info!(
        "URL: {}. Total time: {:?}. Final File ID: {}",
        url, end, file_info.id
    );

    Ok((article, file_info))
}
|
||||
|
||||
fn ensure_ingestion_url_allowed(url: &url::Url) -> Result<String, AppError> {
|
||||
match url.scheme() {
|
||||
"http" | "https" => {}
|
||||
scheme => {
|
||||
warn!(%url, %scheme, "Rejected ingestion URL due to unsupported scheme");
|
||||
return Err(AppError::Validation(
|
||||
"Unsupported URL scheme for ingestion".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let host = match url.host_str() {
|
||||
Some(host) => host,
|
||||
None => {
|
||||
warn!(%url, "Rejected ingestion URL missing host");
|
||||
return Err(AppError::Validation(
|
||||
"URL is missing a host component".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if host.eq_ignore_ascii_case("localhost") {
|
||||
warn!(%url, host, "Rejected ingestion URL to localhost");
|
||||
return Err(AppError::Validation(
|
||||
"Ingestion URL host is not allowed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
let is_disallowed = match ip {
|
||||
IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
|
||||
IpAddr::V6(v6) => v6.is_unique_local() || v6.is_unicast_link_local(),
|
||||
};
|
||||
|
||||
if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() || is_disallowed {
|
||||
warn!(%url, host, %ip, "Rejected ingestion URL pointing to restricted network range");
|
||||
return Err(AppError::Validation(
|
||||
"Ingestion URL host is not allowed".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(host.replace(|c: char| !c.is_alphanumeric(), "_"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Non-HTTP(S) schemes (e.g. ftp) must be refused outright.
    #[test]
    fn rejects_unsupported_scheme() {
        let url = url::Url::parse("ftp://example.com").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }

    // `localhost` is blocked by name, before any IP parsing.
    #[test]
    fn rejects_localhost() {
        let url = url::Url::parse("http://localhost/resource").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }

    // RFC 1918 private addresses must be rejected (SSRF guard).
    #[test]
    fn rejects_private_ipv4() {
        let url = url::Url::parse("http://192.168.1.10/index.html").expect("url");
        assert!(ensure_ingestion_url_allowed(&url).is_err());
    }

    // Public hosts pass, and the returned name is file-name safe
    // (dots replaced with underscores).
    #[test]
    fn allows_public_domain_and_sanitizes() {
        let url = url::Url::parse("https://sub.example.com/path").expect("url");
        let sanitized = ensure_ingestion_url_allowed(&url).expect("allowed");
        assert_eq!(sanitized, "sub_example_com");
    }
}
|
||||
Reference in New Issue
Block a user