mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-30 18:11:34 +02:00
chore: improve html-router auth, caching, and analytics while centralizing search labels in common.
small fix
This commit is contained in:
@@ -10,7 +10,7 @@ use crate::{
|
||||
use common::storage::types::user::User;
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct SignupParams {
|
||||
pub struct SignInParams {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub remember_me: Option<String>,
|
||||
@@ -20,7 +20,7 @@ pub async fn show_signin_form(
|
||||
auth: AuthSessionType,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if auth.is_authenticated() {
|
||||
if auth.current_user.is_some() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
if boosted {
|
||||
@@ -37,13 +37,10 @@ pub async fn show_signin_form(
|
||||
pub async fn authenticate_user(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<SignupParams>,
|
||||
Form(form): Form<SignInParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match User::authenticate(&form.email, &form.password, &state.db).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
|
||||
}
|
||||
let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else {
|
||||
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
|
||||
};
|
||||
|
||||
auth.login_user(user.id);
|
||||
|
||||
@@ -2,7 +2,7 @@ use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum_htmx::HxBoosted;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::storage::types::user::{Theme, User};
|
||||
use common::{error::AppError, storage::types::user::{Theme, User}};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
@@ -17,11 +17,18 @@ pub struct Params {
|
||||
pub timezone: String,
|
||||
}
|
||||
|
||||
fn signup_error_message(err: &AppError) -> &str {
|
||||
match err {
|
||||
AppError::Auth(message) if message == "Registration is not allowed" => message,
|
||||
_ => "Could not create account. Please try again.",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn show_signup_form(
|
||||
auth: AuthSessionType,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if auth.is_authenticated() {
|
||||
if auth.current_user.is_some() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
@@ -51,9 +58,9 @@ pub async fn process_signup_and_show_verification(
|
||||
.await
|
||||
{
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
tracing::error!("{:?}", e);
|
||||
return Ok(TemplateResponse::bad_request(&e.to_string()).into_response());
|
||||
Err(err) => {
|
||||
tracing::error!(?err, "signup failed");
|
||||
return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderValue,
|
||||
response::{IntoResponse, Redirect},
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
conversation::Conversation,
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -26,75 +22,12 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatStartParams {
|
||||
user_query: String,
|
||||
llm_response: String,
|
||||
#[serde(deserialize_with = "deserialize_references")]
|
||||
references: Vec<String>,
|
||||
}
|
||||
|
||||
// Custom deserializer function
|
||||
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
serde_json::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ChatPageData {
|
||||
history: Vec<Message>,
|
||||
conversation: Option<Conversation>,
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn show_initialized_chat(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<ChatStartParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
);
|
||||
|
||||
state.db.store_item(conversation.clone()).await?;
|
||||
state.db.store_item(ai_message.clone()).await?;
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
let messages = vec![user_message, ai_message];
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/base.html",
|
||||
ChatPageData {
|
||||
history: messages,
|
||||
conversation: Some(conversation.clone()),
|
||||
},
|
||||
)
|
||||
.into_response();
|
||||
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn show_chat_base(
|
||||
State(_state): State<HtmlState>,
|
||||
RequireUser(_user): RequireUser,
|
||||
@@ -131,8 +64,6 @@ pub async fn show_existing_chat(
|
||||
))
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
@@ -171,11 +102,9 @@ pub async fn new_user_message(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
@@ -184,10 +113,6 @@ pub async fn new_chat_user_message(
|
||||
conversation: Conversation,
|
||||
}
|
||||
|
||||
let Some(user) = auth.current_user else {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
|
||||
let user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
@@ -213,7 +138,7 @@ pub async fn new_chat_user_message(
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response.into_response())
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
||||
@@ -25,7 +25,7 @@ use retrieval_pipeline::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use std::sync::Mutex;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc::channel;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
@@ -39,7 +39,10 @@ use common::storage::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, AuthSessionType};
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::auth_middleware::RequireUser,
|
||||
};
|
||||
|
||||
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||
|
||||
@@ -61,15 +64,9 @@ fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
user: User,
|
||||
message_id: &str,
|
||||
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
|
||||
let Some(user) = current_user else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
let message = match db.get_item::<Message>(message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
@@ -136,7 +133,7 @@ fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
|
||||
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
|
||||
messages
|
||||
.iter()
|
||||
.skip(user_message_index + 1)
|
||||
.skip(user_message_index.saturating_add(1))
|
||||
.take_while(|message| message.role != MessageRole::User)
|
||||
.find(|message| message.role == MessageRole::AI)
|
||||
.cloned()
|
||||
@@ -202,11 +199,11 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
RequireUser(user): RequireUser,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> SseResponse {
|
||||
let (user_message, user, _conversation, history, existing_ai_response) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
match get_message_and_user(&state.db, user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history, existing_ai_response)) => (
|
||||
user_message,
|
||||
user,
|
||||
@@ -289,7 +286,7 @@ fn build_chat_event_stream(
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
let display_content = {
|
||||
let mut state = json_state.lock().expect("json parser mutex poisoned");
|
||||
let mut state = json_state.lock().await;
|
||||
state.process_chunk(&content)
|
||||
};
|
||||
if !display_content.is_empty() {
|
||||
@@ -540,6 +537,8 @@ impl StreamParserState {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::missing_docs_in_private_items)]
|
||||
|
||||
use super::*;
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use common::storage::{
|
||||
@@ -707,7 +706,7 @@ mod tests {
|
||||
.expect("failed to store second user message");
|
||||
|
||||
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
|
||||
get_message_and_user(&db, Some(user.clone()), &user_message.id)
|
||||
get_message_and_user(&db, user.clone(), &user_message.id)
|
||||
.await
|
||||
.expect("expected first turn to load");
|
||||
|
||||
@@ -717,7 +716,7 @@ mod tests {
|
||||
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
|
||||
|
||||
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
|
||||
get_message_and_user(&db, Some(user), &second_user_message.id)
|
||||
get_message_and_user(&db, user, &second_user_message.id)
|
||||
.await
|
||||
.expect("expected second turn to load");
|
||||
|
||||
|
||||
@@ -5,14 +5,13 @@ mod references;
|
||||
|
||||
use axum::{
|
||||
extract::FromRef,
|
||||
routing::{get, post},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
pub use chat_handlers::{
|
||||
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
||||
reload_sidebar, show_conversation_editing_title,
|
||||
show_chat_base as show_base, show_existing_chat as show_existing,
|
||||
show_initialized_chat as show_initialized,
|
||||
};
|
||||
use message_response_stream::get_response_stream;
|
||||
use references::show_reference_tooltip;
|
||||
@@ -37,7 +36,6 @@ where
|
||||
get(show_conversation_editing_title).patch(patch_conversation_title),
|
||||
)
|
||||
.route("/chat/sidebar", get(reload_sidebar))
|
||||
.route("/initialized-chat", post(show_initialized))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/chat/reference/{id}", get(show_reference_tooltip))
|
||||
}
|
||||
|
||||
@@ -51,21 +51,21 @@ impl ReferenceReasonStats {
|
||||
match reason {
|
||||
InvalidReferenceReason::Empty => self.empty = self.empty.saturating_add(1),
|
||||
InvalidReferenceReason::UnsupportedPrefix => {
|
||||
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1)
|
||||
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1);
|
||||
}
|
||||
InvalidReferenceReason::MalformedUuid => {
|
||||
self.malformed_uuid = self.malformed_uuid.saturating_add(1)
|
||||
self.malformed_uuid = self.malformed_uuid.saturating_add(1);
|
||||
}
|
||||
InvalidReferenceReason::Duplicate => self.duplicate = self.duplicate.saturating_add(1),
|
||||
InvalidReferenceReason::NotInContext => {
|
||||
self.not_in_context = self.not_in_context.saturating_add(1)
|
||||
self.not_in_context = self.not_in_context.saturating_add(1);
|
||||
}
|
||||
InvalidReferenceReason::NotFound => self.not_found = self.not_found.saturating_add(1),
|
||||
InvalidReferenceReason::WrongUser => {
|
||||
self.wrong_user = self.wrong_user.saturating_add(1)
|
||||
self.wrong_user = self.wrong_user.saturating_add(1);
|
||||
}
|
||||
InvalidReferenceReason::OverLimit => {
|
||||
self.over_limit = self.over_limit.saturating_add(1)
|
||||
self.over_limit = self.over_limit.saturating_add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::{
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
AuthSessionType,
|
||||
utils::truncate::with_ellipsis,
|
||||
};
|
||||
use common::storage::types::user::DashboardStats;
|
||||
use common::{
|
||||
@@ -36,13 +36,9 @@ pub struct IndexPageData {
|
||||
|
||||
pub async fn index_handler(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let Some(user) = auth.current_user else {
|
||||
return Ok(TemplateResponse::redirect("/signin"));
|
||||
};
|
||||
|
||||
let (text_contents, stats, active_jobs) = try_join!(
|
||||
let (text_contents, dashboard_stats, active_jobs) = try_join!(
|
||||
User::get_latest_text_contents(&user.id, &state.db),
|
||||
User::get_dashboard_stats(&user.id, &state.db),
|
||||
User::get_unfinished_ingestion_tasks(&user.id, &state.db)
|
||||
@@ -54,7 +50,7 @@ pub async fn index_handler(
|
||||
"dashboard/base.html",
|
||||
IndexPageData {
|
||||
text_contents,
|
||||
stats,
|
||||
stats: dashboard_stats,
|
||||
active_jobs,
|
||||
},
|
||||
))
|
||||
@@ -223,7 +219,7 @@ pub async fn show_task_archive(
|
||||
fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
match &task.content {
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Text { text, .. } => {
|
||||
("Text".to_string(), truncate_summary(text, 80))
|
||||
("Text".to_string(), with_ellipsis(text, 80))
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
|
||||
("URL".to_string(), url.clone())
|
||||
@@ -234,15 +230,6 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_summary(input: &str, max_chars: usize) -> String {
|
||||
if input.chars().count() <= max_chars {
|
||||
input.to_string()
|
||||
} else {
|
||||
let truncated: String = input.chars().take(max_chars).collect();
|
||||
format!("{truncated}…")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_file(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
|
||||
@@ -11,20 +11,13 @@ use handlers::{
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
pub fn public_router<S>() -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new().route("/", get(index_handler))
|
||||
}
|
||||
|
||||
pub fn protected_router<S>() -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/", get(index_handler))
|
||||
.route("/jobs/{job_id}", delete(delete_job))
|
||||
.route("/jobs/archive", get(show_task_archive))
|
||||
.route("/active-jobs", get(show_active_jobs))
|
||||
|
||||
@@ -33,7 +33,6 @@ use crate::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
AuthSessionType,
|
||||
};
|
||||
|
||||
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
|
||||
@@ -73,6 +72,11 @@ pub async fn hide_ingest_form(
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct NewTasksData {
|
||||
tasks: Vec<IngestionTask>,
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngestionParams {
|
||||
pub content: Option<String>,
|
||||
@@ -95,9 +99,9 @@ pub async fn process_ingest_form(
|
||||
);
|
||||
}
|
||||
|
||||
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
|
||||
let content_bytes = input.content.as_ref().map_or(0, String::len);
|
||||
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||
let context_bytes = input.context.len();
|
||||
let ctx_len = input.context.len();
|
||||
let category_bytes = input.category.len();
|
||||
let file_count = input.files.len();
|
||||
|
||||
@@ -126,7 +130,7 @@ pub async fn process_ingest_form(
|
||||
user_id = %user.id,
|
||||
has_content,
|
||||
content_bytes,
|
||||
context_bytes,
|
||||
ctx_len,
|
||||
category_bytes,
|
||||
file_count,
|
||||
"Received ingest form submission"
|
||||
@@ -149,11 +153,6 @@ pub async fn process_ingest_form(
|
||||
let tasks =
|
||||
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct NewTasksData {
|
||||
tasks: Vec<IngestionTask>,
|
||||
}
|
||||
|
||||
Ok(
|
||||
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
|
||||
.into_response(),
|
||||
@@ -172,21 +171,14 @@ fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
|
||||
pub async fn get_task_updates_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
RequireUser(current_user): RequireUser,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> TaskSse {
|
||||
let task_id = params.task_id.clone();
|
||||
let db = Arc::clone(&state.db);
|
||||
|
||||
// 1. Check for authenticated user
|
||||
let Some(current_user) = auth.current_user else {
|
||||
return sse_with_keep_alive(create_error_stream("User not authenticated"));
|
||||
};
|
||||
|
||||
// 2. Fetch task for initial authorization and to ensure it exists
|
||||
match db.get_item::<IngestionTask>(&task_id).await {
|
||||
Ok(Some(task)) => {
|
||||
// 3. Validate user ownership
|
||||
if task.user_id != current_user.id {
|
||||
return sse_with_keep_alive(create_error_stream(
|
||||
"Access denied: You do not have permission to view updates for this task.",
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fmt,
|
||||
str::FromStr,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use common::storage::types::{text_content::TextContent, user::User, StoredObject};
|
||||
use common::utils::serde_helpers::deserialize_flexible_id;
|
||||
use common::storage::types::{text_content::TextContent, user::User};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||
use surrealdb::RecordId;
|
||||
use std::{fmt, str::FromStr};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
@@ -20,7 +15,6 @@ use crate::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
utils::truncate::{first_non_empty_line, truncate_with_ellipsis},
|
||||
};
|
||||
|
||||
/// Serde deserialization decorator to map empty Strings to None,
|
||||
@@ -37,86 +31,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn source_id_suffix(source_id: &str) -> String {
|
||||
let start = source_id.len().saturating_sub(8);
|
||||
source_id[start..].to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UrlInfoLabel {
|
||||
#[serde(default)]
|
||||
title: String,
|
||||
#[serde(default)]
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct FileInfoLabel {
|
||||
#[serde(default)]
|
||||
file_name: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SourceLabelRow {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
url_info: Option<UrlInfoLabel>,
|
||||
#[serde(default)]
|
||||
file_info: Option<FileInfoLabel>,
|
||||
#[serde(default)]
|
||||
context: Option<String>,
|
||||
#[serde(default)]
|
||||
category: String,
|
||||
#[serde(default)]
|
||||
text: String,
|
||||
}
|
||||
|
||||
fn build_source_label(row: &SourceLabelRow) -> String {
|
||||
const MAX_LABEL_CHARS: usize = 80;
|
||||
|
||||
if let Some(url_info) = row.url_info.as_ref() {
|
||||
let title = url_info.title.trim();
|
||||
if !title.is_empty() {
|
||||
return title.to_string();
|
||||
}
|
||||
|
||||
let url = url_info.url.trim();
|
||||
if !url.is_empty() {
|
||||
return url.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(file_info) = row.file_info.as_ref() {
|
||||
let name = file_info.file_name.trim();
|
||||
if !name.is_empty() {
|
||||
return name.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(context) = row.context.as_ref() {
|
||||
let trimmed = context.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return truncate_with_ellipsis(trimmed, MAX_LABEL_CHARS);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(text_label) = first_non_empty_line(&row.text, MAX_LABEL_CHARS) {
|
||||
return text_label;
|
||||
}
|
||||
|
||||
let category = row.category.trim();
|
||||
if !category.is_empty() {
|
||||
return truncate_with_ellipsis(category, MAX_LABEL_CHARS);
|
||||
}
|
||||
|
||||
format!("Text snippet: {}", source_id_suffix(&row.id))
|
||||
}
|
||||
|
||||
fn fallback_source_label(source_id: &str) -> String {
|
||||
format!("Text snippet: {}", source_id_suffix(source_id))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SearchParams {
|
||||
#[serde(default, deserialize_with = "empty_string_as_none")]
|
||||
@@ -218,7 +132,7 @@ async fn perform_search(
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
};
|
||||
|
||||
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
|
||||
let source_label_map = collect_source_label_map(state, user, &search_result).await?;
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
|
||||
@@ -227,7 +141,7 @@ async fn perform_search(
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
.unwrap_or_else(|| TextContent::fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
@@ -246,7 +160,9 @@ async fn perform_search(
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
.unwrap_or_else(|| {
|
||||
TextContent::fallback_source_label(&entity_result.entity.source_id)
|
||||
});
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
@@ -269,11 +185,11 @@ async fn perform_search(
|
||||
Ok((combined_results, trimmed_query.to_string()))
|
||||
}
|
||||
|
||||
async fn resolve_source_labels(
|
||||
async fn collect_source_label_map(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
search_result: &SearchResult,
|
||||
) -> Result<HashMap<String, String>, HtmlError> {
|
||||
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
@@ -282,47 +198,5 @@ async fn resolve_source_labels(
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let record_ids: Vec<RecordId> = source_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
if id.contains(':') {
|
||||
RecordId::from_str(id).ok()
|
||||
} else {
|
||||
Some(RecordId::from_table_key(TextContent::table_name(), id))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut response = state
|
||||
.db
|
||||
.client
|
||||
.query(
|
||||
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("user_id", user.id.clone()))
|
||||
.bind(("record_ids", record_ids))
|
||||
.await?;
|
||||
let contents: Vec<SourceLabelRow> = response.take(0)?;
|
||||
|
||||
tracing::debug!(
|
||||
source_id_count = source_ids.len(),
|
||||
label_row_count = contents.len(),
|
||||
"Resolved search source labels"
|
||||
);
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
for content in contents {
|
||||
let label = build_source_label(&content);
|
||||
labels.insert(content.id.clone(), label.clone());
|
||||
labels.insert(
|
||||
format!("{}:{}", TextContent::table_name(), content.id),
|
||||
label,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(labels)
|
||||
Ok(TextContent::resolve_source_labels(&state.db, &user.id, source_ids).await?)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user