wip: heavy refactoring html routers

This commit is contained in:
Per Stark
2025-03-08 15:47:44 +01:00
parent 1a641db503
commit 60a0d621e1
50 changed files with 1130 additions and 987 deletions

196
Cargo.lock generated
View File

@@ -138,12 +138,6 @@ dependencies = [
"libc",
]
[[package]]
name = "anstyle"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
[[package]]
name = "any_ascii"
version = "0.3.2"
@@ -167,6 +161,7 @@ dependencies = [
"futures",
"serde",
"tempfile",
"thiserror",
"tokio",
"tracing",
]
@@ -1050,9 +1045,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"axum",
"axum-htmx",
"axum_session",
"axum_session_auth",
"axum_session_surreal",
@@ -1061,31 +1054,41 @@ dependencies = [
"chrono-tz",
"config",
"futures",
"json-stream-parser",
"lettre",
"mime",
"mime_guess",
"minijinja",
"minijinja-autoreload",
"minijinja-contrib",
"mockall",
"plotly",
"reqwest",
"serde",
"serde_json",
"sha2",
"surrealdb",
"tempfile",
"text-splitter",
"thiserror",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
"url",
"uuid",
]
[[package]]
name = "composite-retrieval"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"axum",
"common",
"futures",
"serde",
"serde_json",
"surrealdb",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
@@ -1097,9 +1100,9 @@ dependencies = [
[[package]]
name = "config"
version = "0.15.4"
version = "0.15.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d84f8d224ac58107d53d3ec2b9ad39fd8c8c4e285d3c9cb35485ffd2ca88cb3"
checksum = "fb07d21d12f9f0bc5e7c3e97ccc78b2341b9b4a4604eac3ed7c1d0d6e2c3b23e"
dependencies = [
"async-trait",
"convert_case",
@@ -1110,7 +1113,7 @@ dependencies = [
"serde",
"serde_json",
"toml",
"winnow",
"winnow 0.7.3",
"yaml-rust2",
]
@@ -1448,12 +1451,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]]
name = "downcast"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
[[package]]
name = "dtoa"
version = "1.0.9"
@@ -1576,7 +1573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"nom 7.1.3",
"pin-project-lite",
]
@@ -1627,6 +1624,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "foreign-types"
version = "0.3.2"
@@ -1661,12 +1664,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "fragile"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -1974,14 +1971,17 @@ name = "hashbrown"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb"
dependencies = [
"foldhash",
]
[[package]]
name = "hashlink"
version = "0.9.1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [
"hashbrown 0.14.5",
"hashbrown 0.15.0",
]
[[package]]
@@ -2052,6 +2052,7 @@ dependencies = [
"axum_typed_multipart",
"chrono-tz",
"common",
"composite-retrieval",
"futures",
"json-stream-parser",
"minijinja",
@@ -2062,6 +2063,7 @@ dependencies = [
"serde_json",
"surrealdb",
"tempfile",
"thiserror",
"tokio",
"tower-http",
"tracing",
@@ -2435,6 +2437,8 @@ dependencies = [
"axum",
"chrono",
"common",
"composite-retrieval",
"futures",
"reqwest",
"scraper",
"serde",
@@ -2443,6 +2447,7 @@ dependencies = [
"tiktoken-rs",
"tokio",
"tracing",
"uuid",
]
[[package]]
@@ -2625,9 +2630,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lettre"
version = "0.11.11"
version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab4c9a167ff73df98a5ecc07e8bf5ce90b583665da3d1762eb1f775ad4d0d6f5"
checksum = "5d476fe7a4a798f392ce34947aa7d53d981127e37523c5251da3c927f7fa901f"
dependencies = [
"base64 0.22.1",
"chumsky",
@@ -2640,16 +2645,12 @@ dependencies = [
"idna",
"mime",
"native-tls",
"nom",
"nom 8.0.0",
"percent-encoding",
"quoted_printable",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"socket2",
"tokio",
"url",
"webpki-roots",
]
[[package]]
@@ -2737,45 +2738,18 @@ dependencies = [
"anyhow",
"api-router",
"async-openai",
"async-stream",
"axum",
"axum-htmx",
"axum_session",
"axum_session_auth",
"axum_session_surreal",
"axum_typed_multipart",
"chrono",
"chrono-tz",
"common",
"config",
"futures",
"html-router",
"ingestion-pipeline",
"json-stream-parser",
"lettre",
"mime",
"mime_guess",
"minijinja",
"minijinja-autoreload",
"minijinja-contrib",
"mockall",
"plotly",
"reqwest",
"scraper",
"serde",
"serde_json",
"sha2",
"surrealdb",
"tempfile",
"text-splitter",
"thiserror",
"tiktoken-rs",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
"url",
"uuid",
]
[[package]]
@@ -2959,32 +2933,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "mockall"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4c28b3fb6d753d28c20e826cd46ee611fda1cf3cde03a443a974043247c065a"
dependencies = [
"cfg-if",
"downcast",
"fragile",
"mockall_derive",
"predicates",
"predicates-tree",
]
[[package]]
name = "mockall_derive"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "341014e7f530314e9a1fdbc7400b244efea7122662c96bfa248c31da5bfb2020"
dependencies = [
"cfg-if",
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "multer"
version = "3.1.0"
@@ -3091,6 +3039,15 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nom"
version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
dependencies = [
"memchr",
]
[[package]]
name = "nonempty"
version = "0.7.0"
@@ -3611,32 +3568,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "predicates"
version = "3.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97"
dependencies = [
"anstyle",
"predicates-core",
]
[[package]]
name = "predicates-core"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931"
[[package]]
name = "predicates-tree"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13"
dependencies = [
"predicates-core",
"termtree",
]
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
@@ -4004,7 +3935,7 @@ dependencies = [
"futures-core",
"futures-timer",
"mime",
"nom",
"nom 7.1.3",
"pin-project-lite",
"reqwest",
"thiserror",
@@ -4090,7 +4021,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f9a866e2e00a7a1fb27e46e9e324a6f7c0e7edc4543cae1d38f4e4a100c610"
dependencies = [
"memchr",
"nom",
"nom 7.1.3",
"serde",
]
@@ -5110,12 +5041,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "termtree"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]]
name = "text-splitter"
version = "0.18.1"
@@ -5381,7 +5306,7 @@ dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
"winnow 0.6.20",
]
[[package]]
@@ -6036,6 +5961,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
dependencies = [
"memchr",
]
[[package]]
name = "write16"
version = "1.0.0"
@@ -6084,9 +6018,9 @@ checksum = "791978798f0597cfc70478424c2b4fdc2b7a8024aaff78497ef00f24ef674193"
[[package]]
name = "yaml-rust2"
version = "0.9.0"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a1a1c0bc9823338a3bdf8c61f994f23ac004c6fa32c08cd152984499b445e8d"
checksum = "232bdb534d65520716bef0bbb205ff8f2db72d807b19c0bc3020853b92a0cd4b"
dependencies = [
"arraydeque",
"encoding_rs",

View File

@@ -2,8 +2,11 @@
members = [
"crates/main",
"crates/common",
"crates/api-router"
, "crates/html-router", "crates/ingestion-pipeline"]
"crates/api-router",
"crates/html-router",
"crates/ingestion-pipeline",
"crates/composite-retrieval"
]
resolver = "2"
[workspace.dependencies]
@@ -14,3 +17,6 @@ serde_json = "1.0.128"
thiserror = "1.0.63"
anyhow = "1.0.94"
tracing = "0.1.40"
surrealdb = "2.0.4"
futures = "0.3.31"
async-openai = "0.24.1"

View File

@@ -10,6 +10,7 @@ serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
tempfile = "3.12.0"
futures = "0.3.31"

View File

@@ -0,0 +1,168 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use common::error::AppError;
use serde::Serialize;
use thiserror::Error;
#[derive(Error, Debug, Serialize, Clone)]
pub enum ApiError {
#[error("Internal server error")]
InternalError(String),
#[error("Validation error: {0}")]
ValidationError(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
}
impl From<AppError> for ApiError {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
tracing::error!("Internal error: {:?}", err);
ApiError::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => ApiError::NotFound(msg),
AppError::Validation(msg) => ApiError::ValidationError(msg),
AppError::Auth(msg) => ApiError::Unauthorized(msg),
_ => ApiError::InternalError("Internal server error".to_string()),
}
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let (status, error_response) = match self {
ApiError::InternalError(message) => (
StatusCode::INTERNAL_SERVER_ERROR,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
ApiError::ValidationError(message) => (
StatusCode::BAD_REQUEST,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
ApiError::NotFound(message) => (
StatusCode::NOT_FOUND,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
ApiError::Unauthorized(message) => (
StatusCode::UNAUTHORIZED,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
};
(status, Json(error_response)).into_response()
}
}
#[derive(Serialize, Debug)]
struct ErrorResponse {
error: String,
status: String,
}
#[cfg(test)]
mod tests {
use super::*;
use common::error::AppError;
use std::fmt::Debug;
// Helper to check status code
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
let response = response.into_response();
assert_eq!(response.status(), expected_status);
}
#[test]
fn test_app_error_to_api_error_conversion() {
// Test NotFound error conversion
let not_found = AppError::NotFound("resource not found".to_string());
let api_error = ApiError::from(not_found);
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found"));
// Test Validation error conversion
let validation = AppError::Validation("invalid input".to_string());
let api_error = ApiError::from(validation);
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input"));
// Test Auth error conversion
let auth = AppError::Auth("unauthorized".to_string());
let api_error = ApiError::from(auth);
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized"));
// Test for internal errors - create a mock error that doesn't require surrealdb
let internal_error =
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error"));
let api_error = ApiError::from(internal_error);
assert!(matches!(api_error, ApiError::InternalError(_)));
}
#[test]
fn test_api_error_response_status_codes() {
// Test internal error status
let error = ApiError::InternalError("server error".to_string());
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
// Test not found status
let error = ApiError::NotFound("not found".to_string());
assert_status_code(error, StatusCode::NOT_FOUND);
// Test validation error status
let error = ApiError::ValidationError("invalid input".to_string());
assert_status_code(error, StatusCode::BAD_REQUEST);
// Test unauthorized status
let error = ApiError::Unauthorized("not allowed".to_string());
assert_status_code(error, StatusCode::UNAUTHORIZED);
}
// Alternative approach that doesn't try to parse the response body
#[test]
fn test_error_messages() {
// For validation errors
let message = "invalid data format";
let error = ApiError::ValidationError(message.to_string());
// Check that the error itself contains the message
assert_eq!(error.to_string(), format!("Validation error: {}", message));
// For not found errors
let message = "user not found";
let error = ApiError::NotFound(message.to_string());
assert_eq!(error.to_string(), format!("Not found: {}", message));
}
// Alternative approach for internal error test
#[test]
fn test_internal_error_sanitization() {
// Create a sensitive error message
let sensitive_info = "db password incorrect";
// Create ApiError with sensitive info
let api_error = ApiError::InternalError(sensitive_info.to_string());
// Check the error message is correctly set
assert_eq!(api_error.to_string(), "Internal server error");
// Also verify correct status code
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
}
}

View File

@@ -9,6 +9,7 @@ use middleware_api_auth::api_auth;
use routes::ingress::ingest_data;
pub mod api_state;
pub mod error;
mod middleware_api_auth;
mod routes;

View File

@@ -4,9 +4,9 @@ use axum::{
response::Response,
};
use common::{error::ApiError, storage::types::user::User};
use common::storage::types::user::User;
use crate::api_state::ApiState;
use crate::{api_state::ApiState, error::ApiError};
pub async fn api_auth(
State(state): State<ApiState>,

View File

@@ -1,7 +1,7 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use common::{
error::{ApiError, AppError},
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
user::User,
@@ -11,7 +11,7 @@ use futures::{future::try_join_all, TryFutureExt};
use tempfile::NamedTempFile;
use tracing::info;
use crate::api_state::ApiState;
use crate::{api_state::ApiState, error::ApiError};
#[derive(Debug, TryFromMultipart)]
pub struct IngestParams {

View File

@@ -12,10 +12,10 @@ tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
surrealdb = { workspace = true }
async-openai = { workspace = true }
futures = { workspace = true }
async-openai = "0.24.1"
async-stream = "0.3.6"
axum-htmx = "0.6.0"
axum_session = "0.14.4"
axum_session_auth = "0.14.1"
axum_session_surreal = "0.2.1"
@@ -23,23 +23,14 @@ axum_typed_multipart = "0.12.1"
chrono = { version = "0.4.39", features = ["serde"] }
chrono-tz = "0.10.1"
config = "0.15.4"
futures = "0.3.31"
json-stream-parser = "0.1.4"
lettre = { version = "0.11.11", features = ["rustls-tls"] }
lettre = { version = "0.11.11", features = [] }
mime = "0.3.17"
mime_guess = "2.0.5"
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
minijinja-autoreload = "2.5.0"
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
mockall = "0.13.0"
plotly = "0.12.1"
reqwest = {version = "0.12.12", features = ["charset", "json"]}
sha2 = "0.10.8"
surrealdb = "2.0.4"
tempfile = "3.12.0"
text-splitter = "0.18.1"
tower-http = { version = "0.6.2", features = ["fs"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }

View File

@@ -1,15 +1,4 @@
use std::sync::Arc;
use async_openai::error::OpenAIError;
use axum::{
http::StatusCode,
response::{Html, IntoResponse, Response},
Json,
};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use serde::Serialize;
use serde_json::json;
use thiserror::Error;
use tokio::task::JoinError;
@@ -49,206 +38,3 @@ pub enum AppError {
#[error("Ingress Processing error: {0}")]
Processing(String),
}
// API-specific errors
#[derive(Debug, Serialize)]
pub enum ApiError {
InternalError(String),
ValidationError(String),
NotFound(String),
Unauthorized(String),
}
impl From<AppError> for ApiError {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
tracing::error!("Internal error: {:?}", err);
ApiError::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => ApiError::NotFound(msg),
AppError::Validation(msg) => ApiError::ValidationError(msg),
AppError::Auth(msg) => ApiError::Unauthorized(msg),
_ => ApiError::InternalError("Internal server error".to_string()),
}
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let (status, body) = match self {
ApiError::InternalError(message) => (
StatusCode::INTERNAL_SERVER_ERROR,
json!({
"error": message,
"status": "error"
}),
),
ApiError::ValidationError(message) => (
StatusCode::BAD_REQUEST,
json!({
"error": message,
"status": "error"
}),
),
ApiError::NotFound(message) => (
StatusCode::NOT_FOUND,
json!({
"error": message,
"status": "error"
}),
),
ApiError::Unauthorized(message) => (
StatusCode::UNAUTHORIZED,
json!({
"error": message,
"status": "error"
}),
), // ... other matches
};
(status, Json(body)).into_response()
}
}
pub type TemplateResult<T> = Result<T, HtmlError>;
// Helper trait for converting to HtmlError with templates
pub trait IntoHtmlError {
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
}
// // Implement for AppError
impl IntoHtmlError for AppError {
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
HtmlError::new(self, templates)
}
}
// // Implement for minijinja::Error directly
impl IntoHtmlError for minijinja::Error {
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
HtmlError::from_template_error(self, templates)
}
}
#[derive(Clone)]
pub struct ErrorContext {
#[allow(dead_code)]
templates: Arc<AutoReloader>,
}
impl ErrorContext {
pub fn new(templates: Arc<AutoReloader>) -> Self {
Self { templates }
}
}
pub enum HtmlError {
ServerError(Arc<AutoReloader>),
NotFound(Arc<AutoReloader>),
Unauthorized(Arc<AutoReloader>),
BadRequest(String, Arc<AutoReloader>),
Template(String, Arc<AutoReloader>),
}
impl HtmlError {
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
match error {
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
_ => {
tracing::error!("Internal error: {:?}", error);
HtmlError::ServerError(templates)
}
}
}
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
tracing::error!("Template error: {:?}", error);
HtmlError::Template(error.to_string(), templates)
}
}
impl IntoResponse for HtmlError {
fn into_response(self) -> Response {
let (status, context, templates) = match self {
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
StatusCode::INTERNAL_SERVER_ERROR,
context! {
status_code => 500,
title => "Internal Server Error",
error => "Internal Server Error",
description => "Something went wrong on our end."
},
templates,
),
HtmlError::NotFound(templates) => (
StatusCode::NOT_FOUND,
context! {
status_code => 404,
title => "Page Not Found",
error => "Not Found",
description => "The page you're looking for doesn't exist or was removed."
},
templates,
),
HtmlError::Unauthorized(templates) => (
StatusCode::UNAUTHORIZED,
context! {
status_code => 401,
title => "Unauthorized",
error => "Access Denied",
description => "You need to be logged in to access this page."
},
templates,
),
HtmlError::BadRequest(msg, templates) => (
StatusCode::BAD_REQUEST,
context! {
status_code => 400,
title => "Bad Request",
error => "Bad Request",
description => msg
},
templates,
),
};
let html = match templates.acquire_env() {
Ok(env) => match env.get_template("errors/error.html") {
Ok(tmpl) => match tmpl.render(context) {
Ok(output) => output,
Err(e) => {
tracing::error!("Template render error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Template get error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Environment acquire error: {:?}", e);
Self::fallback_html()
}
};
(status, Html(html)).into_response()
}
}
impl HtmlError {
fn fallback_html() -> String {
r#"
<html>
<body>
<div class="container mx-auto p-4">
<h1 class="text-4xl text-error">Error</h1>
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
</div>
</body>
</html>
"#
.to_string()
}
}

View File

@@ -1,3 +0,0 @@
pub mod ingress_analyser;
pub mod prompt;
pub mod types;

View File

@@ -1,2 +0,0 @@
pub mod graph_mapper;
pub mod llm_analysis_result;

View File

@@ -1,2 +0,0 @@
pub mod analysis;
pub mod content_processor;

View File

@@ -1,5 +1,3 @@
pub mod error;
pub mod ingress;
pub mod retrieval;
pub mod storage;
pub mod utils;

View File

@@ -0,0 +1,19 @@
[package]
name = "composite-retrieval"
version = "0.1.0"
edition = "2021"
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
common = { path = "../common" }

View File

@@ -10,13 +10,14 @@ use serde::Deserialize;
use serde_json::{json, Value};
use tracing::debug;
use crate::{
use common::{
error::AppError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
use crate::retrieve_entities;
use super::answer_retrieval_helper::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
#[derive(Debug, Deserialize)]
pub struct Reference {
@@ -31,47 +32,6 @@ pub struct LLMResponseFormat {
pub references: Vec<Reference>,
}
// /// Orchestrator function that takes a query and clients and returns a answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDn
// /// * `openai_client` - Client for interacting with openai
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>, ApiError)` - Will return the answer, and the list of references or Error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// Orchestrates query processing and returns an answer with references
///
/// Takes a query and uses the provided clients to generate an answer with supporting references.
@@ -98,9 +58,7 @@ pub async fn get_answer_with_references(
query: &str,
user_id: &str,
) -> Result<Answer, AppError> {
let entities =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
.await?;
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
let entities_json = format_entities_json(&entities);
debug!("{:?}", entities_json);

View File

@@ -1,7 +1,7 @@
use surrealdb::Error;
use tracing::debug;
use crate::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
/// Retrieves database entries that match a specific source identifier.
///

View File

@@ -1,21 +1,19 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod graph;
pub mod query_helper;
pub mod query_helper_prompt;
pub mod vector;
use crate::{
use common::{
error::AppError,
retrieval::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
vector::find_items_by_vector_similarity,
},
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use futures::future::{try_join, try_join_all};
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
use std::collections::HashMap;
use vector::find_items_by_vector_similarity;
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
@@ -38,7 +36,7 @@ use std::collections::HashMap;
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
/// knowledge entities, or an error if the retrieval process fails
pub async fn combined_knowledge_entity_retrieval(
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,

View File

@@ -1,6 +1,6 @@
use surrealdb::{engine::any::Any, Surreal};
use crate::{error::AppError, utils::embedding::generate_embedding};
use common::{error::AppError, utils::embedding::generate_embedding};
/// Compares vectors and retrieves a number of items from the specified table.
///

View File

@@ -10,7 +10,8 @@ serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
serde_json = { workspace = true }
async-openai = { workspace = true }
thiserror = { workspace = true }
axum-htmx = "0.6.0"
axum_session = "0.14.4"
@@ -28,8 +29,6 @@ plotly = "0.12.1"
surrealdb = "2.0.4"
tower-http = { version = "0.6.2", features = ["fs"] }
chrono-tz = "0.10.1"
async-openai = "0.24.1"
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }

View File

@@ -0,0 +1,139 @@
// use axum::{
// http::StatusCode,
// response::{Html, IntoResponse, Response},
// };
// use common::error::AppError;
// use minijinja::context;
// use minijinja_autoreload::AutoReloader;
// use std::sync::Arc;
// pub type TemplateResult<T> = Result<T, HtmlError>;
// // Helper trait for converting to HtmlError with templates
// pub trait IntoHtmlError {
// fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
// }
// // // Implement for AppError
// impl IntoHtmlError for AppError {
// fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
// HtmlError::new(self, templates)
// }
// }
// // // Implement for minijinja::Error directly
// impl IntoHtmlError for minijinja::Error {
// fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
// HtmlError::from_template_error(self, templates)
// }
// }
// pub enum HtmlError {
// ServerError(Arc<AutoReloader>),
// NotFound(Arc<AutoReloader>),
// Unauthorized(Arc<AutoReloader>),
// BadRequest(String, Arc<AutoReloader>),
// Template(String, Arc<AutoReloader>),
// }
// impl HtmlError {
// pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
// match error {
// AppError::NotFound(_msg) => HtmlError::NotFound(templates),
// AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
// AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
// _ => {
// tracing::error!("Internal error: {:?}", error);
// HtmlError::ServerError(templates)
// }
// }
// }
// pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
// tracing::error!("Template error: {:?}", error);
// HtmlError::Template(error.to_string(), templates)
// }
// }
// impl IntoResponse for HtmlError {
// fn into_response(self) -> Response {
// let (status, context, templates) = match self {
// HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
// StatusCode::INTERNAL_SERVER_ERROR,
// context! {
// status_code => 500,
// title => "Internal Server Error",
// error => "Internal Server Error",
// description => "Something went wrong on our end."
// },
// templates,
// ),
// HtmlError::NotFound(templates) => (
// StatusCode::NOT_FOUND,
// context! {
// status_code => 404,
// title => "Page Not Found",
// error => "Not Found",
// description => "The page you're looking for doesn't exist or was removed."
// },
// templates,
// ),
// HtmlError::Unauthorized(templates) => (
// StatusCode::UNAUTHORIZED,
// context! {
// status_code => 401,
// title => "Unauthorized",
// error => "Access Denied",
// description => "You need to be logged in to access this page."
// },
// templates,
// ),
// HtmlError::BadRequest(msg, templates) => (
// StatusCode::BAD_REQUEST,
// context! {
// status_code => 400,
// title => "Bad Request",
// error => "Bad Request",
// description => msg
// },
// templates,
// ),
// };
// let html = match templates.acquire_env() {
// Ok(env) => match env.get_template("errors/error.html") {
// Ok(tmpl) => match tmpl.render(context) {
// Ok(output) => output,
// Err(e) => {
// tracing::error!("Template render error: {:?}", e);
// Self::fallback_html()
// }
// },
// Err(e) => {
// tracing::error!("Template get error: {:?}", e);
// Self::fallback_html()
// }
// },
// Err(e) => {
// tracing::error!("Environment acquire error: {:?}", e);
// Self::fallback_html()
// }
// };
// (status, Html(html)).into_response()
// }
// }
// impl HtmlError {
// fn fallback_html() -> String {
// r#"
// <html>
// <body>
// <div class="container mx-auto p-4">
// <h1 class="text-4xl text-error">Error</h1>
// <p class="mt-4">Sorry, something went wrong displaying this page.</p>
// </div>
// </body>
// </html>
// "#
// .to_string()
// }
// }

View File

@@ -1,19 +1,23 @@
pub mod error;
pub mod html_state;
mod middleware_analytics;
mod middleware_auth;
mod routes;
mod template_response;
use axum::{
extract::FromRef,
middleware::from_fn_with_state,
middleware::{from_fn_with_state, map_response_with_state},
routing::{delete, get, patch, post},
Router,
};
use axum_session::SessionLayer;
use axum_session_auth::{AuthConfig, AuthSessionLayer};
use axum_session::{Session, SessionLayer};
use axum_session_auth::{AuthConfig, AuthSession, AuthSessionLayer};
use axum_session_surreal::SessionSurrealPool;
use common::storage::types::user::User;
use html_state::HtmlState;
use middleware_analytics::analytics_middleware;
use middleware_auth::require_auth;
use routes::{
account::{delete_account, set_api_key, show_account_page, update_timezone},
admin_panel::{show_admin_panel, toggle_registration_status},
@@ -39,26 +43,40 @@ use routes::{
signup::{process_signup_and_show_verification, show_signup_form},
};
use surrealdb::{engine::any::Any, Surreal};
use template_response::with_template_response;
use tower_http::services::ServeDir;
/// Router for HTML endpoints
pub type AuthSessionType = AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>;
pub type SessionType = Session<SessionSurrealPool<Any>>;
/// Html routes
pub fn html_routes<S>(app_state: &HtmlState) -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
// Public routes - no auth required
let public_routes = Router::new()
.route("/", get(index_handler))
.route("/gdpr/accept", post(accept_gdpr))
.route("/gdpr/deny", post(deny_gdpr))
.route("/search", get(search_result_handler))
.route("/signout", get(sign_out_user))
.route("/signin", get(show_signin_form).post(authenticate_user))
.route(
"/signup",
get(show_signup_form).post(process_signup_and_show_verification),
)
.route("/documentation", get(show_documentation_index))
.route("/documentation/privacy-policy", get(show_privacy_policy))
.route("/documentation/get-started", get(show_get_started))
.route("/documentation/mobile-friendly", get(show_mobile_friendly))
.nest_service("/assets", ServeDir::new("assets/"));
// Protected routes - auth required
let protected_routes = Router::new()
.route("/chat", get(show_chat_base).post(new_chat_user_message))
.route("/initialized-chat", post(show_initialized_chat))
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
.route("/chat/response-stream", get(get_response_stream))
.route("/knowledge/:id", get(show_reference_tooltip))
.route("/signout", get(sign_out_user))
.route("/signin", get(show_signin_form).post(authenticate_user))
.route(
"/ingress-form",
get(show_ingress_form).post(process_ingress_form),
@@ -72,6 +90,9 @@ where
"/content/:id",
get(show_text_content_edit_form).patch(patch_text_content),
)
.route("/search", get(search_result_handler))
.route("/chat/response-stream", get(get_response_stream))
.route("/knowledge/:id", get(show_reference_tooltip))
.route("/knowledge", get(show_knowledge_page))
.route(
"/knowledge-entity/:id",
@@ -90,16 +111,17 @@ where
.route("/set-api-key", post(set_api_key))
.route("/update-timezone", patch(update_timezone))
.route("/delete-account", delete(delete_account))
.route(
"/signup",
get(show_signup_form).post(process_signup_and_show_verification),
)
.route("/documentation", get(show_documentation_index))
.route("/documentation/privacy-policy", get(show_privacy_policy))
.route("/documentation/get-started", get(show_get_started))
.route("/documentation/mobile-friendly", get(show_mobile_friendly))
.nest_service("/assets", ServeDir::new("assets/"))
.route_layer(from_fn_with_state(app_state.clone(), require_auth));
// Combine routes and add common middleware
Router::new()
.merge(public_routes)
.merge(protected_routes)
.layer(from_fn_with_state(app_state.clone(), analytics_middleware))
.layer(map_response_with_state(
app_state.clone(),
with_template_response,
))
.layer(
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
app_state.db.client.clone(),

View File

@@ -0,0 +1,48 @@
use axum::{
async_trait,
extract::{FromRequestParts, Request},
http::request::Parts,
middleware::Next,
response::{IntoResponse, Response},
};
use common::storage::types::user::User;
use crate::{template_response::TemplateResponse, AuthSessionType};
#[derive(Debug, Clone)]
pub struct RequireUser(pub User);
// Implement FromRequestParts for RequireUser
#[async_trait]
impl<S> FromRequestParts<S> for RequireUser
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<User>()
.cloned()
.map(RequireUser)
.ok_or_else(|| TemplateResponse::redirect("/signin").into_response())
}
}
// Auth middleware that adds the user to extensions
pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Next) -> Response {
// Check if user is authenticated
match auth.current_user {
Some(user) => {
// Add user to request extensions
request.extensions_mut().insert(user);
// Continue to the handler
next.run(request).await
}
None => {
// Redirect to login
TemplateResponse::redirect("/signin").into_response()
}
}
}

View File

@@ -1,65 +1,42 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{IntoResponse, Redirect},
Form,
};
use axum_htmx::HxRedirect;
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use axum::{extract::State, response::IntoResponse, Form};
use chrono_tz::TZ_VARIANTS;
use surrealdb::{engine::any::Any, Surreal};
use serde::{Deserialize, Serialize};
use common::{
error::{AppError, HtmlError},
storage::types::user::User,
use crate::{
middleware_auth::RequireUser,
template_response::{HtmlError, TemplateResponse},
AuthSessionType,
};
use common::storage::types::user::User;
use crate::{html_state::HtmlState, page_data};
use crate::html_state::HtmlState;
use super::{render_block, render_template};
page_data!(AccountData, "auth/account_settings.html", {
#[derive(Serialize)]
pub struct AccountPageData {
user: User,
timezones: Vec<String>
});
timezones: Vec<String>,
}
pub async fn show_account_page(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
let output = render_template(
AccountData::template_name(),
AccountData { user, timezones },
state.templates.clone(),
)?;
Ok(output.into_response())
Ok(TemplateResponse::new_template(
"auth/account_settings.html",
AccountPageData { user, timezones },
))
}
pub async fn set_api_key(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated
let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
// Generate and set the API key
let api_key = User::set_api_key(&user.id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let api_key = User::set_api_key(&user.id, &state.db).await?;
// Clear the cache so new requests have access to the user with api key
auth.cache_clear_user(user.id.to_string());
// Update the user's API key
@@ -69,40 +46,28 @@ pub async fn set_api_key(
};
// Render the API key section block
let output = render_block(
AccountData::template_name(),
Ok(TemplateResponse::new_partial(
"auth/account_settings.html",
"api_key_section",
AccountData {
AccountPageData {
user: updated_user,
timezones: vec![],
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
pub async fn delete_account(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated
let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
state
.db
.delete_item::<User>(&user.id)
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
state.db.delete_item::<User>(&user.id).await?;
auth.logout_user();
auth.session.destroy();
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
Ok(TemplateResponse::redirect("/"))
}
#[derive(Deserialize)]
@@ -112,18 +77,13 @@ pub struct UpdateTimezoneForm {
pub async fn update_timezone(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<UpdateTimezoneForm>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
User::update_timezone(&user.id, &form.timezone, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache
auth.cache_clear_user(user.id.to_string());
// Update the user's API key
@@ -135,15 +95,12 @@ pub async fn update_timezone(
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
// Render the API key section block
let output = render_block(
AccountData::template_name(),
Ok(TemplateResponse::new_partial(
"auth/account_settings.html",
"timezone_section",
AccountData {
AccountPageData {
user: updated_user,
timezones,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}

View File

@@ -1,62 +1,39 @@
use axum::{
extract::State,
response::{IntoResponse, Redirect},
Form,
use axum::{extract::State, response::IntoResponse, Form};
use serde::{Deserialize, Serialize};
use crate::{
middleware_auth::RequireUser,
template_response::{HtmlError, TemplateResponse},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User};
use common::{
error::HtmlError,
storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User},
};
use crate::html_state::HtmlState;
use crate::{html_state::HtmlState, page_data};
use super::{render_block, render_template};
page_data!(AdminPanelData, "auth/admin_panel.html", {
#[derive(Serialize)]
pub struct AdminPanelData {
user: User,
settings: SystemSettings,
analytics: Analytics,
users: i64,
});
}
pub async fn show_admin_panel(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated and admin
let user = match auth.current_user {
Some(user) if user.admin => user,
_ => return Ok(Redirect::to("/").into_response()),
};
let settings = SystemSettings::get_current(&state.db).await?;
let analytics = Analytics::get_current(&state.db).await?;
let users_count = Analytics::get_users_amount(&state.db).await?;
let settings = SystemSettings::get_current(&state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let analytics = Analytics::get_current(&state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let users_count = Analytics::get_users_amount(&state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let output = render_template(
AdminPanelData::template_name(),
Ok(TemplateResponse::new_template(
"auth/admin_panel.html",
AdminPanelData {
user,
settings,
analytics,
users: users_count,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
@@ -83,36 +60,28 @@ pub struct RegistrationToggleData {
pub async fn toggle_registration_status(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated and admin
let _user = match auth.current_user {
Some(user) if user.admin => user,
_ => return Ok(Redirect::to("/").into_response()),
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let current_settings = SystemSettings::get_current(&state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
registrations_enabled: input.registration_open,
..current_settings.clone()
};
SystemSettings::update(&state.db, new_settings.clone())
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
SystemSettings::update(&state.db, new_settings.clone()).await?;
let output = render_block(
AdminPanelData::template_name(),
Ok(TemplateResponse::new_partial(
"auth/admin_panel.html",
"registration_status_input",
RegistrationToggleData {
settings: new_settings,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}

View File

@@ -10,6 +10,12 @@ use axum::{
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use composite_retrieval::{
answer_retrieval::{
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
},
retrieve_entities,
};
use futures::{
stream::{self, once},
Stream, StreamExt, TryStreamExt,
@@ -21,19 +27,11 @@ use surrealdb::{engine::any::Any, Surreal};
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{error, info};
use common::{
retrieval::{
combined_knowledge_entity_retrieval,
query_helper::{
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
},
},
storage::{
db::SurrealDbClient,
types::{
message::{Message, MessageRole},
user::User,
},
use common::storage::{
db::SurrealDbClient,
types::{
message::{Message, MessageRole},
user::User,
},
};
@@ -100,7 +98,7 @@ pub async fn get_response_stream(
};
// 2. Retrieve knowledge entities
let entities = match combined_knowledge_entity_retrieval(
let entities = match retrieve_entities(
&state.db,
&state.openai_client,
&user_message.content,

View File

@@ -12,8 +12,9 @@ use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use crate::routes::HtmlError;
use common::{
error::{AppError, HtmlError},
error::AppError,
storage::types::{
conversation::Conversation,
message::{Message, MessageRole},

View File

@@ -8,8 +8,9 @@ use serde::Serialize;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use crate::routes::HtmlError;
use common::{
error::{AppError, HtmlError},
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
};

View File

@@ -6,12 +6,9 @@ use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::HtmlError,
storage::types::{text_content::TextContent, user::User},
};
use common::storage::types::{text_content::TextContent, user::User};
use crate::{html_state::HtmlState, page_data};
use crate::{error::HtmlError, html_state::HtmlState, page_data};
use super::render_template;

View File

@@ -1,78 +1,54 @@
use axum::{extract::State, response::IntoResponse};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use axum::response::IntoResponse;
use common::storage::types::user::User;
use serde::Serialize;
use common::{error::HtmlError, storage::types::user::User};
use crate::template_response::{HtmlError, TemplateResponse};
use crate::AuthSessionType;
use crate::{html_state::HtmlState, page_data};
use super::render_template;
page_data!(DocumentationData, "do_not_use_this", {
#[derive(Serialize)]
pub struct DocumentationPageData {
user: Option<User>,
current_path: String
});
pub async fn show_privacy_policy(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
let output = render_template(
"documentation/privacy.html",
DocumentationData {
user: auth.current_user,
current_path: "/privacy_policy".to_string(),
},
state.templates.clone(),
)?;
Ok(output.into_response())
current_path: String,
}
pub async fn show_get_started(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
let output = render_template(
pub async fn show_privacy_policy(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
Ok(TemplateResponse::new_template(
"documentation/privacy.html",
DocumentationPageData {
user: auth.current_user,
current_path: "/privacy-policy".to_string(),
},
))
}
pub async fn show_get_started(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
Ok(TemplateResponse::new_template(
"documentation/get_started.html",
DocumentationData {
DocumentationPageData {
user: auth.current_user,
current_path: "/get-started".to_string(),
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
pub async fn show_mobile_friendly(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
let output = render_template(
pub async fn show_mobile_friendly(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
Ok(TemplateResponse::new_template(
"documentation/mobile_friendly.html",
DocumentationData {
DocumentationPageData {
user: auth.current_user,
current_path: "/mobile-friendly".to_string(),
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
pub async fn show_documentation_index(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
let output = render_template(
Ok(TemplateResponse::new_template(
"documentation/index.html",
DocumentationData {
DocumentationPageData {
user: auth.current_user,
current_path: "/index".to_string(),
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}

View File

@@ -1,22 +1,15 @@
use axum::response::{Html, IntoResponse};
use axum_session::Session;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::engine::any::Any;
use common::error::HtmlError;
use crate::SessionType;
pub async fn accept_gdpr(
session: Session<SessionSurrealPool<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
pub async fn accept_gdpr(session: SessionType) -> impl IntoResponse {
session.set("gdpr_accepted", true);
Ok(Html("").into_response())
Html("").into_response()
}
pub async fn deny_gdpr(
session: Session<SessionSurrealPool<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
pub async fn deny_gdpr(session: SessionType) -> impl IntoResponse {
session.set("gdpr_accepted", true);
Ok(Html("").into_response())
Html("").into_response()
}

View File

@@ -1,16 +1,18 @@
use axum::{
debug_handler,
extract::{Path, State},
response::{IntoResponse, Redirect},
response::IntoResponse,
};
use axum_session::Session;
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use serde::Serialize;
use tokio::join;
use tracing::info;
use crate::{
middleware_auth::RequireUser,
template_response::{HtmlError, TemplateResponse},
AuthSessionType, SessionType,
};
use common::{
error::{AppError, HtmlError},
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
@@ -18,67 +20,51 @@ use common::{
},
};
use crate::{html_state::HtmlState, page_data, routes::render_template};
use crate::html_state::HtmlState;
use super::render_block;
page_data!(IndexData, "index/index.html", {
#[derive(Serialize)]
pub struct IndexPageData {
gdpr_accepted: bool,
user: Option<User>,
latest_text_contents: Vec<TextContent>,
active_jobs: Vec<IngestionTask>
});
active_jobs: Vec<IngestionTask>,
}
pub async fn index_handler(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
session: Session<SessionSurrealPool<Any>>,
auth: AuthSessionType,
session: SessionType,
) -> Result<impl IntoResponse, HtmlError> {
info!("Displaying index page");
let gdpr_accepted = auth.current_user.is_some() | session.get("gdpr_accepted").unwrap_or(false);
let active_jobs = match auth.current_user.is_some() {
true => {
User::get_unfinished_ingestion_tasks(&auth.current_user.clone().unwrap().id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?
.await?
}
false => vec![],
};
let latest_text_contents = match auth.current_user.clone().is_some() {
true => User::get_latest_text_contents(
auth.current_user.clone().unwrap().id.as_str(),
&state.db,
)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?,
true => {
User::get_latest_text_contents(
auth.current_user.clone().unwrap().id.as_str(),
&state.db,
)
.await?
}
false => vec![],
};
// let latest_knowledge_entities = match auth.current_user.is_some() {
// true => User::get_latest_knowledge_entities(
// auth.current_user.clone().unwrap().id.as_str(),
// &state.db,
// )
// .await
// .map_err(|e| HtmlError::new(e, state.templates.clone()))?,
// false => vec![],
// };
let output = render_template(
IndexData::template_name(),
IndexData {
Ok(TemplateResponse::new_template(
"index/index.html",
IndexPageData {
gdpr_accepted,
user: auth.current_user,
latest_text_contents,
active_jobs,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
#[derive(Serialize)]
@@ -87,21 +73,17 @@ pub struct LatestTextContentData {
user: User,
}
#[debug_handler]
pub async fn delete_text_content(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match &auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
// Get and validate TextContent
let text_content = get_and_validate_text_content(&state, &id, user).await?;
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
// Perform concurrent deletions
let deletion_tasks = join!(
join!(
async {
if let Some(file_info) = text_content.file_info {
FileInfo::delete_by_id(&file_info.id, &state.db).await
@@ -115,33 +97,17 @@ pub async fn delete_text_content(
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db)
);
// Handle potential errors from concurrent operations
match deletion_tasks {
(Ok(_), Ok(_), Ok(_), Ok(_), Ok(_)) => (),
_ => {
return Err(HtmlError::new(
AppError::Processing("Failed to delete one or more items".to_string()),
state.templates.clone(),
))
}
}
// Render updated content
let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db).await?;
let output = render_block(
Ok(TemplateResponse::new_partial(
"index/signed_in/recent_content.html",
"latest_content_section",
LatestTextContentData {
user: user.clone(),
user: user.to_owned(),
latest_text_contents,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
// Helper function to get and validate text content
@@ -149,23 +115,16 @@ async fn get_and_validate_text_content(
state: &HtmlState,
id: &str,
user: &User,
) -> Result<TextContent, HtmlError> {
) -> Result<TextContent, AppError> {
let text_content = state
.db
.get_item::<TextContent>(id)
.await
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?
.ok_or_else(|| {
HtmlError::new(
AppError::NotFound("No item found".to_string()),
state.templates.clone(),
)
})?;
.await?
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
if text_content.user_id != user.id {
return Err(HtmlError::new(
AppError::Auth("You are not the owner of that content".to_string()),
state.templates.clone(),
return Err(AppError::Auth(
"You are not the owner of that content".to_string(),
));
}
@@ -180,57 +139,35 @@ pub struct ActiveJobsData {
pub async fn delete_job(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/signin").into_response()),
};
User::validate_and_delete_job(&id, &user.id, &state.db).await?;
User::validate_and_delete_job(&id, &user.id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let output = render_block(
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}
pub async fn show_active_jobs(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/signin").into_response()),
};
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let output = render_block(
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
state.templates.clone(),
)?;
Ok(output.into_response())
))
}

View File

@@ -11,7 +11,7 @@ use tempfile::NamedTempFile;
use tracing::info;
use common::{
error::{AppError, HtmlError, IntoHtmlError},
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
user::User,
@@ -19,6 +19,7 @@ use common::{
};
use crate::{
error::{HtmlError, IntoHtmlError},
html_state::HtmlState,
page_data,
routes::{index::ActiveJobsData, render_block},

View File

@@ -14,7 +14,7 @@ use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{
error::{AppError, HtmlError},
error::AppError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
@@ -22,7 +22,7 @@ use common::{
},
};
use crate::{html_state::HtmlState, page_data, routes::render_template};
use crate::{error::HtmlError, html_state::HtmlState, page_data, routes::render_template};
page_data!(KnowledgeBaseData, "knowledge/base.html", {
entities: Vec<KnowledgeEntity>,

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
use common::error::{HtmlError, IntoHtmlError};
use crate::error::{HtmlError, IntoHtmlError};
pub mod account;
pub mod admin_panel;
@@ -19,73 +19,73 @@ pub mod signin;
pub mod signout;
pub mod signup;
pub trait PageData {
fn template_name() -> &'static str;
}
// pub trait PageData {
// fn template_name() -> &'static str;
// }
// Helper function for render_template
pub fn render_template<T>(
template_name: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, HtmlError>
where
T: serde::Serialize,
{
let env = templates
.acquire_env()
.map_err(|e| e.with_template(templates.clone()))?;
let tmpl = env
.get_template(template_name)
.map_err(|e| e.with_template(templates.clone()))?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl
.render(context)
.map_err(|e| e.with_template(templates.clone()))?;
Ok(Html(output))
}
// // Helper function for render_template
// pub fn render_template<T>(
// template_name: &str,
// context: T,
// templates: Arc<AutoReloader>,
// ) -> Result<Html<String>, HtmlError>
// where
// T: serde::Serialize,
// {
// let env = templates
// .acquire_env()
// .map_err(|e| e.with_template(templates.clone()))?;
// let tmpl = env
// .get_template(template_name)
// .map_err(|e| e.with_template(templates.clone()))?;
// let context = minijinja::Value::from_serialize(&context);
// let output = tmpl
// .render(context)
// .map_err(|e| e.with_template(templates.clone()))?;
// Ok(Html(output))
// }
pub fn render_block<T>(
template_name: &str,
block: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, HtmlError>
where
T: serde::Serialize,
{
let env = templates
.acquire_env()
.map_err(|e| e.with_template(templates.clone()))?;
let tmpl = env
.get_template(template_name)
.map_err(|e| e.with_template(templates.clone()))?;
// pub fn render_block<T>(
// template_name: &str,
// block: &str,
// context: T,
// templates: Arc<AutoReloader>,
// ) -> Result<Html<String>, HtmlError>
// where
// T: serde::Serialize,
// {
// let env = templates
// .acquire_env()
// .map_err(|e| e.with_template(templates.clone()))?;
// let tmpl = env
// .get_template(template_name)
// .map_err(|e| e.with_template(templates.clone()))?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl
.eval_to_state(context)
.map_err(|e| e.with_template(templates.clone()))?
.render_block(block)
.map_err(|e| e.with_template(templates.clone()))?;
// let context = minijinja::Value::from_serialize(&context);
// let output = tmpl
// .eval_to_state(context)
// .map_err(|e| e.with_template(templates.clone()))?
// .render_block(block)
// .map_err(|e| e.with_template(templates.clone()))?;
Ok(output.into())
}
// Ok(output.into())
// }
#[macro_export]
macro_rules! page_data {
($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
use serde::{Serialize, Deserialize};
use $crate::routes::PageData;
// #[macro_export]
// macro_rules! page_data {
// ($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
// use serde::{Serialize, Deserialize};
// use $crate::routes::PageData;
#[derive(Debug, Deserialize, Serialize)]
pub struct $name {
$($(#[$attr])* pub $field: $ty),*
}
// #[derive(Debug, Deserialize, Serialize)]
// pub struct $name {
// $($(#[$attr])* pub $field: $ty),*
// }
impl PageData for $name {
fn template_name() -> &'static str {
$template_name
}
}
};
}
// impl PageData for $name {
// fn template_name() -> &'static str {
// $template_name
// }
// }
// };
// }

View File

@@ -8,7 +8,8 @@ use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{error::HtmlError, storage::types::user::User};
use crate::routes::HtmlError;
use common::storage::types::user::User;
use crate::{html_state::HtmlState, routes::render_template};
#[derive(Deserialize)]

View File

@@ -1,19 +1,17 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect},
response::{Html, IntoResponse},
Form,
};
use axum_htmx::{HxBoosted, HxRedirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use common::{error::HtmlError, storage::types::user::User};
use crate::{html_state::HtmlState, page_data};
use super::{render_block, render_template};
use crate::{
html_state::HtmlState,
template_response::{HtmlError, TemplateResponse},
AuthSessionType,
};
use common::storage::types::user::User;
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
@@ -22,36 +20,26 @@ pub struct SignupParams {
pub remember_me: Option<String>,
}
page_data!(ShowSignInForm, "auth/signin_form.html", {});
pub async fn show_signin_form(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response());
return Ok(TemplateResponse::redirect("/"));
}
let output = match boosted {
true => render_block(
ShowSignInForm::template_name(),
match boosted {
true => Ok(TemplateResponse::new_partial(
"auth/signin_form.html",
"body",
ShowSignInForm {},
state.templates.clone(),
)?,
false => render_template(
ShowSignInForm::template_name(),
ShowSignInForm {},
state.templates.clone(),
)?,
};
Ok(output.into_response())
{},
)),
false => Ok(TemplateResponse::new_template("auth/signin_form.html", {})),
}
}
pub async fn authenticate_user(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::authenticate(form.email, form.password, &state.db).await {
@@ -67,5 +55,5 @@ pub async fn authenticate_user(
auth.remember_user(true);
}
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
Ok(TemplateResponse::redirect("/").into_response())
}

View File

@@ -1,18 +1,16 @@
use axum::response::{IntoResponse, Redirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use axum::response::IntoResponse;
use common::{error::ApiError, storage::types::user::User};
use crate::{
template_response::{HtmlError, TemplateResponse},
AuthSessionType,
};
pub async fn sign_out_user(
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> {
pub async fn sign_out_user(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
if !auth.is_authenticated() {
return Ok(Redirect::to("/").into_response());
return Ok(TemplateResponse::redirect("/"));
}
auth.logout_user();
Ok(Redirect::to("/").into_response())
Ok(TemplateResponse::redirect("/"))
}

View File

@@ -1,20 +1,18 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect},
response::{Html, IntoResponse},
Form,
};
use axum_htmx::{HxBoosted, HxRedirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use common::{error::HtmlError, storage::types::user::User};
use common::storage::types::user::User;
use crate::html_state::HtmlState;
use super::{render_block, render_template};
use crate::{
html_state::HtmlState,
template_response::{HtmlError, TemplateResponse},
AuthSessionType,
};
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
@@ -24,24 +22,26 @@ pub struct SignupParams {
}
pub async fn show_signup_form(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
return Ok(Redirect::to("/").into_response());
return Ok(TemplateResponse::redirect("/"));
}
let output = match boosted {
true => render_block("auth/signup_form.html", "body", {}, state.templates.clone())?,
false => render_template("auth/signup_form.html", {}, state.templates.clone())?,
};
Ok(output.into_response())
match boosted {
true => Ok(TemplateResponse::new_partial(
"auth/signup_form.html",
"body",
{},
)),
false => Ok(TemplateResponse::new_template("auth/signup_form.html", {})),
}
}
pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new(form.email, form.password, &state.db, form.timezone).await {
@@ -54,5 +54,5 @@ pub async fn process_signup_and_show_verification(
auth.login_user(user.id);
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
Ok(TemplateResponse::redirect("/").into_response())
}

View File

@@ -0,0 +1,288 @@
use axum::{
extract::State,
http::StatusCode,
response::{Html, IntoResponse, Response},
Extension,
};
use common::error::AppError;
use minijinja::{context, Value};
use minijinja_autoreload::AutoReloader;
use serde::Serialize;
use std::sync::Arc;
use crate::{html_state::HtmlState, AuthSessionType};
// Enum for template types
#[derive(Clone)]
pub enum TemplateKind {
Full(String), // Full page template
Partial(String, String), // Template name, block name
Error(StatusCode), // Error template with status code
Redirect(axum::response::Redirect), // Redirect
}
#[derive(Clone)]
pub struct TemplateResponse {
template_kind: TemplateKind,
context: Value,
}
impl TemplateResponse {
pub fn new_template<T: Serialize>(name: impl Into<String>, context: T) -> Self {
Self {
template_kind: TemplateKind::Full(name.into()),
context: Value::from_serialize(&context),
}
}
pub fn new_partial<T: Serialize>(
template: impl Into<String>,
block: impl Into<String>,
context: T,
) -> Self {
Self {
template_kind: TemplateKind::Partial(template.into(), block.into()),
context: Value::from_serialize(&context),
}
}
pub fn error(status: StatusCode, title: &str, error: &str, description: &str) -> Self {
let ctx = context! {
status_code => status.as_u16(),
title => title,
error => error,
description => description
};
Self {
template_kind: TemplateKind::Error(status),
context: ctx,
}
}
// Convenience methods for common errors
pub fn not_found() -> Self {
Self::error(
StatusCode::NOT_FOUND,
"Page Not Found",
"Not Found",
"The page you're looking for doesn't exist or was removed.",
)
}
pub fn server_error() -> Self {
Self::error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error",
"Internal Server Error",
"Something went wrong on our end.",
)
}
pub fn unauthorized() -> Self {
Self::error(
StatusCode::UNAUTHORIZED,
"Unauthorized",
"Access Denied",
"You need to be logged in to access this page.",
)
}
pub fn bad_request(message: &str) -> Self {
Self::error(
StatusCode::BAD_REQUEST,
"Bad Request",
"Bad Request",
message,
)
}
pub fn redirect(path: impl AsRef<str>) -> Self {
let redirect_response = axum::response::Redirect::to(path.as_ref());
Self {
template_kind: TemplateKind::Redirect(redirect_response),
context: Value::from_serialize(&()),
}
}
}
impl IntoResponse for TemplateResponse {
fn into_response(self) -> Response {
Extension(self).into_response()
}
}
// Wrapper to avoid recursion
struct TemplateStateWrapper {
state: HtmlState,
auth: AuthSessionType,
template_response: TemplateResponse,
}
impl IntoResponse for TemplateStateWrapper {
fn into_response(self) -> Response {
let templates = self.state.templates;
match &self.template_response.template_kind {
TemplateKind::Full(name) => {
render_template(name, self.template_response.context, templates)
}
TemplateKind::Partial(name, block) => {
render_block(name, block, self.template_response.context, templates)
}
TemplateKind::Error(status) => {
let html = match try_render_template(
"errors/error.html",
self.template_response.context,
templates,
) {
Ok(html_string) => Html(html_string),
Err(_) => fallback_error(),
};
(*status, html).into_response()
}
TemplateKind::Redirect(redirect) => redirect.clone().into_response(),
}
}
}
// Helper functions for rendering with error handling
fn render_template(name: &str, context: Value, templates: Arc<AutoReloader>) -> Response {
match try_render_template(name, context, templates.clone()) {
Ok(html) => Html(html).into_response(),
Err(_) => fallback_error().into_response(),
}
}
fn render_block(name: &str, block: &str, context: Value, templates: Arc<AutoReloader>) -> Response {
match try_render_block(name, block, context, templates.clone()) {
Ok(html) => Html(html).into_response(),
Err(_) => fallback_error().into_response(),
}
}
fn try_render_template(
template_name: &str,
context: Value,
templates: Arc<AutoReloader>,
) -> Result<String, ()> {
let env = templates.acquire_env().map_err(|e| {
tracing::error!("Environment error: {:?}", e);
()
})?;
let tmpl = env.get_template(template_name).map_err(|e| {
tracing::error!("Template error: {:?}", e);
()
})?;
tmpl.render(context).map_err(|e| {
tracing::error!("Render error: {:?}", e);
()
})
}
fn try_render_block(
template_name: &str,
block: &str,
context: Value,
templates: Arc<AutoReloader>,
) -> Result<String, ()> {
let env = templates.acquire_env().map_err(|e| {
tracing::error!("Environment error: {:?}", e);
()
})?;
let tmpl = env.get_template(template_name).map_err(|e| {
tracing::error!("Template error: {:?}", e);
()
})?;
let mut state = tmpl.eval_to_state(context).map_err(|e| {
tracing::error!("Eval error: {:?}", e);
()
})?;
state.render_block(block).map_err(|e| {
tracing::error!("Block render error: {:?}", e);
()
})
}
fn fallback_error() -> Html<String> {
Html(
r#"
<html>
<body>
<div class="container mx-auto p-4">
<h1 class="text-4xl text-error">Error</h1>
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
</div>
</body>
</html>
"#
.to_string(),
)
}
pub async fn with_template_response(
State(state): State<HtmlState>,
auth: AuthSessionType,
response: Response,
) -> Response {
// Clone the TemplateResponse from extensions
let template_response = response.extensions().get::<TemplateResponse>().cloned();
if let Some(template_response) = template_response {
TemplateStateWrapper {
state,
auth,
template_response,
}
.into_response()
} else {
response
}
}
// Define HtmlError
pub enum HtmlError {
AppError(AppError),
TemplateError(String),
}
// Conversion from AppError to HtmlError
impl From<AppError> for HtmlError {
fn from(err: AppError) -> Self {
HtmlError::AppError(err)
}
}
// Conversion for database error to HtmlError
impl From<surrealdb::Error> for HtmlError {
fn from(err: surrealdb::Error) -> Self {
HtmlError::AppError(AppError::from(err))
}
}
// Now implement IntoResponse for HtmlError
impl IntoResponse for HtmlError {
fn into_response(self) -> Response {
match self {
HtmlError::AppError(err) => {
let template_response = match err {
AppError::NotFound(_) => TemplateResponse::not_found(),
AppError::Auth(_) => TemplateResponse::unauthorized(),
AppError::Validation(msg) => TemplateResponse::bad_request(&msg),
_ => {
tracing::error!("Internal error: {:?}", err);
TemplateResponse::server_error()
}
};
template_response.into_response()
}
HtmlError::TemplateError(_) => TemplateResponse::server_error().into_response(),
}
}
}

View File

@@ -10,12 +10,15 @@ serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
serde_json = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
async-openai = "0.24.1"
tiktoken-rs = "0.6.0"
reqwest = {version = "0.12.12", features = ["charset", "json"]}
scraper = "0.22.0"
chrono = { version = "0.4.39", features = ["serde"] }
text-splitter = "0.18.1"
uuid = { version = "1.10.0", features = ["v4", "serde"] }
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }

View File

@@ -1,9 +1,5 @@
use crate::{
error::AppError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use std::sync::Arc;
use async_openai::{
error::OpenAIError,
types::{
@@ -12,20 +8,28 @@ use async_openai::{
ResponseFormatJsonSchema,
},
};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use composite_retrieval::retrieve_entities;
use serde_json::json;
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
use crate::{
types::llm_enrichment_result::LLMEnrichmentResult,
utils::llm_instructions::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
};
pub struct IngressAnalyzer<'a> {
db_client: &'a SurrealDbClient,
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
pub struct IngestionEnricher {
db_client: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
}
impl<'a> IngressAnalyzer<'a> {
impl IngestionEnricher {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
db_client: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Self {
Self {
db_client,
@@ -39,7 +43,7 @@ impl<'a> IngressAnalyzer<'a> {
instructions: &str,
text: &str,
user_id: &str,
) -> Result<LLMGraphAnalysisResult, AppError> {
) -> Result<LLMEnrichmentResult, AppError> {
let similar_entities = self
.find_similar_entities(category, instructions, text, user_id)
.await?;
@@ -60,13 +64,7 @@ impl<'a> IngressAnalyzer<'a> {
text, category, instructions
);
combined_knowledge_entity_retrieval(
self.db_client,
self.openai_client,
&input_text,
user_id,
)
.await
retrieve_entities(&self.db_client, &self.openai_client, &input_text, user_id).await
}
fn prepare_llm_request(
@@ -120,7 +118,7 @@ impl<'a> IngressAnalyzer<'a> {
async fn perform_analysis(
&self,
request: CreateChatCompletionRequest,
) -> Result<LLMGraphAnalysisResult, AppError> {
) -> Result<LLMEnrichmentResult, AppError> {
let response = self.openai_client.chat().create(request).await?;
debug!("Received LLM response: {:?}", response);

View File

@@ -1,2 +1,4 @@
pub mod enricher;
pub mod pipeline;
pub mod types;
pub mod utils;

View File

@@ -19,12 +19,11 @@ use common::{
utils::embedding::generate_embedding,
};
use common::ingress::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
use crate::{
enricher::IngestionEnricher,
types::{llm_enrichment_result::LLMEnrichmentResult, to_text_content},
};
use crate::types::to_text_content;
pub struct IngestionPipeline {
db: Arc<SurrealDbClient>,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
@@ -109,8 +108,8 @@ impl IngestionPipeline {
async fn perform_semantic_analysis(
&self,
content: &TextContent,
) -> Result<LLMGraphAnalysisResult, AppError> {
let analyser = IngressAnalyzer::new(&self.db, &self.openai_client);
) -> Result<LLMEnrichmentResult, AppError> {
let analyser = IngestionEnricher::new(self.db.clone(), self.openai_client.clone());
analyser
.analyze_content(
&content.category,

View File

@@ -4,7 +4,7 @@ use chrono::Utc;
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::{
use common::{
error::AppError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
@@ -14,7 +14,7 @@ use crate::{
};
use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
use crate::utils::GraphMapper;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
@@ -35,12 +35,12 @@ pub struct LLMRelationship {
/// Represents the entire graph analysis result from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMGraphAnalysisResult {
pub struct LLMEnrichmentResult {
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
pub relationships: Vec<LLMRelationship>,
}
impl LLMGraphAnalysisResult {
impl LLMEnrichmentResult {
/// Converts the LLM graph analysis result into database entities and relationships.
///
/// # Arguments

View File

@@ -1,3 +1,5 @@
pub mod llm_enrichment_result;
use std::{sync::Arc, time::Duration};
use async_openai::types::{

View File

@@ -1,3 +1,5 @@
pub mod llm_instructions;
use std::collections::HashMap;
use uuid::Uuid;

View File

@@ -11,40 +11,12 @@ thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
axum = { workspace = true }
surrealdb = { workspace = true }
futures = { workspace = true }
async-openai = { workspace = true }
async-openai = "0.24.1"
async-stream = "0.3.6"
axum-htmx = "0.6.0"
axum_session = "0.14.4"
axum_session_auth = "0.14.1"
axum_session_surreal = "0.2.1"
axum_typed_multipart = "0.12.1"
chrono = { version = "0.4.39", features = ["serde"] }
chrono-tz = "0.10.1"
config = "0.15.4"
futures = "0.3.31"
json-stream-parser = "0.1.4"
lettre = { version = "0.11.11", features = ["rustls-tls"] }
mime = "0.3.17"
mime_guess = "2.0.5"
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
minijinja-autoreload = "2.5.0"
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
mockall = "0.13.0"
plotly = "0.12.1"
reqwest = {version = "0.12.12", features = ["charset", "json"]}
scraper = "0.22.0"
sha2 = "0.10.8"
surrealdb = "2.0.4"
tempfile = "3.12.0"
text-splitter = "0.18.1"
tiktoken-rs = "0.6.0"
tower-http = { version = "0.6.2", features = ["fs"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
# Reference to api-router
ingestion-pipeline = { path = "../ingestion-pipeline" }
api-router = { path = "../api-router" }
html-router = { path = "../html-router" }