From a090a8c76e7b55f7b1d992c26f1a2571a2cdecc3 Mon Sep 17 00:00:00 2001 From: Per Stark Date: Tue, 9 Dec 2025 20:35:42 +0100 Subject: [PATCH] retrieval simplified --- common/src/error.rs | 1 + common/src/lib.rs | 2 + common/src/storage/db.rs | 2 + common/src/storage/indexes.rs | 33 +- common/src/storage/store.rs | 2 +- common/src/storage/types/analytics.rs | 3 +- common/src/storage/types/file_info.rs | 40 ++- common/src/storage/types/ingestion_payload.rs | 7 + common/src/storage/types/ingestion_task.rs | 9 + common/src/storage/types/knowledge_entity.rs | 11 + .../types/knowledge_entity_embedding.rs | 4 +- .../storage/types/knowledge_relationship.rs | 12 +- common/src/storage/types/message.rs | 3 +- common/src/storage/types/mod.rs | 8 +- common/src/storage/types/text_chunk.rs | 24 +- .../src/storage/types/text_chunk_embedding.rs | 9 +- common/src/storage/types/text_content.rs | 1 + common/src/storage/types/user.rs | 40 ++- common/src/utils/config.rs | 9 + common/src/utils/embedding.rs | 64 ++-- common/src/utils/template_engine.rs | 1 + evaluations/src/args.rs | 19 +- evaluations/src/corpus/mod.rs | 12 +- evaluations/src/corpus/orchestrator.rs | 22 +- evaluations/src/corpus/store.rs | 13 +- evaluations/src/datasets/mod.rs | 9 +- evaluations/src/db_helpers.rs | 15 +- evaluations/src/namespace.rs | 3 +- evaluations/src/pipeline/context.rs | 2 +- evaluations/src/pipeline/mod.rs | 4 +- .../src/pipeline/stages/prepare_corpus.rs | 2 +- .../src/pipeline/stages/prepare_namespace.rs | 2 +- .../src/pipeline/stages/run_queries.rs | 8 +- evaluations/src/slice.rs | 4 +- html-router/src/html_state.rs | 2 +- .../routes/chat/message_response_stream.rs | 25 +- ingestion-pipeline/src/lib.rs | 5 + ingestion-pipeline/src/pipeline/config.rs | 11 +- .../src/pipeline/enrichment_result.rs | 19 +- ingestion-pipeline/src/pipeline/mod.rs | 31 +- ingestion-pipeline/src/pipeline/services.rs | 6 +- ingestion-pipeline/src/pipeline/stages/mod.rs | 61 ++-- .../src/utils/file_text_extraction.rs | 2 +- 
ingestion-pipeline/src/utils/pdf_ingestion.rs | 12 +- .../src/utils/url_text_retrieval.rs | 15 +- main/src/main.rs | 1 - main/src/worker.rs | 2 +- retrieval-pipeline/src/answer_retrieval.rs | 18 ++ retrieval-pipeline/src/fts.rs | 268 ---------------- retrieval-pipeline/src/graph.rs | 197 +----------- retrieval-pipeline/src/lib.rs | 148 ++------- retrieval-pipeline/src/pipeline/config.rs | 25 +- retrieval-pipeline/src/pipeline/mod.rs | 96 +----- retrieval-pipeline/src/pipeline/stages/mod.rs | 293 ++---------------- retrieval-pipeline/src/pipeline/strategies.rs | 40 +-- 55 files changed, 469 insertions(+), 1208 deletions(-) delete mode 100644 retrieval-pipeline/src/fts.rs diff --git a/common/src/error.rs b/common/src/error.rs index 430b504..873921a 100644 --- a/common/src/error.rs +++ b/common/src/error.rs @@ -5,6 +5,7 @@ use tokio::task::JoinError; use crate::storage::types::file_info::FileError; // Core internal errors +#[allow(clippy::module_name_repetitions)] #[derive(Error, Debug)] pub enum AppError { #[error("Database error: {0}")] diff --git a/common/src/lib.rs b/common/src/lib.rs index d04944f..ae9d7b2 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(clippy::doc_markdown)] +//! Shared utilities and storage helpers for the workspace crates. pub mod error; pub mod storage; pub mod utils; diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 9fa05f4..83d1496 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -13,12 +13,14 @@ use surrealdb::{ use surrealdb_migrations::MigrationRunner; use tracing::debug; +/// Embedded SurrealDB migration directory packaged with the crate. 
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/"); #[derive(Clone)] pub struct SurrealDbClient { pub client: Surreal, } +#[allow(clippy::module_name_repetitions)] pub trait ProvidesDb { fn db(&self) -> &Arc; } diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 0be82ce..bf95cf4 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -1,3 +1,13 @@ +#![allow( + clippy::missing_docs_in_private_items, + clippy::module_name_repetitions, + clippy::items_after_statements, + clippy::arithmetic_side_effects, + clippy::cast_precision_loss, + clippy::redundant_closure_for_method_calls, + clippy::single_match_else, + clippy::uninlined_format_args +)] use std::time::Duration; use anyhow::{Context, Result}; @@ -234,12 +244,25 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> { analyzer = FTS_ANALYZER_NAME ); - db.client + let res = db + .client .query(fallback_query) .await - .context("creating fallback FTS analyzer")? 
- .check() - .context("failed to create fallback FTS analyzer")?; + .context("creating fallback FTS analyzer")?; + + if let Err(err) = res.check() { + warn!( + error = %err, + "Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})", + FTS_ANALYZER_NAME + ); + return Err(err).context("failed to create fallback FTS analyzer"); + } + + warn!( + "Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only", + FTS_ANALYZER_NAME + ); Ok(()) } @@ -466,7 +489,7 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result { let rows: Vec = response .take(0) .context("failed to deserialize count() response")?; - Ok(rows.first().map(|r| r.count).unwrap_or(0)) + Ok(rows.first().map_or(0, |r| r.count)) } async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result { diff --git a/common/src/storage/store.rs b/common/src/storage/store.rs index 9038ecb..6ae5457 100644 --- a/common/src/storage/store.rs +++ b/common/src/storage/store.rs @@ -183,7 +183,7 @@ impl StorageManager { while current.starts_with(base) && current.as_path() != base.as_path() { match tokio::fs::remove_dir(¤t).await { - Ok(_) => {} + Ok(()) => {} Err(err) => match err.kind() { ErrorKind::NotFound => {} ErrorKind::DirectoryNotEmpty => break, diff --git a/common/src/storage/types/analytics.rs b/common/src/storage/types/analytics.rs index 712e89a..a841a2a 100644 --- a/common/src/storage/types/analytics.rs +++ b/common/src/storage/types/analytics.rs @@ -71,6 +71,7 @@ impl Analytics { // We need to use a direct query for COUNT aggregation #[derive(Debug, Deserialize)] struct CountResult { + /// Total user count. count: i64, } @@ -81,7 +82,7 @@ impl Analytics { .await? 
.take(0)?; - Ok(result.map(|r| r.count).unwrap_or(0)) + Ok(result.map_or(0, |r| r.count)) } } diff --git a/common/src/storage/types/file_info.rs b/common/src/storage/types/file_info.rs index 73371d9..4e8f0d9 100644 --- a/common/src/storage/types/file_info.rs +++ b/common/src/storage/types/file_info.rs @@ -3,12 +3,10 @@ use bytes; use mime_guess::from_path; use object_store::Error as ObjectStoreError; use sha2::{Digest, Sha256}; -use std::{ - io::{BufReader, Read}, - path::Path, -}; +use std::{io::{BufReader, Read}, path::Path}; use tempfile::NamedTempFile; use thiserror::Error; +use tokio::task; use tracing::info; use uuid::Uuid; @@ -71,21 +69,29 @@ impl FileInfo { /// /// # Returns /// * `Result` - The SHA256 hash as a hex string or an error. + #[allow(clippy::indexing_slicing)] async fn get_sha(file: &NamedTempFile) -> Result { - let mut reader = BufReader::new(file.as_file()); - let mut hasher = Sha256::new(); - let mut buffer = [0u8; 8192]; // 8KB buffer + let mut file_clone = file.as_file().try_clone()?; - loop { - let n = reader.read(&mut buffer)?; - if n == 0 { - break; + let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> { + let mut reader = BufReader::new(&mut file_clone); + let mut hasher = Sha256::new(); + let mut buffer = [0u8; 8192]; // 8KB buffer + + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + hasher.update(&buffer[..n]); } - hasher.update(&buffer[..n]); - } - let digest = hasher.finalize(); - Ok(format!("{:x}", digest)) + Ok::<_, std::io::Error>(hasher.finalize()) + }) + .await + .map_err(std::io::Error::other)??; + + Ok(format!("{digest:x}")) } /// Sanitizes the file name to prevent security vulnerabilities like directory traversal. 
@@ -103,7 +109,7 @@ impl FileInfo { } }) .collect(); - format!("{}{}", sanitized_name, ext) + format!("{sanitized_name}{ext}") } else { // No extension file_name @@ -292,7 +298,7 @@ impl FileInfo { storage: &StorageManager, ) -> Result { // Logical object location relative to the store root - let location = format!("{}/{}/{}", user_id, uuid, file_name); + let location = format!("{user_id}/{uuid}/{file_name}"); info!("Persisting to object location: {}", location); let bytes = tokio::fs::read(file.path()).await?; diff --git a/common/src/storage/types/ingestion_payload.rs b/common/src/storage/types/ingestion_payload.rs index a1b19fd..828a805 100644 --- a/common/src/storage/types/ingestion_payload.rs +++ b/common/src/storage/types/ingestion_payload.rs @@ -1,3 +1,9 @@ +#![allow( + clippy::result_large_err, + clippy::needless_pass_by_value, + clippy::implicit_clone, + clippy::semicolon_if_nothing_returned +)] use crate::{error::AppError, storage::types::file_info::FileInfo}; use serde::{Deserialize, Serialize}; use tracing::info; @@ -38,6 +44,7 @@ impl IngestionPayload { /// # Returns /// * `Result, AppError>` - On success, returns a vector of ingress objects /// (one per file/content type). On failure, returns an `AppError`. 
+ #[allow(clippy::similar_names)] pub fn create_ingestion_payload( content: Option, context: String, diff --git a/common/src/storage/types/ingestion_task.rs b/common/src/storage/types/ingestion_task.rs index fa1723c..26d5520 100644 --- a/common/src/storage/types/ingestion_task.rs +++ b/common/src/storage/types/ingestion_task.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::cast_possible_wrap, + clippy::items_after_statements, + clippy::arithmetic_side_effects, + clippy::cast_sign_loss, + clippy::missing_docs_in_private_items, + clippy::trivially_copy_pass_by_ref, + clippy::expect_used +)] use std::time::Duration; use chrono::Duration as ChronoDuration; diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 2a13901..9205792 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -1,3 +1,14 @@ +#![allow( + clippy::missing_docs_in_private_items, + clippy::module_name_repetitions, + clippy::match_same_arms, + clippy::format_push_string, + clippy::uninlined_format_args, + clippy::explicit_iter_loop, + clippy::items_after_statements, + clippy::get_first, + clippy::redundant_closure_for_method_calls +)] use std::collections::HashMap; use crate::{ diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs index ad4ccfa..6e92f62 100644 --- a/common/src/storage/types/knowledge_entity_embedding.rs +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -72,7 +72,7 @@ impl KnowledgeEntityEmbedding { return Ok(HashMap::new()); } - let ids_list: Vec = entity_ids.iter().cloned().collect(); + let ids_list: Vec = entity_ids.to_vec(); let query = format!( "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids", @@ -110,6 +110,7 @@ impl KnowledgeEntityEmbedding { } /// Delete embeddings by source_id (via joining to knowledge_entity table) + #[allow(clippy::items_after_statements)] pub async fn 
delete_by_source_id( source_id: &str, db: &SurrealDbClient, @@ -121,6 +122,7 @@ impl KnowledgeEntityEmbedding { .bind(("source_id", source_id.to_owned())) .await .map_err(AppError::Database)?; + #[allow(clippy::missing_docs_in_private_items)] #[derive(Deserialize)] struct IdRow { id: RecordId, diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 7df01b3..6b41c4d 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -65,8 +65,7 @@ impl KnowledgeRelationship { db_client: &SurrealDbClient, ) -> Result<(), AppError> { let query = format!( - "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'", - source_id + "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'" ); db_client.query(query).await?; @@ -81,15 +80,14 @@ impl KnowledgeRelationship { ) -> Result<(), AppError> { let mut authorized_result = db_client .query(format!( - "SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'", - id, user_id + "SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'" )) .await?; let authorized: Vec = authorized_result.take(0).unwrap_or_default(); if authorized.is_empty() { let mut exists_result = db_client - .query(format!("SELECT * FROM relates_to:`{}`", id)) + .query(format!("SELECT * FROM relates_to:`{id}`")) .await?; let existing: Option = exists_result.take(0)?; @@ -98,11 +96,11 @@ impl KnowledgeRelationship { "Not authorized to delete relationship".into(), )) } else { - Err(AppError::NotFound(format!("Relationship {} not found", id))) + Err(AppError::NotFound(format!("Relationship {id} not found"))) } } else { db_client - .query(format!("DELETE relates_to:`{}`", id)) + .query(format!("DELETE relates_to:`{id}`")) .await?; Ok(()) } diff --git a/common/src/storage/types/message.rs b/common/src/storage/types/message.rs index 3ddf937..4070f79 
100644 --- a/common/src/storage/types/message.rs +++ b/common/src/storage/types/message.rs @@ -1,3 +1,4 @@ +#![allow(clippy::module_name_repetitions)] use uuid::Uuid; use crate::stored_object; @@ -56,7 +57,7 @@ impl fmt::Display for Message { pub fn format_history(history: &[Message]) -> String { history .iter() - .map(|msg| format!("{}", msg)) + .map(|msg| format!("{msg}")) .collect::>() .join("\n") } diff --git a/common/src/storage/types/mod.rs b/common/src/storage/types/mod.rs index 4f053ee..8434254 100644 --- a/common/src/storage/types/mod.rs +++ b/common/src/storage/types/mod.rs @@ -1,3 +1,4 @@ +#![allow(clippy::unsafe_derive_deserialize)] use serde::{Deserialize, Serialize}; pub mod analytics; pub mod conversation; @@ -23,7 +24,7 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> { #[macro_export] macro_rules! stored_object { - ($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => { + ($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => { use serde::{Deserialize, Deserializer, Serialize}; use surrealdb::sql::Thing; use $crate::storage::types::StoredObject; @@ -87,6 +88,7 @@ macro_rules! stored_object { } #[allow(dead_code)] + #[allow(clippy::ref_option)] fn serialize_option_datetime( date: &Option>, serializer: S, @@ -102,6 +104,7 @@ macro_rules! stored_object { } #[allow(dead_code)] + #[allow(clippy::ref_option)] fn deserialize_option_datetime<'de, D>( deserializer: D, ) -> Result>, D::Error> @@ -113,6 +116,7 @@ macro_rules! stored_object { } + $(#[$struct_attr])* #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct $name { #[serde(deserialize_with = "deserialize_flexible_id")] @@ -121,7 +125,7 @@ macro_rules! 
stored_object { pub created_at: DateTime, #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)] pub updated_at: DateTime, - $( $(#[$attr])* pub $field: $ty),* + $( $(#[$field_attr])* pub $field: $ty),* } impl StoredObject for $name { diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 219220f..faf1bd8 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -1,4 +1,6 @@ +#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)] use std::collections::HashMap; +use std::fmt::Write; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; @@ -18,6 +20,7 @@ stored_object!(TextChunk, "text_chunk", { }); /// Search result including hydrated chunk. +#[allow(clippy::module_name_repetitions)] #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)] pub struct TextChunkSearchResult { pub chunk: TextChunk, @@ -98,6 +101,7 @@ impl TextChunk { db: &SurrealDbClient, user_id: &str, ) -> Result, AppError> { + #[allow(clippy::missing_docs_in_private_items)] #[derive(Deserialize)] struct Row { chunk_id: TextChunk, @@ -160,6 +164,8 @@ impl TextChunk { score: f32, } + let limit = i64::try_from(take).unwrap_or(i64::MAX); + let sql = format!( r#" SELECT @@ -183,7 +189,7 @@ impl TextChunk { .query(&sql) .bind(("terms", terms.to_owned())) .bind(("user_id", user_id.to_owned())) - .bind(("limit", take as i64)) + .bind(("limit", limit)) .await .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; @@ -245,7 +251,7 @@ impl TextChunk { // Generate all new embeddings in memory let mut new_embeddings: HashMap, String, String)> = HashMap::new(); info!("Generating new embeddings for all chunks..."); - for chunk in all_chunks.iter() { + for chunk in &all_chunks { let retry_strategy = 
ExponentialBackoff::from_millis(100).map(jitter).take(3); let embedding = Retry::spawn(retry_strategy, || { @@ -283,12 +289,13 @@ impl TextChunk { "[{}]", embedding .iter() - .map(|f| f.to_string()) + .map(ToString::to_string) .collect::>() .join(",") ); // Use the chunk id as the embedding record id to keep a 1:1 mapping - transaction_query.push_str(&format!( + write!( + &mut transaction_query, "UPSERT type::thing('text_chunk_embedding', '{id}') SET \ chunk_id = type::thing('text_chunk', '{id}'), \ source_id = '{source_id}', \ @@ -300,13 +307,16 @@ impl TextChunk { embedding = embedding_str, user_id = user_id, source_id = source_id - )); + ) + .map_err(|e| AppError::InternalError(e.to_string()))?; } - transaction_query.push_str(&format!( + write!( + &mut transaction_query, "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", new_dimensions - )); + ) + .map_err(|e| AppError::InternalError(e.to_string()))?; transaction_query.push_str("COMMIT TRANSACTION;"); diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs index 771b9ca..734eadd 100644 --- a/common/src/storage/types/text_chunk_embedding.rs +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -110,6 +110,11 @@ impl TextChunkEmbedding { source_id: &str, db: &SurrealDbClient, ) -> Result<(), AppError> { + #[allow(clippy::missing_docs_in_private_items)] + #[derive(Deserialize)] + struct IdRow { + id: RecordId, + } let ids_query = format!( "SELECT id FROM {} WHERE source_id = $source_id", TextChunk::table_name() @@ -120,10 +125,6 @@ impl TextChunkEmbedding { .bind(("source_id", source_id.to_owned())) .await .map_err(AppError::Database)?; - #[derive(Deserialize)] - struct IdRow { - id: RecordId, - } let ids: Vec = res.take(0).map_err(AppError::Database)?; if ids.is_empty() { diff --git a/common/src/storage/types/text_content.rs b/common/src/storage/types/text_content.rs index 
ef3f9aa..fe11014 100644 --- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -5,6 +5,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use super::file_info::FileInfo; +#[allow(clippy::module_name_repetitions)] #[derive(Debug, Deserialize, Serialize)] pub struct TextContentSearchResult { #[serde(deserialize_with = "deserialize_flexible_id")] diff --git a/common/src/storage/types/user.rs b/common/src/storage/types/user.rs index 4d1f121..5e0633a 100644 --- a/common/src/storage/types/user.rs +++ b/common/src/storage/types/user.rs @@ -1,4 +1,5 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; +use anyhow::anyhow; use async_trait::async_trait; use axum_session_auth::Authentication; use chrono_tz::Tz; @@ -17,12 +18,16 @@ use super::{ use chrono::Duration; use futures::try_join; +/// Result row for returning user category. #[derive(Deserialize)] pub struct CategoryResponse { + /// Category name tied to the user. category: String, } -stored_object!(User, "user", { +stored_object!( + #[allow(clippy::unsafe_derive_deserialize)] + User, "user", { email: String, password: String, anonymous: bool, @@ -35,11 +40,11 @@ stored_object!(User, "user", { #[async_trait] impl Authentication> for User { async fn load_user(userid: String, db: Option<&Surreal>) -> Result { - let db = db.unwrap(); + let db = db.ok_or_else(|| anyhow!("Database handle missing"))?; Ok(db .select((Self::table_name(), userid.as_str())) .await? - .unwrap()) + .ok_or_else(|| anyhow!("User {userid} not found"))?) } fn is_authenticated(&self) -> bool { @@ -55,14 +60,14 @@ impl Authentication> for User { } } +/// Ensures a timezone string parses, defaulting to UTC when invalid. 
fn validate_timezone(input: &str) -> String { - match input.parse::() { - Ok(_) => input.to_owned(), - Err(_) => { - tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input); - "UTC".to_owned() - } + if input.parse::().is_ok() { + return input.to_owned(); } + + tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input); + "UTC".to_owned() } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -77,12 +82,15 @@ pub struct DashboardStats { pub new_text_chunks_week: i64, } +/// Helper for aggregating `SurrealDB` count responses. #[derive(Deserialize)] struct CountResult { + /// Row count returned by the query. count: i64, } impl User { + /// Counts all objects of a given type belonging to the user. async fn count_total( db: &SurrealDbClient, user_id: &str, @@ -94,9 +102,10 @@ impl User { .bind(("user_id", user_id.to_string())) .await? .take(0)?; - Ok(result.map(|r| r.count).unwrap_or(0)) + Ok(result.map_or(0, |r| r.count)) } + /// Counts objects of a given type created after a specific timestamp. async fn count_since( db: &SurrealDbClient, user_id: &str, @@ -112,14 +121,16 @@ impl User { .bind(("since", surrealdb::Datetime::from(since))) .await? 
.take(0)?; - Ok(result.map(|r| r.count).unwrap_or(0)) + Ok(result.map_or(0, |r| r.count)) } pub async fn get_dashboard_stats( user_id: &str, db: &SurrealDbClient, ) -> Result { - let since = chrono::Utc::now() - Duration::days(7); + let since = chrono::Utc::now() + .checked_sub_signed(Duration::days(7)) + .unwrap_or_else(chrono::Utc::now); let ( total_documents, @@ -261,7 +272,7 @@ impl User { pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result { // Generate a secure random API key - let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", "")); + let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace('-', "")); // Update the user record with the new API key let user: Option = db @@ -341,6 +352,7 @@ impl User { ) -> Result, AppError> { #[derive(Deserialize)] struct EntityTypeResponse { + /// Raw entity type value from the database. entity_type: String, } @@ -358,7 +370,7 @@ impl User { .into_iter() .map(|item| { let normalized = KnowledgeEntityType::from(item.entity_type); - format!("{:?}", normalized) + format!("{normalized:?}") }) .collect(); diff --git a/common/src/utils/config.rs b/common/src/utils/config.rs index 01b08ce..b1329a2 100644 --- a/common/src/utils/config.rs +++ b/common/src/utils/config.rs @@ -9,6 +9,7 @@ pub enum StorageKind { Memory, } +/// Default storage backend when none is configured. fn default_storage_kind() -> StorageKind { StorageKind::Local } @@ -23,10 +24,13 @@ pub enum PdfIngestMode { LlmFirst, } +/// Default PDF ingestion mode when unset. fn default_pdf_ingest_mode() -> PdfIngestMode { PdfIngestMode::LlmFirst } +/// Application configuration loaded from files and environment variables. +#[allow(clippy::module_name_repetitions)] #[derive(Clone, Deserialize, Debug)] pub struct AppConfig { pub openai_api_key: String, @@ -58,14 +62,17 @@ pub struct AppConfig { pub retrieval_strategy: Option, } +/// Default data directory for persisted assets. 
fn default_data_dir() -> String { "./data".to_string() } +/// Default base URL used for OpenAI-compatible APIs. fn default_base_url() -> String { "https://api.openai.com/v1".to_string() } +/// Whether reranking is enabled by default. fn default_reranking_enabled() -> bool { false } @@ -124,6 +131,8 @@ impl Default for AppConfig { } } +/// Loads the application configuration from the environment and optional config file. +#[allow(clippy::module_name_repetitions)] pub fn get_config() -> Result { ensure_ort_path(); diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index 88a8a5b..b7813b8 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -16,19 +16,16 @@ use crate::{ storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, }; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Supported embedding backends. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmbeddingBackend { OpenAI, + #[default] FastEmbed, Hashed, } -impl Default for EmbeddingBackend { - fn default() -> Self { - Self::FastEmbed - } -} - impl std::str::FromStr for EmbeddingBackend { type Err = anyhow::Error; @@ -44,24 +41,38 @@ impl std::str::FromStr for EmbeddingBackend { } } +/// Wrapper around the chosen embedding backend. +#[allow(clippy::module_name_repetitions)] #[derive(Clone)] pub struct EmbeddingProvider { + /// Concrete backend implementation. inner: EmbeddingInner, } +/// Concrete embedding implementations. #[derive(Clone)] enum EmbeddingInner { + /// Uses an `OpenAI`-compatible API. OpenAI { + /// Client used to issue embedding requests. client: Arc>, + /// Model identifier for the API. model: String, + /// Expected output dimensions. dimensions: u32, }, + /// Generates deterministic hashed embeddings without external calls. Hashed { + /// Output vector length. dimension: usize, }, + /// Uses `FastEmbed` running locally. FastEmbed { + /// Shared `FastEmbed` model. 
model: Arc>, + /// Model metadata used for info logging. model_name: EmbeddingModel, + /// Output vector length. dimension: usize, }, } @@ -77,8 +88,9 @@ impl EmbeddingProvider { pub fn dimension(&self) -> usize { match &self.inner { - EmbeddingInner::Hashed { dimension } => *dimension, - EmbeddingInner::FastEmbed { dimension, .. } => *dimension, + EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => { + *dimension + } EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize, } } @@ -172,12 +184,12 @@ impl EmbeddingProvider { } } - pub async fn new_openai( + pub fn new_openai( client: Arc>, model: String, dimensions: u32, ) -> Result { - Ok(EmbeddingProvider { + Ok(Self { inner: EmbeddingInner::OpenAI { client, model, @@ -226,6 +238,7 @@ impl EmbeddingProvider { } // Helper functions for hashed embeddings +/// Generates a hashed embedding vector without external dependencies. fn hashed_embedding(text: &str, dimension: usize) -> Vec { let dim = dimension.max(1); let mut vector = vec![0.0f32; dim]; @@ -233,15 +246,11 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec { return vector; } - let mut token_count = 0f32; for token in tokens(text) { - token_count += 1.0; let idx = bucket(&token, dim); - vector[idx] += 1.0; - } - - if token_count == 0.0 { - return vector; + if let Some(slot) = vector.get_mut(idx) { + *slot += 1.0; + } } let norm = vector.iter().map(|v| v * v).sum::().sqrt(); @@ -254,16 +263,22 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec { vector } +/// Tokenizes the text into alphanumeric lowercase tokens. fn tokens(text: &str) -> impl Iterator + '_ { text.split(|c: char| !c.is_ascii_alphanumeric()) .filter(|token| !token.is_empty()) - .map(|token| token.to_ascii_lowercase()) + .map(str::to_ascii_lowercase) } +/// Buckets a token into the hashed embedding vector. 
+#[allow(clippy::arithmetic_side_effects)] fn bucket(token: &str, dimension: usize) -> usize { + let safe_dimension = dimension.max(1); let mut hasher = DefaultHasher::new(); token.hash(&mut hasher); - (hasher.finish() as usize) % dimension + usize::try_from(hasher.finish()) + .unwrap_or_default() + % safe_dimension } // Backward compatibility function @@ -274,15 +289,15 @@ pub async fn generate_embedding_with_provider( provider.embed(input).await.map_err(AppError::from) } -/// Generates an embedding vector for the given input text using OpenAI's embedding model. +/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model. /// /// This function takes a text input and converts it into a numerical vector representation (embedding) -/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity +/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity /// comparisons, vector search, and other natural language processing tasks. /// /// # Arguments /// -/// * `client`: The OpenAI client instance used to make API requests. +/// * `client`: The `OpenAI` client instance used to make API requests. /// * `input`: The text string to generate embeddings for. 
/// /// # Returns @@ -294,9 +309,10 @@ pub async fn generate_embedding_with_provider( /// # Errors /// /// This function can return a `AppError` in the following cases: -/// * If the OpenAI API request fails +/// * If the `OpenAI` API request fails /// * If the request building fails /// * If no embedding data is received in the response +#[allow(clippy::module_name_repetitions)] pub async fn generate_embedding( client: &async_openai::Client, input: &str, diff --git a/common/src/utils/template_engine.rs b/common/src/utils/template_engine.rs index 7c3a612..634e928 100644 --- a/common/src/utils/template_engine.rs +++ b/common/src/utils/template_engine.rs @@ -4,6 +4,7 @@ pub use minijinja_contrib; pub use minijinja_embed; use std::sync::Arc; +#[allow(clippy::module_name_repetitions)] pub trait ProvidesTemplateEngine { fn template_engine(&self) -> &Arc; } diff --git a/evaluations/src/args.rs b/evaluations/src/args.rs index 9f6838c..8bf6fea 100644 --- a/evaluations/src/args.rs +++ b/evaluations/src/args.rs @@ -28,19 +28,14 @@ fn default_ingestion_cache_dir() -> PathBuf { pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025; -#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)] #[value(rename_all = "lowercase")] pub enum EmbeddingBackend { Hashed, + #[default] FastEmbed, } -impl Default for EmbeddingBackend { - fn default() -> Self { - Self::FastEmbed - } -} - impl std::fmt::Display for EmbeddingBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -109,7 +104,7 @@ pub struct RetrievalSettings { pub require_verified_chunks: bool, /// Select the retrieval pipeline strategy - #[arg(long, default_value_t = RetrievalStrategy::Initial)] + #[arg(long, default_value_t = RetrievalStrategy::Default)] pub strategy: RetrievalStrategy, } @@ -130,7 +125,7 @@ impl Default for RetrievalSettings { chunk_rrf_use_vector: None, chunk_rrf_use_fts: None, require_verified_chunks: true, - 
strategy: RetrievalStrategy::Initial, + strategy: RetrievalStrategy::Default, } } } @@ -378,11 +373,7 @@ impl Config { self.summary_sample = self.sample.max(1); // Handle retrieval settings - if self.llm_mode { - self.retrieval.require_verified_chunks = false; - } else { - self.retrieval.require_verified_chunks = true; - } + self.retrieval.require_verified_chunks = !self.llm_mode; if self.dataset == DatasetKind::Beir { self.negative_multiplier = 9.0; diff --git a/evaluations/src/corpus/mod.rs b/evaluations/src/corpus/mod.rs index 3726307..a64ad0c 100644 --- a/evaluations/src/corpus/mod.rs +++ b/evaluations/src/corpus/mod.rs @@ -14,13 +14,13 @@ pub use store::{ }; pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig { - let mut tuning = ingestion_pipeline::IngestionTuning::default(); - tuning.chunk_min_tokens = config.ingest.ingest_chunk_min_tokens; - tuning.chunk_max_tokens = config.ingest.ingest_chunk_max_tokens; - tuning.chunk_overlap_tokens = config.ingest.ingest_chunk_overlap_tokens; - ingestion_pipeline::IngestionConfig { - tuning, + tuning: ingestion_pipeline::IngestionTuning { + chunk_min_tokens: config.ingest.ingest_chunk_min_tokens, + chunk_max_tokens: config.ingest.ingest_chunk_max_tokens, + chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens, + ..Default::default() + }, chunk_only: config.ingest.ingest_chunks_only, } } diff --git a/evaluations/src/corpus/orchestrator.rs b/evaluations/src/corpus/orchestrator.rs index 93656e8..66004f7 100644 --- a/evaluations/src/corpus/orchestrator.rs +++ b/evaluations/src/corpus/orchestrator.rs @@ -106,6 +106,7 @@ struct IngestionStats { negative_ingested: usize, } +#[allow(clippy::too_many_arguments)] pub async fn ensure_corpus( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, @@ -337,11 +338,9 @@ pub async fn ensure_corpus( }); } - for record in &mut records { - if let Some(ref mut entry) = record { - if entry.dirty { - store.persist(&entry.shard)?; - } + 
for entry in records.iter_mut().flatten() { + if entry.dirty { + store.persist(&entry.shard)?; } } @@ -403,6 +402,7 @@ pub async fn ensure_corpus( Ok(handle) } +#[allow(clippy::too_many_arguments)] async fn ingest_paragraph_batch( dataset: &ConvertedDataset, targets: &[IngestRequest<'_>], @@ -430,8 +430,10 @@ async fn ingest_paragraph_batch( .await .context("applying migrations for ingestion")?; - let mut app_config = AppConfig::default(); - app_config.storage = StorageKind::Memory; + let app_config = AppConfig { + storage: StorageKind::Memory, + ..Default::default() + }; let backend: DynStore = Arc::new(InMemory::new()); let storage = StorageManager::with_backend(backend, StorageKind::Memory); @@ -444,8 +446,7 @@ async fn ingest_paragraph_batch( storage, embedding.clone(), pipeline_config, - ) - .await?; + )?; let pipeline = Arc::new(pipeline); let mut shards = Vec::with_capacity(targets.len()); @@ -454,7 +455,7 @@ async fn ingest_paragraph_batch( info!( batch = batch_index, batch_size = batch.len(), - total_batches = (targets.len() + batch_size - 1) / batch_size, + total_batches = targets.len().div_ceil(batch_size), "Ingesting paragraph batch" ); let model_clone = embedding_model.clone(); @@ -486,6 +487,7 @@ async fn ingest_paragraph_batch( Ok(shards) } +#[allow(clippy::too_many_arguments)] async fn ingest_single_paragraph( pipeline: Arc, request: IngestRequest<'_>, diff --git a/evaluations/src/corpus/store.rs b/evaluations/src/corpus/store.rs index 14061b0..0bd7f02 100644 --- a/evaluations/src/corpus/store.rs +++ b/evaluations/src/corpus/store.rs @@ -481,6 +481,7 @@ impl ParagraphShardStore { } impl ParagraphShard { + #[allow(clippy::too_many_arguments)] pub fn new( paragraph: &ConvertedParagraph, shard_path: String, @@ -674,10 +675,8 @@ async fn execute_batched_inserts( let slice = &batches[start..group_end]; let mut query = db.client.query("BEGIN TRANSACTION;"); - let mut bind_index = 0usize; - for batch in slice { + for (bind_index, batch) in 
slice.iter().enumerate() { let name = format!("{prefix}{bind_index}"); - bind_index += 1; query = query .query(format!("{} ${};", statement.as_ref(), name)) .bind((name, batch.items.clone())); @@ -702,7 +701,7 @@ async fn execute_batched_inserts( pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { let batches = build_manifest_batches(manifest).context("preparing manifest batches")?; - let result = (|| async { + let result = async { execute_batched_inserts( db, format!("INSERT INTO {}", TextContent::table_name()), @@ -752,7 +751,7 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife .await?; Ok(()) - })() + } .await; if result.is_err() { @@ -778,7 +777,6 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife #[cfg(test)] mod tests { use super::*; - use crate::db_helpers::change_embedding_length_in_hnsw_indexes; use chrono::Utc; use common::storage::types::knowledge_entity::KnowledgeEntityType; use uuid::Uuid; @@ -905,9 +903,6 @@ mod tests { db.apply_migrations() .await .expect("apply migrations for memory db"); - change_embedding_length_in_hnsw_indexes(&db, 3) - .await - .expect("set embedding index dimension for test"); let manifest = build_manifest(); seed_manifest_into_db(&db, &manifest) diff --git a/evaluations/src/datasets/mod.rs b/evaluations/src/datasets/mod.rs index 108c36f..353b567 100644 --- a/evaluations/src/datasets/mod.rs +++ b/evaluations/src/datasets/mod.rs @@ -245,8 +245,9 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> { catalog.dataset(kind.id()) } -#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)] pub enum DatasetKind { + #[default] SquadV2, NaturalQuestions, Beir, @@ -368,12 +369,6 @@ impl std::fmt::Display for DatasetKind { } } -impl Default for DatasetKind { - fn default() -> Self { - Self::SquadV2 - } -} - impl FromStr for DatasetKind 
{ type Err = anyhow::Error; diff --git a/evaluations/src/db_helpers.rs b/evaluations/src/db_helpers.rs index 2f56937..c703631 100644 --- a/evaluations/src/db_helpers.rs +++ b/evaluations/src/db_helpers.rs @@ -36,13 +36,14 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s Ok(()) } -// Test helper to force index dimension change -pub async fn change_embedding_length_in_hnsw_indexes( - db: &SurrealDbClient, - dimension: usize, -) -> Result<()> { - recreate_indexes(db, dimension).await -} +// // Test helper to force index dimension change +// #[allow(dead_code)] +// pub async fn change_embedding_length_in_hnsw_indexes( +// db: &SurrealDbClient, +// dimension: usize, +// ) -> Result<()> { +// recreate_indexes(db, dimension).await +// } #[cfg(test)] mod tests { diff --git a/evaluations/src/namespace.rs b/evaluations/src/namespace.rs index 8eaa856..07c0f99 100644 --- a/evaluations/src/namespace.rs +++ b/evaluations/src/namespace.rs @@ -86,6 +86,7 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result { } /// Determine if we can reuse an existing namespace based on cached state. +#[allow(clippy::too_many_arguments)] pub(crate) async fn can_reuse_namespace( db: &SurrealDbClient, descriptor: &snapshot::Descriptor, @@ -213,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result { timezone: "UTC".to_string(), }; - if let Some(existing) = db.get_item::(&user.get_id()).await? { + if let Some(existing) = db.get_item::(user.get_id()).await? 
{ return Ok(existing); } diff --git a/evaluations/src/pipeline/context.rs b/evaluations/src/pipeline/context.rs index 99c0eef..390a724 100644 --- a/evaluations/src/pipeline/context.rs +++ b/evaluations/src/pipeline/context.rs @@ -154,7 +154,7 @@ impl<'a> EvaluationContext<'a> { } pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) { - let elapsed = duration.as_millis() as u128; + let elapsed = duration.as_millis(); match stage { EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed, EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed, diff --git a/evaluations/src/pipeline/mod.rs b/evaluations/src/pipeline/mod.rs index f563346..6980557 100644 --- a/evaluations/src/pipeline/mod.rs +++ b/evaluations/src/pipeline/mod.rs @@ -21,9 +21,7 @@ pub async fn run_evaluation( let machine = stages::prepare_namespace(machine, &mut ctx).await?; let machine = stages::run_queries(machine, &mut ctx).await?; let machine = stages::summarize(machine, &mut ctx).await?; - let machine = stages::finalize(machine, &mut ctx).await?; - - drop(machine); + let _ = stages::finalize(machine, &mut ctx).await?; Ok(ctx.into_summary()) } diff --git a/evaluations/src/pipeline/stages/prepare_corpus.rs b/evaluations/src/pipeline/stages/prepare_corpus.rs index 102a3b3..cd73690 100644 --- a/evaluations/src/pipeline/stages/prepare_corpus.rs +++ b/evaluations/src/pipeline/stages/prepare_corpus.rs @@ -113,7 +113,7 @@ pub(crate) async fn prepare_corpus( .metadata .ingestion_fingerprint .clone(); - let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128; + let ingestion_duration_ms = ingestion_timer.elapsed().as_millis(); info!( cache = %corpus_handle.path.display(), reused_ingestion = corpus_handle.reused_ingestion, diff --git a/evaluations/src/pipeline/stages/prepare_namespace.rs b/evaluations/src/pipeline/stages/prepare_namespace.rs index ec87f75..a12c81f 100644 --- a/evaluations/src/pipeline/stages/prepare_namespace.rs +++ 
b/evaluations/src/pipeline/stages/prepare_namespace.rs @@ -119,7 +119,7 @@ pub(crate) async fn prepare_namespace( corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed) .await .context("seeding ingestion corpus from manifest")?; - namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128); + namespace_seed_ms = Some(seed_start.elapsed().as_millis()); // Recreate indexes AFTER data is loaded (correct bulk loading pattern) if indexes_disabled { diff --git a/evaluations/src/pipeline/stages/run_queries.rs b/evaluations/src/pipeline/stages/run_queries.rs index 37edefc..bb25523 100644 --- a/evaluations/src/pipeline/stages/run_queries.rs +++ b/evaluations/src/pipeline/stages/run_queries.rs @@ -50,8 +50,10 @@ pub(crate) async fn run_queries( None }; - let mut retrieval_config = RetrievalConfig::default(); - retrieval_config.strategy = config.retrieval.strategy; + let mut retrieval_config = RetrievalConfig { + strategy: config.retrieval.strategy, + ..Default::default() + }; retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top; if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top { retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top; @@ -213,7 +215,7 @@ pub(crate) async fn run_queries( .with_context(|| format!("running pipeline for question {}", question_id))?; (outcome.results, None, outcome.stage_timings) }; - let query_latency = query_start.elapsed().as_millis() as u128; + let query_latency = query_start.elapsed().as_millis(); let candidates = adapt_strategy_output(result_output); let mut retrieved = Vec::new(); diff --git a/evaluations/src/slice.rs b/evaluations/src/slice.rs index e21f5ce..41ce029 100644 --- a/evaluations/src/slice.rs +++ b/evaluations/src/slice.rs @@ -436,8 +436,8 @@ pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result( - dataset: &'a ConvertedDataset, +fn load_explicit_slice( + dataset: &ConvertedDataset, index: &DatasetIndex, config: 
&SliceConfig<'_>, slice_arg: &str, diff --git a/html-router/src/html_state.rs b/html-router/src/html_state.rs index c488738..615b834 100644 --- a/html-router/src/html_state.rs +++ b/html-router/src/html_state.rs @@ -46,7 +46,7 @@ impl HtmlState { .retrieval_strategy .as_deref() .and_then(|value| value.parse().ok()) - .unwrap_or(RetrievalStrategy::Initial) + .unwrap_or(RetrievalStrategy::Default) } } impl ProvidesDb for HtmlState { diff --git a/html-router/src/routes/chat/message_response_stream.rs b/html-router/src/routes/chat/message_response_stream.rs index 8fa37dc..5aec618 100644 --- a/html-router/src/routes/chat/message_response_stream.rs +++ b/html-router/src/routes/chat/message_response_stream.rs @@ -15,7 +15,10 @@ use futures::{ use json_stream_parser::JsonStreamParser; use minijinja::Value; use retrieval_pipeline::{ - answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, + answer_retrieval::{ + chunks_to_chat_context, create_chat_request, create_user_message_with_history, + LLMResponseFormat, + }, retrieved_entities_to_json, }; use serde::{Deserialize, Serialize}; @@ -126,7 +129,7 @@ pub async fn get_response_stream( let strategy = state.retrieval_strategy(); let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy); - let entities = match retrieval_pipeline::retrieve_entities( + let retrieval_result = match retrieval_pipeline::retrieve_entities( &state.db, &state.openai_client, &user_message.content, @@ -136,19 +139,21 @@ pub async fn get_response_stream( ) .await { - Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => entities, - Ok(retrieval_pipeline::StrategyOutput::Chunks(_chunks)) => { - return Sse::new(create_error_stream("Chat retrieval currently only supports Entity-based strategies (Initial). 
Revised strategy returns Chunks which are not yet supported by this handler.")); - } + Ok(result) => result, Err(_e) => { - return Sse::new(create_error_stream("Failed to retrieve knowledge entities")); + return Sse::new(create_error_stream("Failed to retrieve knowledge")); } }; - // 3. Create the OpenAI request - let entities_json = retrieved_entities_to_json(&entities); + // 3. Create the OpenAI request with appropriate context format + let context_json = match retrieval_result { + retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks), + retrieval_pipeline::StrategyOutput::Entities(entities) => { + retrieved_entities_to_json(&entities) + } + }; let formatted_user_message = - create_user_message_with_history(&entities_json, &history, &user_message.content); + create_user_message_with_history(&context_json, &history, &user_message.content); let settings = match SystemSettings::get_current(&state.db).await { Ok(s) => s, Err(_) => { diff --git a/ingestion-pipeline/src/lib.rs b/ingestion-pipeline/src/lib.rs index 671f2ae..c7200b2 100644 --- a/ingestion-pipeline/src/lib.rs +++ b/ingestion-pipeline/src/lib.rs @@ -1,3 +1,8 @@ +#![allow( + clippy::missing_docs_in_private_items, + clippy::result_large_err +)] + pub mod pipeline; pub mod utils; diff --git a/ingestion-pipeline/src/pipeline/config.rs b/ingestion-pipeline/src/pipeline/config.rs index b0a9df8..3e8d23a 100644 --- a/ingestion-pipeline/src/pipeline/config.rs +++ b/ingestion-pipeline/src/pipeline/config.rs @@ -31,17 +31,8 @@ impl Default for IngestionTuning { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IngestionConfig { pub tuning: IngestionTuning, pub chunk_only: bool, } - -impl Default for IngestionConfig { - fn default() -> Self { - Self { - tuning: IngestionTuning::default(), - chunk_only: false, - } - } -} diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index 3c6193c..ed1f6c9 
100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -52,7 +52,7 @@ impl LLMEnrichmentResult { entity_concurrency: usize, embedding_provider: Option<&EmbeddingProvider>, ) -> Result<(Vec, Vec), AppError> { - let mapper = Arc::new(self.create_mapper()?); + let mapper = Arc::new(self.create_mapper()); let entities = self .process_entities( @@ -66,21 +66,22 @@ impl LLMEnrichmentResult { ) .await?; - let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?; + let relationships = self.process_relationships(source_id, user_id, mapper.as_ref())?; Ok((entities, relationships)) } - fn create_mapper(&self) -> Result { + fn create_mapper(&self) -> GraphMapper { let mut mapper = GraphMapper::new(); for entity in &self.knowledge_entities { mapper.assign_id(&entity.key); } - Ok(mapper) + mapper } + #[allow(clippy::too_many_arguments)] async fn process_entities( &self, source_id: &str, @@ -91,7 +92,7 @@ impl LLMEnrichmentResult { entity_concurrency: usize, embedding_provider: Option<&EmbeddingProvider>, ) -> Result, AppError> { - stream::iter(self.knowledge_entities.iter().cloned().map(|entity| { + stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| { let mapper = Arc::clone(&mapper); let openai_client = openai_client.clone(); let source_id = source_id.to_string(); @@ -120,7 +121,7 @@ impl LLMEnrichmentResult { &self, source_id: &str, user_id: &str, - mapper: Arc, + mapper: &GraphMapper, ) -> Result, AppError> { self.relationships .iter() @@ -170,9 +171,9 @@ async fn create_single_entity( id: assigned_id, created_at: now, updated_at: now, - name: llm_entity.name.to_string(), - description: llm_entity.description.to_string(), - entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()), + name: llm_entity.name.clone(), + description: llm_entity.description.clone(), + entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()), 
source_id: source_id.to_string(), metadata: None, user_id: user_id.into(), diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index e642343..5294909 100644 --- a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -8,6 +8,7 @@ mod state; pub use config::{IngestionConfig, IngestionTuning}; pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship}; +#[allow(clippy::module_name_repetitions)] pub use services::{DefaultPipelineServices, PipelineServices}; use std::{ @@ -37,6 +38,7 @@ use self::{ state::ready, }; +#[allow(clippy::module_name_repetitions)] pub struct IngestionPipeline { db: Arc, pipeline_config: IngestionConfig, @@ -44,7 +46,7 @@ pub struct IngestionPipeline { } impl IngestionPipeline { - pub async fn new( + pub fn new( db: Arc, openai_client: Arc>, config: AppConfig, @@ -61,10 +63,9 @@ impl IngestionPipeline { embedding_provider, IngestionConfig::default(), ) - .await } - pub async fn new_with_config( + pub fn new_with_config( db: Arc, openai_client: Arc>, config: AppConfig, @@ -74,9 +75,9 @@ impl IngestionPipeline { pipeline_config: IngestionConfig, ) -> Result { let services = DefaultPipelineServices::new( - db.clone(), - openai_client.clone(), - config.clone(), + Arc::clone(&db), + openai_client, + config, reranker_pool, storage, embedding_provider, @@ -181,11 +182,17 @@ impl IngestionPipeline { .saturating_sub(1) .min(tuning.retry_backoff_cap_exponent); let multiplier = 2_u64.pow(capped_attempt); - let delay = tuning.retry_base_delay_secs * multiplier; + let delay = tuning + .retry_base_delay_secs + .saturating_mul(multiplier); Duration::from_secs(delay.min(tuning.retry_max_delay_secs)) } + fn duration_millis(duration: Duration) -> u64 { + u64::try_from(duration.as_millis()).unwrap_or(u64::MAX) + } + #[tracing::instrument( skip_all, fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id) @@ -231,14 +238,14 @@ impl 
IngestionPipeline { let persist_duration = stage_start.elapsed(); let total_duration = pipeline_started.elapsed(); - let prepare_ms = prepare_duration.as_millis() as u64; - let retrieve_ms = retrieve_duration.as_millis() as u64; - let enrich_ms = enrich_duration.as_millis() as u64; - let persist_ms = persist_duration.as_millis() as u64; + let prepare_ms = Self::duration_millis(prepare_duration); + let retrieve_ms = Self::duration_millis(retrieve_duration); + let enrich_ms = Self::duration_millis(enrich_duration); + let persist_ms = Self::duration_millis(persist_duration); info!( task_id = %ctx.task_id, attempt = ctx.attempt, - total_ms = total_duration.as_millis() as u64, + total_ms = Self::duration_millis(total_duration), prepare_ms, retrieve_ms, enrich_ms, diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 8d6ca97..e550322 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -228,7 +228,7 @@ impl PipelineServices for DefaultPipelineServices { ) -> Result<(Vec, Vec), AppError> { analysis .to_database_entities( - &content.get_id(), + content.get_id(), &content.user_id, &self.openai_client, &self.db, @@ -327,13 +327,13 @@ fn truncate_for_embedding(text: &str, max_chars: usize) -> String { return text.to_string(); } - let mut truncated = String::with_capacity(max_chars + 3); + let mut truncated = String::with_capacity(max_chars.saturating_add(3)); for (idx, ch) in text.chars().enumerate() { if idx >= max_chars { break; } truncated.push(ch); } - truncated.push_str("…"); + truncated.push('…'); truncated } diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index 73316b9..ffe6ba8 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -20,6 +20,22 @@ use super::{ state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, 
}; +const STORE_RELATIONSHIPS: &str = r" + BEGIN TRANSACTION; + LET $relationships = $relationships; + + FOR $relationship IN $relationships { + LET $in_node = type::thing('knowledge_entity', $relationship.in); + LET $out_node = type::thing('knowledge_entity', $relationship.out); + RELATE $in_node->relates_to->$out_node CONTENT { + id: type::thing('relates_to', $relationship.id), + metadata: $relationship.metadata + }; + }; + + COMMIT TRANSACTION; +"; + #[instrument( level = "trace", skip_all, @@ -40,8 +56,7 @@ pub async fn prepare_content( let context_len = text_content .context .as_ref() - .map(|c| c.chars().count()) - .unwrap_or(0); + .map_or(0, |c| c.chars().count()); tracing::info!( task_id = %ctx.task_id, @@ -65,7 +80,7 @@ pub async fn prepare_content( machine .prepare() - .map_err(|(_, guard)| map_guard_error("prepare", guard)) + .map_err(|(_, guard)| map_guard_error("prepare", &guard)) } #[instrument( @@ -80,7 +95,7 @@ pub async fn retrieve_related( if ctx.pipeline_config.chunk_only { return machine .retrieve() - .map_err(|(_, guard)| map_guard_error("retrieve", guard)); + .map_err(|(_, guard)| map_guard_error("retrieve", &guard)); } let content = ctx.text_content()?; @@ -97,7 +112,7 @@ pub async fn retrieve_related( machine .retrieve() - .map_err(|(_, guard)| map_guard_error("retrieve", guard)) + .map_err(|(_, guard)| map_guard_error("retrieve", &guard)) } #[instrument( @@ -116,7 +131,7 @@ pub async fn enrich( }); return machine .enrich() - .map_err(|(_, guard)| map_guard_error("enrich", guard)); + .map_err(|(_, guard)| map_guard_error("enrich", &guard)); } let content = ctx.text_content()?; @@ -137,7 +152,7 @@ pub async fn enrich( machine .enrich() - .map_err(|(_, guard)| map_guard_error("enrich", guard)) + .map_err(|(_, guard)| map_guard_error("enrich", &guard)) } #[instrument( @@ -182,10 +197,10 @@ pub async fn persist( machine .persist() - .map_err(|(_, guard)| map_guard_error("persist", guard)) + .map_err(|(_, guard)| map_guard_error("persist", 
&guard)) } -fn map_guard_error(event: &str, guard: GuardError) -> AppError { +fn map_guard_error(event: &str, guard: &GuardError) -> AppError { AppError::InternalError(format!( "invalid ingestion pipeline transition during {event}: {guard:?}" )) @@ -206,43 +221,31 @@ async fn store_graph_entities( return Ok(()); } - const STORE_RELATIONSHIPS: &str = r" - BEGIN TRANSACTION; - LET $relationships = $relationships; - - FOR $relationship IN $relationships { - LET $in_node = type::thing('knowledge_entity', $relationship.in); - LET $out_node = type::thing('knowledge_entity', $relationship.out); - RELATE $in_node->relates_to->$out_node CONTENT { - id: type::thing('relates_to', $relationship.id), - metadata: $relationship.metadata - }; - }; - - COMMIT TRANSACTION; - "; - let relationships = Arc::new(relationships); let mut backoff_ms = tuning.graph_initial_backoff_ms; + let last_attempt = tuning.graph_store_attempts.saturating_sub(1); for attempt in 0..tuning.graph_store_attempts { let result = db .client .query(STORE_RELATIONSHIPS) - .bind(("relationships", relationships.clone())) + .bind(("relationships", Arc::clone(&relationships))) .await; match result { Ok(_) => return Ok(()), Err(err) => { - if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts { + if is_retryable_conflict(&err) && attempt < last_attempt { + let next_attempt = attempt.saturating_add(1); warn!( - attempt = attempt + 1, + attempt = next_attempt, "Transient SurrealDB conflict while storing graph data; retrying" ); sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms); + backoff_ms = backoff_ms + .saturating_mul(2) + .min(tuning.graph_max_backoff_ms); continue; } diff --git a/ingestion-pipeline/src/utils/file_text_extraction.rs b/ingestion-pipeline/src/utils/file_text_extraction.rs index 06b3dab..1ae88c8 100644 --- a/ingestion-pipeline/src/utils/file_text_extraction.rs +++ 
b/ingestion-pipeline/src/utils/file_text_extraction.rs @@ -65,7 +65,7 @@ fn infer_extension(file_info: &FileInfo) -> Option { Path::new(&file_info.path) .extension() .and_then(|ext| ext.to_str()) - .map(|ext| ext.to_string()) + .map(std::string::ToString::to_string) } pub async fn extract_text_from_file( diff --git a/ingestion-pipeline/src/utils/pdf_ingestion.rs b/ingestion-pipeline/src/utils/pdf_ingestion.rs index c4b7a27..5be106c 100644 --- a/ingestion-pipeline/src/utils/pdf_ingestion.rs +++ b/ingestion-pipeline/src/utils/pdf_ingestion.rs @@ -116,6 +116,7 @@ async fn load_page_numbers(pdf_bytes: Vec) -> Result, AppError> { } /// Uses the existing headless Chrome dependency to rasterize the requested PDF pages into PNGs. +#[allow(clippy::too_many_lines)] async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result>, AppError> { let file_url = url::Url::from_file_path(file_path) .map_err(|()| AppError::Processing("Unable to construct PDF file URL".into()))?; @@ -148,7 +149,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result loaded = true; break; } - if attempt + 1 < NAVIGATION_RETRY_ATTEMPTS { + if attempt < NAVIGATION_RETRY_ATTEMPTS.saturating_sub(1) { sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS)).await; } } @@ -172,7 +173,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result break; } Ok(None) => { - if attempt + 1 < CANVAS_VIEWPORT_ATTEMPTS { + if attempt < CANVAS_VIEWPORT_ATTEMPTS.saturating_sub(1) { tokio::time::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS)).await; } } @@ -260,6 +261,7 @@ fn create_browser() -> Result { } /// Sends one or more rendered pages to the configured multimodal model and stitches the resulting Markdown chunks together. 
+#[allow(clippy::too_many_lines)] async fn vision_markdown( rendered_pages: Vec>, db: &SurrealDbClient, @@ -303,10 +305,11 @@ async fn vision_markdown( let mut batch_markdown: Option = None; + let last_attempt = MAX_VISION_ATTEMPTS.saturating_sub(1); for attempt in 0..MAX_VISION_ATTEMPTS { let prompt_text = prompt_for_attempt(attempt, prompt); - let mut content_parts = Vec::with_capacity(encoded_images.len() + 1); + let mut content_parts = Vec::with_capacity(encoded_images.len().saturating_add(1)); content_parts.push( ChatCompletionRequestMessageContentPartTextArgs::default() .text(prompt_text) @@ -375,7 +378,7 @@ async fn vision_markdown( batch = batch_idx, attempt, "Vision model returned low quality response" ); - if attempt + 1 == MAX_VISION_ATTEMPTS { + if attempt == last_attempt { return Err(AppError::Processing( "Vision model failed to transcribe PDF page contents".into(), )); @@ -400,6 +403,7 @@ async fn vision_markdown( } /// Heuristic that determines whether the fast-path text looks like well-formed prose. 
+#[allow(clippy::cast_precision_loss)] fn looks_good_enough(text: &str) -> bool { if text.len() < FAST_PATH_MIN_LEN { return false; diff --git a/ingestion-pipeline/src/utils/url_text_retrieval.rs b/ingestion-pipeline/src/utils/url_text_retrieval.rs index ca61273..a76bfca 100644 --- a/ingestion-pipeline/src/utils/url_text_retrieval.rs +++ b/ingestion-pipeline/src/utils/url_text_retrieval.rs @@ -50,7 +50,7 @@ pub async fn extract_text_from_url( )?; let mut tmp_file = NamedTempFile::new()?; - let temp_path_str = format!("{:?}", tmp_file.path()); + let temp_path_str = tmp_file.path().display().to_string(); tmp_file.write_all(&screenshot)?; tmp_file.as_file().sync_all()?; @@ -108,14 +108,11 @@ fn ensure_ingestion_url_allowed(url: &url::Url) -> Result { } } - let host = match url.host_str() { - Some(host) => host, - None => { - warn!(%url, "Rejected ingestion URL missing host"); - return Err(AppError::Validation( - "URL is missing a host component".to_string(), - )); - } + let Some(host) = url.host_str() else { + warn!(%url, "Rejected ingestion URL missing host"); + return Err(AppError::Validation( + "URL is missing a host component".to_string(), + )); }; if host.eq_ignore_ascii_case("localhost") { diff --git a/main/src/main.rs b/main/src/main.rs index 7db36ca..76be6a9 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -138,7 +138,6 @@ async fn main() -> Result<(), Box> { storage.clone(), embedding_provider, ) - .await .unwrap(), ); diff --git a/main/src/worker.rs b/main/src/worker.rs index e9d13f8..fc1f29a 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -53,7 +53,7 @@ async fn main() -> Result<(), Box> { storage, embedding_provider, ) - .await?, + ?, ); run_worker_loop(db, ingestion_pipeline).await diff --git a/retrieval-pipeline/src/answer_retrieval.rs b/retrieval-pipeline/src/answer_retrieval.rs index e0f0108..0a2ff9c 100644 --- a/retrieval-pipeline/src/answer_retrieval.rs +++ b/retrieval-pipeline/src/answer_retrieval.rs @@ -51,6 +51,24 @@ pub fn 
create_user_message(entities_json: &Value, query: &str) -> String { ) } +/// Convert chunk-based retrieval results to JSON format for LLM context +pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value { + fn round_score(value: f32) -> f64 { + (f64::from(value) * 1000.0).round() / 1000.0 + } + + serde_json::json!(chunks + .iter() + .map(|chunk| { + serde_json::json!({ + "content": chunk.chunk.chunk, + "source_id": chunk.chunk.source_id, + "score": round_score(chunk.score), + }) + }) + .collect::>()) +} + pub fn create_user_message_with_history( entities_json: &Value, history: &[Message], diff --git a/retrieval-pipeline/src/fts.rs b/retrieval-pipeline/src/fts.rs deleted file mode 100644 index f439d3b..0000000 --- a/retrieval-pipeline/src/fts.rs +++ /dev/null @@ -1,268 +0,0 @@ -use std::collections::HashMap; - -use serde::Deserialize; -use tracing::debug; - -use common::{ - error::AppError, - storage::{db::SurrealDbClient, types::StoredObject}, -}; - -use crate::scoring::Scored; -use common::storage::types::file_info::deserialize_flexible_id; -use surrealdb::sql::Thing; - -#[derive(Debug, Deserialize)] -struct FtsScoreRow { - #[serde(deserialize_with = "deserialize_flexible_id")] - id: String, - fts_score: Option, -} - -/// Executes a full-text search query against SurrealDB and returns scored results. -/// -/// The function expects FTS indexes to exist for the provided table. Currently supports -/// `knowledge_entity` (name + description) and `text_chunk` (chunk). 
-pub async fn find_items_by_fts( - take: usize, - query: &str, - db_client: &SurrealDbClient, - table: &str, - user_id: &str, -) -> Result>, AppError> -where - T: for<'de> serde::Deserialize<'de> + StoredObject, -{ - let (filter_clause, score_clause) = match table { - "knowledge_entity" => ( - "(name @0@ $terms OR description @1@ $terms)", - "(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \ - (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)", - ), - "text_chunk" => ( - "(chunk @0@ $terms)", - "IF search::score(0) != NONE THEN search::score(0) ELSE 0 END", - ), - _ => { - return Err(AppError::Validation(format!( - "FTS not configured for table '{table}'" - ))) - } - }; - - let sql = format!( - "SELECT id, {score_clause} AS fts_score \ - FROM {table} \ - WHERE {filter_clause} \ - AND user_id = $user_id \ - ORDER BY fts_score DESC \ - LIMIT $limit", - table = table, - filter_clause = filter_clause, - score_clause = score_clause - ); - - debug!( - table = table, - limit = take, - "Executing FTS query with filter clause: {}", - filter_clause - ); - - let mut response = db_client - .query(sql) - .bind(("terms", query.to_owned())) - .bind(("user_id", user_id.to_owned())) - .bind(("limit", take as i64)) - .await?; - - let score_rows: Vec = response.take(0)?; - - if score_rows.is_empty() { - return Ok(Vec::new()); - } - - let ids: Vec = score_rows.iter().map(|row| row.id.clone()).collect(); - let thing_ids: Vec = ids - .iter() - .map(|id| Thing::from((table, id.as_str()))) - .collect(); - - let mut items_response = db_client - .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") - .bind(("table", table.to_owned())) - .bind(("things", thing_ids.clone())) - .bind(("user_id", user_id.to_owned())) - .await?; - - let items: Vec = items_response.take(0)?; - - let mut item_map: HashMap = items - .into_iter() - .map(|item| (item.get_id().to_owned(), item)) - .collect(); - - let mut results = 
Vec::with_capacity(score_rows.len()); - for row in score_rows { - if let Some(item) = item_map.remove(&row.id) { - let score = row.fts_score.unwrap_or_default(); - results.push(Scored::new(item).with_fts_score(score)); - } - } - - Ok(results) -} - -#[cfg(test)] -mod tests { - use super::*; - use common::storage::indexes::ensure_runtime_indexes; - use common::storage::types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - text_chunk::TextChunk, - StoredObject, - }; - use uuid::Uuid; - - #[tokio::test] - async fn fts_preserves_single_field_score_for_name() { - let namespace = "fts_test_ns"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("failed to create in-memory surreal"); - - db.apply_migrations() - .await - .expect("failed to apply migrations"); - ensure_runtime_indexes(&db, 1536) - .await - .expect("failed to build runtime indexes"); - - let user_id = "user_fts"; - let entity = KnowledgeEntity::new( - "source_a".into(), - "Rustacean handbook".into(), - "completely unrelated description".into(), - KnowledgeEntityType::Document, - None, - user_id.into(), - ); - - db.store_item(entity.clone()) - .await - .expect("failed to insert entity"); - - db.rebuild_indexes() - .await - .expect("failed to rebuild indexes"); - - let results = find_items_by_fts::( - 5, - "rustacean", - &db, - KnowledgeEntity::table_name(), - user_id, - ) - .await - .expect("fts query failed"); - - assert!(!results.is_empty(), "expected at least one FTS result"); - assert!( - results[0].scores.fts.is_some(), - "expected an FTS score when only the name matched" - ); - } - - #[tokio::test] - async fn fts_preserves_single_field_score_for_description() { - let namespace = "fts_test_ns_desc"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("failed to create in-memory surreal"); - - db.apply_migrations() - .await - .expect("failed to apply 
migrations"); - ensure_runtime_indexes(&db, 1536) - .await - .expect("failed to build runtime indexes"); - - let user_id = "user_fts_desc"; - let entity = KnowledgeEntity::new( - "source_b".into(), - "neutral name".into(), - "Detailed notes about async runtimes".into(), - KnowledgeEntityType::Document, - None, - user_id.into(), - ); - - db.store_item(entity.clone()) - .await - .expect("failed to insert entity"); - - db.rebuild_indexes() - .await - .expect("failed to rebuild indexes"); - - let results = find_items_by_fts::( - 5, - "async", - &db, - KnowledgeEntity::table_name(), - user_id, - ) - .await - .expect("fts query failed"); - - assert!(!results.is_empty(), "expected at least one FTS result"); - assert!( - results[0].scores.fts.is_some(), - "expected an FTS score when only the description matched" - ); - } - - #[tokio::test] - async fn fts_preserves_scores_for_text_chunks() { - let namespace = "fts_test_ns_chunks"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("failed to create in-memory surreal"); - - db.apply_migrations() - .await - .expect("failed to apply migrations"); - ensure_runtime_indexes(&db, 1536) - .await - .expect("failed to build runtime indexes"); - - let user_id = "user_fts_chunk"; - let chunk = TextChunk::new( - "source_chunk".into(), - "GraphQL documentation reference".into(), - user_id.into(), - ); - - TextChunk::store_with_embedding(chunk.clone(), vec![0.0; 1536], &db) - .await - .expect("failed to insert chunk"); - - db.rebuild_indexes() - .await - .expect("failed to rebuild indexes"); - - let results = - find_items_by_fts::(5, "graphql", &db, TextChunk::table_name(), user_id) - .await - .expect("fts query failed"); - - assert!(!results.is_empty(), "expected at least one FTS result"); - assert!( - results[0].scores.fts.is_some(), - "expected an FTS score when chunk field matched" - ); - } -} diff --git a/retrieval-pipeline/src/graph.rs 
b/retrieval-pipeline/src/graph.rs index 494404d..0c7cc5a 100644 --- a/retrieval-pipeline/src/graph.rs +++ b/retrieval-pipeline/src/graph.rs @@ -10,54 +10,17 @@ use common::storage::{ }, }; -/// Retrieves database entries that match a specific source identifier. +/// Find entities related to the given entity via graph relationships. /// -/// This function queries the database for all records in a specified table that have -/// a matching `source_id` field. It's commonly used to find related entities or -/// track the origin of database entries. +/// Queries the `relates_to` edge table for all relationships involving the entity, +/// then fetches and returns the neighboring entities. /// /// # Arguments -/// -/// * `source_id` - The identifier to search for in the database -/// * `table_name` - The name of the table to search in -/// * `db_client` - The `SurrealDB` client instance for database operations -/// -/// # Type Parameters -/// -/// * `T` - The type to deserialize the query results into. 
Must implement `serde::Deserialize` -/// -/// # Returns -/// -/// Returns a `Result` containing either: -/// * `Ok(Vec)` - A vector of matching records deserialized into type `T` -/// * `Err(Error)` - An error if the database query fails -/// -/// # Errors -/// -/// This function will return a `Error` if: -/// * The database query fails to execute -/// * The results cannot be deserialized into type `T` -pub async fn find_entities_by_source_ids( - source_ids: Vec, - table_name: &str, - user_id: &str, - db: &SurrealDbClient, -) -> Result, Error> -where - T: for<'de> serde::Deserialize<'de>, -{ - let query = - "SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id"; +/// * `db` - Database client +/// * `entity_id` - ID of the entity to find neighbors for +/// * `user_id` - User ID for access control +/// * `limit` - Maximum number of neighbors to return - db.query(query) - .bind(("table", table_name.to_owned())) - .bind(("source_ids", source_ids)) - .bind(("user_id", user_id.to_owned())) - .await? 
- .take(0) -} - -/// Find entities by their relationship to the id pub async fn find_entities_by_relationship_by_id( db: &SurrealDbClient, entity_id: &str, @@ -153,154 +116,8 @@ mod tests { use super::*; use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::knowledge_relationship::KnowledgeRelationship; - use common::storage::types::StoredObject; use uuid::Uuid; - #[tokio::test] - async fn test_find_entities_by_source_ids() { - // Setup in-memory database for testing - let namespace = "test_ns"; - let database = &Uuid::new_v4().to_string(); - let db = SurrealDbClient::memory(namespace, database) - .await - .expect("Failed to start in-memory surrealdb"); - - // Create some test entities with different source_ids - let source_id1 = "source123".to_string(); - let source_id2 = "source456".to_string(); - let source_id3 = "source789".to_string(); - - let entity_type = KnowledgeEntityType::Document; - let user_id = "user123".to_string(); - - // Entity with source_id1 - let entity1 = KnowledgeEntity::new( - source_id1.clone(), - "Entity 1".to_string(), - "Description 1".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - // Entity with source_id2 - let entity2 = KnowledgeEntity::new( - source_id2.clone(), - "Entity 2".to_string(), - "Description 2".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - // Another entity with source_id1 - let entity3 = KnowledgeEntity::new( - source_id1.clone(), - "Entity 3".to_string(), - "Description 3".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - // Entity with source_id3 - let entity4 = KnowledgeEntity::new( - source_id3.clone(), - "Entity 4".to_string(), - "Description 4".to_string(), - entity_type.clone(), - None, - user_id.clone(), - ); - - // Store all entities - db.store_item(entity1.clone()) - .await - .expect("Failed to store entity 1"); - db.store_item(entity2.clone()) - .await - .expect("Failed to store 
entity 2"); - db.store_item(entity3.clone()) - .await - .expect("Failed to store entity 3"); - db.store_item(entity4.clone()) - .await - .expect("Failed to store entity 4"); - - // Test finding entities by multiple source_ids - let source_ids = vec![source_id1.clone(), source_id2.clone()]; - let found_entities: Vec = - find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db) - .await - .expect("Failed to find entities by source_ids"); - - // Should find 3 entities (2 with source_id1, 1 with source_id2) - assert_eq!( - found_entities.len(), - 3, - "Should find 3 entities with the specified source_ids" - ); - - // Check that entities with source_id1 and source_id2 are found - let found_source_ids: Vec = - found_entities.iter().map(|e| e.source_id.clone()).collect(); - assert!( - found_source_ids.contains(&source_id1), - "Should find entities with source_id1" - ); - assert!( - found_source_ids.contains(&source_id2), - "Should find entities with source_id2" - ); - assert!( - !found_source_ids.contains(&source_id3), - "Should not find entities with source_id3" - ); - - // Test finding entities by a single source_id - let single_source_id = vec![source_id1.clone()]; - let found_entities: Vec = find_entities_by_source_ids( - single_source_id, - KnowledgeEntity::table_name(), - &user_id, - &db, - ) - .await - .expect("Failed to find entities by single source_id"); - - // Should find 2 entities with source_id1 - assert_eq!( - found_entities.len(), - 2, - "Should find 2 entities with source_id1" - ); - - // Check that all found entities have source_id1 - for entity in found_entities { - assert_eq!( - entity.source_id, source_id1, - "All found entities should have source_id1" - ); - } - - // Test finding entities with non-existent source_id - let non_existent_source_id = vec!["non_existent_source".to_string()]; - let found_entities: Vec = find_entities_by_source_ids( - non_existent_source_id, - KnowledgeEntity::table_name(), - &user_id, - &db, - ) 
- .await - .expect("Failed to find entities by non-existent source_id"); - - // Should find 0 entities - assert_eq!( - found_entities.len(), - 0, - "Should find 0 entities with non-existent source_id" - ); - } #[tokio::test] async fn test_find_entities_by_relationship_by_id() { diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 063f6e0..5b1ee34 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -1,6 +1,6 @@ pub mod answer_retrieval; pub mod answer_retrieval_helper; -pub mod fts; + pub mod graph; pub mod pipeline; pub mod reranking; @@ -70,11 +70,7 @@ mod tests { use super::*; use async_openai::Client; use common::storage::indexes::ensure_runtime_indexes; - use common::storage::types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - knowledge_relationship::KnowledgeRelationship, - text_chunk::TextChunk, - }; + use common::storage::types::text_chunk::TextChunk; use pipeline::{RetrievalConfig, RetrievalStrategy}; use uuid::Uuid; @@ -82,14 +78,6 @@ mod tests { vec![0.9, 0.1, 0.0] } - fn entity_embedding_high() -> Vec { - vec![0.8, 0.2, 0.0] - } - - fn entity_embedding_low() -> Vec { - vec![0.1, 0.9, 0.0] - } - fn chunk_embedding_primary() -> Vec { vec![0.85, 0.15, 0.0] } @@ -113,41 +101,19 @@ mod tests { .await .expect("failed to build runtime indexes"); - db.query( - "BEGIN TRANSACTION; - REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding; - DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3; - REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding; - DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3; - COMMIT TRANSACTION;", - ) - .await - .expect("Failed to configure indices"); - db } #[tokio::test] - async fn test_retrieve_entities_with_embedding_basic_flow() { + async fn 
test_default_strategy_retrieves_chunks() { let db = setup_test_db().await; let user_id = "test_user"; - let entity = KnowledgeEntity::new( - "source_1".into(), - "Rust async guide".into(), - "Detailed notes about async runtimes".into(), - KnowledgeEntityType::Document, - None, - user_id.into(), - ); let chunk = TextChunk::new( - entity.source_id.clone(), + "source_1".into(), "Tokio uses cooperative scheduling for fairness.".into(), user_id.into(), ); - KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db) - .await - .expect("Failed to store entity"); TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) .await .expect("Failed to store chunk"); @@ -164,64 +130,32 @@ mod tests { None, ) .await - .expect("Hybrid retrieval failed"); + .expect("Default strategy retrieval failed"); - let entities = match results { - StrategyOutput::Entities(items) => items, - other => panic!("expected entity results, got {:?}", other), + let chunks = match results { + StrategyOutput::Chunks(items) => items, + other => panic!("expected chunk results, got {:?}", other), }; + assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!( - !entities.is_empty(), - "Expected at least one retrieval result" - ); - let top = &entities[0]; - assert!( - top.entity.name.contains("Rust"), - "Expected Rust entity to be ranked first" - ); - assert!( - !top.chunks.is_empty(), - "Expected Rust entity to include supporting chunks" + chunks[0].chunk.chunk.contains("Tokio"), + "Expected chunk about Tokio" ); } #[tokio::test] - async fn test_graph_relationship_enriches_results() { + async fn test_default_strategy_returns_chunks_from_multiple_sources() { let db = setup_test_db().await; - let user_id = "graph_user"; - - let primary = KnowledgeEntity::new( - "primary_source".into(), - "Async Rust patterns".into(), - "Explores async runtimes and scheduling strategies.".into(), - KnowledgeEntityType::Document, - None, - user_id.into(), - ); 
- let neighbor = KnowledgeEntity::new( - "neighbor_source".into(), - "Tokio Scheduler Deep Dive".into(), - "Details on Tokio's cooperative scheduler.".into(), - KnowledgeEntityType::Document, - None, - user_id.into(), - ); - - KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db) - .await - .expect("Failed to store primary entity"); - KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db) - .await - .expect("Failed to store neighbor entity"); + let user_id = "multi_source_user"; let primary_chunk = TextChunk::new( - primary.source_id.clone(), + "primary_source".into(), "Rust async tasks use Tokio's cooperative scheduler.".into(), user_id.into(), ); - let neighbor_chunk = TextChunk::new( - neighbor.source_id.clone(), + let secondary_chunk = TextChunk::new( + "secondary_source".into(), "Tokio's scheduler manages task fairness across executors.".into(), user_id.into(), ); @@ -229,23 +163,11 @@ mod tests { TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db) .await .expect("Failed to store primary chunk"); - TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db) + TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db) .await - .expect("Failed to store neighbor chunk"); + .expect("Failed to store secondary chunk"); let openai_client = Client::new(); - let relationship = KnowledgeRelationship::new( - primary.id.clone(), - neighbor.id.clone(), - user_id.into(), - "relationship_source".into(), - "references".into(), - ); - relationship - .store_relationship(&db) - .await - .expect("Failed to store relationship"); - let results = pipeline::run_pipeline_with_embedding( &db, &openai_client, @@ -257,35 +179,23 @@ mod tests { None, ) .await - .expect("Hybrid retrieval failed"); + .expect("Default strategy retrieval failed"); - let entities = match results { - StrategyOutput::Entities(items) => items, - other => panic!("expected entity 
results, got {:?}", other), + let chunks = match results { + StrategyOutput::Chunks(items) => items, + other => panic!("expected chunk results, got {:?}", other), }; - let mut neighbor_entry = None; - for entity in &entities { - if entity.entity.id == neighbor.id { - neighbor_entry = Some(entity.clone()); - } - } - - println!("{:?}", entities); - - let neighbor_entry = - neighbor_entry.expect("Graph-enriched neighbor should appear in results"); - + assert!(chunks.len() >= 2, "Expected chunks from multiple sources"); assert!( - neighbor_entry.score > 0.2, - "Graph-enriched entity should have a meaningful fused score" + chunks.iter().any(|c| c.chunk.source_id == "primary_source"), + "Should include primary source chunk" ); assert!( - neighbor_entry - .chunks + chunks .iter() - .all(|chunk| chunk.chunk.source_id == neighbor.source_id), - "Neighbor entity should surface its own supporting chunks" + .any(|c| c.chunk.source_id == "secondary_source"), + "Should include secondary source chunk" ); } @@ -311,7 +221,7 @@ mod tests { .await .expect("Failed to store chunk two"); - let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised); + let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default); let openai_client = Client::new(); let results = pipeline::run_pipeline_with_embedding( &db, diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index 42f3c50..684d41b 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -6,15 +6,17 @@ use crate::scoring::FusionWeights; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)] #[serde(rename_all = "snake_case")] pub enum RetrievalStrategy { - Initial, - Revised, + /// Primary hybrid chunk retrieval for search/chat (formerly Revised) + Default, + /// Entity retrieval for suggesting relationships when creating manual entities RelationshipSuggestion, + /// Entity retrieval for 
context during content ingestion Ingestion, } impl Default for RetrievalStrategy { fn default() -> Self { - Self::Initial + Self::Default } } @@ -23,8 +25,16 @@ impl std::str::FromStr for RetrievalStrategy { fn from_str(value: &str) -> Result { match value.to_ascii_lowercase().as_str() { - "initial" => Ok(Self::Initial), - "revised" => Ok(Self::Revised), + "default" => Ok(Self::Default), + // Backward compatibility: treat "initial" and "revised" as "default" + "initial" | "revised" => { + tracing::warn!( + "Retrieval strategy '{}' is deprecated. Use 'default' instead. \ + The 'initial' strategy has been removed in favor of the simpler hybrid chunk retrieval.", + value + ); + Ok(Self::Default) + } "relationship_suggestion" => Ok(Self::RelationshipSuggestion), "ingestion" => Ok(Self::Ingestion), other => Err(format!("unknown retrieval strategy '{other}'")), @@ -35,8 +45,7 @@ impl std::str::FromStr for RetrievalStrategy { impl fmt::Display for RetrievalStrategy { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let label = match self { - RetrievalStrategy::Initial => "initial", - RetrievalStrategy::Revised => "revised", + RetrievalStrategy::Default => "default", RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion", RetrievalStrategy::Ingestion => "ingestion", }; @@ -136,7 +145,7 @@ pub struct RetrievalConfig { impl RetrievalConfig { pub fn new(tuning: RetrievalTuning) -> Self { Self { - strategy: RetrievalStrategy::Initial, + strategy: RetrievalStrategy::Default, tuning, } } diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 7a0d026..9537f7a 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -17,9 +17,7 @@ use std::time::{Duration, Instant}; use tracing::info; use stages::PipelineContext; -use strategies::{ - IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver, -}; +use strategies::{DefaultStrategyDriver, 
IngestionDriver, RelationshipSuggestionDriver}; // Export StrategyOutput publicly from this module // (it's defined in lib.rs but we re-export it here) @@ -132,25 +130,8 @@ pub async fn run_pipeline( ); match config.strategy { - RetrievalStrategy::Initial => { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - None, - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Revised => { - let driver = RevisedStrategyDriver::new(); + RetrievalStrategy::Default => { + let driver = DefaultStrategyDriver::new(); let run = execute_strategy( driver, db_client, @@ -214,25 +195,8 @@ pub async fn run_pipeline_with_embedding( reranker: Option, ) -> Result { match config.strategy { - RetrievalStrategy::Initial => { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(StrategyOutput::Entities(run.results)) - } - RetrievalStrategy::Revised => { - let driver = RevisedStrategyDriver::new(); + RetrievalStrategy::Default => { + let driver = DefaultStrategyDriver::new(); let run = execute_strategy( driver, db_client, @@ -301,29 +265,8 @@ pub async fn run_pipeline_with_embedding_with_metrics( reranker: Option, ) -> Result, AppError> { match config.strategy { - RetrievalStrategy::Initial => { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - false, - ) - .await?; - Ok(PipelineRunOutput { - results: StrategyOutput::Entities(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) - } - RetrievalStrategy::Revised => { - let driver = 
RevisedStrategyDriver::new(); + RetrievalStrategy::Default => { + let driver = DefaultStrategyDriver::new(); let run = execute_strategy( driver, db_client, @@ -361,29 +304,8 @@ pub async fn run_pipeline_with_embedding_with_diagnostics( reranker: Option, ) -> Result, AppError> { match config.strategy { - RetrievalStrategy::Initial => { - let driver = InitialStrategyDriver::new(); - let run = execute_strategy( - driver, - db_client, - openai_client, - embedding_provider, - Some(query_embedding), - input_text, - user_id, - config, - reranker, - true, - ) - .await?; - Ok(PipelineRunOutput { - results: StrategyOutput::Entities(run.results), - diagnostics: run.diagnostics, - stage_timings: run.stage_timings, - }) - } - RetrievalStrategy::Revised => { - let driver = RevisedStrategyDriver::new(); + RetrievalStrategy::Default => { + let driver = DefaultStrategyDriver::new(); let run = execute_strategy( driver, db_client, diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index 587ec52..3dcc394 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -12,13 +12,13 @@ use fastembed::RerankResult; use futures::{stream::FuturesUnordered, StreamExt}; use std::{ cmp::Ordering, - collections::{HashMap, HashSet}, + collections::HashMap, }; use tracing::{debug, instrument, warn}; use crate::{ - fts::find_items_by_fts, - graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, + + graph::find_entities_by_relationship_by_id, reranking::RerankerLease, scoring::{ clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion, @@ -45,7 +45,6 @@ pub struct PipelineContext<'a> { pub config: RetrievalConfig, pub query_embedding: Option>, pub entity_candidates: HashMap>, - pub chunk_candidates: HashMap>, pub filtered_entities: Vec>, pub chunk_values: Vec>, pub revised_chunk_values: Vec>, @@ -75,7 +74,6 @@ impl<'a> PipelineContext<'a> { 
config, query_embedding: None, entity_candidates: HashMap::new(), - chunk_candidates: HashMap::new(), filtered_entities: Vec::new(), chunk_values: Vec::new(), revised_chunk_values: Vec::new(), @@ -209,20 +207,6 @@ impl PipelineStage for GraphExpansionStage { } } -#[derive(Debug, Clone, Copy)] -pub struct ChunkAttachStage; - -#[async_trait] -impl PipelineStage for ChunkAttachStage { - fn kind(&self) -> StageKind { - StageKind::ChunkAttach - } - - async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - attach_chunks(ctx).await - } -} - #[derive(Debug, Clone, Copy)] pub struct RerankStage; @@ -324,75 +308,68 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App let weights = FusionWeights::default(); - let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!( + let (vector_entity_results, fts_entity_results) = tokio::try_join!( KnowledgeEntity::vector_search( tuning.entity_vector_take, - embedding.clone(), - ctx.db_client, - &ctx.user_id, - ), - TextChunk::vector_search( - tuning.chunk_vector_take, embedding, ctx.db_client, &ctx.user_id, ), - find_items_by_fts( - tuning.entity_fts_take, - &ctx.input_text, + KnowledgeEntity::search( ctx.db_client, - "knowledge_entity", + &ctx.input_text, &ctx.user_id, - ), - find_items_by_fts( - tuning.chunk_fts_take, - &ctx.input_text, - ctx.db_client, - "text_chunk", - &ctx.user_id - ), + tuning.entity_fts_take, + ) )?; + #[allow(clippy::useless_conversion)] let vector_entities: Vec> = vector_entity_results .into_iter() .map(|row| Scored::new(row.entity).with_vector_score(row.score)) .collect(); - let vector_chunks: Vec> = vector_chunk_results + + let mut fts_entities: Vec> = fts_entity_results .into_iter() - .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) + .map(|res| { + let entity = KnowledgeEntity { + id: res.id, + created_at: res.created_at, + updated_at: res.updated_at, + source_id: res.source_id, + name: 
res.name, + description: res.description, + entity_type: res.entity_type, + metadata: res.metadata, + user_id: res.user_id, + }; + Scored::new(entity).with_fts_score(res.score) + }) .collect(); debug!( vector_entities = vector_entities.len(), - vector_chunks = vector_chunks.len(), fts_entities = fts_entities.len(), - fts_chunks = fts_chunks.len(), "Hybrid retrieval initial candidate counts" ); if ctx.diagnostics_enabled() { ctx.record_collect_candidates(CollectCandidatesStats { vector_entity_candidates: vector_entities.len(), - vector_chunk_candidates: vector_chunks.len(), + vector_chunk_candidates: 0, fts_entity_candidates: fts_entities.len(), - fts_chunk_candidates: fts_chunks.len(), - vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { - chunk.scores.vector.unwrap_or(0.0) - }), - fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)), + fts_chunk_candidates: 0, + vector_chunk_scores: Vec::new(), + fts_chunk_scores: Vec::new(), }); } normalize_fts_scores(&mut fts_entities); - normalize_fts_scores(&mut fts_chunks); merge_scored_by_id(&mut ctx.entity_candidates, vector_entities); merge_scored_by_id(&mut ctx.entity_candidates, fts_entities); - merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks); - merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks); apply_fusion(&mut ctx.entity_candidates, weights); - apply_fusion(&mut ctx.chunk_candidates, weights); Ok(()) } @@ -467,82 +444,6 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> Ok(()) } -#[instrument(level = "trace", skip_all)] -pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { - debug!("Attaching chunks to surviving entities"); - let tuning = &ctx.config.tuning; - let weights = FusionWeights::default(); - - let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates); - let chunk_candidates_before = ctx.chunk_candidates.len(); - let chunk_sources_considered = chunk_by_source.len(); - - 
backfill_entities_from_chunks( - &mut ctx.entity_candidates, - &chunk_by_source, - ctx.db_client, - &ctx.user_id, - weights, - ) - .await?; - - boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights); - - let mut entity_results: Vec> = - ctx.entity_candidates.values().cloned().collect(); - sort_by_fused_desc(&mut entity_results); - - let mut filtered_entities: Vec> = entity_results - .iter() - .filter(|candidate| candidate.fused >= tuning.score_threshold) - .cloned() - .collect(); - - if filtered_entities.len() < tuning.fallback_min_results { - filtered_entities = entity_results - .into_iter() - .take(tuning.fallback_min_results) - .collect(); - } - - ctx.filtered_entities = filtered_entities; - - let mut chunk_results: Vec> = - ctx.chunk_candidates.values().cloned().collect(); - sort_by_fused_desc(&mut chunk_results); - - let mut chunk_by_id: HashMap> = HashMap::new(); - for chunk in chunk_results { - chunk_by_id.insert(chunk.item.id.clone(), chunk); - } - - enrich_chunks_from_entities( - &mut chunk_by_id, - &ctx.filtered_entities, - ctx.db_client, - &ctx.user_id, - weights, - ) - .await?; - - let mut chunk_values: Vec> = chunk_by_id.into_values().collect(); - sort_by_fused_desc(&mut chunk_values); - - if ctx.diagnostics_enabled() { - ctx.record_chunk_enrichment(ChunkEnrichmentStats { - filtered_entity_count: ctx.filtered_entities.len(), - fallback_min_results: tuning.fallback_min_results, - chunk_sources_considered, - chunk_candidates_before_enrichment: chunk_candidates_before, - chunk_candidates_after_enrichment: chunk_values.len(), - top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused), - }); - } - - ctx.chunk_values = chunk_values; - - Ok(()) -} #[instrument(level = "trace", skip_all)] pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { @@ -960,142 +861,6 @@ where } } -fn group_chunks_by_source( - chunks: &HashMap>, -) -> HashMap>> { - let mut by_source: HashMap>> = HashMap::new(); - - for chunk 
in chunks.values() { - by_source - .entry(chunk.item.source_id.clone()) - .or_default() - .push(chunk.clone()); - } - by_source -} - -async fn backfill_entities_from_chunks( - entity_candidates: &mut HashMap>, - chunk_by_source: &HashMap>>, - db_client: &SurrealDbClient, - user_id: &str, - weights: FusionWeights, -) -> Result<(), AppError> { - let mut missing_sources = Vec::new(); - - for source_id in chunk_by_source.keys() { - if !entity_candidates - .values() - .any(|entity| entity.item.source_id == *source_id) - { - missing_sources.push(source_id.clone()); - } - } - - if missing_sources.is_empty() { - return Ok(()); - } - - let related_entities: Vec = find_entities_by_source_ids( - missing_sources.clone(), - "knowledge_entity", - user_id, - db_client, - ) - .await - .unwrap_or_default(); - - if related_entities.is_empty() { - warn!("expected related entities for missing chunk sources, but none were found"); - } - - for entity in related_entities { - if let Some(chunks) = chunk_by_source.get(&entity.source_id) { - let best_chunk_score = chunks - .iter() - .map(|chunk| chunk.fused) - .fold(0.0f32, f32::max); - - let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score); - let fused = fuse_scores(&scored.scores, weights); - scored.update_fused(fused); - entity_candidates.insert(entity.id.clone(), scored); - } - } - - Ok(()) -} - -fn boost_entities_with_chunks( - entity_candidates: &mut HashMap>, - chunk_by_source: &HashMap>>, - weights: FusionWeights, -) { - for entity in entity_candidates.values_mut() { - if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) { - let best_chunk_score = chunks - .iter() - .map(|chunk| chunk.fused) - .fold(0.0f32, f32::max); - - if best_chunk_score > 0.0 { - let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score); - entity.scores.vector = Some(boosted); - let fused = fuse_scores(&entity.scores, weights); - entity.update_fused(fused); - } - } - } -} - -async fn 
enrich_chunks_from_entities( - chunk_candidates: &mut HashMap>, - entities: &[Scored], - db_client: &SurrealDbClient, - user_id: &str, - weights: FusionWeights, -) -> Result<(), AppError> { - let mut source_ids: HashSet = HashSet::new(); - for entity in entities { - source_ids.insert(entity.item.source_id.clone()); - } - - if source_ids.is_empty() { - return Ok(()); - } - - let chunks = find_entities_by_source_ids::( - source_ids.into_iter().collect(), - "text_chunk", - user_id, - db_client, - ) - .await?; - - let mut entity_score_lookup: HashMap = HashMap::new(); - for entity in entities { - entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused); - } - - for chunk in chunks { - let entry = chunk_candidates - .entry(chunk.id.clone()) - .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0)); - - let entity_score = entity_score_lookup - .get(&chunk.source_id) - .copied() - .unwrap_or(0.0); - - entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); - let fused = fuse_scores(&entry.scores, weights); - entry.update_fused(fused); - entry.item = chunk; - } - - Ok(()) -} - fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec { if ctx.filtered_entities.is_empty() { return Vec::new(); diff --git a/retrieval-pipeline/src/pipeline/strategies.rs b/retrieval-pipeline/src/pipeline/strategies.rs index 35d6f31..3318872 100644 --- a/retrieval-pipeline/src/pipeline/strategies.rs +++ b/retrieval-pipeline/src/pipeline/strategies.rs @@ -1,50 +1,24 @@ use super::{ stages::{ - AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage, - ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, - RerankStage, + AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage, + CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage, }, BoxedStage, StrategyDriver, }; use crate::{RetrievedChunk, 
RetrievedEntity}; use common::error::AppError; -pub struct InitialStrategyDriver; -impl InitialStrategyDriver { + +pub struct DefaultStrategyDriver; + +impl DefaultStrategyDriver { pub fn new() -> Self { Self } } -impl StrategyDriver for InitialStrategyDriver { - type Output = Vec; - - fn stages(&self) -> Vec { - vec![ - Box::new(EmbedStage), - Box::new(CollectCandidatesStage), - Box::new(GraphExpansionStage), - Box::new(ChunkAttachStage), - Box::new(RerankStage), - Box::new(AssembleEntitiesStage), - ] - } - - fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result { - Ok(ctx.take_entity_results()) - } -} - -pub struct RevisedStrategyDriver; - -impl RevisedStrategyDriver { - pub fn new() -> Self { - Self - } -} - -impl StrategyDriver for RevisedStrategyDriver { +impl StrategyDriver for DefaultStrategyDriver { type Output = Vec; fn stages(&self) -> Vec {