feat: ingestion task streaming feedback

This commit is contained in:
Per Stark
2025-05-13 21:45:57 +02:00
parent 850878d5c3
commit d504903db3
13 changed files with 271 additions and 88 deletions

View File

@@ -8,8 +8,6 @@ use axum::{
Sse,
},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use composite_retrieval::{
answer_retrieval::{
create_chat_request, create_user_message_with_history, format_entities_json,
@@ -25,7 +23,6 @@ use json_stream_parser::JsonStreamParser;
use minijinja::Value;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use surrealdb::{engine::any::Any, Surreal};
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{debug, error};
@@ -39,7 +36,7 @@ use common::storage::{
},
};
use crate::html_state::HtmlState;
use crate::{html_state::HtmlState, AuthSessionType};
// Error handling function
fn create_error_stream(
@@ -110,7 +107,8 @@ pub struct QueryParams {
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// 1. Authentication and initial data validation

View File

@@ -176,3 +176,24 @@ pub async fn show_content_read_modal(
TextContentReadModalData { user, text_content },
))
}
/// Renders the recent-content template for the signed-in user.
///
/// Fetches the user's latest text contents via
/// `User::get_latest_text_contents` and passes them, together with the user,
/// to the `/index/signed_in/recent_content.html` template. A failed lookup is
/// returned to the caller through `?` as an `HtmlError`.
pub async fn show_recent_content(
    State(state): State<HtmlState>,
    RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
    /// Template context: the requesting user plus the fetched text contents.
    #[derive(Serialize)]
    pub struct RecentTextContentData {
        pub user: User,
        pub text_contents: Vec<TextContent>,
    }

    let latest = User::get_latest_text_contents(&user.id, &state.db).await?;
    let template_data = RecentTextContentData {
        user,
        text_contents: latest,
    };
    let response =
        TemplateResponse::new_template("/index/signed_in/recent_content.html", template_data);

    Ok(response)
}

View File

@@ -3,7 +3,7 @@ mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::{
delete_text_content, patch_text_content, show_content_page, show_content_read_modal,
show_text_content_edit_form,
show_recent_content, show_text_content_edit_form,
};
use crate::html_state::HtmlState;
@@ -15,6 +15,7 @@ where
{
Router::new()
.route("/content", get(show_content_page))
.route("/content/recent", get(show_recent_content))
.route("/content/{id}/read", get(show_content_read_modal))
.route(
"/content/{id}",

View File

@@ -1,20 +1,29 @@
use std::{pin::Pin, time::Duration};
use axum::{
extract::State,
response::{Html, IntoResponse},
extract::{Query, State},
response::{
sse::{Event, KeepAlive},
Html, IntoResponse, Sse,
},
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use futures::{future::try_join_all, TryFutureExt};
use serde::Serialize;
use futures::{future::try_join_all, stream, Stream, StreamExt, TryFutureExt};
use minijinja::{context, Value};
use serde::{Deserialize, Serialize};
use tempfile::NamedTempFile;
use tracing::info;
use tokio::time::sleep;
use tracing::{error, info};
use common::{
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
file_info::FileInfo,
ingestion_payload::IngestionPayload,
ingestion_task::{IngestionTask, IngestionTaskStatus},
text_content::TextContent,
user::User,
},
utils::config::AppConfig,
};
use crate::{
@@ -23,7 +32,7 @@ use crate::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
routes::index::handlers::ActiveJobsData,
AuthSessionType,
};
pub async fn show_ingress_form(
@@ -104,22 +113,182 @@ pub async fn process_ingress_form(
let futures: Vec<_> = payloads
.into_iter()
.map(|object| {
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
})
.map(|object| IngestionTask::create_and_add_to_db(object, user.id.clone(), &state.db))
.collect();
try_join_all(futures).await?;
let tasks = try_join_all(futures).await?;
// Update the active jobs page with the newly created job
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
// let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
#[derive(Serialize)]
struct NewTasksData {
user: User,
tasks: Vec<IngestionTask>,
}
Ok(TemplateResponse::new_template(
"index/signed_in/new_task.html",
NewTasksData { user, tasks },
))
}
/// Query-string parameters accepted by `get_task_updates_stream`.
#[derive(Deserialize)]
pub struct QueryParams {
    /// Identifier of the ingestion task to look up and stream updates for.
    task_id: String,
}
// pub async fn get_task_updates_stream(
// State(state): State<HtmlState>,
// auth: AuthSessionType,
// Query(params): Query<QueryParams>,
// ) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// let task_id = params.task_id.clone();
// let db = state.db.clone();
// let stream = async_stream::stream! {
// loop {
// match db.get_item::<IngestionTask>(&task_id).await {
// Ok(Some(_task)) => {
// // For now, just sending a placeholder event
// yield Ok(Event::default().event("status").data("hey"));
// },
// _ => {
// yield Ok(Event::default().event("error").data("Failed to get item"));
// break;
// }
// }
// sleep(Duration::from_secs(5)).await;
// }
// };
// Sse::new(stream.boxed()).keep_alive(
// KeepAlive::new()
// .interval(Duration::from_secs(15))
// .text("keep-alive"),
// )
// }
// Error handling function
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
pub async fn get_task_updates_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
let task_id = params.task_id.clone();
let db = state.db.clone();
// 1. Check for authenticated user
let current_user = match auth.current_user {
Some(user) => user,
None => {
return Sse::new(create_error_stream(
"User not authenticated. Please log in.",
));
}
};
// 2. Fetch task for initial authorization and to ensure it exists
match db.get_item::<IngestionTask>(&task_id).await {
Ok(Some(task)) => {
// 3. Validate user ownership
if task.user_id != current_user.id {
return Sse::new(create_error_stream(
"Access denied: You do not have permission to view updates for this task.",
));
}
let sse_stream = async_stream::stream! {
let mut consecutive_db_errors = 0;
let max_consecutive_db_errors = 3;
loop {
match db.get_item::<IngestionTask>(&task_id).await {
Ok(Some(updated_task)) => {
consecutive_db_errors = 0; // Reset error count on success
// Format the status message based on IngestionTaskStatus
let status_message = match &updated_task.status {
IngestionTaskStatus::Created => "Created".to_string(),
IngestionTaskStatus::InProgress { attempts, .. } => {
// Following your template's current display
format!("In progress, attempt {}", attempts)
}
IngestionTaskStatus::Completed => "Completed".to_string(),
IngestionTaskStatus::Error(ref err_msg) => {
// Providing a user-friendly error message from the status
format!("Error: {}", err_msg)
}
IngestionTaskStatus::Cancelled => "Cancelled".to_string(),
};
yield Ok(Event::default().event("status").data(status_message));
// Check for terminal states to close the stream
match updated_task.status {
IngestionTaskStatus::Completed |
IngestionTaskStatus::Error(_) |
IngestionTaskStatus::Cancelled => {
// Send a specific event that HTMX uses to close the connection
// Send a event to reload the recent content
// Send a event to remove the loading indicatior
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or("Ok".to_string());
yield Ok(Event::default().event("stop_loading").data(check_icon));
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
yield Ok(Event::default().event("close_stream").data("Stream complete"));
break; // Exit loop on terminal states
}
_ => {
// Not a terminal state, continue polling
}
}
},
Ok(None) => {
// Task disappeared after initial fetch
yield Ok(Event::default().event("error").data("Task not found during update polling."));
break;
}
Err(db_err) => {
error!("Database error while fetching task '{}': {:?}", task_id, db_err);
consecutive_db_errors += 1;
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {}).", consecutive_db_errors)));
if consecutive_db_errors >= max_consecutive_db_errors {
error!("Max consecutive DB errors reached for task '{}'. Closing stream.", task_id);
yield Ok(Event::default().event("error").data("Persistent error fetching task updates. Stream closed."));
yield Ok(Event::default().event("close_stream").data("Stream complete"));
break;
}
}
}
sleep(Duration::from_secs(2)).await;
}
};
Sse::new(sse_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive-ping"),
)
}
Ok(None) => Sse::new(create_error_stream(format!(
"Task with ID '{}' not found.",
task_id
))),
Err(e) => {
error!(
"Failed to fetch task '{}' for authorization: {:?}",
task_id, e
);
Sse::new(create_error_stream(
"An error occurred while retrieving task details. Please try again later.",
))
}
}
}

View File

@@ -1,7 +1,9 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::{hide_ingress_form, process_ingress_form, show_ingress_form};
use handlers::{
get_task_updates_stream, hide_ingress_form, process_ingress_form, show_ingress_form,
};
use crate::html_state::HtmlState;
@@ -15,5 +17,6 @@ where
"/ingress-form",
get(show_ingress_form).post(process_ingress_form),
)
.route("/task/status-stream", get(get_task_updates_stream))
.route("/hide-ingress-form", get(hide_ingress_form))
}