chore: improve html-router auth, caching, and analytics while centralizing search labels in common.

small fix
This commit is contained in:
Per Stark
2026-05-29 14:42:20 +02:00
parent d3443d4153
commit 2aa92b6ad7
27 changed files with 510 additions and 344 deletions
+15 -7
View File
@@ -115,15 +115,17 @@ impl HtmlState {
user_id.to_string(),
ConversationArchiveCacheEntry {
conversations,
expires_at: now + CONVERSATION_ARCHIVE_CACHE_TTL,
expires_at: now
.checked_add(CONVERSATION_ARCHIVE_CACHE_TTL)
.unwrap_or(now),
},
);
let writes = self
.conversation_archive_cache_writes
.fetch_add(1, Ordering::Relaxed)
+ 1;
if writes % CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL == 0 {
.saturating_add(1);
if writes.is_multiple_of(CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL) {
Self::purge_expired_entries(&mut cache, now);
}
@@ -147,7 +149,7 @@ impl HtmlState {
return;
}
let overflow = cache.len() - CONVERSATION_ARCHIVE_CACHE_MAX_USERS;
let overflow = cache.len().saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS);
let mut by_expiry: Vec<(String, Instant)> = cache
.iter()
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
@@ -178,6 +180,8 @@ impl crate::middlewares::response_middleware::ProvidesHtmlState for HtmlState {
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
use super::*;
use common::{
storage::types::conversation::SidebarConversation,
@@ -202,8 +206,10 @@ mod tests {
.expect("Failed to create session store"),
);
let mut config = AppConfig::default();
config.storage = StorageKind::Memory;
let config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let storage = StorageManager::new(&config)
.await
@@ -239,7 +245,9 @@ mod tests {
id: "conv-1".to_string(),
title: "A stale chat".to_string(),
}]),
expires_at: Instant::now() - Duration::from_secs(1),
expires_at: Instant::now()
.checked_sub(Duration::from_secs(1))
.unwrap_or_else(Instant::now),
},
);
}
+7 -2
View File
@@ -1,3 +1,9 @@
//! SSR + HTMX HTML router for Minne.
//!
//! Handlers return [`middlewares::response_middleware::TemplateResponse`] values;
//! the template middleware renders them with shared layout context. Route composition
//! and middleware layering are handled by [`router_factory::RouterFactory`].
pub mod html_state;
pub mod middlewares;
pub mod router_factory;
@@ -18,14 +24,13 @@ pub type SessionType = Session<SessionSurrealPool<Any>>;
pub type SessionStoreType = SessionStore<SessionSurrealPool<Any>>;
pub type OpenAIClientType = async_openai::Client<async_openai::config::OpenAIConfig>;
/// Html routes
/// Builds the HTML router with public/protected routes, assets, and middleware.
pub fn html_routes<S>(app_state: &HtmlState) -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
RouterFactory::new(app_state)
.add_public_routes(routes::index::public_router())
.add_public_routes(routes::auth::router())
.with_public_assets("/assets", "assets/")
.add_protected_routes(routes::index::protected_router())
@@ -1,3 +1,5 @@
use std::sync::Arc;
use axum::{
extract::{Request, State},
http::Method,
@@ -10,7 +12,7 @@ use common::storage::{db::ProvidesDb, types::analytics::Analytics};
use crate::SessionType;
/// Middleware to count unique visitors and page loads
/// Middleware to count unique visitors and page loads.
pub async fn analytics_middleware<S>(
State(state): State<S>,
session: SessionType,
@@ -21,17 +23,18 @@ where
S: ProvidesDb + Clone + Send + Sync + 'static,
{
let path = request.uri().path();
// Only count visits/page loads for GET requests to non-asset, non-static paths
if request.method() == Method::GET && !path.starts_with("/assets") && !path.contains('.') {
if !session.get::<bool>("counted_visitor").unwrap_or(false) {
if let Err(e) = Analytics::increment_visitors(state.db()).await {
warn!("failed to increment visitor count: {e}");
}
let is_new_visitor = !session.get::<bool>("counted_visitor").unwrap_or(false);
if is_new_visitor {
session.set("counted_visitor", true);
}
if let Err(e) = Analytics::increment_page_loads(state.db()).await {
warn!("failed to increment page load count: {e}");
}
let db = Arc::clone(state.db());
tokio::spawn(async move {
if let Err(error) = Analytics::record_page_view(&db, is_new_visitor).await {
warn!("failed to record page view: {error}");
}
});
}
next.run(request).await
}
@@ -11,6 +11,7 @@ use crate::AuthSessionType;
use super::response_middleware::TemplateResponse;
#[derive(Debug, Clone)]
/// Authenticated user extracted from request extensions by [`require_auth`].
pub struct RequireUser(pub User);
// Implement FromRequestParts for RequireUser
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;
use axum::{
extract::{Request, State},
@@ -36,6 +37,7 @@ pub enum TemplateKind {
}
#[derive(Clone)]
/// Handler response that the template middleware renders into HTML.
pub struct TemplateResponse {
template_kind: TemplateKind,
context: Value,
@@ -180,6 +182,7 @@ fn context_to_map(
}
}
#[allow(clippy::too_many_lines)]
pub async fn with_template_response<S>(
State(state): State<S>,
HxRequest(is_htmx): HxRequest,
@@ -221,14 +224,15 @@ where
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(user_id).await
{
conversation_archive = cached_archive;
conversation_archive = cached_archive.to_vec();
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
{
let cached = Arc::from(archive);
html_state
.set_cached_conversation_archive(user_id, archive.clone())
.set_cached_conversation_archive(user_id, Arc::clone(&cached))
.await;
conversation_archive = archive;
conversation_archive = cached.to_vec();
}
}
}
@@ -245,8 +249,8 @@ where
};
let context = ContextWrapper {
user_theme: &user_theme,
initial_theme: &initial_theme,
user_theme,
initial_theme,
is_authenticated,
user: current_user.as_ref(),
conversation_archive,
@@ -290,13 +294,13 @@ where
.context
.get_attr("title")
.ok()
.and_then(|v| v.as_str().map(|s| s.to_string()))
.and_then(|v| v.as_str().map(str::to_string))
.unwrap_or_else(|| "Error".to_string());
let description = template_response
.context
.get_attr("description")
.ok()
.and_then(|v| v.as_str().map(|s| s.to_string()))
.and_then(|v| v.as_str().map(str::to_string))
.unwrap_or_else(|| "An error occurred.".to_string());
let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}});
+1
View File
@@ -36,6 +36,7 @@ macro_rules! create_asset_service {
pub type MiddleWareVecType<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
/// Builder for composing public/protected HTML routes and middleware layers.
pub struct RouterFactory<S> {
app_state: HtmlState,
public_routers: Vec<Router<S>>,
+5 -8
View File
@@ -10,7 +10,7 @@ use crate::{
use common::storage::types::user::User;
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
pub struct SignInParams {
pub email: String,
pub password: String,
pub remember_me: Option<String>,
@@ -20,7 +20,7 @@ pub async fn show_signin_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/"));
}
if boosted {
@@ -37,13 +37,10 @@ pub async fn show_signin_form(
pub async fn authenticate_user(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
Form(form): Form<SignInParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::authenticate(&form.email, &form.password, &state.db).await {
Ok(user) => user,
Err(_) => {
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
}
let Ok(user) = User::authenticate(&form.email, &form.password, &state.db).await else {
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
};
auth.login_user(user.id);
+12 -5
View File
@@ -2,7 +2,7 @@ use axum::{extract::State, response::IntoResponse, Form};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use common::storage::types::user::{Theme, User};
use common::{error::AppError, storage::types::user::{Theme, User}};
use crate::{
html_state::HtmlState,
@@ -17,11 +17,18 @@ pub struct Params {
pub timezone: String,
}
fn signup_error_message(err: &AppError) -> &str {
match err {
AppError::Auth(message) if message == "Registration is not allowed" => message,
_ => "Could not create account. Please try again.",
}
}
pub async fn show_signup_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
if auth.current_user.is_some() {
return Ok(TemplateResponse::redirect("/"));
}
@@ -51,9 +58,9 @@ pub async fn process_signup_and_show_verification(
.await
{
Ok(user) => user,
Err(e) => {
tracing::error!("{:?}", e);
return Ok(TemplateResponse::bad_request(&e.to_string()).into_response());
Err(err) => {
tracing::error!(?err, "signup failed");
return Ok(TemplateResponse::bad_request(signup_error_message(&err)).into_response());
}
};
+3 -78
View File
@@ -1,20 +1,16 @@
use axum::{
extract::{Path, State},
http::HeaderValue,
response::{IntoResponse, Redirect},
response::IntoResponse,
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::AppError,
storage::types::{
conversation::Conversation,
message::{Message, MessageRole},
user::User,
},
};
@@ -26,75 +22,12 @@ use crate::{
},
};
#[derive(Debug, Deserialize)]
pub struct ChatStartParams {
user_query: String,
llm_response: String,
#[serde(deserialize_with = "deserialize_references")]
references: Vec<String>,
}
// Custom deserializer function
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(serde::de::Error::custom)
}
#[derive(Serialize)]
pub struct ChatPageData {
history: Vec<Message>,
conversation: Option<Conversation>,
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn show_initialized_chat(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<ChatStartParams>,
) -> Result<impl IntoResponse, HtmlError> {
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
conversation.id.clone(),
MessageRole::AI,
form.llm_response,
Some(form.references),
);
state.db.store_item(conversation.clone()).await?;
state.db.store_item(ai_message.clone()).await?;
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let messages = vec![user_message, ai_message];
let mut response = TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
history: messages,
conversation: Some(conversation.clone()),
},
)
.into_response();
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
pub async fn show_chat_base(
State(_state): State<HtmlState>,
RequireUser(_user): RequireUser,
@@ -131,8 +64,6 @@ pub async fn show_existing_chat(
))
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
@@ -171,11 +102,9 @@ pub async fn new_user_message(
Ok(response)
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
@@ -184,10 +113,6 @@ pub async fn new_chat_user_message(
conversation: Conversation,
}
let Some(user) = auth.current_user else {
return Ok(Redirect::to("/").into_response());
};
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
let user_message = Message::new(
conversation.id.clone(),
@@ -213,7 +138,7 @@ pub async fn new_chat_user_message(
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response.into_response())
Ok(response)
}
#[derive(Deserialize)]
@@ -25,7 +25,7 @@ use retrieval_pipeline::{
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::sync::Mutex;
use tokio::sync::Mutex;
use tokio::sync::mpsc::channel;
use tracing::{debug, error, info};
@@ -39,7 +39,10 @@ use common::storage::{
},
};
use crate::{html_state::HtmlState, AuthSessionType};
use crate::{
html_state::HtmlState,
middlewares::auth_middleware::RequireUser,
};
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
@@ -61,15 +64,9 @@ fn create_error_stream(message: impl Into<String>) -> EventStream {
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
user: User,
message_id: &str,
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature",
)));
};
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
@@ -136,7 +133,7 @@ fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
messages
.iter()
.skip(user_message_index + 1)
.skip(user_message_index.saturating_add(1))
.take_while(|message| message.role != MessageRole::User)
.find(|message| message.role == MessageRole::AI)
.cloned()
@@ -202,11 +199,11 @@ fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
#[allow(clippy::too_many_lines)]
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
RequireUser(user): RequireUser,
Query(params): Query<QueryParams>,
) -> SseResponse {
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
match get_message_and_user(&state.db, user, &params.message_id).await {
Ok((user_message, user, conversation, history, existing_ai_response)) => (
user_message,
user,
@@ -289,7 +286,7 @@ fn build_chat_event_stream(
let _ = tx_storage.send(content.clone()).await;
let display_content = {
let mut state = json_state.lock().expect("json parser mutex poisoned");
let mut state = json_state.lock().await;
state.process_chunk(&content)
};
if !display_content.is_empty() {
@@ -540,6 +537,8 @@ impl StreamParserState {
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::missing_docs_in_private_items)]
use super::*;
use chrono::{Duration as ChronoDuration, Utc};
use common::storage::{
@@ -707,7 +706,7 @@ mod tests {
.expect("failed to store second user message");
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
get_message_and_user(&db, Some(user.clone()), &user_message.id)
get_message_and_user(&db, user.clone(), &user_message.id)
.await
.expect("expected first turn to load");
@@ -717,7 +716,7 @@ mod tests {
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
get_message_and_user(&db, Some(user), &second_user_message.id)
get_message_and_user(&db, user, &second_user_message.id)
.await
.expect("expected second turn to load");
+1 -3
View File
@@ -5,14 +5,13 @@ mod references;
use axum::{
extract::FromRef,
routing::{get, post},
routing::get,
Router,
};
pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_conversation_editing_title,
show_chat_base as show_base, show_existing_chat as show_existing,
show_initialized_chat as show_initialized,
};
use message_response_stream::get_response_stream;
use references::show_reference_tooltip;
@@ -37,7 +36,6 @@ where
get(show_conversation_editing_title).patch(patch_conversation_title),
)
.route("/chat/sidebar", get(reload_sidebar))
.route("/initialized-chat", post(show_initialized))
.route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/{id}", get(show_reference_tooltip))
}
@@ -51,21 +51,21 @@ impl ReferenceReasonStats {
match reason {
InvalidReferenceReason::Empty => self.empty = self.empty.saturating_add(1),
InvalidReferenceReason::UnsupportedPrefix => {
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1)
self.unsupported_prefix = self.unsupported_prefix.saturating_add(1);
}
InvalidReferenceReason::MalformedUuid => {
self.malformed_uuid = self.malformed_uuid.saturating_add(1)
self.malformed_uuid = self.malformed_uuid.saturating_add(1);
}
InvalidReferenceReason::Duplicate => self.duplicate = self.duplicate.saturating_add(1),
InvalidReferenceReason::NotInContext => {
self.not_in_context = self.not_in_context.saturating_add(1)
self.not_in_context = self.not_in_context.saturating_add(1);
}
InvalidReferenceReason::NotFound => self.not_found = self.not_found.saturating_add(1),
InvalidReferenceReason::WrongUser => {
self.wrong_user = self.wrong_user.saturating_add(1)
self.wrong_user = self.wrong_user.saturating_add(1);
}
InvalidReferenceReason::OverLimit => {
self.over_limit = self.over_limit.saturating_add(1)
self.over_limit = self.over_limit.saturating_add(1);
}
}
}
+5 -18
View File
@@ -15,7 +15,7 @@ use crate::{
response_middleware::{HtmlError, TemplateResponse},
},
utils::text_content_preview::truncate_text_contents,
AuthSessionType,
utils::truncate::with_ellipsis,
};
use common::storage::types::user::DashboardStats;
use common::{
@@ -36,13 +36,9 @@ pub struct IndexPageData {
pub async fn index_handler(
State(state): State<HtmlState>,
auth: AuthSessionType,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let Some(user) = auth.current_user else {
return Ok(TemplateResponse::redirect("/signin"));
};
let (text_contents, stats, active_jobs) = try_join!(
let (text_contents, dashboard_stats, active_jobs) = try_join!(
User::get_latest_text_contents(&user.id, &state.db),
User::get_dashboard_stats(&user.id, &state.db),
User::get_unfinished_ingestion_tasks(&user.id, &state.db)
@@ -54,7 +50,7 @@ pub async fn index_handler(
"dashboard/base.html",
IndexPageData {
text_contents,
stats,
stats: dashboard_stats,
active_jobs,
},
))
@@ -223,7 +219,7 @@ pub async fn show_task_archive(
fn summarize_task_content(task: &IngestionTask) -> (String, String) {
match &task.content {
common::storage::types::ingestion_payload::IngestionPayload::Text { text, .. } => {
("Text".to_string(), truncate_summary(text, 80))
("Text".to_string(), with_ellipsis(text, 80))
}
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
("URL".to_string(), url.clone())
@@ -234,15 +230,6 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
}
}
fn truncate_summary(input: &str, max_chars: usize) -> String {
if input.chars().count() <= max_chars {
input.to_string()
} else {
let truncated: String = input.chars().take(max_chars).collect();
format!("{truncated}")
}
}
pub async fn serve_file(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
+1 -8
View File
@@ -11,20 +11,13 @@ use handlers::{
use crate::html_state::HtmlState;
pub fn public_router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new().route("/", get(index_handler))
}
pub fn protected_router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/", get(index_handler))
.route("/jobs/{job_id}", delete(delete_job))
.route("/jobs/archive", get(show_task_archive))
.route("/active-jobs", get(show_active_jobs))
+9 -17
View File
@@ -33,7 +33,6 @@ use crate::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
AuthSessionType,
};
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
@@ -73,6 +72,11 @@ pub async fn hide_ingest_form(
))
}
#[derive(Serialize)]
struct NewTasksData {
tasks: Vec<IngestionTask>,
}
#[derive(Debug, TryFromMultipart)]
pub struct IngestionParams {
pub content: Option<String>,
@@ -95,9 +99,9 @@ pub async fn process_ingest_form(
);
}
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let content_bytes = input.content.as_ref().map_or(0, String::len);
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let ctx_len = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
@@ -126,7 +130,7 @@ pub async fn process_ingest_form(
user_id = %user.id,
has_content,
content_bytes,
context_bytes,
ctx_len,
category_bytes,
file_count,
"Received ingest form submission"
@@ -149,11 +153,6 @@ pub async fn process_ingest_form(
let tasks =
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
#[derive(Serialize)]
struct NewTasksData {
tasks: Vec<IngestionTask>,
}
Ok(
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
.into_response(),
@@ -172,21 +171,14 @@ fn create_error_stream(message: impl Into<String>) -> EventStream {
pub async fn get_task_updates_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
RequireUser(current_user): RequireUser,
Query(params): Query<QueryParams>,
) -> TaskSse {
let task_id = params.task_id.clone();
let db = Arc::clone(&state.db);
// 1. Check for authenticated user
let Some(current_user) = auth.current_user else {
return sse_with_keep_alive(create_error_stream("User not authenticated"));
};
// 2. Fetch task for initial authorization and to ensure it exists
match db.get_item::<IngestionTask>(&task_id).await {
Ok(Some(task)) => {
// 3. Validate user ownership
if task.user_id != current_user.id {
return sse_with_keep_alive(create_error_stream(
"Access denied: You do not have permission to view updates for this task.",
+11 -137
View File
@@ -1,18 +1,13 @@
use std::{
collections::{HashMap, HashSet},
fmt,
str::FromStr,
};
use std::collections::HashSet;
use axum::{
extract::{Query, State},
response::IntoResponse,
};
use common::storage::types::{text_content::TextContent, user::User, StoredObject};
use common::utils::serde_helpers::deserialize_flexible_id;
use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
use serde::{de, Deserialize, Deserializer, Serialize};
use surrealdb::RecordId;
use std::{fmt, str::FromStr};
use crate::{
html_state::HtmlState,
@@ -20,7 +15,6 @@ use crate::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
utils::truncate::{first_non_empty_line, truncate_with_ellipsis},
};
/// Serde deserialization decorator to map empty Strings to None,
@@ -37,86 +31,6 @@ where
}
}
fn source_id_suffix(source_id: &str) -> String {
let start = source_id.len().saturating_sub(8);
source_id[start..].to_string()
}
#[derive(Deserialize)]
struct UrlInfoLabel {
#[serde(default)]
title: String,
#[serde(default)]
url: String,
}
#[derive(Deserialize)]
struct FileInfoLabel {
#[serde(default)]
file_name: String,
}
#[derive(Deserialize)]
struct SourceLabelRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
#[serde(default)]
url_info: Option<UrlInfoLabel>,
#[serde(default)]
file_info: Option<FileInfoLabel>,
#[serde(default)]
context: Option<String>,
#[serde(default)]
category: String,
#[serde(default)]
text: String,
}
fn build_source_label(row: &SourceLabelRow) -> String {
const MAX_LABEL_CHARS: usize = 80;
if let Some(url_info) = row.url_info.as_ref() {
let title = url_info.title.trim();
if !title.is_empty() {
return title.to_string();
}
let url = url_info.url.trim();
if !url.is_empty() {
return url.to_string();
}
}
if let Some(file_info) = row.file_info.as_ref() {
let name = file_info.file_name.trim();
if !name.is_empty() {
return name.to_string();
}
}
if let Some(context) = row.context.as_ref() {
let trimmed = context.trim();
if !trimmed.is_empty() {
return truncate_with_ellipsis(trimmed, MAX_LABEL_CHARS);
}
}
if let Some(text_label) = first_non_empty_line(&row.text, MAX_LABEL_CHARS) {
return text_label;
}
let category = row.category.trim();
if !category.is_empty() {
return truncate_with_ellipsis(category, MAX_LABEL_CHARS);
}
format!("Text snippet: {}", source_id_suffix(&row.id))
}
fn fallback_source_label(source_id: &str) -> String {
format!("Text snippet: {}", source_id_suffix(source_id))
}
#[derive(Deserialize)]
pub struct SearchParams {
#[serde(default, deserialize_with = "empty_string_as_none")]
@@ -218,7 +132,7 @@ async fn perform_search(
_ => SearchResult::new(vec![], vec![]),
};
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
let source_label_map = collect_source_label_map(state, user, &search_result).await?;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
@@ -227,7 +141,7 @@ async fn perform_search(
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
.unwrap_or_else(|| TextContent::fallback_source_label(&chunk_result.chunk.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "text_chunk".to_string(),
score: chunk_result.score,
@@ -246,7 +160,9 @@ async fn perform_search(
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
.unwrap_or_else(|| {
TextContent::fallback_source_label(&entity_result.entity.source_id)
});
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score: entity_result.score,
@@ -269,11 +185,11 @@ async fn perform_search(
Ok((combined_results, trimmed_query.to_string()))
}
async fn resolve_source_labels(
async fn collect_source_label_map(
state: &HtmlState,
user: &User,
search_result: &SearchResult,
) -> Result<HashMap<String, String>, HtmlError> {
) -> Result<std::collections::HashMap<String, String>, HtmlError> {
let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks {
source_ids.insert(chunk_result.chunk.source_id.clone());
@@ -282,47 +198,5 @@ async fn resolve_source_labels(
source_ids.insert(entity_result.entity.source_id.clone());
}
if source_ids.is_empty() {
return Ok(HashMap::new());
}
let record_ids: Vec<RecordId> = source_ids
.iter()
.filter_map(|id| {
if id.contains(':') {
RecordId::from_str(id).ok()
} else {
Some(RecordId::from_table_key(TextContent::table_name(), id))
}
})
.collect();
let mut response = state
.db
.client
.query(
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
)
.bind(("table_name", TextContent::table_name()))
.bind(("user_id", user.id.clone()))
.bind(("record_ids", record_ids))
.await?;
let contents: Vec<SourceLabelRow> = response.take(0)?;
tracing::debug!(
source_id_count = source_ids.len(),
label_row_count = contents.len(),
"Resolved search source labels"
);
let mut labels = HashMap::new();
for content in contents {
let label = build_source_label(&content);
labels.insert(content.id.clone(), label.clone());
labels.insert(
format!("{}:{}", TextContent::table_name(), content.id),
label,
);
}
Ok(labels)
Ok(TextContent::resolve_source_labels(&state.db, &user.id, source_ids).await?)
}
+1
View File
@@ -1,2 +1,3 @@
pub mod pagination;
pub mod text_content_preview;
pub mod truncate;
+4 -16
View File
@@ -1,26 +1,14 @@
use common::storage::types::text_content::TextContent;
use super::truncate::with_ellipsis;
const TEXT_PREVIEW_LENGTH: usize = 50;
fn maybe_truncate(value: &str) -> Option<String> {
for (char_count, (idx, _)) in value.char_indices().enumerate() {
if char_count == TEXT_PREVIEW_LENGTH {
return Some(value[..idx].to_string());
}
}
None
}
pub fn truncate_text_content(mut content: TextContent) -> TextContent {
if let Some(truncated) = maybe_truncate(&content.text) {
content.text = truncated;
}
content.text = with_ellipsis(&content.text, TEXT_PREVIEW_LENGTH);
if let Some(context) = content.context.as_mut() {
if let Some(truncated) = maybe_truncate(context) {
*context = truncated;
}
*context = with_ellipsis(context, TEXT_PREVIEW_LENGTH);
}
content
+61
View File
@@ -0,0 +1,61 @@
const ELLIPSIS: &str = "";
/// Truncates `value` to at most `max_chars` Unicode scalars, appending an ellipsis when shortened.
pub fn with_ellipsis(value: &str, max_chars: usize) -> String {
if max_chars == 0 {
return if value.is_empty() {
String::new()
} else {
ELLIPSIS.to_string()
};
}
let mut end_byte = value.len();
for (count, (idx, _)) in value.char_indices().enumerate() {
if count == max_chars {
end_byte = idx;
break;
}
}
if end_byte == value.len() {
return value.to_string();
}
format!("{}{}", &value[..end_byte], ELLIPSIS)
}
/// Returns the first non-empty line of `text`, truncated with an ellipsis when needed.
pub fn first_non_empty_line(text: &str, max_chars: usize) -> Option<String> {
text.lines().find_map(|line| {
let trimmed = line.trim();
if trimmed.is_empty() {
None
} else {
Some(with_ellipsis(trimmed, max_chars))
}
})
}
#[cfg(test)]
mod tests {
use super::{first_non_empty_line, with_ellipsis};
#[test]
fn leaves_short_strings_unchanged() {
assert_eq!(with_ellipsis("hello", 10), "hello");
}
#[test]
fn truncates_at_char_boundary_with_ellipsis() {
assert_eq!(with_ellipsis("hello world", 5), "hello…");
}
#[test]
fn first_non_empty_line_skips_blank_lines() {
assert_eq!(
first_non_empty_line("\n \nTitle line\nBody", 20),
Some("Title line".to_string())
);
}
}