feat: reduced memory usage

This commit is contained in:
Per Stark
2025-01-16 08:29:49 +01:00
parent 2e70bd0636
commit e58ead5cd7
5 changed files with 119 additions and 96 deletions

View File

@@ -78,16 +78,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?, .await?,
); );
let openai_client = Arc::new(async_openai::Client::new());
let app_state = AppState { let app_state = AppState {
surreal_db_client: surreal_db_client.clone(), surreal_db_client: surreal_db_client.clone(),
openai_client: Arc::new(async_openai::Client::new()),
templates: Arc::new(reloader), templates: Arc::new(reloader),
openai_client: openai_client.clone(),
mailer: Arc::new(Mailer::new( mailer: Arc::new(Mailer::new(
config.smtp_username, config.smtp_username,
config.smtp_relayer, config.smtp_relayer,
config.smtp_password, config.smtp_password,
)?), )?),
job_queue: Arc::new(JobQueue::new(surreal_db_client)), job_queue: Arc::new(JobQueue::new(surreal_db_client, openai_client)),
}; };
let session_config = SessionConfig::default() let session_config = SessionConfig::default()

View File

@@ -38,9 +38,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?, .await?,
); );
let job_queue = JobQueue::new(surreal_db_client.clone()); let openai_client = Arc::new(async_openai::Client::new());
let content_processor = ContentProcessor::new(surreal_db_client).await?; let job_queue = JobQueue::new(surreal_db_client.clone(), openai_client.clone());
let content_processor = ContentProcessor::new(surreal_db_client, openai_client).await?;
loop { loop {
// First, check for any unfinished jobs // First, check for any unfinished jobs

View File

@@ -21,14 +21,17 @@ use super::analysis::{
pub struct ContentProcessor { pub struct ContentProcessor {
db_client: Arc<SurrealDbClient>, db_client: Arc<SurrealDbClient>,
openai_client: async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
} }
impl ContentProcessor { impl ContentProcessor {
pub async fn new(surreal_db_client: Arc<SurrealDbClient>) -> Result<Self, AppError> { pub async fn new(
surreal_db_client: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<Self, AppError> {
Ok(Self { Ok(Self {
db_client: surreal_db_client, db_client: surreal_db_client,
openai_client: async_openai::Client::new(), openai_client,
}) })
} }

View File

@@ -1,11 +1,8 @@
use chrono::Utc; use chrono::Utc;
use futures::Stream; use futures::Stream;
use std::{ use std::sync::Arc;
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use surrealdb::{opt::PatchOp, Error, Notification}; use surrealdb::{opt::PatchOp, Error, Notification};
use tracing::{error, info}; use tracing::{debug, error, info};
use crate::{ use crate::{
error::AppError, error::AppError,
@@ -22,21 +19,28 @@ use super::{content_processor::ContentProcessor, types::ingress_object::IngressO
pub struct JobQueue { pub struct JobQueue {
pub db: Arc<SurrealDbClient>, pub db: Arc<SurrealDbClient>,
pub openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
} }
pub const MAX_ATTEMPTS: u32 = 3; pub const MAX_ATTEMPTS: u32 = 3;
impl JobQueue { impl JobQueue {
pub fn new(db: Arc<SurrealDbClient>) -> Self { pub fn new(
Self { db } db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Self {
Self { db, openai_client }
} }
/// Creates a new job and stores it in the database /// Creates a new job and stores it in the database
pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<Job, AppError> { pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> {
let job = Job::new(content, user_id).await; let job = Job::new(content, user_id).await;
info!("{:?}", job); info!("{:?}", job);
store_item(&self.db, job.clone()).await?;
Ok(job) store_item(&self.db, job).await?;
Ok(())
} }
/// Gets all jobs for a specific user /// Gets all jobs for a specific user
@@ -44,11 +48,12 @@ impl JobQueue {
let jobs: Vec<Job> = self let jobs: Vec<Job> = self
.db .db
.query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC") .query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC")
.bind(("user_id", user_id.to_string())) .bind(("user_id", user_id.to_owned()))
.await? .await?
.take(0)?; .take(0)?;
info!("{:?}", jobs); debug!("{:?}", jobs);
Ok(jobs) Ok(jobs)
} }
@@ -69,12 +74,8 @@ impl JobQueue {
Ok(()) Ok(())
} }
pub async fn update_status( pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> {
&self, let _job: Option<Job> = self
id: &str,
status: JobStatus,
) -> Result<Option<Job>, AppError> {
let job: Option<Job> = self
.db .db
.update((Job::table_name(), id)) .update((Job::table_name(), id))
.patch(PatchOp::replace("/status", status)) .patch(PatchOp::replace("/status", status))
@@ -84,7 +85,7 @@ impl JobQueue {
)) ))
.await?; .await?;
Ok(job) Ok(())
} }
/// Listen for new jobs /// Listen for new jobs
@@ -137,7 +138,7 @@ impl JobQueue {
) )
.await?; .await?;
let text_content = job.content.to_text_content().await?; let text_content = job.content.to_text_content(&self.openai_client).await?;
match processor.process(&text_content).await { match processor.process(&text_content).await {
Ok(_) => { Ok(_) => {

View File

@@ -1,3 +1,5 @@
use std::{sync::Arc, time::Duration};
use crate::{ use crate::{
error::AppError, error::AppError,
storage::types::{file_info::FileInfo, text_content::TextContent}, storage::types::{file_info::FileInfo, text_content::TextContent},
@@ -9,7 +11,8 @@ use async_openai::types::{
use reqwest; use reqwest;
use scraper::{Html, Selector}; use scraper::{Html, Selector};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tiktoken_rs::o200k_base; use std::fmt::Write;
use tiktoken_rs::{o200k_base, CoreBPE};
use tracing::info; use tracing::info;
/// Knowledge object type, containing the content or reference to it, as well as metadata /// Knowledge object type, containing the content or reference to it, as well as metadata
@@ -43,7 +46,10 @@ impl IngressObject {
/// ///
/// # Returns /// # Returns
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc. /// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
pub async fn to_text_content(&self) -> Result<TextContent, AppError> { pub async fn to_text_content(
&self,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<TextContent, AppError> {
match self { match self {
IngressObject::Url { IngressObject::Url {
url, url,
@@ -51,7 +57,7 @@ impl IngressObject {
category, category,
user_id, user_id,
} => { } => {
let text = Self::fetch_text_from_url(url).await?; let text = Self::fetch_text_from_url(url, openai_client).await?;
Ok(TextContent::new( Ok(TextContent::new(
text, text,
instructions.into(), instructions.into(),
@@ -90,69 +96,62 @@ impl IngressObject {
} }
} }
/// Fetches and extracts text from a URL. /// Get text from url, will return it as a markdown formatted string
async fn fetch_text_from_url(url: &str) -> Result<String, AppError> { async fn fetch_text_from_url(
let response = reqwest::get(url).await?.text().await?; url: &str,
let document = Html::parse_document(&response); openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
// Use a client with timeouts and reuse
let client = reqwest::ClientBuilder::new()
.timeout(Duration::from_secs(30))
.build()?;
let response = client.get(url).send().await?.text().await?;
// Select main content areas first // Preallocate string with capacity
let main_selectors = Selector::parse(concat!( let mut structured_content = String::with_capacity(response.len() / 2);
"article, main, .article-content,", // Common main content classes
".post-content, .entry-content,", // Common blog/article classes let document = Html::parse_document(&response);
"[role='main']" // Accessibility marker let main_selectors = Selector::parse(
)) "article, main, .article-content, .post-content, .entry-content, [role='main']",
)
.unwrap(); .unwrap();
// If no main content found, fallback to body
let content_element = document let content_element = document
.select(&main_selectors) .select(&main_selectors)
.next() .next()
.or_else(|| document.select(&Selector::parse("body").unwrap()).next()) .or_else(|| document.select(&Selector::parse("body").unwrap()).next())
.ok_or(AppError::NotFound("No content found".into()))?; .ok_or(AppError::NotFound("No content found".into()))?;
// Remove unwanted elements but preserve structure // Compile selectors once
// let exclude_selector = Selector::parse(concat!( let heading_selector = Selector::parse("h1, h2, h3").unwrap();
// "script, style, noscript,", let paragraph_selector = Selector::parse("p").unwrap();
// "[class*='window'], [id*='window'],",
// "[class*='env'], [id*='env'],",
// "iframe, nav, footer, .comments,",
// ".advertisement, .social-share"
// ))
// .unwrap();
// Collect structured content // Process content in one pass
let mut structured_content = String::new(); for element in content_element.select(&heading_selector) {
let _ = writeln!(
// Process headings structured_content,
for heading in content_element.select(&Selector::parse("h1, h2, h3").unwrap()) { "<heading>{}</heading>",
structured_content.push_str(&format!( element.text().collect::<String>().trim()
"<heading>{}</heading>\n", );
heading.text().collect::<String>().trim() }
)); for element in content_element.select(&paragraph_selector) {
let _ = writeln!(
structured_content,
"<paragraph>{}</paragraph>",
element.text().collect::<String>().trim()
);
} }
// Process paragraphs
for paragraph in content_element.select(&Selector::parse("p").unwrap()) {
structured_content.push_str(&format!(
"<paragraph>{}</paragraph>\n",
paragraph.text().collect::<String>().trim()
));
}
// Clean up
let content = structured_content let content = structured_content
.replace(|c: char| c.is_control(), " ") .replace(|c: char| c.is_control(), " ")
.replace(" ", " "); .replace(" ", " ");
Self::process_web_content(content, openai_client).await
let processed_content = Self::process_web_content(content.trim().to_string()).await?;
info!("Extracted content from page: {:?}", processed_content);
Ok(processed_content)
} }
pub async fn process_web_content(content: String) -> Result<String, AppError> { pub async fn process_web_content(
let openai_client = async_openai::Client::new(); content: String,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
const MAX_TOKENS: usize = 122000; const MAX_TOKENS: usize = 122000;
const SYSTEM_PROMPT: &str = r#" const SYSTEM_PROMPT: &str = r#"
You are a precise content extractor for web pages. Your task: You are a precise content extractor for web pages. Your task:
@@ -182,25 +181,10 @@ impl IngressObject {
"#; "#;
let bpe = o200k_base()?; let bpe = o200k_base()?;
let token_count = bpe.encode_with_special_tokens(&content).len();
let content = if token_count > MAX_TOKENS { // Process content in chunks if needed
// Split content into structural blocks let truncated_content = if bpe.encode_with_special_tokens(&content).len() > MAX_TOKENS {
let blocks: Vec<&str> = content.split('\n').collect(); Self::truncate_content(&content, MAX_TOKENS, &bpe)?
let mut truncated = String::new();
let mut current_tokens = 0;
// Keep adding blocks until we approach the limit
for block in blocks {
let block_tokens = bpe.encode_with_special_tokens(block).len();
if current_tokens + block_tokens > MAX_TOKENS {
break;
}
truncated.push_str(block);
truncated.push('\n');
current_tokens += block_tokens;
}
truncated
} else { } else {
content content
}; };
@@ -211,7 +195,7 @@ impl IngressObject {
.max_tokens(16200u32) .max_tokens(16200u32)
.messages([ .messages([
ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(), ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(),
ChatCompletionRequestUserMessage::from(content).into(), ChatCompletionRequestUserMessage::from(truncated_content).into(),
]) ])
.build()?; .build()?;
@@ -221,10 +205,41 @@ impl IngressObject {
.choices .choices
.first() .first()
.and_then(|choice| choice.message.content.as_ref()) .and_then(|choice| choice.message.content.as_ref())
.map(|content| content.to_string()) .map(|content| content.to_owned())
.ok_or(AppError::LLMParsing("No content in response".into())) .ok_or(AppError::LLMParsing("No content in response".into()))
} }
fn truncate_content(
content: &str,
max_tokens: usize,
tokenizer: &CoreBPE,
) -> Result<String, AppError> {
// Pre-allocate with estimated size
let mut result = String::with_capacity(content.len() / 2);
let mut current_tokens = 0;
// Process content by paragraph to maintain context
for paragraph in content.split("\n\n") {
let tokens = tokenizer.encode_with_special_tokens(paragraph).len();
// Check if adding paragraph exceeds limit
if current_tokens + tokens > max_tokens {
break;
}
result.push_str(paragraph);
result.push_str("\n\n");
current_tokens += tokens;
}
// Ensure we return valid content
if result.is_empty() {
return Err(AppError::Processing("Content exceeds token limit".into()));
}
Ok(result.trim_end().to_string())
}
/// Extracts text from a file based on its MIME type. /// Extracts text from a file based on its MIME type.
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> { async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
match file_info.mime_type.as_str() { match file_info.mime_type.as_str() {