clippy: adhere to pedantic clippy, uniform test error handling

This commit is contained in:
Per Stark
2026-05-26 11:43:45 +02:00
parent 6a5d631287
commit 000852c94c
68 changed files with 2468 additions and 2547 deletions
+27 -26
View File
@@ -39,30 +39,33 @@ const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
pub struct StateResources {
pub db: Arc<SurrealDbClient>,
pub openai_client: Arc<OpenAIClientType>,
pub session_store: Arc<SessionStoreType>,
pub storage: StorageManager,
pub config: AppConfig,
pub reranker_pool: Option<Arc<RerankerPool>>,
pub embedding_provider: Arc<EmbeddingProvider>,
pub template_engine: Option<Arc<TemplateEngine>>,
}
impl HtmlState {
pub async fn new_with_resources(
db: Arc<SurrealDbClient>,
openai_client: Arc<OpenAIClientType>,
session_store: Arc<SessionStoreType>,
storage: StorageManager,
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
embedding_provider: Arc<EmbeddingProvider>,
template_engine: Option<Arc<TemplateEngine>>,
) -> Self {
let templates =
template_engine.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
pub fn new_with_resources(resources: StateResources) -> Self {
let templates = resources
.template_engine
.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
debug!("Template engine configured for html_router.");
Self {
db,
openai_client,
session_store,
db: resources.db,
openai_client: resources.openai_client,
templates,
config,
storage,
reranker_pool,
embedding_provider,
session_store: resources.session_store,
config: resources.config,
storage: resources.storage,
reranker_pool: resources.reranker_pool,
embedding_provider: resources.embedding_provider,
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
}
@@ -210,18 +213,16 @@ mod tests {
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
);
HtmlState::new_with_resources(
HtmlState::new_with_resources(StateResources {
db,
Arc::new(async_openai::Client::new()),
openai_client: Arc::new(async_openai::Client::new()),
session_store,
storage,
config,
None,
reranker_pool: None,
embedding_provider,
None,
)
.await
.expect("Failed to create HtmlState")
template_engine: None,
})
}
#[tokio::test]
+1 -1
View File
@@ -2,6 +2,6 @@ use tower_http::compression::CompressionLayer;
/// Provides a default compression layer that negotiates encoding based on the
/// `Accept-Encoding` header of the incoming request.
pub fn compression_layer() -> CompressionLayer {
pub fn layer() -> CompressionLayer {
CompressionLayer::new()
}
@@ -10,7 +10,7 @@ use axum::{
use axum_htmx::{HxRequest, HX_TRIGGER};
use common::{
error::AppError,
utils::template_engine::{ProvidesTemplateEngine, Value},
utils::template_engine::{ProvidesTemplateEngine, TemplateEngine, Value},
};
use minijinja::context;
use serde::Serialize;
@@ -146,6 +146,40 @@ struct ContextWrapper<'a> {
context: HashMap<String, Value>,
}
const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
for &header_name in HTMX_HEADERS_TO_FORWARD {
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) {
if let Some(value) = from.get(&name) {
to.insert(name.clone(), value.clone());
}
}
}
}
fn context_to_map(
value: &Value,
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
match value.kind() {
minijinja::value::ValueKind::Map => {
let mut map = HashMap::new();
if let Ok(keys) = value.try_iter() {
for key in keys {
if let Ok(val) = value.get_item(&key) {
map.insert(key.to_string(), val);
}
}
}
Ok(map)
}
minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => {
Ok(HashMap::new())
}
other => Err(other),
}
}
pub async fn with_template_response<S>(
State(state): State<S>,
HxRequest(is_htmx): HxRequest,
@@ -158,14 +192,12 @@ where
let mut user_theme = Theme::System.as_str();
let mut initial_theme = Theme::System.initial_theme();
let mut is_authenticated = false;
let mut current_user_id = None;
let mut current_user = None;
{
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
if let Some(user) = &auth.current_user {
is_authenticated = true;
current_user_id = Some(user.id.clone());
user_theme = user.theme.as_str();
initial_theme = user.theme.initial_theme();
current_user = Some(TemplateUser::from(user));
@@ -175,9 +207,6 @@ where
let response = next.run(req).await;
// Headers to forward from the original response
const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine();
@@ -187,56 +216,23 @@ where
matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive {
if let Some(user_id) = current_user_id {
if let Some(user_id) = current_user.as_ref().map(|u| &u.id) {
let html_state = state.html_state();
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(&user_id).await
html_state.get_cached_conversation_archive(user_id).await
{
conversation_archive = cached_archive;
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
Conversation::get_user_sidebar_conversations(user_id, &html_state.db).await
{
html_state
.set_cached_conversation_archive(&user_id, archive.clone())
.set_cached_conversation_archive(user_id, archive.clone())
.await;
conversation_archive = archive;
}
}
}
fn context_to_map(
value: &Value,
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
match value.kind() {
minijinja::value::ValueKind::Map => {
let mut map = HashMap::new();
if let Ok(keys) = value.try_iter() {
for key in keys {
if let Ok(val) = value.get_item(&key) {
map.insert(key.to_string(), val);
}
}
}
Ok(map)
}
minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => {
Ok(HashMap::new())
}
other => Err(other),
}
}
// Helper to forward relevant headers
fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
for &header_name in HTMX_HEADERS_TO_FORWARD {
if let Ok(name) = HeaderName::from_bytes(header_name.as_bytes()) {
if let Some(value) = from.get(&name) {
to.insert(name.clone(), value.clone());
}
}
}
}
let context_map = match context_to_map(&template_response.context) {
Ok(map) => map,
Err(kind) => {
@@ -290,18 +286,17 @@ where
}
TemplateKind::Error(status) => {
if is_htmx {
// HTMX request: Send 204 + HX-Trigger for toast
let title = template_response
.context
.get_attr("title")
.ok()
.and_then(|v| v.as_str().map(String::from))
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "Error".to_string());
let description = template_response
.context
.get_attr("description")
.ok()
.and_then(|v| v.as_str().map(String::from))
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "An error occurred.".to_string());
let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}});
@@ -312,14 +307,12 @@ where
});
(StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response()
} else {
// Non-HTMX request: Render the full errors/error.html page
match template_engine
.render("errors/error.html", &Value::from_serialize(&context))
{
Ok(html) => (*status, Html(html)).into_response(),
Err(e) => {
error!("Critical: Failed to render 'errors/error.html': {:?}", e);
// Fallback HTML, but use the intended status code
(*status, Html(fallback_error())).into_response()
}
}
+9 -2
View File
@@ -9,7 +9,7 @@ use crate::{
html_state::HtmlState,
middlewares::{
analytics_middleware::analytics_middleware, auth_middleware::require_auth,
compression::compression_layer, response_middleware::with_template_response,
compression, response_middleware::with_template_response,
},
};
@@ -71,6 +71,7 @@ where
}
// Add a serving of assets
#[must_use]
pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self {
self.public_assets_config = Some(AssetsConfig {
path: path.to_string(),
@@ -80,24 +81,28 @@ where
}
// Add a public router that will be merged at the root level
#[must_use]
pub fn add_public_routes(mut self, routes: Router<S>) -> Self {
self.public_routers.push(routes);
self
}
// Add a protected router that will be merged at the root level
#[must_use]
pub fn add_protected_routes(mut self, routes: Router<S>) -> Self {
self.protected_routers.push(routes);
self
}
// Nest a public router under a path prefix
#[must_use]
pub fn nest_public_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_routes.push((path.to_string(), routes));
self
}
// Nest a protected router under a path prefix
#[must_use]
pub fn nest_protected_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_protected_routes
.push((path.to_string(), routes));
@@ -105,6 +110,7 @@ where
}
// Add custom middleware to be applied before the standard ones
#[must_use]
pub fn with_middleware<F>(mut self, middleware_fn: F) -> Self
where
F: FnOnce(Router<S>) -> Router<S> + Send + 'static,
@@ -114,6 +120,7 @@ where
}
/// Enables response compression when building the router.
#[must_use]
pub const fn with_compression(mut self) -> Self {
self.compression_enabled = true;
self
@@ -191,7 +198,7 @@ where
// Apply Global Middleware (Compression)
if self.compression_enabled {
final_router = final_router.layer(compression_layer());
final_router = final_router.layer(compression::layer());
}
final_router
+3 -3
View File
@@ -62,7 +62,7 @@ pub async fn set_api_key(
let api_key = User::set_api_key(&user.id, &state.db).await?;
// Clear the cache so new requests have access to the user with api key
auth.cache_clear_user(user.id.to_string());
auth.cache_clear_user(user.id.clone());
// Render the API key section block
Ok(TemplateResponse::new_partial(
@@ -106,7 +106,7 @@ pub async fn update_timezone(
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache
auth.cache_clear_user(user.id.to_string());
auth.cache_clear_user(user.id.clone());
let timezones = TZ_VARIANTS
.iter()
@@ -141,7 +141,7 @@ pub async fn update_theme(
User::update_theme(&user.id, &form.theme, &state.db).await?;
// Clear the cache
auth.cache_clear_user(user.id.to_string());
auth.cache_clear_user(user.id.clone());
let theme_options = vec![
Theme::Light.as_str().to_string(),
+7 -12
View File
@@ -1,3 +1,5 @@
use std::sync::Arc;
use async_openai::types::ListModelResponse;
use axum::{
extract::{Query, State},
@@ -37,18 +39,14 @@ pub struct AdminPanelData {
current_section: AdminSection,
}
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum AdminSection {
#[default]
Overview,
Models,
}
impl Default for AdminSection {
fn default() -> Self {
Self::Overview
}
}
#[derive(Deserialize)]
pub struct AdminPanelQuery {
@@ -107,10 +105,7 @@ fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: serde::Deserializer<'de>,
{
match String::deserialize(deserializer) {
Ok(string) => Ok(string == "on"),
Err(_) => Ok(false),
}
String::deserialize(deserializer).map(|s| s == "on")
}
#[derive(Deserialize)]
@@ -219,8 +214,8 @@ pub async fn update_model_settings(
if reembedding_needed {
info!("Embedding dimensions changed. Spawning background re-embedding task...");
let db_for_task = state.db.clone();
let openai_for_task = state.openai_client.clone();
let db_for_task = Arc::clone(&state.db);
let openai_for_task = Arc::clone(&state.openai_client);
let new_model_for_task = new_settings.embedding_model.clone();
let new_dims_for_task = new_settings.embedding_dimensions;
+2 -2
View File
@@ -11,7 +11,7 @@ use crate::{
};
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
pub struct Params {
pub email: String,
pub password: String,
pub timezone: String,
@@ -39,7 +39,7 @@ pub async fn show_signup_form(
pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
Form(form): Form<Params>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new(
form.email,
+30 -27
View File
@@ -49,6 +49,8 @@ pub struct ChatPageData {
conversation: Option<Conversation>,
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn show_initialized_chat(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
@@ -57,14 +59,14 @@ pub async fn show_initialized_chat(
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new(
conversation.id.to_string(),
conversation.id.clone(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
conversation.id.to_string(),
conversation.id.clone(),
MessageRole::AI,
form.llm_response,
Some(form.references),
@@ -86,10 +88,9 @@ pub async fn show_initialized_chat(
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
@@ -130,12 +131,19 @@ pub async fn show_existing_chat(
))
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let conversation: Conversation = state
.db
.get_item(&conversation_id)
@@ -150,33 +158,34 @@ pub async fn new_user_message(
state.db.store_item(user_message.clone()).await?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let mut response = TemplateResponse::new_template(
"chat/streaming_response.html",
SSEResponseInitData { user_message },
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response)
}
/// # Panics
/// Panics if the HX-Push header value cannot be parsed.
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let Some(user) = auth.current_user else {
return Ok(Redirect::to("/").into_response());
};
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
@@ -191,11 +200,6 @@ pub async fn new_chat_user_message(
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let mut response = TemplateResponse::new_template(
"chat/new_chat_first_response.html",
SSEResponseInitData {
@@ -205,10 +209,9 @@ pub async fn new_chat_user_message(
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
response.headers_mut().insert("HX-Push", header_value);
}
Ok(response.into_response())
}
@@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
)
}
// Error handling function
fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
// Helper function to get message and user
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
// Check authentication
let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature",
)));
};
// Retrieve message
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
@@ -88,7 +84,6 @@ async fn get_message_and_user(
}
};
// Get conversation history
let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
@@ -209,7 +204,6 @@ pub async fn get_response_stream(
auth: AuthSessionType,
Query(params): Query<QueryParams>,
) -> SseResponse {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history, existing_ai_response)) => (
@@ -226,9 +220,123 @@ pub async fn get_response_stream(
return create_replayed_response_stream(&state, existing_ai_message);
}
// 2. Retrieve knowledge entities
let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await {
Ok(result) => result,
Err(sse) => return sse,
};
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids)
}
fn build_chat_event_stream(
state: HtmlState,
openai_stream: impl Stream<Item = Result<async_openai::types::CreateChatCompletionStreamResponse, async_openai::error::OpenAIError>> + Send + 'static,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) -> SseResponse {
let (tx, rx) = channel::<String>(1000);
let (tx_final, mut rx_final) = channel::<Message>(1);
spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids);
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
let _ = tx_storage.send(content.clone()).await;
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
}
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
#[derive(Serialize)]
struct LocalReferenceData {
message: Message,
}
if let Some(message) = rx_final.recv().await {
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty"));
}
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(LocalReferenceData { message }),
) {
Ok(html) => Ok(Event::default().event("references").data(html)),
Err(_) => Ok(Event::default()
.event("error")
.data("Failed to render references")),
}
} else {
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}))
.boxed();
sse_with_keep_alive(event_stream)
}
async fn prepare_chat_request(
state: &HtmlState,
user_message: &Message,
user: &User,
history: &[Message],
) -> Result<
(async_openai::types::CreateChatCompletionRequest, Vec<String>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
Some(pool) => pool.checkout().await,
None => None,
};
@@ -248,59 +356,49 @@ pub async fn get_response_stream(
{
Ok(result) => result,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
return Err(Sse::new(create_error_stream("Failed to retrieve knowledge")));
}
};
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format
let context_json = match &retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks)
}
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);
create_user_message_with_history(&context_json, history, &user_message.content);
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
return Err(Sse::new(create_error_stream("Failed to retrieve system settings")));
};
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
return Err(Sse::new(create_error_stream("Failed to create chat request")));
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
Ok((request, allowed_reference_ids))
}
// 5. Create channel for collecting complete response
let (tx, mut rx) = channel::<String>(1000);
let tx_clone = tx.clone();
let (tx_final, mut rx_final) = channel::<Message>(1);
fn spawn_storage_task(
db_client: Arc<SurrealDbClient>,
mut rx: tokio::sync::mpsc::Receiver<String>,
tx_final: tokio::sync::mpsc::Sender<Message>,
user_message: &Message,
user_id: String,
allowed_reference_ids: Vec<String>,
) {
let conversation_id = user_message.conversation_id.clone();
// 6. Set up the collection task for DB storage
let db_client = Arc::clone(&state.db);
let user_id = user.id.clone();
let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new();
while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk);
}
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let raw_references = extract_reference_strings(&response);
let answer = response.answer;
@@ -347,7 +445,7 @@ pub async fn get_response_stream(
);
let ai_message = Message::new(
user_message.conversation_id,
conversation_id,
MessageRole::AI,
answer,
Some(initial_validation.valid_refs),
@@ -362,104 +460,11 @@ pub async fn get_response_stream(
} else {
error!("Failed to parse LLM response as structured format");
// Fallback - store raw response
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
None,
);
let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None);
let _ = db_client.store_item(ai_message).await;
}
});
// Create a shared state for tracking the JSON parsing
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {e}")));
}
}
}
})
.flatten()
.chain(stream::once(async move {
if let Some(message) = rx_final.recv().await {
// Don't send any event if references is empty
if message
.references
.as_ref()
.is_some_and(std::vec::Vec::is_empty)
{
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(ReferenceData { message }),
) {
Ok(html) => {
// Return the rendered HTML
Ok(Event::default().event("references").data(html))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// Handle case where no references were received
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
sse_with_keep_alive(event_stream.boxed())
}
struct StreamParserState {
@@ -478,23 +483,18 @@ impl StreamParserState {
}
fn process_chunk(&mut self, chunk: &str) -> String {
// Feed all characters into the parser
for c in chunk.chars() {
let _ = self.parser.add_char(c);
}
// Get the current state of the JSON
let json = self.parser.get_result();
// Check if we're in the answer field
if let Some(obj) = json.as_object() {
if let Some(answer) = obj.get("answer") {
self.in_answer_field = true;
// Get current answer content
let current_content = answer.as_str().unwrap_or_default().to_string();
// Calculate difference to send only new content
if current_content.len() > self.last_answer_content.len() {
let new_content = current_content[self.last_answer_content.len()..].to_string();
self.last_answer_content = current_content;
@@ -503,7 +503,6 @@ impl StreamParserState {
}
}
// No new content to return
String::new()
}
}
+6 -5
View File
@@ -10,8 +10,9 @@ use axum::{
};
pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
show_initialized_chat,
reload_sidebar, show_conversation_editing_title,
show_chat_base as show_base, show_existing_chat as show_existing,
show_initialized_chat as show_initialized,
};
use message_response_stream::get_response_stream;
use references::show_reference_tooltip;
@@ -24,10 +25,10 @@ where
HtmlState: FromRef<S>,
{
Router::new()
.route("/chat", get(show_chat_base).post(new_chat_user_message))
.route("/chat", get(show_base).post(new_chat_user_message))
.route(
"/chat/{id}",
get(show_existing_chat)
get(show_existing)
.post(new_user_message)
.delete(delete_conversation),
)
@@ -36,7 +37,7 @@ where
get(show_conversation_editing_title).patch(patch_conversation_title),
)
.route("/chat/sidebar", get(reload_sidebar))
.route("/initialized-chat", post(show_initialized_chat))
.route("/initialized-chat", post(show_initialized))
.route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/{id}", get(show_reference_tooltip))
}
+5 -4
View File
@@ -102,13 +102,13 @@ pub async fn show_text_content_edit_form(
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
#[derive(Serialize)]
pub struct TextContentEditModal {
pub text_content: TextContent,
}
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"content/edit_text_content_modal.html",
TextContentEditModal { text_content },
@@ -214,13 +214,14 @@ pub async fn show_content_read_modal(
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
// Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
#[derive(Serialize)]
pub struct TextContentReadModalData {
pub text_content: TextContent,
}
// Get and validate the text content
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"content/read_content_modal.html",
TextContentReadModalData { text_content },
+5 -7
View File
@@ -226,7 +226,7 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
("Text".to_string(), truncate_summary(text, 80))
}
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
("URL".to_string(), url.to_string())
("URL".to_string(), url.clone())
}
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
("File".to_string(), file_info.file_name.clone())
@@ -248,18 +248,16 @@ pub async fn serve_file(
RequireUser(user): RequireUser,
Path(file_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let file_info = match FileInfo::get_by_id(&file_id, &state.db).await {
Ok(info) => info,
_ => return Ok(TemplateResponse::not_found().into_response()),
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
return Ok(TemplateResponse::not_found().into_response());
};
if file_info.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
}
let stream = match state.storage.get_stream(&file_info.path).await {
Ok(s) => s,
Err(_) => return Ok(TemplateResponse::server_error().into_response()),
let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
return Ok(TemplateResponse::server_error().into_response());
};
let body = Body::from_stream(stream);
+6 -6
View File
@@ -1,4 +1,4 @@
use std::{pin::Pin, time::Duration};
use std::{pin::Pin, sync::Arc, time::Duration};
use axum::{
extract::{Query, State},
@@ -51,13 +51,13 @@ pub async fn show_ingest_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
#[derive(Serialize)]
pub struct ShowIngestFormData {
user_categories: Vec<String>,
}
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"ingestion_modal.html",
ShowIngestFormData { user_categories },
@@ -180,7 +180,7 @@ pub async fn get_task_updates_stream(
Query(params): Query<QueryParams>,
) -> TaskSse {
let task_id = params.task_id.clone();
let db = state.db.clone();
let db = Arc::clone(&state.db);
// 1. Check for authenticated user
let Some(current_user) = auth.current_user else {
@@ -198,7 +198,7 @@ pub async fn get_task_updates_stream(
}
let sse_stream = async_stream::stream! {
let mut consecutive_db_errors = 0;
let mut consecutive_db_errors: u32 = 0;
let max_consecutive_db_errors = 3;
loop {
@@ -263,7 +263,7 @@ pub async fn get_task_updates_stream(
}
Err(db_err) => {
error!("Database error while fetching task '{}': {:?}", task_id, db_err);
consecutive_db_errors += 1;
consecutive_db_errors = consecutive_db_errors.saturating_add(1);
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors}).")));
if consecutive_db_errors >= max_consecutive_db_errors {
+21 -20
View File
@@ -39,7 +39,7 @@ use url::form_urlencoded;
const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12;
const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"];
const DEFAULT_RELATIONSHIP_TYPE: &str = RELATIONSHIP_TYPE_OPTIONS[0];
const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo";
const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10;
const SUGGESTION_MIN_SCORE: f32 = 0.5;
@@ -61,15 +61,15 @@ fn canonicalize_relationship_type(value: &str) -> String {
let key: String = trimmed
.chars()
.filter(|c| c.is_ascii_alphanumeric())
.flat_map(|c| c.to_lowercase())
.filter(char::is_ascii_alphanumeric)
.flat_map(char::to_lowercase)
.collect();
for option in RELATIONSHIP_TYPE_OPTIONS {
let option_key: String = option
.chars()
.filter(|c| c.is_ascii_alphanumeric())
.flat_map(|c| c.to_lowercase())
.filter(char::is_ascii_alphanumeric)
.flat_map(char::to_lowercase)
.collect();
if option_key == key {
return (*option).to_string();
@@ -141,7 +141,7 @@ pub async fn show_new_knowledge_entity_form(
) -> Result<impl IntoResponse, HtmlError> {
let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter()
.map(|&s| s.to_owned())
.map(ToString::to_string)
.collect();
let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?;
@@ -278,7 +278,7 @@ pub async fn suggest_knowledge_relationships(
if !query_parts.is_empty() {
let query = query_parts.join(" ");
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
Some(pool) => pool.checkout().await,
None => None,
};
@@ -406,9 +406,10 @@ fn build_relationship_table_data(
.map(|relationship| {
let relationship_type_label =
canonicalize_relationship_type(&relationship.metadata.relationship_type);
*frequency
let count = frequency
.entry(relationship_type_label.clone())
.or_insert(0) += 1;
.or_insert(0);
*count = count.saturating_add(1);
RelationshipTableRow {
relationship,
relationship_type_label,
@@ -417,9 +418,7 @@ fn build_relationship_table_data(
.collect();
let default_relationship_type = frequency
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(label, _)| label)
.unwrap_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string());
.max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
RelationshipTableData {
entities,
@@ -800,8 +799,10 @@ pub async fn get_knowledge_graph_json(
for rel in &relationships {
if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) {
// undirected counting for degree
*degree_count.entry(rel.in_.clone()).or_insert(0) += 1;
*degree_count.entry(rel.out.clone()).or_insert(0) += 1;
let count = degree_count.entry(rel.in_.clone()).or_insert(0);
*count = count.saturating_add(1);
let count = degree_count.entry(rel.out.clone()).or_insert(0);
*count = count.saturating_add(1);
links.push(GraphLink {
source: rel.out.clone(),
target: rel.in_.clone(),
@@ -836,11 +837,11 @@ fn normalize_filter(input: Option<String>) -> Option<String> {
fn trim_matching_quotes(value: &str) -> &str {
let bytes = value.as_bytes();
if bytes.len() >= 2 {
let first = bytes[0];
let last = bytes[bytes.len() - 1];
if (first == b'"' && last == b'"') || (first == b'\'' && last == b'\'') {
return &value[1..value.len() - 1];
if let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) {
if bytes.len() >= 2
&& ((first == b'"' && last == b'"') || (first == b'\'' && last == b'\''))
{
return &value[1..value.len().saturating_sub(1)];
}
}
value
@@ -860,7 +861,7 @@ pub async fn show_edit_knowledge_entity_form(
// Get entity types
let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter()
.map(|&s| s.to_owned())
.map(ToString::to_string)
.collect();
// Get the entity and validate ownership
+163 -154
View File
@@ -11,6 +11,7 @@ use axum::{
use common::storage::types::{
serde_helpers::deserialize_flexible_id,
text_content::TextContent,
user::User,
StoredObject,
};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
@@ -46,13 +47,11 @@ fn source_id_suffix(source_id: &str) -> String {
fn truncate_label(value: &str, max_chars: usize) -> String {
let mut end = None;
let mut count = 0;
for (idx, _) in value.char_indices() {
for (count, (idx, _)) in value.char_indices().enumerate() {
if count == max_chars {
end = Some(idx);
break;
}
count += 1;
}
match end {
@@ -174,165 +173,31 @@ struct KnowledgeEntityForTemplate {
score: f32,
}
#[derive(Serialize)]
struct SearchResultForTemplate {
result_type: String,
score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
text_chunk: Option<TextChunkForTemplate>,
#[serde(skip_serializing_if = "Option::is_none")]
knowledge_entity: Option<KnowledgeEntityForTemplate>,
}
#[derive(Serialize)]
pub struct AnswerData {
search_result: Vec<SearchResultForTemplate>,
query_param: String,
}
pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
struct SearchResultForTemplate {
result_type: String,
score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
text_chunk: Option<TextChunkForTemplate>,
#[serde(skip_serializing_if = "Option::is_none")]
knowledge_entity: Option<KnowledgeEntityForTemplate>,
}
#[derive(Serialize)]
pub struct AnswerData {
search_result: Vec<SearchResultForTemplate>,
query_param: String,
}
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
params.query
{
let trimmed_query = actual_query.trim();
if trimmed_query.is_empty() {
(Vec::<SearchResultForTemplate>::new(), String::new())
} else {
// Use retrieval pipeline Search strategy
let config = RetrievalConfig::for_search(SearchTarget::Both);
// Checkout a reranker lease if pool is available
let reranker_lease = match &state.reranker_pool {
Some(pool) => Some(pool.checkout().await),
None => None,
};
let result = retrieval_pipeline::pipeline::run_pipeline(
&state.db,
&state.openai_client,
Some(&state.embedding_provider),
trimmed_query,
&user.id,
config,
reranker_lease,
)
.await?;
let search_result = match result {
StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
};
let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks {
source_ids.insert(chunk_result.chunk.source_id.clone());
}
for entity_result in &search_result.entities {
source_ids.insert(entity_result.entity.source_id.clone());
}
let source_label_map = if source_ids.is_empty() {
HashMap::new()
} else {
let record_ids: Vec<RecordId> = source_ids
.iter()
.filter_map(|id| {
if id.contains(':') {
RecordId::from_str(id).ok()
} else {
Some(RecordId::from_table_key(TextContent::table_name(), id))
}
})
.collect();
let mut response = state
.db
.client
.query(
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
)
.bind(("table_name", TextContent::table_name()))
.bind(("user_id", user.id.clone()))
.bind(("record_ids", record_ids))
.await?;
let contents: Vec<SourceLabelRow> = response.take(0)?;
tracing::debug!(
source_id_count = source_ids.len(),
label_row_count = contents.len(),
"Resolved search source labels"
);
let mut labels = HashMap::new();
for content in contents {
let label = build_source_label(&content);
labels.insert(content.id.clone(), label.clone());
labels.insert(
format!("{}:{}", TextContent::table_name(), content.id),
label,
);
}
labels
};
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
// Add chunk results
for chunk_result in search_result.chunks {
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
source_label,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}
// Add entity results
for entity_result in search_result.entities {
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
source_label,
score: entity_result.score,
}),
});
}
// Sort by score descending
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
// Limit results
const TOTAL_LIMIT: usize = 10;
combined_results.truncate(TOTAL_LIMIT);
(combined_results, trimmed_query.to_string())
}
perform_search(&state, &user, actual_query).await?
} else {
(Vec::<SearchResultForTemplate>::new(), String::new())
};
@@ -345,3 +210,147 @@ pub async fn search_result_handler(
},
))
}
async fn perform_search(
state: &HtmlState,
user: &User,
query: String,
) -> Result<(Vec<SearchResultForTemplate>, String), HtmlError> {
const TOTAL_LIMIT: usize = 10;
let trimmed_query = query.trim();
if trimmed_query.is_empty() {
return Ok((Vec::new(), String::new()));
}
let config = RetrievalConfig::for_search(SearchTarget::Both);
let reranker_lease = match &state.reranker_pool {
Some(pool) => pool.checkout().await,
None => None,
};
let params = retrieval_pipeline::pipeline::StrategyParams {
db_client: &state.db,
openai_client: &state.openai_client,
embedding_provider: Some(&state.embedding_provider),
input_text: trimmed_query,
user_id: &user.id,
config,
reranker: reranker_lease,
};
let result = retrieval_pipeline::pipeline::execute(params).await?;
let search_result = match result {
StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
};
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
for chunk_result in search_result.chunks {
let source_label = source_label_map
.get(&chunk_result.chunk.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
source_label,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}
for entity_result in search_result.entities {
let source_label = source_label_map
.get(&entity_result.entity.source_id)
.cloned()
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
source_label,
score: entity_result.score,
}),
});
}
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
combined_results.truncate(TOTAL_LIMIT);
Ok((combined_results, trimmed_query.to_string()))
}
async fn resolve_source_labels(
state: &HtmlState,
user: &User,
search_result: &SearchResult,
) -> Result<HashMap<String, String>, HtmlError> {
let mut source_ids = HashSet::new();
for chunk_result in &search_result.chunks {
source_ids.insert(chunk_result.chunk.source_id.clone());
}
for entity_result in &search_result.entities {
source_ids.insert(entity_result.entity.source_id.clone());
}
if source_ids.is_empty() {
return Ok(HashMap::new());
}
let record_ids: Vec<RecordId> = source_ids
.iter()
.filter_map(|id| {
if id.contains(':') {
RecordId::from_str(id).ok()
} else {
Some(RecordId::from_table_key(TextContent::table_name(), id))
}
})
.collect();
let mut response = state
.db
.client
.query(
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
)
.bind(("table_name", TextContent::table_name()))
.bind(("user_id", user.id.clone()))
.bind(("record_ids", record_ids))
.await?;
let contents: Vec<SourceLabelRow> = response.take(0)?;
tracing::debug!(
source_id_count = source_ids.len(),
label_row_count = contents.len(),
"Resolved search source labels"
);
let mut labels = HashMap::new();
for content in contents {
let label = build_source_label(&content);
labels.insert(content.id.clone(), label.clone());
labels.insert(
format!("{}:{}", TextContent::table_name(), content.id),
label,
);
}
Ok(labels)
}
+5 -2
View File
@@ -1,7 +1,10 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
pub use handlers::{search_result_handler, SearchParams};
#[allow(clippy::module_name_repetitions)]
pub use handlers::{
search_result_handler as result_handler, SearchParams as SearchQueryParams,
};
use crate::html_state::HtmlState;
@@ -10,5 +13,5 @@ where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new().route("/search", get(search_result_handler))
Router::new().route("/search", get(result_handler))
}
+8 -8
View File
@@ -31,8 +31,8 @@ impl Pagination {
} else {
0
};
let start_index = if page_len == 0 { 0 } else { offset + 1 };
let end_index = if page_len == 0 { 0 } else { offset + page_len };
let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) };
let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) };
Self {
current_page,
@@ -42,12 +42,12 @@ impl Pagination {
has_previous,
has_next,
previous_page: if has_previous {
Some(current_page - 1)
Some(current_page.saturating_sub(1))
} else {
None
},
next_page: if has_next {
Some(current_page + 1)
Some(current_page.saturating_add(1))
} else {
None
},
@@ -68,7 +68,7 @@ pub fn paginate_items<T>(
let total_pages = if total_items == 0 {
0
} else {
((total_items - 1) / per_page) + 1
total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1)
};
let mut current_page = requested_page.unwrap_or(1);
@@ -84,7 +84,7 @@ pub fn paginate_items<T>(
let offset = if total_pages == 0 {
0
} else {
per_page.saturating_mul(current_page - 1)
per_page.saturating_mul(current_page.saturating_sub(1))
};
let page_items: Vec<T> = items.into_iter().skip(offset).take(per_page).collect();
@@ -136,8 +136,8 @@ mod tests {
assert_eq!(page, vec![5]);
assert_eq!(meta.current_page, 3);
assert_eq!(meta.total_pages, 3);
assert_eq!(meta.has_next, false);
assert_eq!(meta.has_previous, true);
assert!(!meta.has_next, "expected no next page");
assert!(meta.has_previous, "expected previous page");
assert_eq!(meta.start_index, 5);
assert_eq!(meta.end_index, 5);
}