mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-31 03:40:38 +02:00
280 lines
10 KiB
Rust
280 lines
10 KiB
Rust
use std::{pin::Pin, sync::Arc, time::Duration};
|
|
|
|
use axum::{
|
|
extract::{Query, State},
|
|
http::StatusCode,
|
|
response::{
|
|
sse::{Event, KeepAlive, KeepAliveStream},
|
|
Sse,
|
|
},
|
|
};
|
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
|
use futures::{future::try_join_all, stream, Stream, StreamExt, TryFutureExt};
|
|
use minijinja::context;
|
|
use serde::{Deserialize, Serialize};
|
|
use tempfile::NamedTempFile;
|
|
use tokio::time::sleep;
|
|
use tracing::{error, info};
|
|
|
|
use common::{
|
|
error::AppError,
|
|
storage::types::{
|
|
file_info::FileInfo,
|
|
ingestion_payload::IngestionPayload,
|
|
ingestion_task::{IngestionTask, TaskState},
|
|
user::User,
|
|
},
|
|
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
|
|
};
|
|
|
|
use crate::{
|
|
html_state::HtmlState,
|
|
middlewares::{
|
|
auth_middleware::RequireUser,
|
|
response_middleware::{TemplateResponse, TemplateResult},
|
|
},
|
|
};
|
|
|
|
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
|
|
type TaskSse = Sse<KeepAliveStream<EventStream>>;
|
|
|
|
fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
|
|
Sse::new(stream).keep_alive(
|
|
KeepAlive::new()
|
|
.interval(Duration::from_secs(15))
|
|
.text("keep-alive-ping"),
|
|
)
|
|
}
|
|
|
|
pub async fn show_ingest_form(
|
|
State(state): State<HtmlState>,
|
|
RequireUser(user): RequireUser,
|
|
) -> TemplateResult {
|
|
#[derive(Serialize)]
|
|
pub struct ShowIngestFormData {
|
|
user_categories: Vec<String>,
|
|
}
|
|
|
|
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
|
|
|
Ok(TemplateResponse::new_template(
|
|
"ingestion_modal.html",
|
|
ShowIngestFormData { user_categories },
|
|
))
|
|
}
|
|
|
|
pub async fn hide_ingest_form(
|
|
RequireUser(_user): RequireUser,
|
|
) -> TemplateResult {
|
|
Ok(TemplateResponse::new_template(
|
|
"ingestion/add_content_button.html",
|
|
(),
|
|
))
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct NewTasksData {
|
|
tasks: Vec<IngestionTask>,
|
|
}
|
|
|
|
#[derive(Debug, TryFromMultipart)]
|
|
pub struct IngestionParams {
|
|
pub content: Option<String>,
|
|
pub context: String,
|
|
pub category: String,
|
|
#[form_data(limit = "20000000")]
|
|
#[form_data(default)]
|
|
pub files: Vec<FieldData<NamedTempFile>>,
|
|
}
|
|
|
|
pub async fn process_ingest_form(
|
|
State(state): State<HtmlState>,
|
|
RequireUser(user): RequireUser,
|
|
TypedMultipart(input): TypedMultipart<IngestionParams>,
|
|
) -> TemplateResult {
|
|
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
|
|
return Ok(TemplateResponse::bad_request(
|
|
"You need to either add files or content",
|
|
));
|
|
}
|
|
|
|
let content_bytes = input.content.as_ref().map_or(0, String::len);
|
|
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
|
let ctx_len = input.context.len();
|
|
let category_bytes = input.category.len();
|
|
let file_count = input.files.len();
|
|
|
|
match validate_ingest_input(
|
|
&state.config,
|
|
input.content.as_deref(),
|
|
&input.context,
|
|
&input.category,
|
|
file_count,
|
|
) {
|
|
Ok(()) => {}
|
|
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
|
return Ok(TemplateResponse::error(
|
|
StatusCode::PAYLOAD_TOO_LARGE,
|
|
"Payload Too Large",
|
|
&message,
|
|
));
|
|
}
|
|
Err(IngestValidationError::BadRequest(message)) => {
|
|
return Ok(TemplateResponse::bad_request(&message));
|
|
}
|
|
}
|
|
|
|
info!(
|
|
user_id = %user.id,
|
|
has_content,
|
|
content_bytes,
|
|
ctx_len,
|
|
category_bytes,
|
|
file_count,
|
|
"Received ingest form submission"
|
|
);
|
|
|
|
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
|
FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage)
|
|
.map_err(AppError::from)
|
|
}))
|
|
.await?;
|
|
|
|
let payloads = IngestionPayload::create_ingestion_payload(
|
|
input.content,
|
|
input.context,
|
|
input.category,
|
|
file_infos,
|
|
user.id.clone(),
|
|
)?;
|
|
|
|
let tasks =
|
|
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
|
|
|
|
Ok(TemplateResponse::new_template(
|
|
"dashboard/current_task.html",
|
|
NewTasksData { tasks },
|
|
))
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct QueryParams {
|
|
task_id: String,
|
|
}
|
|
|
|
fn create_error_stream(message: impl Into<String>) -> EventStream {
|
|
let message = message.into();
|
|
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
|
}
|
|
|
|
pub async fn get_task_updates_stream(
|
|
State(state): State<HtmlState>,
|
|
RequireUser(current_user): RequireUser,
|
|
Query(params): Query<QueryParams>,
|
|
) -> TaskSse {
|
|
let task_id = params.task_id.clone();
|
|
let db = Arc::clone(&state.db);
|
|
|
|
match db.get_item::<IngestionTask>(&task_id).await {
|
|
Ok(Some(task)) => {
|
|
if task.user_id != current_user.id {
|
|
return sse_with_keep_alive(create_error_stream(
|
|
"Access denied: You do not have permission to view updates for this task.",
|
|
));
|
|
}
|
|
|
|
let sse_stream = async_stream::stream! {
|
|
let mut consecutive_db_errors: u32 = 0;
|
|
let max_consecutive_db_errors = 3;
|
|
|
|
loop {
|
|
match db.get_item::<IngestionTask>(&task_id).await {
|
|
Ok(Some(updated_task)) => {
|
|
consecutive_db_errors = 0; // Reset error count on success
|
|
|
|
let status_message = match updated_task.state {
|
|
TaskState::Pending => "Pending".to_string(),
|
|
TaskState::Reserved => format!(
|
|
"Reserved (attempt {} of {})",
|
|
updated_task.attempts,
|
|
updated_task.max_attempts
|
|
),
|
|
TaskState::Processing => format!(
|
|
"Processing (attempt {} of {})",
|
|
updated_task.attempts,
|
|
updated_task.max_attempts
|
|
),
|
|
TaskState::Succeeded => "Completed".to_string(),
|
|
TaskState::Failed => {
|
|
let mut base = format!(
|
|
"Retry scheduled (attempt {} of {})",
|
|
updated_task.attempts,
|
|
updated_task.max_attempts
|
|
);
|
|
if let Some(message) = updated_task.error_message.as_ref() {
|
|
base.push_str(": ");
|
|
base.push_str(message);
|
|
}
|
|
base
|
|
}
|
|
TaskState::Cancelled => "Cancelled".to_string(),
|
|
TaskState::DeadLetter => {
|
|
let mut base = "Failed permanently".to_string();
|
|
if let Some(message) = updated_task.error_message.as_ref() {
|
|
base.push_str(": ");
|
|
base.push_str(message);
|
|
}
|
|
base
|
|
}
|
|
};
|
|
|
|
yield Ok(Event::default().event("status").data(status_message));
|
|
|
|
// Check for terminal states to close the stream
|
|
if updated_task.state.is_terminal() {
|
|
// Send a specific event that HTMX uses to close the connection
|
|
// Send a event to reload the recent content
|
|
// Send a event to remove the loading indicatior
|
|
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or_else(|_| "Ok".to_string());
|
|
yield Ok(Event::default().event("stop_loading").data(check_icon));
|
|
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
|
|
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
|
break; // Exit loop on terminal states
|
|
}
|
|
},
|
|
Ok(None) => {
|
|
// Task disappeared after initial fetch
|
|
yield Ok(Event::default().event("error").data("Task not found during update polling."));
|
|
break;
|
|
}
|
|
Err(db_err) => {
|
|
error!("Database error while fetching task '{}': {:?}", task_id, db_err);
|
|
consecutive_db_errors = consecutive_db_errors.saturating_add(1);
|
|
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors}).")));
|
|
|
|
if consecutive_db_errors >= max_consecutive_db_errors {
|
|
error!("Max consecutive DB errors reached for task '{task_id}'. Closing stream.");
|
|
yield Ok(Event::default().event("error").data("Persistent error fetching task updates. Stream closed."));
|
|
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
sleep(Duration::from_secs(2)).await;
|
|
}
|
|
};
|
|
|
|
sse_with_keep_alive(sse_stream.boxed())
|
|
}
|
|
Ok(None) => sse_with_keep_alive(create_error_stream(format!(
|
|
"Task with ID '{task_id}' not found."
|
|
))),
|
|
Err(e) => {
|
|
error!("Failed to fetch task '{task_id}' for authorization: {e:?}");
|
|
sse_with_keep_alive(create_error_stream(
|
|
"An error occurred while retrieving task details. Please try again later.",
|
|
))
|
|
}
|
|
}
|
|
}
|