diff --git a/src/bin/server.rs b/src/bin/server.rs index 31bf083..3b3a479 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -78,16 +78,18 @@ async fn main() -> Result<(), Box> { .await?, ); + let openai_client = Arc::new(async_openai::Client::new()); + let app_state = AppState { surreal_db_client: surreal_db_client.clone(), - openai_client: Arc::new(async_openai::Client::new()), templates: Arc::new(reloader), + openai_client: openai_client.clone(), mailer: Arc::new(Mailer::new( config.smtp_username, config.smtp_relayer, config.smtp_password, )?), - job_queue: Arc::new(JobQueue::new(surreal_db_client)), + job_queue: Arc::new(JobQueue::new(surreal_db_client, openai_client)), }; let session_config = SessionConfig::default() diff --git a/src/bin/worker.rs b/src/bin/worker.rs index 2098829..ab214bd 100644 --- a/src/bin/worker.rs +++ b/src/bin/worker.rs @@ -38,9 +38,11 @@ async fn main() -> Result<(), Box> { .await?, ); - let job_queue = JobQueue::new(surreal_db_client.clone()); + let openai_client = Arc::new(async_openai::Client::new()); - let content_processor = ContentProcessor::new(surreal_db_client).await?; + let job_queue = JobQueue::new(surreal_db_client.clone(), openai_client.clone()); + + let content_processor = ContentProcessor::new(surreal_db_client, openai_client).await?; loop { // First, check for any unfinished jobs diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs index fbc746f..266162f 100644 --- a/src/ingress/content_processor.rs +++ b/src/ingress/content_processor.rs @@ -21,14 +21,17 @@ use super::analysis::{ pub struct ContentProcessor { db_client: Arc, - openai_client: async_openai::Client, + openai_client: Arc>, } impl ContentProcessor { - pub async fn new(surreal_db_client: Arc) -> Result { + pub async fn new( + surreal_db_client: Arc, + openai_client: Arc>, + ) -> Result { Ok(Self { db_client: surreal_db_client, - openai_client: async_openai::Client::new(), + openai_client, }) } diff --git a/src/ingress/jobqueue.rs b/src/ingress/jobqueue.rs index df8c1c8..5e5bd48 100644 --- a/src/ingress/jobqueue.rs +++ b/src/ingress/jobqueue.rs @@ -1,11 +1,8 @@ use chrono::Utc; use futures::Stream; -use std::{ - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::sync::Arc; use surrealdb::{opt::PatchOp, Error, Notification}; -use tracing::{error, info}; +use tracing::{debug, error, info}; use crate::{ error::AppError, @@ -22,21 +19,28 @@ use super::{content_processor::ContentProcessor, types::ingress_object::IngressO pub struct JobQueue { pub db: Arc, + pub openai_client: Arc>, } pub const MAX_ATTEMPTS: u32 = 3; impl JobQueue { - pub fn new(db: Arc) -> Self { - Self { db } + pub fn new( + db: Arc, + openai_client: Arc>, + ) -> Self { + Self { db, openai_client } } /// Creates a new job and stores it in the database - pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result { + pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> { let job = Job::new(content, user_id).await; + info!("{:?}", job); - store_item(&self.db, job.clone()).await?; - Ok(job) + + store_item(&self.db, job).await?; + + Ok(()) } /// Gets all jobs for a specific user @@ -44,11 +48,12 @@ impl JobQueue { let jobs: Vec = self .db .query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC") - .bind(("user_id", user_id.to_string())) + .bind(("user_id", user_id.to_owned())) .await? .take(0)?; - info!("{:?}", jobs); + debug!("{:?}", jobs); + Ok(jobs) } @@ -69,12 +74,8 @@ impl JobQueue { Ok(()) } - pub async fn update_status( - &self, - id: &str, - status: JobStatus, - ) -> Result, AppError> { - let job: Option = self + pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> { + let _job: Option = self .db .update((Job::table_name(), id)) .patch(PatchOp::replace("/status", status)) @@ -84,7 +85,7 @@ impl JobQueue { )) .await?; - Ok(job) + Ok(()) } /// Listen for new jobs @@ -137,7 +138,7 @@ impl JobQueue { ) .await?; - let text_content = job.content.to_text_content().await?; + let text_content = job.content.to_text_content(&self.openai_client).await?; match processor.process(&text_content).await { Ok(_) => { diff --git a/src/ingress/types/ingress_object.rs b/src/ingress/types/ingress_object.rs index 3ca70fd..cb0c6e5 100644 --- a/src/ingress/types/ingress_object.rs +++ b/src/ingress/types/ingress_object.rs @@ -1,3 +1,5 @@ +use std::{sync::Arc, time::Duration}; + use crate::{ error::AppError, storage::types::{file_info::FileInfo, text_content::TextContent}, @@ -9,7 +11,8 @@ use async_openai::types::{ use reqwest; use scraper::{Html, Selector}; use serde::{Deserialize, Serialize}; -use tiktoken_rs::o200k_base; +use std::fmt::Write; +use tiktoken_rs::{o200k_base, CoreBPE}; use tracing::info; /// Knowledge object type, containing the content or reference to it, as well as metadata @@ -43,7 +46,10 @@ impl IngressObject { /// /// # Returns /// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc. - pub async fn to_text_content(&self) -> Result { + pub async fn to_text_content( + &self, + openai_client: &Arc>, + ) -> Result { match self { IngressObject::Url { url, @@ -51,7 +57,7 @@ impl IngressObject { category, user_id, } => { - let text = Self::fetch_text_from_url(url).await?; + let text = Self::fetch_text_from_url(url, openai_client).await?; Ok(TextContent::new( text, instructions.into(), @@ -90,69 +96,62 @@ impl IngressObject { } } - /// Fetches and extracts text from a URL. - async fn fetch_text_from_url(url: &str) -> Result { - let response = reqwest::get(url).await?.text().await?; - let document = Html::parse_document(&response); + /// Get text from url, will return it as a markdown formatted string + async fn fetch_text_from_url( + url: &str, + openai_client: &Arc>, + ) -> Result { + // Use a client with timeouts and reuse + let client = reqwest::ClientBuilder::new() + .timeout(Duration::from_secs(30)) + .build()?; + let response = client.get(url).send().await?.text().await?; - // Select main content areas first - let main_selectors = Selector::parse(concat!( - "article, main, .article-content,", // Common main content classes - ".post-content, .entry-content,", // Common blog/article classes - "[role='main']" // Accessibility marker - )) + // Preallocate string with capacity + let mut structured_content = String::with_capacity(response.len() / 2); + + let document = Html::parse_document(&response); + let main_selectors = Selector::parse( + "article, main, .article-content, .post-content, .entry-content, [role='main']", + ) .unwrap(); - // If no main content found, fallback to body let content_element = document .select(&main_selectors) .next() .or_else(|| document.select(&Selector::parse("body").unwrap()).next()) .ok_or(AppError::NotFound("No content found".into()))?; - // Remove unwanted elements but preserve structure - // let exclude_selector = Selector::parse(concat!( - // "script, style, noscript,", - // "[class*='window'], [id*='window'],", - // "[class*='env'], [id*='env'],", - // "iframe, nav, footer, .comments,", - // ".advertisement, .social-share" - // )) - // .unwrap(); + // Compile selectors once + let heading_selector = Selector::parse("h1, h2, h3").unwrap(); + let paragraph_selector = Selector::parse("p").unwrap(); - // Collect structured content - let mut structured_content = String::new(); - - // Process headings - for heading in content_element.select(&Selector::parse("h1, h2, h3").unwrap()) { - structured_content.push_str(&format!( - "{}\n", - heading.text().collect::().trim() - )); + // Process content in one pass + for element in content_element.select(&heading_selector) { + let _ = writeln!( + structured_content, + "{}", + element.text().collect::().trim() + ); + } + for element in content_element.select(¶graph_selector) { + let _ = writeln!( + structured_content, + "{}", + element.text().collect::().trim() + ); } - // Process paragraphs - for paragraph in content_element.select(&Selector::parse("p").unwrap()) { - structured_content.push_str(&format!( - "{}\n", - paragraph.text().collect::().trim() - )); - } - - // Clean up let content = structured_content .replace(|c: char| c.is_control(), " ") .replace(" ", " "); - - let processed_content = Self::process_web_content(content.trim().to_string()).await?; - - info!("Extracted content from page: {:?}", processed_content); - - Ok(processed_content) + Self::process_web_content(content, openai_client).await } - pub async fn process_web_content(content: String) -> Result { - let openai_client = async_openai::Client::new(); + pub async fn process_web_content( + content: String, + openai_client: &Arc>, + ) -> Result { const MAX_TOKENS: usize = 122000; const SYSTEM_PROMPT: &str = r#" You are a precise content extractor for web pages. Your task: @@ -182,25 +181,10 @@ impl IngressObject { "#; let bpe = o200k_base()?; - let token_count = bpe.encode_with_special_tokens(&content).len(); - let content = if token_count > MAX_TOKENS { - // Split content into structural blocks - let blocks: Vec<&str> = content.split('\n').collect(); - let mut truncated = String::new(); - let mut current_tokens = 0; - - // Keep adding blocks until we approach the limit - for block in blocks { - let block_tokens = bpe.encode_with_special_tokens(block).len(); - if current_tokens + block_tokens > MAX_TOKENS { - break; - } - truncated.push_str(block); - truncated.push('\n'); - current_tokens += block_tokens; - } - truncated + // Process content in chunks if needed + let truncated_content = if bpe.encode_with_special_tokens(&content).len() > MAX_TOKENS { + Self::truncate_content(&content, MAX_TOKENS, &bpe)? } else { content }; @@ -211,7 +195,7 @@ impl IngressObject { .max_tokens(16200u32) .messages([ ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(), - ChatCompletionRequestUserMessage::from(content).into(), + ChatCompletionRequestUserMessage::from(truncated_content).into(), ]) .build()?; @@ -221,10 +205,41 @@ impl IngressObject { .choices .first() .and_then(|choice| choice.message.content.as_ref()) - .map(|content| content.to_string()) + .map(|content| content.to_owned()) .ok_or(AppError::LLMParsing("No content in response".into())) } + fn truncate_content( + content: &str, + max_tokens: usize, + tokenizer: &CoreBPE, + ) -> Result { + // Pre-allocate with estimated size + let mut result = String::with_capacity(content.len() / 2); + let mut current_tokens = 0; + + // Process content by paragraph to maintain context + for paragraph in content.split("\n\n") { + let tokens = tokenizer.encode_with_special_tokens(paragraph).len(); + + // Check if adding paragraph exceeds limit + if current_tokens + tokens > max_tokens { + break; + } + + result.push_str(paragraph); + result.push_str("\n\n"); + current_tokens += tokens; + } + + // Ensure we return valid content + if result.is_empty() { + return Err(AppError::Processing("Content exceeds token limit".into())); + } + + Ok(result.trim_end().to_string()) + } + /// Extracts text from a file based on its MIME type. async fn extract_text_from_file(file_info: &FileInfo) -> Result { match file_info.mime_type.as_str() {