use std::{ ops::Range, sync::{Arc, OnceLock}, }; use async_openai::types::{ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, ResponseFormatJsonSchema, }; use async_trait::async_trait; use common::{ error::AppError, storage::{ db::SurrealDbClient, store::StorageManager, types::{ ingestion_payload::IngestionPayload, knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings, text_chunk::TextChunk, text_content::TextContent, StoredObject, }, }, utils::{config::AppConfig, embedding::EmbeddingProvider}, }; use retrieval_pipeline::{reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity}; use text_splitter::{ChunkCapacity, ChunkConfig, TextSplitter}; use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use crate::utils::llm_instructions::get_ingress_analysis_schema; const EMBEDDING_QUERY_CHAR_LIMIT: usize = 12_000; #[async_trait] pub trait PipelineServices: Send + Sync { async fn prepare_text_content( &self, payload: IngestionPayload, ) -> Result; async fn retrieve_similar_entities( &self, content: &TextContent, ) -> Result, AppError>; async fn run_enrichment( &self, content: &TextContent, similar_entities: &[RetrievedEntity], ) -> Result; async fn convert_analysis( &self, content: &TextContent, analysis: &LLMEnrichmentResult, entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError>; async fn prepare_chunks( &self, content: &TextContent, token_range: Range, overlap_tokens: usize, ) -> Result, AppError>; } pub struct DefaultPipelineServices { db: Arc, openai_client: Arc>, config: AppConfig, reranker_pool: Option>, storage: StorageManager, embedding_provider: Arc, } impl DefaultPipelineServices { pub fn new( db: Arc, openai_client: Arc>, config: AppConfig, reranker_pool: Option>, storage: StorageManager, embedding_provider: Arc, ) -> Self { Self { db, openai_client, config, reranker_pool, storage, embedding_provider, } } async fn prepare_llm_request( &self, category: &str, context: Option<&str>, text: &str, similar_entities: &[RetrievedEntity], ) -> Result { let settings = SystemSettings::get_current(&self.db).await?; let entities_json = retrieved_entities_to_json(similar_entities); let user_message = format!( "Category:\n{category}\ncontext:\n{context:?}\nContent:\n{text}\nExisting KnowledgeEntities in database:\n{entities_json}" ); let response_format = ResponseFormat::JsonSchema { json_schema: ResponseFormatJsonSchema { description: Some("Structured analysis of the submitted content".into()), name: "content_analysis".into(), schema: Some(get_ingress_analysis_schema()), strict: Some(true), }, }; let request = CreateChatCompletionRequestArgs::default() .model(&settings.processing_model) .messages([ ChatCompletionRequestSystemMessage::from(settings.ingestion_system_prompt.as_str()) .into(), ChatCompletionRequestUserMessage::from(user_message).into(), ]) .response_format(response_format) .build()?; Ok(request) } async fn perform_analysis( &self, request: CreateChatCompletionRequest, ) -> Result { let response = self.openai_client.chat().create(request).await?; let content = response .choices .first() .and_then(|choice| choice.message.content.as_ref()) .ok_or(AppError::LLMParsing( "No content found in LLM response".into(), ))?; serde_json::from_str::(content).map_err(|e| { AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}")) }) } } #[async_trait] impl PipelineServices for DefaultPipelineServices { async fn prepare_text_content( &self, payload: IngestionPayload, ) -> Result { to_text_content( payload, &self.db, &self.config, &self.openai_client, &self.storage, ) .await } async fn retrieve_similar_entities( &self, content: &TextContent, ) -> Result, AppError> { let truncated_body = truncate_for_embedding(&content.text, EMBEDDING_QUERY_CHAR_LIMIT); let input_text = format!( "content: {}\n[truncated={}], category: {}, user_context: {:?}", truncated_body, truncated_body.len() < content.text.len(), content.category, content.context ); let rerank_lease = match &self.reranker_pool { Some(pool) => pool.checkout().await, None => None, }; let config = retrieval_pipeline::RetrievalConfig::for_search( retrieval_pipeline::SearchTarget::EntitiesOnly, ); match retrieval_pipeline::retrieve_entities( &self.db, &self.openai_client, Some(&*self.embedding_provider), &input_text, &content.user_id, config, rerank_lease, ) .await { Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), Ok(retrieval_pipeline::StrategyOutput::Search(search)) => { let chunk_count = search.chunks.len(); let entities = search.entities; tracing::debug!( chunk_count, entity_count = entities.len(), "ingestion search results returned entities" ); Ok(entities) } Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( "Ingestion retrieval should return entities".into(), )), Err(e) => Err(e), } } async fn run_enrichment( &self, content: &TextContent, similar_entities: &[RetrievedEntity], ) -> Result { let request = self .prepare_llm_request( &content.category, content.context.as_deref(), &content.text, similar_entities, ) .await?; self.perform_analysis(request).await } async fn convert_analysis( &self, content: &TextContent, analysis: &LLMEnrichmentResult, entity_concurrency: usize, ) -> Result<(Vec, Vec), AppError> { analysis .to_database_entities( content.id(), &content.user_id, &self.openai_client, &self.db, entity_concurrency, Some(&*self.embedding_provider), ) .await } async fn prepare_chunks( &self, content: &TextContent, token_range: Range, overlap_tokens: usize, ) -> Result, AppError> { let chunk_candidates = prepare_chunks( &content.text, token_range.start, token_range.end, overlap_tokens, )?; let mut chunks = Vec::with_capacity(chunk_candidates.len()); for chunk_text in chunk_candidates { let embedding = self .embedding_provider .embed(&chunk_text) .await .map_err(|e| AppError::InternalError(format!("FastEmbed embedding for chunk failed: {e}")))?; let chunk_struct = TextChunk::new( content.id().to_string(), chunk_text, content.user_id.clone(), ); chunks.push(EmbeddedTextChunk { chunk: chunk_struct, embedding, }); } Ok(chunks) } } fn prepare_chunks( text: &str, min_tokens: usize, max_tokens: usize, overlap_tokens: usize, ) -> Result, AppError> { if min_tokens == 0 || max_tokens == 0 || min_tokens > max_tokens { return Err(AppError::Validation( "invalid chunk token bounds; ensure 0 < min <= max".into(), )); } if overlap_tokens >= min_tokens { return Err(AppError::Validation(format!( "chunk_min_tokens must be greater than the configured overlap of {overlap_tokens}" ))); } let tokenizer = get_tokenizer()?; let chunk_capacity = ChunkCapacity::new(min_tokens) .with_max(max_tokens) .map_err(|e| AppError::Validation(format!("invalid chunk token bounds: {e}")))?; let chunk_config = ChunkConfig::new(chunk_capacity) .with_overlap(overlap_tokens) .map_err(|e| AppError::Validation(format!("invalid chunk overlap: {e}")))? .with_sizer(tokenizer); let splitter = TextSplitter::new(chunk_config); let mut chunks: Vec = splitter.chunks(text).map(str::to_owned).collect(); if chunks.is_empty() { chunks.push(String::new()); } Ok(chunks) } fn get_tokenizer() -> Result<&'static tokenizers::Tokenizer, AppError> { static TOKENIZER: OnceLock> = OnceLock::new(); match TOKENIZER.get_or_init(|| { tokenizers::Tokenizer::from_pretrained("bert-base-cased", None) .map_err(|e| format!("failed to initialize tokenizer: {e}")) }) { Ok(tokenizer) => Ok(tokenizer), Err(err) => Err(AppError::InternalError(err.clone())), } } fn truncate_for_embedding(text: &str, max_chars: usize) -> String { if text.chars().count() <= max_chars { return text.to_string(); } let mut truncated = String::with_capacity(max_chars.saturating_add(3)); for (idx, ch) in text.chars().enumerate() { if idx >= max_chars { break; } truncated.push(ch); } truncated.push('…'); truncated } #[cfg(test)] mod tests { use std::sync::Arc; use anyhow::Context; use async_openai::{config::OpenAIConfig, types::ChatCompletionRequestMessage, Client}; use common::{ storage::{ db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettingsPatch, }, utils::{ config::{AppConfig, StorageKind}, embedding::EmbeddingProvider, }, }; use uuid::Uuid; use super::DefaultPipelineServices; fn system_prompt_from_request( request: &async_openai::types::CreateChatCompletionRequest, ) -> String { let ChatCompletionRequestMessage::System(system) = &request.messages[0] else { panic!("expected first message to be system"); }; match &system.content { async_openai::types::ChatCompletionRequestSystemMessageContent::Text(text) => { text.clone() } other => panic!("unexpected system message content: {other:?}"), } } #[tokio::test] async fn prepare_llm_request_uses_ingestion_prompt_from_system_settings( ) -> anyhow::Result<()> { const SENTINEL: &str = "ingestion-prompt-sentinel-from-db"; let db = Arc::new( SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()) .await .context("start in-memory db")?, ); db.apply_migrations().await.context("apply migrations")?; SystemSettingsPatch { ingestion_system_prompt: Some(SENTINEL.to_string()), ..Default::default() } .apply(&db) .await .context("patch ingestion prompt")?; let config = AppConfig { storage: StorageKind::Memory, ..Default::default() }; let storage = StorageManager::new(&config).await.context("storage manager")?; let openai_client = Arc::new(Client::with_config(OpenAIConfig::default())); let embedding_provider = Arc::new(EmbeddingProvider::new_hashed(384)?); let services = DefaultPipelineServices::new( db, openai_client, config, None, storage, embedding_provider, ); let request = services .prepare_llm_request("notes", None, "hello world", &[]) .await .context("prepare llm request")?; assert_eq!(system_prompt_from_request(&request), SENTINEL); Ok(()) } }