diff --git a/Cargo.lock b/Cargo.lock index ed92636..f2c421f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,12 +138,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anstyle" -version = "1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" - [[package]] name = "any_ascii" version = "0.3.2" @@ -167,6 +161,7 @@ dependencies = [ "futures", "serde", "tempfile", + "thiserror", "tokio", "tracing", ] @@ -1050,9 +1045,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-openai", - "async-stream", "axum", - "axum-htmx", "axum_session", "axum_session_auth", "axum_session_surreal", @@ -1061,31 +1054,41 @@ dependencies = [ "chrono-tz", "config", "futures", - "json-stream-parser", "lettre", "mime", "mime_guess", "minijinja", "minijinja-autoreload", - "minijinja-contrib", - "mockall", - "plotly", "reqwest", "serde", "serde_json", "sha2", "surrealdb", "tempfile", - "text-splitter", "thiserror", "tokio", - "tower-http", "tracing", - "tracing-subscriber", "url", "uuid", ] +[[package]] +name = "composite-retrieval" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "axum", + "common", + "futures", + "serde", + "serde_json", + "surrealdb", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1097,9 +1100,9 @@ dependencies = [ [[package]] name = "config" -version = "0.15.4" +version = "0.15.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d84f8d224ac58107d53d3ec2b9ad39fd8c8c4e285d3c9cb35485ffd2ca88cb3" +checksum = "fb07d21d12f9f0bc5e7c3e97ccc78b2341b9b4a4604eac3ed7c1d0d6e2c3b23e" dependencies = [ "async-trait", "convert_case", @@ -1110,7 +1113,7 @@ dependencies = [ "serde", "serde_json", "toml", - "winnow", + "winnow 0.7.3", "yaml-rust2", ] @@ -1448,12 +1451,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" -[[package]] -name = "downcast" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" - [[package]] name = "dtoa" version = "1.0.9" @@ -1576,7 +1573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" dependencies = [ "futures-core", - "nom", + "nom 7.1.3", "pin-project-lite", ] @@ -1627,6 +1624,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1661,12 +1664,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "fragile" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" - [[package]] name = "fsevent-sys" version = "4.1.0" @@ -1974,14 +1971,17 @@ name = "hashbrown" version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +dependencies = [ + "foldhash", +] [[package]] name = "hashlink" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.0", ] [[package]] @@ -2052,6 +2052,7 @@ dependencies = [ "axum_typed_multipart", "chrono-tz", "common", + "composite-retrieval", "futures", "json-stream-parser", "minijinja", @@ -2062,6 +2063,7 @@ dependencies = [ "serde_json", "surrealdb", "tempfile", + "thiserror", "tokio", "tower-http", "tracing", @@ -2435,6 +2437,8 @@ dependencies = [ "axum", "chrono", "common", + "composite-retrieval", + "futures", "reqwest", "scraper", "serde", @@ -2443,6 +2447,7 @@ dependencies = [ "tiktoken-rs", "tokio", "tracing", + "uuid", ] [[package]] @@ -2625,9 +2630,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lettre" -version = "0.11.11" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab4c9a167ff73df98a5ecc07e8bf5ce90b583665da3d1762eb1f775ad4d0d6f5" +checksum = "5d476fe7a4a798f392ce34947aa7d53d981127e37523c5251da3c927f7fa901f" dependencies = [ "base64 0.22.1", "chumsky", @@ -2640,16 +2645,12 @@ dependencies = [ "idna", "mime", "native-tls", - "nom", + "nom 8.0.0", "percent-encoding", "quoted_printable", - "rustls", - "rustls-pemfile", - "rustls-pki-types", "socket2", "tokio", "url", - "webpki-roots", ] [[package]] @@ -2737,45 +2738,18 @@ dependencies = [ "anyhow", "api-router", "async-openai", - "async-stream", "axum", - "axum-htmx", - "axum_session", - "axum_session_auth", - "axum_session_surreal", - "axum_typed_multipart", - "chrono", - "chrono-tz", "common", - "config", "futures", "html-router", "ingestion-pipeline", - "json-stream-parser", - "lettre", - "mime", - "mime_guess", - "minijinja", - "minijinja-autoreload", - "minijinja-contrib", - "mockall", - "plotly", - "reqwest", - "scraper", "serde", "serde_json", - "sha2", "surrealdb", - "tempfile", - "text-splitter", "thiserror", - "tiktoken-rs", "tokio", - "tower-http", "tracing", "tracing-subscriber", - "url", - "uuid", ] [[package]] @@ -2959,32 +2933,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "mockall" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c28b3fb6d753d28c20e826cd46ee611fda1cf3cde03a443a974043247c065a" -dependencies = [ - "cfg-if", - "downcast", - "fragile", - "mockall_derive", - "predicates", - "predicates-tree", -] - -[[package]] -name = "mockall_derive" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "341014e7f530314e9a1fdbc7400b244efea7122662c96bfa248c31da5bfb2020" -dependencies = [ - "cfg-if", - "proc-macro2", - "quote", - "syn 2.0.87", -] - [[package]] name = "multer" version = "3.1.0" @@ -3091,6 +3039,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "nonempty" version = "0.7.0" @@ -3611,32 +3568,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" -[[package]] -name = "predicates" -version = "3.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" -dependencies = [ - "anstyle", - "predicates-core", -] - -[[package]] -name = "predicates-core" -version = "1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" - -[[package]] -name = "predicates-tree" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" -dependencies = [ - "predicates-core", - "termtree", -] - [[package]] name = "proc-macro-crate" version = "3.2.0" @@ -4004,7 +3935,7 @@ dependencies = [ "futures-core", "futures-timer", "mime", - "nom", + "nom 7.1.3", "pin-project-lite", "reqwest", "thiserror", @@ -4090,7 +4021,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93f9a866e2e00a7a1fb27e46e9e324a6f7c0e7edc4543cae1d38f4e4a100c610" dependencies = [ "memchr", - "nom", + "nom 7.1.3", "serde", ] @@ -5110,12 +5041,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "termtree" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" - [[package]] name = "text-splitter" version = "0.18.1" @@ -5381,7 +5306,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow", + "winnow 0.6.20", ] [[package]] @@ -6036,6 +5961,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" +dependencies = [ + "memchr", +] + [[package]] name = "write16" version = "1.0.0" @@ -6084,9 +6018,9 @@ checksum = "791978798f0597cfc70478424c2b4fdc2b7a8024aaff78497ef00f24ef674193" [[package]] name = "yaml-rust2" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a1a1c0bc9823338a3bdf8c61f994f23ac004c6fa32c08cd152984499b445e8d" +checksum = "232bdb534d65520716bef0bbb205ff8f2db72d807b19c0bc3020853b92a0cd4b" dependencies = [ "arraydeque", "encoding_rs", diff --git a/Cargo.toml b/Cargo.toml index f192162..9ab4d4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,11 @@ members = [ "crates/main", "crates/common", - "crates/api-router" -, "crates/html-router", "crates/ingestion-pipeline"] + "crates/api-router", + "crates/html-router", + "crates/ingestion-pipeline", + "crates/composite-retrieval" +] resolver = "2" [workspace.dependencies] @@ -14,3 +17,6 @@ serde_json = "1.0.128" thiserror = "1.0.63" anyhow = "1.0.94" tracing = "0.1.40" +surrealdb = "2.0.4" +futures = "0.3.31" +async-openai = "0.24.1" diff --git a/crates/api-router/Cargo.toml b/crates/api-router/Cargo.toml index 6ae0a93..902bdaa 100644 --- a/crates/api-router/Cargo.toml +++ b/crates/api-router/Cargo.toml @@ -10,6 +10,7 @@ serde = { workspace = true } axum = { workspace = true } tracing = { workspace = true } anyhow = { workspace = true } +thiserror = { workspace = true } tempfile = "3.12.0" futures = "0.3.31" diff --git a/crates/api-router/src/error.rs b/crates/api-router/src/error.rs new file mode 100644 index 0000000..f67ead0 --- /dev/null +++ b/crates/api-router/src/error.rs @@ -0,0 +1,168 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use common::error::AppError; +use serde::Serialize; +use thiserror::Error; + +#[derive(Error, Debug, Serialize, Clone)] +pub enum ApiError { + #[error("Internal server error")] + InternalError(String), + + #[error("Validation error: {0}")] + ValidationError(String), + + #[error("Not found: {0}")] + NotFound(String), + + #[error("Unauthorized: {0}")] + Unauthorized(String), +} + +impl From for ApiError { + fn from(err: AppError) -> Self { + match err { + AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => { + tracing::error!("Internal error: {:?}", err); + ApiError::InternalError("Internal server error".to_string()) + } + AppError::NotFound(msg) => ApiError::NotFound(msg), + AppError::Validation(msg) => ApiError::ValidationError(msg), + AppError::Auth(msg) => ApiError::Unauthorized(msg), + _ => ApiError::InternalError("Internal server error".to_string()), + } + } +} +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let (status, error_response) = match self { + ApiError::InternalError(message) => ( + StatusCode::INTERNAL_SERVER_ERROR, + ErrorResponse { + error: message, + status: "error".to_string(), + }, + ), + ApiError::ValidationError(message) => ( + StatusCode::BAD_REQUEST, + ErrorResponse { + error: message, + status: "error".to_string(), + }, + ), + ApiError::NotFound(message) => ( + StatusCode::NOT_FOUND, + ErrorResponse { + error: message, + status: "error".to_string(), + }, + ), + ApiError::Unauthorized(message) => ( + StatusCode::UNAUTHORIZED, + ErrorResponse { + error: message, + status: "error".to_string(), + }, + ), + }; + + (status, Json(error_response)).into_response() + } +} + +#[derive(Serialize, Debug)] +struct ErrorResponse { + error: String, + status: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use common::error::AppError; + use std::fmt::Debug; + + // Helper to check status code + fn assert_status_code(response: T, expected_status: StatusCode) { + let response = response.into_response(); + assert_eq!(response.status(), expected_status); + } + + #[test] + fn test_app_error_to_api_error_conversion() { + // Test NotFound error conversion + let not_found = AppError::NotFound("resource not found".to_string()); + let api_error = ApiError::from(not_found); + assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found")); + + // Test Validation error conversion + let validation = AppError::Validation("invalid input".to_string()); + let api_error = ApiError::from(validation); + assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input")); + + // Test Auth error conversion + let auth = AppError::Auth("unauthorized".to_string()); + let api_error = ApiError::from(auth); + assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized")); + + // Test for internal errors - create a mock error that doesn't require surrealdb + let internal_error = + AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let api_error = ApiError::from(internal_error); + assert!(matches!(api_error, ApiError::InternalError(_))); + } + + #[test] + fn test_api_error_response_status_codes() { + // Test internal error status + let error = ApiError::InternalError("server error".to_string()); + assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR); + + // Test not found status + let error = ApiError::NotFound("not found".to_string()); + assert_status_code(error, StatusCode::NOT_FOUND); + + // Test validation error status + let error = ApiError::ValidationError("invalid input".to_string()); + assert_status_code(error, StatusCode::BAD_REQUEST); + + // Test unauthorized status + let error = ApiError::Unauthorized("not allowed".to_string()); + assert_status_code(error, StatusCode::UNAUTHORIZED); + } + + // Alternative approach that doesn't try to parse the response body + #[test] + fn test_error_messages() { + // For validation errors + let message = "invalid data format"; + let error = ApiError::ValidationError(message.to_string()); + + // Check that the error itself contains the message + assert_eq!(error.to_string(), format!("Validation error: {}", message)); + + // For not found errors + let message = "user not found"; + let error = ApiError::NotFound(message.to_string()); + assert_eq!(error.to_string(), format!("Not found: {}", message)); + } + + // Alternative approach for internal error test + #[test] + fn test_internal_error_sanitization() { + // Create a sensitive error message + let sensitive_info = "db password incorrect"; + + // Create ApiError with sensitive info + let api_error = ApiError::InternalError(sensitive_info.to_string()); + + // Check the error message is correctly set + assert_eq!(api_error.to_string(), "Internal server error"); + + // Also verify correct status code + assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/crates/api-router/src/lib.rs b/crates/api-router/src/lib.rs index d31ac06..ed13579 100644 --- a/crates/api-router/src/lib.rs +++ b/crates/api-router/src/lib.rs @@ -9,6 +9,7 @@ use middleware_api_auth::api_auth; use routes::ingress::ingest_data; pub mod api_state; +pub mod error; mod middleware_api_auth; mod routes; diff --git a/crates/api-router/src/middleware_api_auth.rs b/crates/api-router/src/middleware_api_auth.rs index 628e666..29e6bee 100644 --- a/crates/api-router/src/middleware_api_auth.rs +++ b/crates/api-router/src/middleware_api_auth.rs @@ -4,9 +4,9 @@ use axum::{ response::Response, }; -use common::{error::ApiError, storage::types::user::User}; +use common::storage::types::user::User; -use crate::api_state::ApiState; +use crate::{api_state::ApiState, error::ApiError}; pub async fn api_auth( State(state): State, diff --git a/crates/api-router/src/routes/ingress.rs b/crates/api-router/src/routes/ingress.rs index 6fde00e..9f6ba94 100644 --- a/crates/api-router/src/routes/ingress.rs +++ b/crates/api-router/src/routes/ingress.rs @@ -1,7 +1,7 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension}; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use common::{ - error::{ApiError, AppError}, + error::AppError, storage::types::{ file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, user::User, @@ -11,7 +11,7 @@ use futures::{future::try_join_all, TryFutureExt}; use tempfile::NamedTempFile; use tracing::info; -use crate::api_state::ApiState; +use crate::{api_state::ApiState, error::ApiError}; #[derive(Debug, TryFromMultipart)] pub struct IngestParams { diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index f7c6ff6..287d723 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -12,10 +12,10 @@ tracing = { workspace = true } anyhow = { workspace = true } thiserror = { workspace = true } serde_json = { workspace = true } +surrealdb = { workspace = true } +async-openai = { workspace = true } +futures = { workspace = true } -async-openai = "0.24.1" -async-stream = "0.3.6" -axum-htmx = "0.6.0" axum_session = "0.14.4" axum_session_auth = "0.14.1" axum_session_surreal = "0.2.1" @@ -23,23 +23,14 @@ axum_typed_multipart = "0.12.1" chrono = { version = "0.4.39", features = ["serde"] } chrono-tz = "0.10.1" config = "0.15.4" -futures = "0.3.31" -json-stream-parser = "0.1.4" -lettre = { version = "0.11.11", features = ["rustls-tls"] } +lettre = { version = "0.11.11", features = [] } mime = "0.3.17" mime_guess = "2.0.5" minijinja = { version = "2.5.0", features = ["loader", "multi_template"] } minijinja-autoreload = "2.5.0" -minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] } -mockall = "0.13.0" -plotly = "0.12.1" reqwest = {version = "0.12.12", features = ["charset", "json"]} sha2 = "0.10.8" -surrealdb = "2.0.4" tempfile = "3.12.0" -text-splitter = "0.18.1" -tower-http = { version = "0.6.2", features = ["fs"] } -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } url = { version = "2.5.2", features = ["serde"] } uuid = { version = "1.10.0", features = ["v4", "serde"] } diff --git a/crates/common/src/error.rs b/crates/common/src/error.rs index ff4b7ec..630957e 100644 --- a/crates/common/src/error.rs +++ b/crates/common/src/error.rs @@ -1,15 +1,4 @@ -use std::sync::Arc; - use async_openai::error::OpenAIError; -use axum::{ - http::StatusCode, - response::{Html, IntoResponse, Response}, - Json, -}; -use minijinja::context; -use minijinja_autoreload::AutoReloader; -use serde::Serialize; -use serde_json::json; use thiserror::Error; use tokio::task::JoinError; @@ -49,206 +38,3 @@ pub enum AppError { #[error("Ingress Processing error: {0}")] Processing(String), } - -// API-specific errors -#[derive(Debug, Serialize)] -pub enum ApiError { - InternalError(String), - ValidationError(String), - NotFound(String), - Unauthorized(String), -} - -impl From for ApiError { - fn from(err: AppError) -> Self { - match err { - AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => { - tracing::error!("Internal error: {:?}", err); - ApiError::InternalError("Internal server error".to_string()) - } - AppError::NotFound(msg) => ApiError::NotFound(msg), - AppError::Validation(msg) => ApiError::ValidationError(msg), - AppError::Auth(msg) => ApiError::Unauthorized(msg), - _ => ApiError::InternalError("Internal server error".to_string()), - } - } -} - -impl IntoResponse for ApiError { - fn into_response(self) -> Response { - let (status, body) = match self { - ApiError::InternalError(message) => ( - StatusCode::INTERNAL_SERVER_ERROR, - json!({ - "error": message, - "status": "error" - }), - ), - ApiError::ValidationError(message) => ( - StatusCode::BAD_REQUEST, - json!({ - "error": message, - "status": "error" - }), - ), - ApiError::NotFound(message) => ( - StatusCode::NOT_FOUND, - json!({ - "error": message, - "status": "error" - }), - ), - ApiError::Unauthorized(message) => ( - StatusCode::UNAUTHORIZED, - json!({ - "error": message, - "status": "error" - }), - ), // ... other matches - }; - (status, Json(body)).into_response() - } -} - -pub type TemplateResult = Result; - -// Helper trait for converting to HtmlError with templates -pub trait IntoHtmlError { - fn with_template(self, templates: Arc) -> HtmlError; -} -// // Implement for AppError -impl IntoHtmlError for AppError { - fn with_template(self, templates: Arc) -> HtmlError { - HtmlError::new(self, templates) - } -} -// // Implement for minijinja::Error directly -impl IntoHtmlError for minijinja::Error { - fn with_template(self, templates: Arc) -> HtmlError { - HtmlError::from_template_error(self, templates) - } -} - -#[derive(Clone)] -pub struct ErrorContext { - #[allow(dead_code)] - templates: Arc, -} - -impl ErrorContext { - pub fn new(templates: Arc) -> Self { - Self { templates } - } -} - -pub enum HtmlError { - ServerError(Arc), - NotFound(Arc), - Unauthorized(Arc), - BadRequest(String, Arc), - Template(String, Arc), -} - -impl HtmlError { - pub fn new(error: AppError, templates: Arc) -> Self { - match error { - AppError::NotFound(_msg) => HtmlError::NotFound(templates), - AppError::Auth(_msg) => HtmlError::Unauthorized(templates), - AppError::Validation(msg) => HtmlError::BadRequest(msg, templates), - _ => { - tracing::error!("Internal error: {:?}", error); - HtmlError::ServerError(templates) - } - } - } - - pub fn from_template_error(error: minijinja::Error, templates: Arc) -> Self { - tracing::error!("Template error: {:?}", error); - HtmlError::Template(error.to_string(), templates) - } -} - -impl IntoResponse for HtmlError { - fn into_response(self) -> Response { - let (status, context, templates) = match self { - HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => ( - StatusCode::INTERNAL_SERVER_ERROR, - context! { - status_code => 500, - title => "Internal Server Error", - error => "Internal Server Error", - description => "Something went wrong on our end." - }, - templates, - ), - HtmlError::NotFound(templates) => ( - StatusCode::NOT_FOUND, - context! { - status_code => 404, - title => "Page Not Found", - error => "Not Found", - description => "The page you're looking for doesn't exist or was removed." - }, - templates, - ), - HtmlError::Unauthorized(templates) => ( - StatusCode::UNAUTHORIZED, - context! { - status_code => 401, - title => "Unauthorized", - error => "Access Denied", - description => "You need to be logged in to access this page." - }, - templates, - ), - HtmlError::BadRequest(msg, templates) => ( - StatusCode::BAD_REQUEST, - context! { - status_code => 400, - title => "Bad Request", - error => "Bad Request", - description => msg - }, - templates, - ), - }; - - let html = match templates.acquire_env() { - Ok(env) => match env.get_template("errors/error.html") { - Ok(tmpl) => match tmpl.render(context) { - Ok(output) => output, - Err(e) => { - tracing::error!("Template render error: {:?}", e); - Self::fallback_html() - } - }, - Err(e) => { - tracing::error!("Template get error: {:?}", e); - Self::fallback_html() - } - }, - Err(e) => { - tracing::error!("Environment acquire error: {:?}", e); - Self::fallback_html() - } - }; - - (status, Html(html)).into_response() - } -} - -impl HtmlError { - fn fallback_html() -> String { - r#" - - -
-

Error

-

Sorry, something went wrong displaying this page.

-
- - - "# - .to_string() - } -} diff --git a/crates/common/src/ingress/analysis/mod.rs b/crates/common/src/ingress/analysis/mod.rs deleted file mode 100644 index 5a2916f..0000000 --- a/crates/common/src/ingress/analysis/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod ingress_analyser; -pub mod prompt; -pub mod types; diff --git a/crates/common/src/ingress/analysis/types/mod.rs b/crates/common/src/ingress/analysis/types/mod.rs deleted file mode 100644 index f8def07..0000000 --- a/crates/common/src/ingress/analysis/types/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod graph_mapper; -pub mod llm_analysis_result; diff --git a/crates/common/src/ingress/content_processor.rs b/crates/common/src/ingress/content_processor.rs deleted file mode 100644 index 8b13789..0000000 --- a/crates/common/src/ingress/content_processor.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/common/src/ingress/mod.rs b/crates/common/src/ingress/mod.rs deleted file mode 100644 index e2a6c38..0000000 --- a/crates/common/src/ingress/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod analysis; -pub mod content_processor; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index adec445..d04944f 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -1,5 +1,3 @@ pub mod error; -pub mod ingress; -pub mod retrieval; pub mod storage; pub mod utils; diff --git a/crates/composite-retrieval/Cargo.toml b/crates/composite-retrieval/Cargo.toml new file mode 100644 index 0000000..29ac9fb --- /dev/null +++ b/crates/composite-retrieval/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "composite-retrieval" +version = "0.1.0" +edition = "2021" + +[dependencies] +# Workspace dependencies +tokio = { workspace = true } +serde = { workspace = true } +axum = { workspace = true } +tracing = { workspace = true } +anyhow = { workspace = true } +thiserror = { workspace = true } +serde_json = { workspace = true } +surrealdb = { workspace = true } +futures = { workspace = true } +async-openai = { workspace = true } + +common = { path = "../common" } diff --git a/crates/common/src/retrieval/query_helper.rs b/crates/composite-retrieval/src/answer_retrieval.rs similarity index 69% rename from crates/common/src/retrieval/query_helper.rs rename to crates/composite-retrieval/src/answer_retrieval.rs index 56f95ab..671defa 100644 --- a/crates/common/src/retrieval/query_helper.rs +++ b/crates/composite-retrieval/src/answer_retrieval.rs @@ -10,13 +10,14 @@ use serde::Deserialize; use serde_json::{json, Value}; use tracing::debug; -use crate::{ +use common::{ error::AppError, - retrieval::combined_knowledge_entity_retrieval, storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, }; -use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT}; +use crate::retrieve_entities; + +use super::answer_retrieval_helper::{get_query_response_schema, QUERY_SYSTEM_PROMPT}; #[derive(Debug, Deserialize)] pub struct Reference { @@ -31,47 +32,6 @@ pub struct LLMResponseFormat { pub references: Vec, } -// /// Orchestrator function that takes a query and clients and returns a answer with references -// /// -// /// # Arguments -// /// * `surreal_db_client` - Client for interacting with SurrealDn -// /// * `openai_client` - Client for interacting with openai -// /// * `query` - The query -// /// -// /// # Returns -// /// * `Result<(String, Vec, ApiError)` - Will return the answer, and the list of references or Error -// pub async fn get_answer_with_references( -// surreal_db_client: &SurrealDbClient, -// openai_client: &async_openai::Client, -// query: &str, -// ) -> Result<(String, Vec), ApiError> { -// let entities = -// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?; - -// // Format entities and create message -// let entities_json = format_entities_json(&entities); -// let user_message = create_user_message(&entities_json, query); - -// // Create and send request -// let request = create_chat_request(user_message)?; -// let response = openai_client -// .chat() -// .create(request) -// .await -// .map_err(|e| ApiError::QueryError(e.to_string()))?; - -// // Process response -// let answer = process_llm_response(response).await?; - -// let references: Vec = answer -// .references -// .into_iter() -// .map(|reference| reference.reference) -// .collect(); - -// Ok((answer.answer, references)) -// } - /// Orchestrates query processing and returns an answer with references /// /// Takes a query and uses the provided clients to generate an answer with supporting references. @@ -98,9 +58,7 @@ pub async fn get_answer_with_references( query: &str, user_id: &str, ) -> Result { - let entities = - combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id) - .await?; + let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?; let entities_json = format_entities_json(&entities); debug!("{:?}", entities_json); diff --git a/crates/common/src/retrieval/query_helper_prompt.rs b/crates/composite-retrieval/src/answer_retrieval_helper.rs similarity index 100% rename from crates/common/src/retrieval/query_helper_prompt.rs rename to crates/composite-retrieval/src/answer_retrieval_helper.rs diff --git a/crates/common/src/retrieval/graph.rs b/crates/composite-retrieval/src/graph.rs similarity index 95% rename from crates/common/src/retrieval/graph.rs rename to crates/composite-retrieval/src/graph.rs index 99388b4..81a9712 100644 --- a/crates/common/src/retrieval/graph.rs +++ b/crates/composite-retrieval/src/graph.rs @@ -1,7 +1,7 @@ use surrealdb::Error; use tracing::debug; -use crate::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}; +use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}; /// Retrieves database entries that match a specific source identifier. /// diff --git a/crates/common/src/retrieval/mod.rs b/crates/composite-retrieval/src/lib.rs similarity index 91% rename from crates/common/src/retrieval/mod.rs rename to crates/composite-retrieval/src/lib.rs index 863ab49..735579d 100644 --- a/crates/common/src/retrieval/mod.rs +++ b/crates/composite-retrieval/src/lib.rs @@ -1,21 +1,19 @@ +pub mod answer_retrieval; +pub mod answer_retrieval_helper; pub mod graph; -pub mod query_helper; -pub mod query_helper_prompt; pub mod vector; -use crate::{ +use common::{ error::AppError, - retrieval::{ - graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, - vector::find_items_by_vector_similarity, - }, storage::{ db::SurrealDbClient, types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, }, }; use futures::future::{try_join, try_join_all}; +use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}; use std::collections::HashMap; +use vector::find_items_by_vector_similarity; /// Performs a comprehensive knowledge entity retrieval using multiple search strategies /// to find the most relevant entities for a given query. @@ -38,7 +36,7 @@ use std::collections::HashMap; /// # Returns /// * `Result, AppError>` - A deduplicated vector of relevant /// knowledge entities, or an error if the retrieval process fails -pub async fn combined_knowledge_entity_retrieval( +pub async fn retrieve_entities( db_client: &SurrealDbClient, openai_client: &async_openai::Client, query: &str, diff --git a/crates/common/src/retrieval/vector.rs b/crates/composite-retrieval/src/vector.rs similarity index 96% rename from crates/common/src/retrieval/vector.rs rename to crates/composite-retrieval/src/vector.rs index 8b4ea90..a965b7a 100644 --- a/crates/common/src/retrieval/vector.rs +++ b/crates/composite-retrieval/src/vector.rs @@ -1,6 +1,6 @@ use surrealdb::{engine::any::Any, Surreal}; -use crate::{error::AppError, utils::embedding::generate_embedding}; +use common::{error::AppError, utils::embedding::generate_embedding}; /// Compares vectors and retrieves a number of items from the specified table. /// diff --git a/crates/html-router/Cargo.toml b/crates/html-router/Cargo.toml index 697b1d9..a37724a 100644 --- a/crates/html-router/Cargo.toml +++ b/crates/html-router/Cargo.toml @@ -10,7 +10,8 @@ serde = { workspace = true } axum = { workspace = true } tracing = { workspace = true } serde_json = { workspace = true } - +async-openai = { workspace = true } +thiserror = { workspace = true } axum-htmx = "0.6.0" axum_session = "0.14.4" @@ -28,8 +29,6 @@ plotly = "0.12.1" surrealdb = "2.0.4" tower-http = { version = "0.6.2", features = ["fs"] } chrono-tz = "0.10.1" -async-openai = "0.24.1" - - common = { path = "../common" } +composite-retrieval = { path = "../composite-retrieval" } diff --git a/crates/html-router/src/error.rs b/crates/html-router/src/error.rs new file mode 100644 index 0000000..f0d08f3 --- /dev/null +++ b/crates/html-router/src/error.rs @@ -0,0 +1,139 @@ +// use axum::{ +// http::StatusCode, +// response::{Html, IntoResponse, Response}, +// }; +// use common::error::AppError; +// use minijinja::context; +// use minijinja_autoreload::AutoReloader; +// use std::sync::Arc; + +// pub type TemplateResult = Result; + +// // Helper trait for converting to HtmlError with templates +// pub trait IntoHtmlError { +// fn with_template(self, templates: Arc) -> HtmlError; +// } +// // // Implement for AppError +// impl IntoHtmlError for AppError { +// fn with_template(self, templates: Arc) -> HtmlError { +// HtmlError::new(self, templates) +// } +// } +// // // Implement for minijinja::Error directly +// impl IntoHtmlError for minijinja::Error { +// fn with_template(self, templates: Arc) -> HtmlError { +// HtmlError::from_template_error(self, templates) +// } +// } + +// pub enum HtmlError { +// ServerError(Arc), +// NotFound(Arc), +// Unauthorized(Arc), +// BadRequest(String, Arc), +// Template(String, Arc), +// } + +// impl HtmlError { +// pub fn new(error: AppError, templates: Arc) -> Self { +// match error { +// AppError::NotFound(_msg) => HtmlError::NotFound(templates), +// AppError::Auth(_msg) => HtmlError::Unauthorized(templates), +// AppError::Validation(msg) => HtmlError::BadRequest(msg, templates), +// _ => { +// tracing::error!("Internal error: {:?}", error); +// HtmlError::ServerError(templates) +// } +// } +// } + +// pub fn from_template_error(error: minijinja::Error, templates: Arc) -> Self { +// tracing::error!("Template error: {:?}", error); +// HtmlError::Template(error.to_string(), templates) +// } +// } + +// impl IntoResponse for HtmlError { +// fn into_response(self) -> Response { +// let (status, context, templates) = match self { +// HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => ( +// StatusCode::INTERNAL_SERVER_ERROR, +// context! { +// status_code => 500, +// title => "Internal Server Error", +// error => "Internal Server Error", +// description => "Something went wrong on our end." +// }, +// templates, +// ), +// HtmlError::NotFound(templates) => ( +// StatusCode::NOT_FOUND, +// context! { +// status_code => 404, +// title => "Page Not Found", +// error => "Not Found", +// description => "The page you're looking for doesn't exist or was removed." +// }, +// templates, +// ), +// HtmlError::Unauthorized(templates) => ( +// StatusCode::UNAUTHORIZED, +// context! { +// status_code => 401, +// title => "Unauthorized", +// error => "Access Denied", +// description => "You need to be logged in to access this page." +// }, +// templates, +// ), +// HtmlError::BadRequest(msg, templates) => ( +// StatusCode::BAD_REQUEST, +// context! { +// status_code => 400, +// title => "Bad Request", +// error => "Bad Request", +// description => msg +// }, +// templates, +// ), +// }; + +// let html = match templates.acquire_env() { +// Ok(env) => match env.get_template("errors/error.html") { +// Ok(tmpl) => match tmpl.render(context) { +// Ok(output) => output, +// Err(e) => { +// tracing::error!("Template render error: {:?}", e); +// Self::fallback_html() +// } +// }, +// Err(e) => { +// tracing::error!("Template get error: {:?}", e); +// Self::fallback_html() +// } +// }, +// Err(e) => { +// tracing::error!("Environment acquire error: {:?}", e); +// Self::fallback_html() +// } +// }; + +// (status, Html(html)).into_response() +// } +// } + +// impl HtmlError { +// fn fallback_html() -> String { +// r#" +// +// +//
+//

Error

+//

Sorry, something went wrong displaying this page.

+//
+// +// +// "# +// .to_string() +// } +// } diff --git a/crates/html-router/src/lib.rs b/crates/html-router/src/lib.rs index 28c47f4..c3ccc4c 100644 --- a/crates/html-router/src/lib.rs +++ b/crates/html-router/src/lib.rs @@ -1,19 +1,23 @@ +pub mod error; pub mod html_state; mod middleware_analytics; +mod middleware_auth; mod routes; +mod template_response; use axum::{ extract::FromRef, - middleware::from_fn_with_state, + middleware::{from_fn_with_state, map_response_with_state}, routing::{delete, get, patch, post}, Router, }; -use axum_session::SessionLayer; -use axum_session_auth::{AuthConfig, AuthSessionLayer}; +use axum_session::{Session, SessionLayer}; +use axum_session_auth::{AuthConfig, AuthSession, AuthSessionLayer}; use axum_session_surreal::SessionSurrealPool; use common::storage::types::user::User; use html_state::HtmlState; use middleware_analytics::analytics_middleware; +use middleware_auth::require_auth; use routes::{ account::{delete_account, set_api_key, show_account_page, update_timezone}, admin_panel::{show_admin_panel, toggle_registration_status}, @@ -39,26 +43,40 @@ use routes::{ signup::{process_signup_and_show_verification, show_signup_form}, }; use surrealdb::{engine::any::Any, Surreal}; +use template_response::with_template_response; use tower_http::services::ServeDir; -/// Router for HTML endpoints +pub type AuthSessionType = AuthSession, Surreal>; +pub type SessionType = Session>; + +/// Html routes pub fn html_routes(app_state: &HtmlState) -> Router where S: Clone + Send + Sync + 'static, HtmlState: FromRef, { - Router::new() + // Public routes - no auth required + let public_routes = Router::new() .route("/", get(index_handler)) .route("/gdpr/accept", post(accept_gdpr)) .route("/gdpr/deny", post(deny_gdpr)) - .route("/search", get(search_result_handler)) + .route("/signout", get(sign_out_user)) + .route("/signin", get(show_signin_form).post(authenticate_user)) + .route( + "/signup", + get(show_signup_form).post(process_signup_and_show_verification), + ) + .route("/documentation", get(show_documentation_index)) + .route("/documentation/privacy-policy", get(show_privacy_policy)) + .route("/documentation/get-started", get(show_get_started)) + .route("/documentation/mobile-friendly", get(show_mobile_friendly)) + .nest_service("/assets", ServeDir::new("assets/")); + + // Protected routes - auth required + let protected_routes = Router::new() .route("/chat", get(show_chat_base).post(new_chat_user_message)) .route("/initialized-chat", post(show_initialized_chat)) .route("/chat/:id", get(show_existing_chat).post(new_user_message)) - .route("/chat/response-stream", get(get_response_stream)) - .route("/knowledge/:id", get(show_reference_tooltip)) - .route("/signout", get(sign_out_user)) - .route("/signin", get(show_signin_form).post(authenticate_user)) .route( "/ingress-form", get(show_ingress_form).post(process_ingress_form), @@ -72,6 +90,9 @@ where "/content/:id", get(show_text_content_edit_form).patch(patch_text_content), ) + .route("/search", get(search_result_handler)) + .route("/chat/response-stream", get(get_response_stream)) + .route("/knowledge/:id", get(show_reference_tooltip)) .route("/knowledge", get(show_knowledge_page)) .route( "/knowledge-entity/:id", @@ -90,16 +111,17 @@ where .route("/set-api-key", post(set_api_key)) .route("/update-timezone", patch(update_timezone)) .route("/delete-account", delete(delete_account)) - .route( - "/signup", - get(show_signup_form).post(process_signup_and_show_verification), - ) - .route("/documentation", get(show_documentation_index)) - .route("/documentation/privacy-policy", get(show_privacy_policy)) - .route("/documentation/get-started", get(show_get_started)) - .route("/documentation/mobile-friendly", get(show_mobile_friendly)) - .nest_service("/assets", ServeDir::new("assets/")) + .route_layer(from_fn_with_state(app_state.clone(), require_auth)); + + // Combine routes and add common middleware + Router::new() + .merge(public_routes) + .merge(protected_routes) .layer(from_fn_with_state(app_state.clone(), analytics_middleware)) + .layer(map_response_with_state( + app_state.clone(), + with_template_response, + )) .layer( AuthSessionLayer::, Surreal>::new(Some( app_state.db.client.clone(), diff --git a/crates/html-router/src/middleware_auth.rs b/crates/html-router/src/middleware_auth.rs new file mode 100644 index 0000000..ae5d6fa --- /dev/null +++ b/crates/html-router/src/middleware_auth.rs @@ -0,0 +1,48 @@ +use axum::{ + async_trait, + extract::{FromRequestParts, Request}, + http::request::Parts, + middleware::Next, + response::{IntoResponse, Response}, +}; +use common::storage::types::user::User; + +use crate::{template_response::TemplateResponse, AuthSessionType}; + +#[derive(Debug, Clone)] +pub struct RequireUser(pub User); + +// Implement FromRequestParts for RequireUser +#[async_trait] +impl FromRequestParts for RequireUser +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .map(RequireUser) + .ok_or_else(|| TemplateResponse::redirect("/signin").into_response()) + } +} + +// Auth middleware that adds the user to extensions +pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Next) -> Response { + // Check if user is authenticated + match auth.current_user { + Some(user) => { + // Add user to request extensions + request.extensions_mut().insert(user); + // Continue to the handler + next.run(request).await + } + None => { + // Redirect to login + TemplateResponse::redirect("/signin").into_response() + } + } +} diff --git a/crates/html-router/src/routes/account.rs b/crates/html-router/src/routes/account.rs index 65d73b8..2e038ef 100644 --- a/crates/html-router/src/routes/account.rs +++ b/crates/html-router/src/routes/account.rs @@ -1,65 +1,42 @@ -use axum::{ - extract::State, - http::{StatusCode, Uri}, - response::{IntoResponse, Redirect}, - Form, -}; -use axum_htmx::HxRedirect; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; +use axum::{extract::State, response::IntoResponse, Form}; use chrono_tz::TZ_VARIANTS; -use surrealdb::{engine::any::Any, Surreal}; +use serde::{Deserialize, Serialize}; -use common::{ - error::{AppError, HtmlError}, - storage::types::user::User, +use crate::{ + middleware_auth::RequireUser, + template_response::{HtmlError, TemplateResponse}, + AuthSessionType, }; +use common::storage::types::user::User; -use crate::{html_state::HtmlState, page_data}; +use crate::html_state::HtmlState; -use super::{render_block, render_template}; - -page_data!(AccountData, "auth/account_settings.html", { +#[derive(Serialize)] +pub struct AccountPageData { user: User, - timezones: Vec -}); + timezones: Vec, +} pub async fn show_account_page( - State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, ) -> Result { - // Early return if the user is not authenticated - let user = match auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), - }; - let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect(); - let output = render_template( - AccountData::template_name(), - AccountData { user, timezones }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + Ok(TemplateResponse::new_template( + "auth/account_settings.html", + AccountPageData { user, timezones }, + )) } pub async fn set_api_key( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, + auth: AuthSessionType, ) -> Result { - // Early return if the user is not authenticated - let user = match &auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), - }; - // Generate and set the API key - let api_key = User::set_api_key(&user.id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let api_key = User::set_api_key(&user.id, &state.db).await?; + // Clear the cache so new requests have access to the user with api key auth.cache_clear_user(user.id.to_string()); // Update the user's API key @@ -69,40 +46,28 @@ pub async fn set_api_key( }; // Render the API key section block - let output = render_block( - AccountData::template_name(), + Ok(TemplateResponse::new_partial( + "auth/account_settings.html", "api_key_section", - AccountData { + AccountPageData { user: updated_user, timezones: vec![], }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } pub async fn delete_account( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, + auth: AuthSessionType, ) -> Result { - // Early return if the user is not authenticated - let user = match &auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), - }; - - state - .db - .delete_item::(&user.id) - .await - .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?; + state.db.delete_item::(&user.id).await?; auth.logout_user(); auth.session.destroy(); - Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response()) + Ok(TemplateResponse::redirect("/")) } #[derive(Deserialize)] @@ -112,18 +77,13 @@ pub struct UpdateTimezoneForm { pub async fn update_timezone( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, + auth: AuthSessionType, Form(form): Form, ) -> Result { - let user = match &auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), - }; - - User::update_timezone(&user.id, &form.timezone, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + User::update_timezone(&user.id, &form.timezone, &state.db).await?; + // Clear the cache auth.cache_clear_user(user.id.to_string()); // Update the user's API key @@ -135,15 +95,12 @@ pub async fn update_timezone( let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect(); // Render the API key section block - let output = render_block( - AccountData::template_name(), + Ok(TemplateResponse::new_partial( + "auth/account_settings.html", "timezone_section", - AccountData { + AccountPageData { user: updated_user, timezones, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } diff --git a/crates/html-router/src/routes/admin_panel.rs b/crates/html-router/src/routes/admin_panel.rs index f6b42f8..40462b8 100644 --- a/crates/html-router/src/routes/admin_panel.rs +++ b/crates/html-router/src/routes/admin_panel.rs @@ -1,62 +1,39 @@ -use axum::{ - extract::State, - response::{IntoResponse, Redirect}, - Form, +use axum::{extract::State, response::IntoResponse, Form}; +use serde::{Deserialize, Serialize}; + +use crate::{ + middleware_auth::RequireUser, + template_response::{HtmlError, TemplateResponse}, }; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::{engine::any::Any, Surreal}; +use common::storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User}; -use common::{ - error::HtmlError, - storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User}, -}; +use crate::html_state::HtmlState; -use crate::{html_state::HtmlState, page_data}; - -use super::{render_block, render_template}; - -page_data!(AdminPanelData, "auth/admin_panel.html", { +#[derive(Serialize)] +pub struct AdminPanelData { user: User, settings: SystemSettings, analytics: Analytics, users: i64, -}); +} pub async fn show_admin_panel( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, ) -> Result { - // Early return if the user is not authenticated and admin - let user = match auth.current_user { - Some(user) if user.admin => user, - _ => return Ok(Redirect::to("/").into_response()), - }; + let settings = SystemSettings::get_current(&state.db).await?; + let analytics = Analytics::get_current(&state.db).await?; + let users_count = Analytics::get_users_amount(&state.db).await?; - let settings = SystemSettings::get_current(&state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - - let analytics = Analytics::get_current(&state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - - let users_count = Analytics::get_users_amount(&state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - - let output = render_template( - AdminPanelData::template_name(), + Ok(TemplateResponse::new_template( + "auth/admin_panel.html", AdminPanelData { user, settings, analytics, users: users_count, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } fn checkbox_to_bool<'de, D>(deserializer: D) -> Result @@ -83,36 +60,28 @@ pub struct RegistrationToggleData { pub async fn toggle_registration_status( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, Form(input): Form, ) -> Result { - // Early return if the user is not authenticated and admin - let _user = match auth.current_user { - Some(user) if user.admin => user, - _ => return Ok(Redirect::to("/").into_response()), + // Early return if the user is not admin + if !user.admin { + return Ok(TemplateResponse::redirect("/")); }; - let current_settings = SystemSettings::get_current(&state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let current_settings = SystemSettings::get_current(&state.db).await?; let new_settings = SystemSettings { registrations_enabled: input.registration_open, ..current_settings.clone() }; - SystemSettings::update(&state.db, new_settings.clone()) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + SystemSettings::update(&state.db, new_settings.clone()).await?; - let output = render_block( - AdminPanelData::template_name(), + Ok(TemplateResponse::new_partial( + "auth/admin_panel.html", "registration_status_input", RegistrationToggleData { settings: new_settings, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } diff --git a/crates/html-router/src/routes/chat/message_response_stream.rs b/crates/html-router/src/routes/chat/message_response_stream.rs index b6ab5b0..1551e03 100644 --- a/crates/html-router/src/routes/chat/message_response_stream.rs +++ b/crates/html-router/src/routes/chat/message_response_stream.rs @@ -10,6 +10,12 @@ use axum::{ }; use axum_session_auth::AuthSession; use axum_session_surreal::SessionSurrealPool; +use composite_retrieval::{ + answer_retrieval::{ + create_chat_request, create_user_message, format_entities_json, LLMResponseFormat, + }, + retrieve_entities, +}; use futures::{ stream::{self, once}, Stream, StreamExt, TryStreamExt, @@ -21,19 +27,11 @@ use surrealdb::{engine::any::Any, Surreal}; use tokio::sync::{mpsc::channel, Mutex}; use tracing::{error, info}; -use common::{ - retrieval::{ - combined_knowledge_entity_retrieval, - query_helper::{ - create_chat_request, create_user_message, format_entities_json, LLMResponseFormat, - }, - }, - storage::{ - db::SurrealDbClient, - types::{ - message::{Message, MessageRole}, - user::User, - }, +use common::storage::{ + db::SurrealDbClient, + types::{ + message::{Message, MessageRole}, + user::User, }, }; @@ -100,7 +98,7 @@ pub async fn get_response_stream( }; // 2. Retrieve knowledge entities - let entities = match combined_knowledge_entity_retrieval( + let entities = match retrieve_entities( &state.db, &state.openai_client, &user_message.content, diff --git a/crates/html-router/src/routes/chat/mod.rs b/crates/html-router/src/routes/chat/mod.rs index f8c56aa..ee01bed 100644 --- a/crates/html-router/src/routes/chat/mod.rs +++ b/crates/html-router/src/routes/chat/mod.rs @@ -12,8 +12,9 @@ use axum_session_surreal::SessionSurrealPool; use surrealdb::{engine::any::Any, Surreal}; use tracing::info; +use crate::routes::HtmlError; use common::{ - error::{AppError, HtmlError}, + error::AppError, storage::types::{ conversation::Conversation, message::{Message, MessageRole}, diff --git a/crates/html-router/src/routes/chat/references.rs b/crates/html-router/src/routes/chat/references.rs index 6f5695c..229fe02 100644 --- a/crates/html-router/src/routes/chat/references.rs +++ b/crates/html-router/src/routes/chat/references.rs @@ -8,8 +8,9 @@ use serde::Serialize; use surrealdb::{engine::any::Any, Surreal}; use tracing::info; +use crate::routes::HtmlError; use common::{ - error::{AppError, HtmlError}, + error::AppError, storage::types::{knowledge_entity::KnowledgeEntity, user::User}, }; diff --git a/crates/html-router/src/routes/content/mod.rs b/crates/html-router/src/routes/content/mod.rs index 85eebce..e5a4d77 100644 --- a/crates/html-router/src/routes/content/mod.rs +++ b/crates/html-router/src/routes/content/mod.rs @@ -6,12 +6,9 @@ use axum_session_auth::AuthSession; use axum_session_surreal::SessionSurrealPool; use surrealdb::{engine::any::Any, Surreal}; -use common::{ - error::HtmlError, - storage::types::{text_content::TextContent, user::User}, -}; +use common::storage::types::{text_content::TextContent, user::User}; -use crate::{html_state::HtmlState, page_data}; +use crate::{error::HtmlError, html_state::HtmlState, page_data}; use super::render_template; diff --git a/crates/html-router/src/routes/documentation/mod.rs b/crates/html-router/src/routes/documentation/mod.rs index 5b246c1..17f1edb 100644 --- a/crates/html-router/src/routes/documentation/mod.rs +++ b/crates/html-router/src/routes/documentation/mod.rs @@ -1,78 +1,54 @@ -use axum::{extract::State, response::IntoResponse}; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::{engine::any::Any, Surreal}; +use axum::response::IntoResponse; +use common::storage::types::user::User; +use serde::Serialize; -use common::{error::HtmlError, storage::types::user::User}; +use crate::template_response::{HtmlError, TemplateResponse}; +use crate::AuthSessionType; -use crate::{html_state::HtmlState, page_data}; - -use super::render_template; - -page_data!(DocumentationData, "do_not_use_this", { +#[derive(Serialize)] +pub struct DocumentationPageData { user: Option, - current_path: String -}); - -pub async fn show_privacy_policy( - State(state): State, - auth: AuthSession, Surreal>, -) -> Result { - let output = render_template( - "documentation/privacy.html", - DocumentationData { - user: auth.current_user, - current_path: "/privacy_policy".to_string(), - }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + current_path: String, } -pub async fn show_get_started( - State(state): State, - auth: AuthSession, Surreal>, -) -> Result { - let output = render_template( +pub async fn show_privacy_policy(auth: AuthSessionType) -> Result { + Ok(TemplateResponse::new_template( + "documentation/privacy.html", + DocumentationPageData { + user: auth.current_user, + current_path: "/privacy-policy".to_string(), + }, + )) +} + +pub async fn show_get_started(auth: AuthSessionType) -> Result { + Ok(TemplateResponse::new_template( "documentation/get_started.html", - DocumentationData { + DocumentationPageData { user: auth.current_user, current_path: "/get-started".to_string(), }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } -pub async fn show_mobile_friendly( - State(state): State, - auth: AuthSession, Surreal>, -) -> Result { - let output = render_template( + +pub async fn show_mobile_friendly(auth: AuthSessionType) -> Result { + Ok(TemplateResponse::new_template( "documentation/mobile_friendly.html", - DocumentationData { + DocumentationPageData { user: auth.current_user, current_path: "/mobile-friendly".to_string(), }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } pub async fn show_documentation_index( - State(state): State, - auth: AuthSession, Surreal>, + auth: AuthSessionType, ) -> Result { - let output = render_template( + Ok(TemplateResponse::new_template( "documentation/index.html", - DocumentationData { + DocumentationPageData { user: auth.current_user, current_path: "/index".to_string(), }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } diff --git a/crates/html-router/src/routes/gdpr.rs b/crates/html-router/src/routes/gdpr.rs index 253b831..6703d3f 100644 --- a/crates/html-router/src/routes/gdpr.rs +++ b/crates/html-router/src/routes/gdpr.rs @@ -1,22 +1,15 @@ use axum::response::{Html, IntoResponse}; -use axum_session::Session; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::engine::any::Any; -use common::error::HtmlError; +use crate::SessionType; -pub async fn accept_gdpr( - session: Session>, -) -> Result { +pub async fn accept_gdpr(session: SessionType) -> impl IntoResponse { session.set("gdpr_accepted", true); - Ok(Html("").into_response()) + Html("").into_response() } -pub async fn deny_gdpr( - session: Session>, -) -> Result { +pub async fn deny_gdpr(session: SessionType) -> impl IntoResponse { session.set("gdpr_accepted", true); - Ok(Html("").into_response()) + Html("").into_response() } diff --git a/crates/html-router/src/routes/index.rs b/crates/html-router/src/routes/index.rs index bb0b669..28708a3 100644 --- a/crates/html-router/src/routes/index.rs +++ b/crates/html-router/src/routes/index.rs @@ -1,16 +1,18 @@ use axum::{ + debug_handler, extract::{Path, State}, - response::{IntoResponse, Redirect}, + response::IntoResponse, }; -use axum_session::Session; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::{engine::any::Any, Surreal}; +use serde::Serialize; use tokio::join; -use tracing::info; +use crate::{ + middleware_auth::RequireUser, + template_response::{HtmlError, TemplateResponse}, + AuthSessionType, SessionType, +}; use common::{ - error::{AppError, HtmlError}, + error::AppError, storage::types::{ file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, @@ -18,67 +20,51 @@ use common::{ }, }; -use crate::{html_state::HtmlState, page_data, routes::render_template}; +use crate::html_state::HtmlState; -use super::render_block; - -page_data!(IndexData, "index/index.html", { +#[derive(Serialize)] +pub struct IndexPageData { gdpr_accepted: bool, user: Option, latest_text_contents: Vec, - active_jobs: Vec -}); + active_jobs: Vec, +} pub async fn index_handler( State(state): State, - auth: AuthSession, Surreal>, - session: Session>, + auth: AuthSessionType, + session: SessionType, ) -> Result { - info!("Displaying index page"); - let gdpr_accepted = auth.current_user.is_some() | session.get("gdpr_accepted").unwrap_or(false); let active_jobs = match auth.current_user.is_some() { true => { User::get_unfinished_ingestion_tasks(&auth.current_user.clone().unwrap().id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))? + .await? } false => vec![], }; let latest_text_contents = match auth.current_user.clone().is_some() { - true => User::get_latest_text_contents( - auth.current_user.clone().unwrap().id.as_str(), - &state.db, - ) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?, + true => { + User::get_latest_text_contents( + auth.current_user.clone().unwrap().id.as_str(), + &state.db, + ) + .await? + } false => vec![], }; - // let latest_knowledge_entities = match auth.current_user.is_some() { - // true => User::get_latest_knowledge_entities( - // auth.current_user.clone().unwrap().id.as_str(), - // &state.db, - // ) - // .await - // .map_err(|e| HtmlError::new(e, state.templates.clone()))?, - // false => vec![], - // }; - - let output = render_template( - IndexData::template_name(), - IndexData { + Ok(TemplateResponse::new_template( + "index/index.html", + IndexPageData { gdpr_accepted, user: auth.current_user, latest_text_contents, active_jobs, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } #[derive(Serialize)] @@ -87,21 +73,17 @@ pub struct LatestTextContentData { user: User, } +#[debug_handler] pub async fn delete_text_content( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, Path(id): Path, ) -> Result { - let user = match &auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/").into_response()), - }; - // Get and validate TextContent - let text_content = get_and_validate_text_content(&state, &id, user).await?; + let text_content = get_and_validate_text_content(&state, &id, &user).await?; // Perform concurrent deletions - let deletion_tasks = join!( + join!( async { if let Some(file_info) = text_content.file_info { FileInfo::delete_by_id(&file_info.id, &state.db).await @@ -115,33 +97,17 @@ pub async fn delete_text_content( KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db) ); - // Handle potential errors from concurrent operations - match deletion_tasks { - (Ok(_), Ok(_), Ok(_), Ok(_), Ok(_)) => (), - _ => { - return Err(HtmlError::new( - AppError::Processing("Failed to delete one or more items".to_string()), - state.templates.clone(), - )) - } - } - // Render updated content - let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db).await?; - let output = render_block( + Ok(TemplateResponse::new_partial( "index/signed_in/recent_content.html", "latest_content_section", LatestTextContentData { - user: user.clone(), + user: user.to_owned(), latest_text_contents, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } // Helper function to get and validate text content @@ -149,23 +115,16 @@ async fn get_and_validate_text_content( state: &HtmlState, id: &str, user: &User, -) -> Result { +) -> Result { let text_content = state .db .get_item::(id) - .await - .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))? - .ok_or_else(|| { - HtmlError::new( - AppError::NotFound("No item found".to_string()), - state.templates.clone(), - ) - })?; + .await? + .ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?; if text_content.user_id != user.id { - return Err(HtmlError::new( - AppError::Auth("You are not the owner of that content".to_string()), - state.templates.clone(), + return Err(AppError::Auth( + "You are not the owner of that content".to_string(), )); } @@ -180,57 +139,35 @@ pub struct ActiveJobsData { pub async fn delete_job( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, Path(id): Path, ) -> Result { - let user = match auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/signin").into_response()), - }; + User::validate_and_delete_job(&id, &user.id, &state.db).await?; - User::validate_and_delete_job(&id, &user.id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; + let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; - let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - - let output = render_block( + Ok(TemplateResponse::new_partial( "index/signed_in/active_jobs.html", "active_jobs_section", ActiveJobsData { user: user.clone(), active_jobs, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } pub async fn show_active_jobs( State(state): State, - auth: AuthSession, Surreal>, + RequireUser(user): RequireUser, ) -> Result { - let user = match auth.current_user { - Some(user) => user, - None => return Ok(Redirect::to("/signin").into_response()), - }; + let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?; - let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db) - .await - .map_err(|e| HtmlError::new(e, state.templates.clone()))?; - - let output = render_block( + Ok(TemplateResponse::new_partial( "index/signed_in/active_jobs.html", "active_jobs_section", ActiveJobsData { user: user.clone(), active_jobs, }, - state.templates.clone(), - )?; - - Ok(output.into_response()) + )) } diff --git a/crates/html-router/src/routes/ingress_form.rs b/crates/html-router/src/routes/ingress_form.rs index 01d4693..fafb305 100644 --- a/crates/html-router/src/routes/ingress_form.rs +++ b/crates/html-router/src/routes/ingress_form.rs @@ -11,7 +11,7 @@ use tempfile::NamedTempFile; use tracing::info; use common::{ - error::{AppError, HtmlError, IntoHtmlError}, + error::AppError, storage::types::{ file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, user::User, @@ -19,6 +19,7 @@ use common::{ }; use crate::{ + error::{HtmlError, IntoHtmlError}, html_state::HtmlState, page_data, routes::{index::ActiveJobsData, render_block}, diff --git a/crates/html-router/src/routes/knowledge/mod.rs b/crates/html-router/src/routes/knowledge/mod.rs index 40f4293..35f1228 100644 --- a/crates/html-router/src/routes/knowledge/mod.rs +++ b/crates/html-router/src/routes/knowledge/mod.rs @@ -14,7 +14,7 @@ use surrealdb::{engine::any::Any, Surreal}; use tracing::info; use common::{ - error::{AppError, HtmlError}, + error::AppError, storage::types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_relationship::KnowledgeRelationship, @@ -22,7 +22,7 @@ use common::{ }, }; -use crate::{html_state::HtmlState, page_data, routes::render_template}; +use crate::{error::HtmlError, html_state::HtmlState, page_data, routes::render_template}; page_data!(KnowledgeBaseData, "knowledge/base.html", { entities: Vec, diff --git a/crates/html-router/src/routes/mod.rs b/crates/html-router/src/routes/mod.rs index 8a27a41..46aab45 100644 --- a/crates/html-router/src/routes/mod.rs +++ b/crates/html-router/src/routes/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use axum::response::Html; use minijinja_autoreload::AutoReloader; -use common::error::{HtmlError, IntoHtmlError}; +use crate::error::{HtmlError, IntoHtmlError}; pub mod account; pub mod admin_panel; @@ -19,73 +19,73 @@ pub mod signin; pub mod signout; pub mod signup; -pub trait PageData { - fn template_name() -> &'static str; -} +// pub trait PageData { +// fn template_name() -> &'static str; +// } -// Helper function for render_template -pub fn render_template( - template_name: &str, - context: T, - templates: Arc, -) -> Result, HtmlError> -where - T: serde::Serialize, -{ - let env = templates - .acquire_env() - .map_err(|e| e.with_template(templates.clone()))?; - let tmpl = env - .get_template(template_name) - .map_err(|e| e.with_template(templates.clone()))?; - let context = minijinja::Value::from_serialize(&context); - let output = tmpl - .render(context) - .map_err(|e| e.with_template(templates.clone()))?; - Ok(Html(output)) -} +// // Helper function for render_template +// pub fn render_template( +// template_name: &str, +// context: T, +// templates: Arc, +// ) -> Result, HtmlError> +// where +// T: serde::Serialize, +// { +// let env = templates +// .acquire_env() +// .map_err(|e| e.with_template(templates.clone()))?; +// let tmpl = env +// .get_template(template_name) +// .map_err(|e| e.with_template(templates.clone()))?; +// let context = minijinja::Value::from_serialize(&context); +// let output = tmpl +// .render(context) +// .map_err(|e| e.with_template(templates.clone()))?; +// Ok(Html(output)) +// } -pub fn render_block( - template_name: &str, - block: &str, - context: T, - templates: Arc, -) -> Result, HtmlError> -where - T: serde::Serialize, -{ - let env = templates - .acquire_env() - .map_err(|e| e.with_template(templates.clone()))?; - let tmpl = env - .get_template(template_name) - .map_err(|e| e.with_template(templates.clone()))?; +// pub fn render_block( +// template_name: &str, +// block: &str, +// context: T, +// templates: Arc, +// ) -> Result, HtmlError> +// where +// T: serde::Serialize, +// { +// let env = templates +// .acquire_env() +// .map_err(|e| e.with_template(templates.clone()))?; +// let tmpl = env +// .get_template(template_name) +// .map_err(|e| e.with_template(templates.clone()))?; - let context = minijinja::Value::from_serialize(&context); - let output = tmpl - .eval_to_state(context) - .map_err(|e| e.with_template(templates.clone()))? - .render_block(block) - .map_err(|e| e.with_template(templates.clone()))?; +// let context = minijinja::Value::from_serialize(&context); +// let output = tmpl +// .eval_to_state(context) +// .map_err(|e| e.with_template(templates.clone()))? +// .render_block(block) +// .map_err(|e| e.with_template(templates.clone()))?; - Ok(output.into()) -} +// Ok(output.into()) +// } -#[macro_export] -macro_rules! page_data { - ($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => { - use serde::{Serialize, Deserialize}; - use $crate::routes::PageData; +// #[macro_export] +// macro_rules! page_data { +// ($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => { +// use serde::{Serialize, Deserialize}; +// use $crate::routes::PageData; - #[derive(Debug, Deserialize, Serialize)] - pub struct $name { - $($(#[$attr])* pub $field: $ty),* - } +// #[derive(Debug, Deserialize, Serialize)] +// pub struct $name { +// $($(#[$attr])* pub $field: $ty),* +// } - impl PageData for $name { - fn template_name() -> &'static str { - $template_name - } - } - }; -} +// impl PageData for $name { +// fn template_name() -> &'static str { +// $template_name +// } +// } +// }; +// } diff --git a/crates/html-router/src/routes/search_result.rs b/crates/html-router/src/routes/search_result.rs index 8625715..cced2d3 100644 --- a/crates/html-router/src/routes/search_result.rs +++ b/crates/html-router/src/routes/search_result.rs @@ -8,7 +8,8 @@ use serde::{Deserialize, Serialize}; use surrealdb::{engine::any::Any, Surreal}; use tracing::info; -use common::{error::HtmlError, storage::types::user::User}; +use crate::routes::HtmlError; +use common::storage::types::user::User; use crate::{html_state::HtmlState, routes::render_template}; #[derive(Deserialize)] diff --git a/crates/html-router/src/routes/signin.rs b/crates/html-router/src/routes/signin.rs index 24c8cdf..b2eadec 100644 --- a/crates/html-router/src/routes/signin.rs +++ b/crates/html-router/src/routes/signin.rs @@ -1,19 +1,17 @@ use axum::{ extract::State, - http::{StatusCode, Uri}, - response::{Html, IntoResponse, Redirect}, + response::{Html, IntoResponse}, Form, }; -use axum_htmx::{HxBoosted, HxRedirect}; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::{engine::any::Any, Surreal}; +use axum_htmx::HxBoosted; +use serde::{Deserialize, Serialize}; -use common::{error::HtmlError, storage::types::user::User}; - -use crate::{html_state::HtmlState, page_data}; - -use super::{render_block, render_template}; +use crate::{ + html_state::HtmlState, + template_response::{HtmlError, TemplateResponse}, + AuthSessionType, +}; +use common::storage::types::user::User; #[derive(Deserialize, Serialize)] pub struct SignupParams { @@ -22,36 +20,26 @@ pub struct SignupParams { pub remember_me: Option, } -page_data!(ShowSignInForm, "auth/signin_form.html", {}); - pub async fn show_signin_form( - State(state): State, - auth: AuthSession, Surreal>, + auth: AuthSessionType, HxBoosted(boosted): HxBoosted, ) -> Result { if auth.is_authenticated() { - return Ok(Redirect::to("/").into_response()); + return Ok(TemplateResponse::redirect("/")); } - let output = match boosted { - true => render_block( - ShowSignInForm::template_name(), + match boosted { + true => Ok(TemplateResponse::new_partial( + "auth/signin_form.html", "body", - ShowSignInForm {}, - state.templates.clone(), - )?, - false => render_template( - ShowSignInForm::template_name(), - ShowSignInForm {}, - state.templates.clone(), - )?, - }; - - Ok(output.into_response()) + {}, + )), + false => Ok(TemplateResponse::new_template("auth/signin_form.html", {})), + } } pub async fn authenticate_user( State(state): State, - auth: AuthSession, Surreal>, + auth: AuthSessionType, Form(form): Form, ) -> Result { let user = match User::authenticate(form.email, form.password, &state.db).await { @@ -67,5 +55,5 @@ pub async fn authenticate_user( auth.remember_user(true); } - Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response()) + Ok(TemplateResponse::redirect("/").into_response()) } diff --git a/crates/html-router/src/routes/signout.rs b/crates/html-router/src/routes/signout.rs index 147bd94..1b54f50 100644 --- a/crates/html-router/src/routes/signout.rs +++ b/crates/html-router/src/routes/signout.rs @@ -1,18 +1,16 @@ -use axum::response::{IntoResponse, Redirect}; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; -use surrealdb::{engine::any::Any, Surreal}; +use axum::response::IntoResponse; -use common::{error::ApiError, storage::types::user::User}; +use crate::{ + template_response::{HtmlError, TemplateResponse}, + AuthSessionType, +}; -pub async fn sign_out_user( - auth: AuthSession, Surreal>, -) -> Result { +pub async fn sign_out_user(auth: AuthSessionType) -> Result { if !auth.is_authenticated() { - return Ok(Redirect::to("/").into_response()); + return Ok(TemplateResponse::redirect("/")); } auth.logout_user(); - Ok(Redirect::to("/").into_response()) + Ok(TemplateResponse::redirect("/")) } diff --git a/crates/html-router/src/routes/signup.rs b/crates/html-router/src/routes/signup.rs index 67993b4..68cb359 100644 --- a/crates/html-router/src/routes/signup.rs +++ b/crates/html-router/src/routes/signup.rs @@ -1,20 +1,18 @@ use axum::{ extract::State, - http::{StatusCode, Uri}, - response::{Html, IntoResponse, Redirect}, + response::{Html, IntoResponse}, Form, }; -use axum_htmx::{HxBoosted, HxRedirect}; -use axum_session_auth::AuthSession; -use axum_session_surreal::SessionSurrealPool; +use axum_htmx::HxBoosted; use serde::{Deserialize, Serialize}; -use surrealdb::{engine::any::Any, Surreal}; -use common::{error::HtmlError, storage::types::user::User}; +use common::storage::types::user::User; -use crate::html_state::HtmlState; - -use super::{render_block, render_template}; +use crate::{ + html_state::HtmlState, + template_response::{HtmlError, TemplateResponse}, + AuthSessionType, +}; #[derive(Deserialize, Serialize)] pub struct SignupParams { @@ -24,24 +22,26 @@ pub struct SignupParams { } pub async fn show_signup_form( - State(state): State, - auth: AuthSession, Surreal>, + auth: AuthSessionType, HxBoosted(boosted): HxBoosted, ) -> Result { if auth.is_authenticated() { - return Ok(Redirect::to("/").into_response()); + return Ok(TemplateResponse::redirect("/")); } - let output = match boosted { - true => render_block("auth/signup_form.html", "body", {}, state.templates.clone())?, - false => render_template("auth/signup_form.html", {}, state.templates.clone())?, - }; - Ok(output.into_response()) + match boosted { + true => Ok(TemplateResponse::new_partial( + "auth/signup_form.html", + "body", + {}, + )), + false => Ok(TemplateResponse::new_template("auth/signup_form.html", {})), + } } pub async fn process_signup_and_show_verification( State(state): State, - auth: AuthSession, Surreal>, + auth: AuthSessionType, Form(form): Form, ) -> Result { let user = match User::create_new(form.email, form.password, &state.db, form.timezone).await { @@ -54,5 +54,5 @@ pub async fn process_signup_and_show_verification( auth.login_user(user.id); - Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response()) + Ok(TemplateResponse::redirect("/").into_response()) } diff --git a/crates/html-router/src/template_response.rs b/crates/html-router/src/template_response.rs new file mode 100644 index 0000000..1bfe055 --- /dev/null +++ b/crates/html-router/src/template_response.rs @@ -0,0 +1,288 @@ +use axum::{ + extract::State, + http::StatusCode, + response::{Html, IntoResponse, Response}, + Extension, +}; +use common::error::AppError; +use minijinja::{context, Value}; +use minijinja_autoreload::AutoReloader; +use serde::Serialize; +use std::sync::Arc; + +use crate::{html_state::HtmlState, AuthSessionType}; + +// Enum for template types +#[derive(Clone)] +pub enum TemplateKind { + Full(String), // Full page template + Partial(String, String), // Template name, block name + Error(StatusCode), // Error template with status code + Redirect(axum::response::Redirect), // Redirect +} + +#[derive(Clone)] +pub struct TemplateResponse { + template_kind: TemplateKind, + context: Value, +} + +impl TemplateResponse { + pub fn new_template(name: impl Into, context: T) -> Self { + Self { + template_kind: TemplateKind::Full(name.into()), + context: Value::from_serialize(&context), + } + } + + pub fn new_partial( + template: impl Into, + block: impl Into, + context: T, + ) -> Self { + Self { + template_kind: TemplateKind::Partial(template.into(), block.into()), + context: Value::from_serialize(&context), + } + } + + pub fn error(status: StatusCode, title: &str, error: &str, description: &str) -> Self { + let ctx = context! { + status_code => status.as_u16(), + title => title, + error => error, + description => description + }; + + Self { + template_kind: TemplateKind::Error(status), + context: ctx, + } + } + + // Convenience methods for common errors + pub fn not_found() -> Self { + Self::error( + StatusCode::NOT_FOUND, + "Page Not Found", + "Not Found", + "The page you're looking for doesn't exist or was removed.", + ) + } + + pub fn server_error() -> Self { + Self::error( + StatusCode::INTERNAL_SERVER_ERROR, + "Internal Server Error", + "Internal Server Error", + "Something went wrong on our end.", + ) + } + + pub fn unauthorized() -> Self { + Self::error( + StatusCode::UNAUTHORIZED, + "Unauthorized", + "Access Denied", + "You need to be logged in to access this page.", + ) + } + + pub fn bad_request(message: &str) -> Self { + Self::error( + StatusCode::BAD_REQUEST, + "Bad Request", + "Bad Request", + message, + ) + } + + pub fn redirect(path: impl AsRef) -> Self { + let redirect_response = axum::response::Redirect::to(path.as_ref()); + + Self { + template_kind: TemplateKind::Redirect(redirect_response), + context: Value::from_serialize(&()), + } + } +} + +impl IntoResponse for TemplateResponse { + fn into_response(self) -> Response { + Extension(self).into_response() + } +} + +// Wrapper to avoid recursion +struct TemplateStateWrapper { + state: HtmlState, + auth: AuthSessionType, + template_response: TemplateResponse, +} + +impl IntoResponse for TemplateStateWrapper { + fn into_response(self) -> Response { + let templates = self.state.templates; + + match &self.template_response.template_kind { + TemplateKind::Full(name) => { + render_template(name, self.template_response.context, templates) + } + TemplateKind::Partial(name, block) => { + render_block(name, block, self.template_response.context, templates) + } + TemplateKind::Error(status) => { + let html = match try_render_template( + "errors/error.html", + self.template_response.context, + templates, + ) { + Ok(html_string) => Html(html_string), + Err(_) => fallback_error(), + }; + (*status, html).into_response() + } + TemplateKind::Redirect(redirect) => redirect.clone().into_response(), + } + } +} + +// Helper functions for rendering with error handling +fn render_template(name: &str, context: Value, templates: Arc) -> Response { + match try_render_template(name, context, templates.clone()) { + Ok(html) => Html(html).into_response(), + Err(_) => fallback_error().into_response(), + } +} + +fn render_block(name: &str, block: &str, context: Value, templates: Arc) -> Response { + match try_render_block(name, block, context, templates.clone()) { + Ok(html) => Html(html).into_response(), + Err(_) => fallback_error().into_response(), + } +} + +fn try_render_template( + template_name: &str, + context: Value, + templates: Arc, +) -> Result { + let env = templates.acquire_env().map_err(|e| { + tracing::error!("Environment error: {:?}", e); + () + })?; + + let tmpl = env.get_template(template_name).map_err(|e| { + tracing::error!("Template error: {:?}", e); + () + })?; + + tmpl.render(context).map_err(|e| { + tracing::error!("Render error: {:?}", e); + () + }) +} + +fn try_render_block( + template_name: &str, + block: &str, + context: Value, + templates: Arc, +) -> Result { + let env = templates.acquire_env().map_err(|e| { + tracing::error!("Environment error: {:?}", e); + () + })?; + + let tmpl = env.get_template(template_name).map_err(|e| { + tracing::error!("Template error: {:?}", e); + () + })?; + + let mut state = tmpl.eval_to_state(context).map_err(|e| { + tracing::error!("Eval error: {:?}", e); + () + })?; + + state.render_block(block).map_err(|e| { + tracing::error!("Block render error: {:?}", e); + () + }) +} + +fn fallback_error() -> Html { + Html( + r#" + + +
+

Error

+

Sorry, something went wrong displaying this page.

+
+ + + "# + .to_string(), + ) +} + +pub async fn with_template_response( + State(state): State, + auth: AuthSessionType, + response: Response, +) -> Response { + // Clone the TemplateResponse from extensions + let template_response = response.extensions().get::().cloned(); + + if let Some(template_response) = template_response { + TemplateStateWrapper { + state, + auth, + template_response, + } + .into_response() + } else { + response + } +} + +// Define HtmlError +pub enum HtmlError { + AppError(AppError), + TemplateError(String), +} + +// Conversion from AppError to HtmlError +impl From for HtmlError { + fn from(err: AppError) -> Self { + HtmlError::AppError(err) + } +} + +// Conversion for database error to HtmlError +impl From for HtmlError { + fn from(err: surrealdb::Error) -> Self { + HtmlError::AppError(AppError::from(err)) + } +} + +// Now implement IntoResponse for HtmlError +impl IntoResponse for HtmlError { + fn into_response(self) -> Response { + match self { + HtmlError::AppError(err) => { + let template_response = match err { + AppError::NotFound(_) => TemplateResponse::not_found(), + AppError::Auth(_) => TemplateResponse::unauthorized(), + AppError::Validation(msg) => TemplateResponse::bad_request(&msg), + _ => { + tracing::error!("Internal error: {:?}", err); + TemplateResponse::server_error() + } + }; + template_response.into_response() + } + HtmlError::TemplateError(_) => TemplateResponse::server_error().into_response(), + } + } +} diff --git a/crates/ingestion-pipeline/Cargo.toml b/crates/ingestion-pipeline/Cargo.toml index 2f6e913..32d6850 100644 --- a/crates/ingestion-pipeline/Cargo.toml +++ b/crates/ingestion-pipeline/Cargo.toml @@ -10,12 +10,15 @@ serde = { workspace = true } axum = { workspace = true } tracing = { workspace = true } serde_json = { workspace = true } +futures = { workspace = true } +async-openai = { workspace = true } -async-openai = "0.24.1" tiktoken-rs = "0.6.0" reqwest = {version = "0.12.12", features = ["charset", "json"]} scraper = "0.22.0" chrono = { version = "0.4.39", features = ["serde"] } text-splitter = "0.18.1" +uuid = { version = "1.10.0", features = ["v4", "serde"] } common = { path = "../common" } +composite-retrieval = { path = "../composite-retrieval" } diff --git a/crates/common/src/ingress/analysis/ingress_analyser.rs b/crates/ingestion-pipeline/src/enricher.rs similarity index 82% rename from crates/common/src/ingress/analysis/ingress_analyser.rs rename to crates/ingestion-pipeline/src/enricher.rs index 4a3cb37..b362a88 100644 --- a/crates/common/src/ingress/analysis/ingress_analyser.rs +++ b/crates/ingestion-pipeline/src/enricher.rs @@ -1,9 +1,5 @@ -use crate::{ - error::AppError, - ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, - retrieval::combined_knowledge_entity_retrieval, - storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, -}; +use std::sync::Arc; + use async_openai::{ error::OpenAIError, types::{ @@ -12,20 +8,28 @@ use async_openai::{ ResponseFormatJsonSchema, }, }; +use common::{ + error::AppError, + storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity}, +}; +use composite_retrieval::retrieve_entities; use serde_json::json; use tracing::debug; -use super::types::llm_analysis_result::LLMGraphAnalysisResult; +use crate::{ + types::llm_enrichment_result::LLMEnrichmentResult, + utils::llm_instructions::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE}, +}; -pub struct IngressAnalyzer<'a> { - db_client: &'a SurrealDbClient, - openai_client: &'a async_openai::Client, +pub struct IngestionEnricher { + db_client: Arc, + openai_client: Arc>, } -impl<'a> IngressAnalyzer<'a> { +impl IngestionEnricher { pub fn new( - db_client: &'a SurrealDbClient, - openai_client: &'a async_openai::Client, + db_client: Arc, + openai_client: Arc>, ) -> Self { Self { db_client, @@ -39,7 +43,7 @@ impl<'a> IngressAnalyzer<'a> { instructions: &str, text: &str, user_id: &str, - ) -> Result { + ) -> Result { let similar_entities = self .find_similar_entities(category, instructions, text, user_id) .await?; @@ -60,13 +64,7 @@ impl<'a> IngressAnalyzer<'a> { text, category, instructions ); - combined_knowledge_entity_retrieval( - self.db_client, - self.openai_client, - &input_text, - user_id, - ) - .await + retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await } fn prepare_llm_request( @@ -120,7 +118,7 @@ impl<'a> IngressAnalyzer<'a> { async fn perform_analysis( &self, request: CreateChatCompletionRequest, - ) -> Result { + ) -> Result { let response = self.openai_client.chat().create(request).await?; debug!("Received LLM response: {:?}", response); diff --git a/crates/ingestion-pipeline/src/lib.rs b/crates/ingestion-pipeline/src/lib.rs index 6fca8de..60cf9ea 100644 --- a/crates/ingestion-pipeline/src/lib.rs +++ b/crates/ingestion-pipeline/src/lib.rs @@ -1,2 +1,4 @@ +pub mod enricher; pub mod pipeline; pub mod types; +pub mod utils; diff --git a/crates/ingestion-pipeline/src/pipeline.rs b/crates/ingestion-pipeline/src/pipeline.rs index 713f386..66f69dd 100644 --- a/crates/ingestion-pipeline/src/pipeline.rs +++ b/crates/ingestion-pipeline/src/pipeline.rs @@ -19,12 +19,11 @@ use common::{ utils::embedding::generate_embedding, }; -use common::ingress::analysis::{ - ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult, +use crate::{ + enricher::IngestionEnricher, + types::{llm_enrichment_result::LLMEnrichmentResult, to_text_content}, }; -use crate::types::to_text_content; - pub struct IngestionPipeline { db: Arc, openai_client: Arc>, @@ -109,8 +108,8 @@ impl IngestionPipeline { async fn perform_semantic_analysis( &self, content: &TextContent, - ) -> Result { - let analyser = IngressAnalyzer::new(&self.db, &self.openai_client); + ) -> Result { + let analyser = IngestionEnricher::new(self.db.clone(), self.openai_client.clone()); analyser .analyze_content( &content.category, diff --git a/crates/common/src/ingress/analysis/types/llm_analysis_result.rs b/crates/ingestion-pipeline/src/types/llm_enrichment_result.rs similarity index 97% rename from crates/common/src/ingress/analysis/types/llm_analysis_result.rs rename to crates/ingestion-pipeline/src/types/llm_enrichment_result.rs index ec62225..8b845d7 100644 --- a/crates/common/src/ingress/analysis/types/llm_analysis_result.rs +++ b/crates/ingestion-pipeline/src/types/llm_enrichment_result.rs @@ -4,7 +4,7 @@ use chrono::Utc; use serde::{Deserialize, Serialize}; use tokio::task; -use crate::{ +use common::{ error::AppError, storage::types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, @@ -14,7 +14,7 @@ use crate::{ }; use futures::future::try_join_all; -use super::graph_mapper::GraphMapper; // For future parallelization +use crate::utils::GraphMapper; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMKnowledgeEntity { @@ -35,12 +35,12 @@ pub struct LLMRelationship { /// Represents the entire graph analysis result from the LLM. #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct LLMGraphAnalysisResult { +pub struct LLMEnrichmentResult { pub knowledge_entities: Vec, pub relationships: Vec, } -impl LLMGraphAnalysisResult { +impl LLMEnrichmentResult { /// Converts the LLM graph analysis result into database entities and relationships. /// /// # Arguments diff --git a/crates/ingestion-pipeline/src/types/mod.rs b/crates/ingestion-pipeline/src/types/mod.rs index 2ee6e7a..95b2826 100644 --- a/crates/ingestion-pipeline/src/types/mod.rs +++ b/crates/ingestion-pipeline/src/types/mod.rs @@ -1,3 +1,5 @@ +pub mod llm_enrichment_result; + use std::{sync::Arc, time::Duration}; use async_openai::types::{ diff --git a/crates/common/src/ingress/analysis/prompt.rs b/crates/ingestion-pipeline/src/utils/llm_instructions.rs similarity index 100% rename from crates/common/src/ingress/analysis/prompt.rs rename to crates/ingestion-pipeline/src/utils/llm_instructions.rs diff --git a/crates/common/src/ingress/analysis/types/graph_mapper.rs b/crates/ingestion-pipeline/src/utils/mod.rs similarity index 97% rename from crates/common/src/ingress/analysis/types/graph_mapper.rs rename to crates/ingestion-pipeline/src/utils/mod.rs index 08d235f..210dfba 100644 --- a/crates/common/src/ingress/analysis/types/graph_mapper.rs +++ b/crates/ingestion-pipeline/src/utils/mod.rs @@ -1,3 +1,5 @@ +pub mod llm_instructions; + use std::collections::HashMap; use uuid::Uuid; diff --git a/crates/main/Cargo.toml b/crates/main/Cargo.toml index a86316d..5542f39 100644 --- a/crates/main/Cargo.toml +++ b/crates/main/Cargo.toml @@ -11,40 +11,12 @@ thiserror = { workspace = true } anyhow = { workspace = true } tracing = { workspace = true } axum = { workspace = true } +surrealdb = { workspace = true } +futures = { workspace = true } +async-openai = { workspace = true } -async-openai = "0.24.1" -async-stream = "0.3.6" -axum-htmx = "0.6.0" -axum_session = "0.14.4" -axum_session_auth = "0.14.1" -axum_session_surreal = "0.2.1" -axum_typed_multipart = "0.12.1" -chrono = { version = "0.4.39", features = ["serde"] } -chrono-tz = "0.10.1" -config = "0.15.4" -futures = "0.3.31" -json-stream-parser = "0.1.4" -lettre = { version = "0.11.11", features = ["rustls-tls"] } -mime = "0.3.17" -mime_guess = "2.0.5" -minijinja = { version = "2.5.0", features = ["loader", "multi_template"] } -minijinja-autoreload = "2.5.0" -minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] } -mockall = "0.13.0" -plotly = "0.12.1" -reqwest = {version = "0.12.12", features = ["charset", "json"]} -scraper = "0.22.0" -sha2 = "0.10.8" -surrealdb = "2.0.4" -tempfile = "3.12.0" -text-splitter = "0.18.1" -tiktoken-rs = "0.6.0" -tower-http = { version = "0.6.2", features = ["fs"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -url = { version = "2.5.2", features = ["serde"] } -uuid = { version = "1.10.0", features = ["v4", "serde"] } -# Reference to api-router ingestion-pipeline = { path = "../ingestion-pipeline" } api-router = { path = "../api-router" } html-router = { path = "../html-router" }