use std::{pin::Pin, sync::Arc, time::Duration}; use axum::{ extract::{Query, State}, http::StatusCode, response::{ sse::{Event, KeepAlive, KeepAliveStream}, Sse, }, }; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use futures::{future::try_join_all, stream, Stream, StreamExt, TryFutureExt}; use minijinja::context; use serde::{Deserialize, Serialize}; use tempfile::NamedTempFile; use tokio::time::sleep; use tracing::{error, info}; use common::{ error::AppError, storage::types::{ file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::{IngestionTask, TaskState}, user::User, }, utils::ingest_limits::{validate_ingest_input, IngestValidationError}, }; use crate::{ html_state::HtmlState, middlewares::{ auth_middleware::RequireUser, response_middleware::{TemplateResponse, TemplateResult}, }, }; type EventStream = Pin> + Send>>; type TaskSse = Sse>; fn sse_with_keep_alive(stream: EventStream) -> TaskSse { Sse::new(stream).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive-ping"), ) } pub async fn show_ingest_form( State(state): State, RequireUser(user): RequireUser, ) -> TemplateResult { #[derive(Serialize)] pub struct ShowIngestFormData { user_categories: Vec, } let user_categories = User::get_user_categories(&user.id, &state.db).await?; Ok(TemplateResponse::new_template( "ingestion_modal.html", ShowIngestFormData { user_categories }, )) } pub async fn hide_ingest_form( RequireUser(_user): RequireUser, ) -> TemplateResult { Ok(TemplateResponse::new_template( "ingestion/add_content_button.html", (), )) } #[derive(Serialize)] struct NewTasksData { tasks: Vec, } #[derive(Debug, TryFromMultipart)] pub struct IngestionParams { pub content: Option, pub context: String, pub category: String, #[form_data(limit = "20000000")] #[form_data(default)] pub files: Vec>, } pub async fn process_ingest_form( State(state): State, RequireUser(user): RequireUser, TypedMultipart(input): TypedMultipart, ) -> TemplateResult { if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() { return Ok(TemplateResponse::bad_request( "You need to either add files or content", )); } let content_bytes = input.content.as_ref().map_or(0, String::len); let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty()); let ctx_len = input.context.len(); let category_bytes = input.category.len(); let file_count = input.files.len(); match validate_ingest_input( &state.config, input.content.as_deref(), &input.context, &input.category, file_count, ) { Ok(()) => {} Err(IngestValidationError::PayloadTooLarge(message)) => { return Ok(TemplateResponse::error( StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large", &message, )); } Err(IngestValidationError::BadRequest(message)) => { return Ok(TemplateResponse::bad_request(&message)); } } info!( user_id = %user.id, has_content, content_bytes, ctx_len, category_bytes, file_count, "Received ingest form submission" ); let file_infos = try_join_all(input.files.into_iter().map(|file| { FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage) .map_err(AppError::from) })) .await?; let payloads = IngestionPayload::create_ingestion_payload( input.content, input.context, input.category, file_infos, user.id.clone(), )?; let tasks = IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?; Ok(TemplateResponse::new_template( "dashboard/current_task.html", NewTasksData { tasks }, )) } #[derive(Deserialize)] pub struct QueryParams { task_id: String, } fn create_error_stream(message: impl Into) -> EventStream { let message = message.into(); stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed() } pub async fn get_task_updates_stream( State(state): State, RequireUser(current_user): RequireUser, Query(params): Query, ) -> TaskSse { let task_id = params.task_id.clone(); let db = Arc::clone(&state.db); match db.get_item::(&task_id).await { Ok(Some(task)) => { if task.user_id != current_user.id { return sse_with_keep_alive(create_error_stream( "Access denied: You do not have permission to view updates for this task.", )); } let sse_stream = async_stream::stream! { let mut consecutive_db_errors: u32 = 0; let max_consecutive_db_errors = 3; loop { match db.get_item::(&task_id).await { Ok(Some(updated_task)) => { consecutive_db_errors = 0; // Reset error count on success let status_message = match updated_task.state { TaskState::Pending => "Pending".to_string(), TaskState::Reserved => format!( "Reserved (attempt {} of {})", updated_task.attempts, updated_task.max_attempts ), TaskState::Processing => format!( "Processing (attempt {} of {})", updated_task.attempts, updated_task.max_attempts ), TaskState::Succeeded => "Completed".to_string(), TaskState::Failed => { let mut base = format!( "Retry scheduled (attempt {} of {})", updated_task.attempts, updated_task.max_attempts ); if let Some(message) = updated_task.error_message.as_ref() { base.push_str(": "); base.push_str(message); } base } TaskState::Cancelled => "Cancelled".to_string(), TaskState::DeadLetter => { let mut base = "Failed permanently".to_string(); if let Some(message) = updated_task.error_message.as_ref() { base.push_str(": "); base.push_str(message); } base } }; yield Ok(Event::default().event("status").data(status_message)); // Check for terminal states to close the stream if updated_task.state.is_terminal() { // Send a specific event that HTMX uses to close the connection // Send a event to reload the recent content // Send a event to remove the loading indicatior let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or_else(|_| "Ok".to_string()); yield Ok(Event::default().event("stop_loading").data(check_icon)); yield Ok(Event::default().event("update_latest_content").data("Update latest content")); yield Ok(Event::default().event("close_stream").data("Stream complete")); break; // Exit loop on terminal states } }, Ok(None) => { // Task disappeared after initial fetch yield Ok(Event::default().event("error").data("Task not found during update polling.")); break; } Err(db_err) => { error!("Database error while fetching task '{}': {:?}", task_id, db_err); consecutive_db_errors = consecutive_db_errors.saturating_add(1); yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors})."))); if consecutive_db_errors >= max_consecutive_db_errors { error!("Max consecutive DB errors reached for task '{task_id}'. Closing stream."); yield Ok(Event::default().event("error").data("Persistent error fetching task updates. Stream closed.")); yield Ok(Event::default().event("close_stream").data("Stream complete")); break; } } } sleep(Duration::from_secs(2)).await; } }; sse_with_keep_alive(sse_stream.boxed()) } Ok(None) => sse_with_keep_alive(create_error_stream(format!( "Task with ID '{task_id}' not found." ))), Err(e) => { error!("Failed to fetch task '{task_id}' for authorization: {e:?}"); sse_with_keep_alive(create_error_stream( "An error occurred while retrieving task details. Please try again later.", )) } } }