refactor: better separation of dependencies to crates

node stuff to html crate only
This commit is contained in:
Per Stark
2025-04-04 12:50:38 +02:00
parent 20fc43638b
commit 5bc48fb30b
160 changed files with 231 additions and 337 deletions

View File

@@ -0,0 +1,43 @@
use common::storage::db::SurrealDbClient;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{create_template_engine, storage::db::ProvidesDb};
use std::sync::Arc;
use tracing::debug;
use crate::{OpenAIClientType, SessionStoreType};
#[derive(Clone)]
pub struct HtmlState {
pub db: Arc<SurrealDbClient>,
pub openai_client: Arc<OpenAIClientType>,
pub templates: Arc<TemplateEngine>,
pub session_store: Arc<SessionStoreType>,
}
impl HtmlState {
pub fn new_with_resources(
db: Arc<SurrealDbClient>,
openai_client: Arc<OpenAIClientType>,
session_store: Arc<SessionStoreType>,
) -> Result<Self, Box<dyn std::error::Error>> {
let template_engine = create_template_engine!("templates");
debug!("Template engine created for html_router.");
Ok(Self {
db,
openai_client,
session_store,
templates: Arc::new(template_engine),
})
}
}
impl ProvidesDb for HtmlState {
fn db(&self) -> &Arc<SurrealDbClient> {
&self.db
}
}
impl ProvidesTemplateEngine for HtmlState {
fn template_engine(&self) -> &Arc<TemplateEngine> {
&self.templates
}
}

39
html-router/src/lib.rs Normal file
View File

@@ -0,0 +1,39 @@
pub mod html_state;
pub mod middlewares;
pub mod router_factory;
pub mod routes;
use axum::{extract::FromRef, Router};
use axum_session::{Session, SessionStore};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use common::storage::types::user::User;
use html_state::HtmlState;
use router_factory::RouterFactory;
use surrealdb::{engine::any::Any, Surreal};
pub type AuthSessionType = AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>;
pub type SessionType = Session<SessionSurrealPool<Any>>;
pub type SessionStoreType = SessionStore<SessionSurrealPool<Any>>;
pub type OpenAIClientType = async_openai::Client<async_openai::config::OpenAIConfig>;
/// Html routes
pub fn html_routes<S>(app_state: &HtmlState) -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
RouterFactory::new(app_state)
.add_public_routes(routes::index::public_router())
.add_public_routes(routes::auth::router())
.with_public_assets("/assets", "assets/")
.add_protected_routes(routes::index::protected_router())
.add_protected_routes(routes::search::router())
.add_protected_routes(routes::account::router())
.add_protected_routes(routes::admin::router())
.add_protected_routes(routes::chat::router())
.add_protected_routes(routes::content::router())
.add_protected_routes(routes::knowledge::router())
.add_protected_routes(routes::ingestion::router())
.build()
}

View File

@@ -0,0 +1,30 @@
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use common::storage::{db::ProvidesDb, types::analytics::Analytics};
use crate::SessionType;
/// Middleware to count unique visitors and page loads
pub async fn analytics_middleware<S>(
State(state): State<S>,
session: SessionType,
request: Request,
next: Next,
) -> Response
where
S: ProvidesDb + Clone + Send + Sync + 'static,
{
let path = request.uri().path();
if !path.starts_with("/assets") && !path.contains('.') {
if !session.get::<bool>("counted_visitor").unwrap_or(false) {
let _ = Analytics::increment_visitors(state.db()).await;
session.set("counted_visitor", true);
}
let _ = Analytics::increment_page_loads(state.db()).await;
}
next.run(request).await
}

View File

@@ -0,0 +1,50 @@
use axum::{
async_trait,
extract::{FromRequestParts, Request},
http::request::Parts,
middleware::Next,
response::{IntoResponse, Response},
};
use common::storage::types::user::User;
use crate::AuthSessionType;
use super::response_middleware::TemplateResponse;
#[derive(Debug, Clone)]
pub struct RequireUser(pub User);
// Implement FromRequestParts for RequireUser
#[async_trait]
impl<S> FromRequestParts<S> for RequireUser
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<User>()
.cloned()
.map(RequireUser)
.ok_or_else(|| TemplateResponse::redirect("/signin").into_response())
}
}
// Auth middleware that adds the user to extensions
pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Next) -> Response {
// Check if user is authenticated
match auth.current_user {
Some(user) => {
// Add user to request extensions
request.extensions_mut().insert(user);
// Continue to the handler
next.run(request).await
}
None => {
// Redirect to login
TemplateResponse::redirect("/signin").into_response()
}
}
}

View File

@@ -0,0 +1,3 @@
pub mod analytics_middleware;
pub mod auth_middleware;
pub mod response_middleware;

View File

@@ -0,0 +1,208 @@
use axum::{
extract::State,
http::StatusCode,
response::{Html, IntoResponse, Response},
Extension,
};
use common::{error::AppError, utils::template_engine::ProvidesTemplateEngine};
use minijinja::{context, Value};
use serde::Serialize;
use tracing::error;
#[derive(Clone)]
pub enum TemplateKind {
Full(String),
Partial(String, String),
Error(StatusCode),
Redirect(String),
}
#[derive(Clone)]
pub struct TemplateResponse {
template_kind: TemplateKind,
context: Value,
}
impl TemplateResponse {
pub fn new_template<T: Serialize>(name: impl Into<String>, context: T) -> Self {
Self {
template_kind: TemplateKind::Full(name.into()),
context: Value::from_serialize(&context),
}
}
pub fn new_partial<T: Serialize>(
template: impl Into<String>,
block: impl Into<String>,
context: T,
) -> Self {
Self {
template_kind: TemplateKind::Partial(template.into(), block.into()),
context: Value::from_serialize(&context),
}
}
pub fn error(status: StatusCode, title: &str, error: &str, description: &str) -> Self {
let ctx = context! {
status_code => status.as_u16(),
title => title,
error => error,
description => description
};
Self {
template_kind: TemplateKind::Error(status),
context: ctx,
}
}
pub fn not_found() -> Self {
Self::error(
StatusCode::NOT_FOUND,
"Page Not Found",
"Not Found",
"The page you're looking for doesn't exist or was removed.",
)
}
pub fn server_error() -> Self {
Self::error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error",
"Internal Server Error",
"Something went wrong on our end.",
)
}
pub fn unauthorized() -> Self {
Self::error(
StatusCode::UNAUTHORIZED,
"Unauthorized",
"Access Denied",
"You need to be logged in to access this page.",
)
}
pub fn bad_request(message: &str) -> Self {
Self::error(
StatusCode::BAD_REQUEST,
"Bad Request",
"Bad Request",
message,
)
}
pub fn redirect(path: impl Into<String>) -> Self {
Self {
template_kind: TemplateKind::Redirect(path.into()),
context: Value::from_serialize(()),
}
}
}
impl IntoResponse for TemplateResponse {
fn into_response(self) -> Response {
Extension(self).into_response()
}
}
pub async fn with_template_response<S>(State(state): State<S>, response: Response) -> Response
where
S: ProvidesTemplateEngine + Clone + Send + Sync + 'static,
{
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine();
match &template_response.template_kind {
TemplateKind::Full(name) => {
match template_engine.render(name, &template_response.context) {
Ok(html) => Html(html).into_response(),
Err(e) => {
error!("Failed to render template '{}': {:?}", name, e);
(StatusCode::INTERNAL_SERVER_ERROR, fallback_error()).into_response()
}
}
}
TemplateKind::Partial(template, block) => {
match template_engine.render_block(template, block, &template_response.context) {
Ok(html) => Html(html).into_response(),
Err(e) => {
error!("Failed to render block '{}/{}': {:?}", template, block, e);
(StatusCode::INTERNAL_SERVER_ERROR, fallback_error()).into_response()
}
}
}
TemplateKind::Error(status) => {
match template_engine.render("errors/error.html", &template_response.context) {
Ok(html) => (*status, Html(html)).into_response(),
Err(e) => {
error!("Failed to render error template: {:?}", e);
(*status, fallback_error()).into_response()
}
}
}
TemplateKind::Redirect(path) => {
(StatusCode::OK, [(axum_htmx::HX_REDIRECT, path.clone())], "").into_response()
}
}
} else {
response
}
}
#[derive(Debug)]
pub enum HtmlError {
AppError(AppError),
TemplateError(String),
}
impl From<AppError> for HtmlError {
fn from(err: AppError) -> Self {
HtmlError::AppError(err)
}
}
impl From<surrealdb::Error> for HtmlError {
fn from(err: surrealdb::Error) -> Self {
HtmlError::AppError(AppError::from(err))
}
}
impl From<minijinja::Error> for HtmlError {
fn from(err: minijinja::Error) -> Self {
HtmlError::TemplateError(err.to_string())
}
}
impl IntoResponse for HtmlError {
fn into_response(self) -> Response {
match self {
HtmlError::AppError(err) => match err {
AppError::NotFound(_) => TemplateResponse::not_found().into_response(),
AppError::Auth(_) => TemplateResponse::unauthorized().into_response(),
AppError::Validation(msg) => TemplateResponse::bad_request(&msg).into_response(),
_ => {
error!("Internal error: {:?}", err);
TemplateResponse::server_error().into_response()
}
},
HtmlError::TemplateError(err) => {
error!("Template error: {}", err);
TemplateResponse::server_error().into_response()
}
}
}
}
fn fallback_error() -> String {
r#"
<html>
<body>
<div class="container mx-auto p-4">
<h1 class="text-4xl text-error">Error</h1>
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
</div>
</body>
</html>
"#
.to_string()
}

View File

@@ -0,0 +1,189 @@
use axum::{
extract::FromRef,
middleware::{from_fn_with_state, map_response_with_state},
Router,
};
use axum_session::SessionLayer;
use axum_session_auth::{AuthConfig, AuthSessionLayer};
use axum_session_surreal::SessionSurrealPool;
use common::storage::types::user::User;
use surrealdb::{engine::any::Any, Surreal};
use crate::{
html_state::HtmlState,
middlewares::{
analytics_middleware::analytics_middleware, auth_middleware::require_auth,
response_middleware::with_template_response,
},
};
#[macro_export]
macro_rules! create_asset_service {
// Takes the relative path to the asset directory
($relative_path:expr) => {{
#[cfg(debug_assertions)]
{
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let assets_path = crate_dir.join($relative_path);
tracing::debug!("Assets: Serving from filesystem: {:?}", assets_path);
tower_http::services::ServeDir::new(assets_path)
}
#[cfg(not(debug_assertions))]
{
tracing::debug!("Assets: Serving embedded directory");
static ASSETS_DIR: include_dir::Dir<'static> =
include_dir::include_dir!("$CARGO_MANIFEST_DIR/assets");
tower_serve_static::ServeDir::new(&ASSETS_DIR)
}
}};
}
pub type MiddleWareVecType<S> = Vec<Box<dyn FnOnce(Router<S>) -> Router<S> + Send>>;
pub struct RouterFactory<S> {
app_state: HtmlState,
public_routers: Vec<Router<S>>,
protected_routers: Vec<Router<S>>,
nested_routes: Vec<(String, Router<S>)>,
nested_protected_routes: Vec<(String, Router<S>)>,
custom_middleware: MiddleWareVecType<S>,
public_assets_config: Option<AssetsConfig>,
}
struct AssetsConfig {
path: String, // URL path for assets
directory: String, // Directory on disk
}
impl<S> RouterFactory<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
pub fn new(app_state: &HtmlState) -> Self {
Self {
app_state: app_state.to_owned(),
public_routers: Vec::new(),
protected_routers: Vec::new(),
nested_routes: Vec::new(),
nested_protected_routes: Vec::new(),
custom_middleware: Vec::new(),
public_assets_config: None,
}
}
// Add a serving of assets
pub fn with_public_assets(mut self, path: &str, directory: &str) -> Self {
self.public_assets_config = Some(AssetsConfig {
path: path.to_string(),
directory: directory.to_string(),
});
self
}
// Add a public router that will be merged at the root level
pub fn add_public_routes(mut self, routes: Router<S>) -> Self {
self.public_routers.push(routes);
self
}
// Add a protected router that will be merged at the root level
pub fn add_protected_routes(mut self, routes: Router<S>) -> Self {
self.protected_routers.push(routes);
self
}
// Nest a public router under a path prefix
pub fn nest_public_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_routes.push((path.to_string(), routes));
self
}
// Nest a protected router under a path prefix
pub fn nest_protected_routes(mut self, path: &str, routes: Router<S>) -> Self {
self.nested_protected_routes
.push((path.to_string(), routes));
self
}
// Add custom middleware to be applied before the standard ones
pub fn with_middleware<F>(mut self, middleware_fn: F) -> Self
where
F: FnOnce(Router<S>) -> Router<S> + Send + 'static,
{
self.custom_middleware.push(Box::new(middleware_fn));
self
}
pub fn build(self) -> Router<S> {
// Start with an empty router
let mut public_router = Router::new();
// Merge all public routers
for router in self.public_routers {
public_router = public_router.merge(router);
}
// Add nested public routes
for (path, router) in self.nested_routes {
public_router = public_router.nest(&path, router);
}
// Add public assets to public router
if let Some(assets_config) = self.public_assets_config {
// Call the macro using the stored relative directory path
let asset_service = create_asset_service!(&assets_config.directory);
// Nest the resulting service under the stored URL path
public_router = public_router.nest_service(&assets_config.path, asset_service);
}
// Start with an empty protected router
let mut protected_router = Router::new();
// Check if there are any protected routers
let has_protected_routes =
!self.protected_routers.is_empty() || !self.nested_protected_routes.is_empty();
// Merge root-level protected routers
for router in self.protected_routers {
protected_router = protected_router.merge(router);
}
// Nest protected routers
for (path, router) in self.nested_protected_routes {
protected_router = protected_router.nest(&path, router);
}
// Apply auth middleware
if has_protected_routes {
protected_router = protected_router
.route_layer(from_fn_with_state(self.app_state.clone(), require_auth));
}
// Combine public and protected routes
let mut router = Router::new().merge(public_router).merge(protected_router);
// Apply custom middleware in order they were added
for middleware_fn in self.custom_middleware {
router = middleware_fn(router);
}
// Apply common middleware
router
.layer(from_fn_with_state(
self.app_state.clone(),
analytics_middleware::<HtmlState>,
))
.layer(map_response_with_state(
self.app_state.clone(),
with_template_response::<HtmlState>,
))
.layer(
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
self.app_state.db.client.clone(),
))
.with_config(AuthConfig::<String>::default()),
)
.layer(SessionLayer::new((*self.app_state.session_store).clone()))
}
}

View File

@@ -0,0 +1,143 @@
use axum::{extract::State, response::IntoResponse, Form};
use chrono_tz::TZ_VARIANTS;
use serde::{Deserialize, Serialize};
use crate::{
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
AuthSessionType,
};
use common::storage::types::user::User;
use crate::html_state::HtmlState;
#[derive(Serialize)]
pub struct AccountPageData {
user: User,
timezones: Vec<String>,
}
pub async fn show_account_page(
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
Ok(TemplateResponse::new_template(
"auth/account_settings.html",
AccountPageData { user, timezones },
))
}
pub async fn set_api_key(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
// Generate and set the API key
let api_key = User::set_api_key(&user.id, &state.db).await?;
// Clear the cache so new requests have access to the user with api key
auth.cache_clear_user(user.id.to_string());
// Update the user's API key
let updated_user = User {
api_key: Some(api_key),
..user.clone()
};
// Render the API key section block
Ok(TemplateResponse::new_partial(
"auth/account_settings.html",
"api_key_section",
AccountPageData {
user: updated_user,
timezones: vec![],
},
))
}
pub async fn delete_account(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
state.db.delete_item::<User>(&user.id).await?;
auth.logout_user();
auth.session.destroy();
Ok(TemplateResponse::redirect("/"))
}
#[derive(Deserialize)]
pub struct UpdateTimezoneForm {
timezone: String,
}
pub async fn update_timezone(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<UpdateTimezoneForm>,
) -> Result<impl IntoResponse, HtmlError> {
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
// Clear the cache
auth.cache_clear_user(user.id.to_string());
// Update the user's API key
let updated_user = User {
timezone: form.timezone,
..user.clone()
};
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
// Render the API key section block
Ok(TemplateResponse::new_partial(
"auth/account_settings.html",
"timezone_section",
AccountPageData {
user: updated_user,
timezones,
},
))
}
pub async fn show_change_password(
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
Ok(TemplateResponse::new_template(
"auth/change_password_form.html",
(),
))
}
#[derive(Deserialize)]
pub struct NewPasswordForm {
old_password: String,
new_password: String,
}
pub async fn change_password(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
auth: AuthSessionType,
Form(form): Form<NewPasswordForm>,
) -> Result<impl IntoResponse, HtmlError> {
// Authenticate to make sure the password matches
let authenticated_user = User::authenticate(&user.email, &form.old_password, &state.db).await?;
User::patch_password(&authenticated_user.email, &form.new_password, &state.db).await?;
auth.cache_clear_user(user.id);
Ok(TemplateResponse::new_partial(
"auth/account_settings.html",
"change_password_section",
(),
))
}

View File

@@ -0,0 +1,24 @@
mod handlers;
use axum::{
extract::FromRef,
routing::{delete, get, patch, post},
Router,
};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/account", get(handlers::show_account_page))
.route("/set-api-key", post(handlers::set_api_key))
.route("/update-timezone", patch(handlers::update_timezone))
.route(
"/change-password",
get(handlers::show_change_password).patch(handlers::change_password),
)
.route("/delete-account", delete(handlers::delete_account))
}

View File

@@ -0,0 +1,259 @@
use axum::{extract::State, response::IntoResponse, Form};
use serde::{Deserialize, Serialize};
use common::storage::types::{
analytics::Analytics,
system_prompts::{DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT, DEFAULT_QUERY_SYSTEM_PROMPT},
system_settings::SystemSettings,
user::User,
};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
#[derive(Serialize)]
pub struct AdminPanelData {
user: User,
settings: SystemSettings,
analytics: Analytics,
users: i64,
default_query_prompt: String,
}
pub async fn show_admin_panel(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let settings = SystemSettings::get_current(&state.db).await?;
let analytics = Analytics::get_current(&state.db).await?;
let users_count = Analytics::get_users_amount(&state.db).await?;
Ok(TemplateResponse::new_template(
"auth/admin_panel.html",
AdminPanelData {
user,
settings,
analytics,
users: users_count,
default_query_prompt: DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
},
))
}
fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: serde::Deserializer<'de>,
{
match String::deserialize(deserializer) {
Ok(string) => Ok(string == "on"),
Err(_) => Ok(false),
}
}
#[derive(Deserialize)]
pub struct RegistrationToggleInput {
#[serde(default)]
#[serde(deserialize_with = "checkbox_to_bool")]
registration_open: bool,
}
#[derive(Serialize)]
pub struct RegistrationToggleData {
settings: SystemSettings,
}
pub async fn toggle_registration_status(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
registrations_enabled: input.registration_open,
..current_settings.clone()
};
SystemSettings::update(&state.db, new_settings.clone()).await?;
Ok(TemplateResponse::new_partial(
"auth/admin_panel.html",
"registration_status_input",
RegistrationToggleData {
settings: new_settings,
},
))
}
#[derive(Deserialize)]
pub struct ModelSettingsInput {
query_model: String,
processing_model: String,
}
#[derive(Serialize)]
pub struct ModelSettingsData {
settings: SystemSettings,
}
pub async fn update_model_settings(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<ModelSettingsInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
query_model: input.query_model,
processing_model: input.processing_model,
..current_settings
};
SystemSettings::update(&state.db, new_settings.clone()).await?;
Ok(TemplateResponse::new_partial(
"auth/admin_panel.html",
"model_settings_form",
ModelSettingsData {
settings: new_settings,
},
))
}
#[derive(Serialize)]
pub struct SystemPromptEditData {
settings: SystemSettings,
default_query_prompt: String,
}
pub async fn show_edit_system_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
"admin/edit_query_prompt_modal.html",
SystemPromptEditData {
settings,
default_query_prompt: DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
},
))
}
#[derive(Deserialize)]
pub struct SystemPromptUpdateInput {
query_system_prompt: String,
}
#[derive(Serialize)]
pub struct SystemPromptSectionData {
settings: SystemSettings,
}
pub async fn patch_query_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<SystemPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
query_system_prompt: input.query_system_prompt,
..current_settings.clone()
};
SystemSettings::update(&state.db, new_settings.clone()).await?;
Ok(TemplateResponse::new_partial(
"auth/admin_panel.html",
"system_prompt_section",
SystemPromptSectionData {
settings: new_settings,
},
))
}
#[derive(Serialize)]
pub struct IngestionPromptEditData {
settings: SystemSettings,
default_ingestion_prompt: String,
}
pub async fn show_edit_ingestion_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
"admin/edit_ingestion_prompt_modal.html",
IngestionPromptEditData {
settings,
default_ingestion_prompt: DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
},
))
}
#[derive(Deserialize)]
pub struct IngestionPromptUpdateInput {
ingestion_system_prompt: String,
}
pub async fn patch_ingestion_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<IngestionPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
};
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
ingestion_system_prompt: input.ingestion_system_prompt,
..current_settings.clone()
};
SystemSettings::update(&state.db, new_settings.clone()).await?;
Ok(TemplateResponse::new_partial(
"auth/admin_panel.html",
"system_prompt_section",
SystemPromptSectionData {
settings: new_settings,
},
))
}

View File

@@ -0,0 +1,27 @@
mod handlers;
use axum::{
extract::FromRef,
routing::{get, patch},
Router,
};
use handlers::{
patch_ingestion_prompt, patch_query_prompt, show_admin_panel, show_edit_ingestion_prompt,
show_edit_system_prompt, toggle_registration_status, update_model_settings,
};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/admin", get(show_admin_panel))
.route("/toggle-registrations", patch(toggle_registration_status))
.route("/update-model-settings", patch(update_model_settings))
.route("/edit-query-prompt", get(show_edit_system_prompt))
.route("/update-query-prompt", patch(patch_query_prompt))
.route("/edit-ingestion-prompt", get(show_edit_ingestion_prompt))
.route("/update-ingestion-prompt", patch(patch_ingestion_prompt))
}

View File

@@ -0,0 +1,24 @@
pub mod signin;
pub mod signout;
pub mod signup;
use axum::{extract::FromRef, routing::get, Router};
use signin::{authenticate_user, show_signin_form};
use signout::sign_out_user;
use signup::{process_signup_and_show_verification, show_signup_form};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/signout", get(sign_out_user))
.route("/signin", get(show_signin_form).post(authenticate_user))
.route(
"/signup",
get(show_signup_form).post(process_signup_and_show_verification),
)
}

View File

@@ -0,0 +1,59 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
Form,
};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use crate::{
html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse},
AuthSessionType,
};
use common::storage::types::user::User;
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
pub email: String,
pub password: String,
pub remember_me: Option<String>,
}
pub async fn show_signin_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
return Ok(TemplateResponse::redirect("/"));
}
match boosted {
true => Ok(TemplateResponse::new_partial(
"auth/signin_base.html",
"body",
(),
)),
false => Ok(TemplateResponse::new_template("auth/signin_base.html", ())),
}
}
pub async fn authenticate_user(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::authenticate(&form.email, &form.password, &state.db).await {
Ok(user) => user,
Err(_) => {
return Ok(Html("<p>Incorrect email or password </p>").into_response());
}
};
auth.login_user(user.id);
if form.remember_me.is_some_and(|string| string == *"on") {
auth.remember_user(true);
}
Ok(TemplateResponse::redirect("/").into_response())
}

View File

@@ -0,0 +1,16 @@
use axum::response::IntoResponse;
use crate::{
middlewares::response_middleware::{HtmlError, TemplateResponse},
AuthSessionType,
};
pub async fn sign_out_user(auth: AuthSessionType) -> Result<impl IntoResponse, HtmlError> {
if !auth.is_authenticated() {
return Ok(TemplateResponse::redirect("/"));
}
auth.logout_user();
Ok(TemplateResponse::redirect("/"))
}

View File

@@ -0,0 +1,58 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
Form,
};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
use common::storage::types::user::User;
use crate::{
html_state::HtmlState,
middlewares::response_middleware::{HtmlError, TemplateResponse},
AuthSessionType,
};
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
pub email: String,
pub password: String,
pub timezone: String,
}
pub async fn show_signup_form(
auth: AuthSessionType,
HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
if auth.is_authenticated() {
return Ok(TemplateResponse::redirect("/"));
}
match boosted {
true => Ok(TemplateResponse::new_partial(
"auth/signup_form.html",
"body",
(),
)),
false => Ok(TemplateResponse::new_template("auth/signup_form.html", ())),
}
}
pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>,
auth: AuthSessionType,
Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new(form.email, form.password, &state.db, form.timezone).await {
Ok(user) => user,
Err(e) => {
tracing::error!("{:?}", e);
return Ok(Html(format!("<p>{}</p>", e)).into_response());
}
};
auth.login_user(user.id);
Ok(TemplateResponse::redirect("/").into_response())
}

View File

@@ -0,0 +1,226 @@
use axum::{
extract::{Path, State},
http::HeaderValue,
response::{IntoResponse, Redirect},
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::AppError,
storage::types::{
conversation::Conversation,
message::{Message, MessageRole},
user::User,
},
};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
#[derive(Debug, Deserialize)]
pub struct ChatStartParams {
user_query: String,
llm_response: String,
#[serde(deserialize_with = "deserialize_references")]
references: Vec<String>,
}
// Custom deserializer function
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(serde::de::Error::custom)
}
#[derive(Serialize)]
pub struct ChatPageData {
user: User,
history: Vec<Message>,
conversation: Option<Conversation>,
conversation_archive: Vec<Conversation>,
}
pub async fn show_initialized_chat(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<ChatStartParams>,
) -> Result<impl IntoResponse, HtmlError> {
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
let user_message = Message::new(
conversation.id.to_string(),
MessageRole::User,
form.user_query,
None,
);
let ai_message = Message::new(
conversation.id.to_string(),
MessageRole::AI,
form.llm_response,
Some(form.references),
);
state.db.store_item(conversation.clone()).await?;
state.db.store_item(ai_message.clone()).await?;
state.db.store_item(user_message.clone()).await?;
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
let messages = vec![user_message, ai_message];
let mut response = TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
history: messages,
user,
conversation_archive,
conversation: Some(conversation.clone()),
},
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
Ok(response)
}
pub async fn show_chat_base(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
history: vec![],
user,
conversation_archive,
conversation: None,
},
))
}
#[derive(Deserialize)]
pub struct NewMessageForm {
content: String,
}
pub async fn show_existing_chat(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
let (conversation, messages) =
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
.await?;
Ok(TemplateResponse::new_template(
"chat/base.html",
ChatPageData {
history: messages,
user,
conversation: Some(conversation.clone()),
conversation_archive,
},
))
}
pub async fn new_user_message(
Path(conversation_id): Path<String>,
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
let conversation: Conversation = state
.db
.get_item(&conversation_id)
.await?
.ok_or_else(|| AppError::NotFound("Conversation was not found".into()))?;
if conversation.user_id != user.id {
return Ok(TemplateResponse::unauthorized().into_response());
};
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
state.db.store_item(user_message.clone()).await?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
}
let mut response = TemplateResponse::new_template(
"chat/streaming_response.html",
SSEResponseInitData { user_message },
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
Ok(response)
}
pub async fn new_chat_user_message(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/").into_response()),
};
let conversation = Conversation::new(user.id, "New chat".to_string());
let user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
form.content,
None,
);
state.db.store_item(conversation.clone()).await?;
state.db.store_item(user_message.clone()).await?;
#[derive(Serialize)]
struct SSEResponseInitData {
user_message: Message,
conversation: Conversation,
}
let mut response = TemplateResponse::new_template(
"chat/new_chat_first_response.html",
SSEResponseInitData {
user_message,
conversation: conversation.clone(),
},
)
.into_response();
response.headers_mut().insert(
"HX-Push",
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
);
Ok(response)
}

View File

@@ -0,0 +1,355 @@
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
use axum::{
extract::{Query, State},
response::{
sse::{Event, KeepAlive},
Sse,
},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use composite_retrieval::{
answer_retrieval::{
create_chat_request, create_user_message_with_history, format_entities_json,
LLMResponseFormat,
},
retrieve_entities,
};
use futures::{
stream::{self, once},
Stream, StreamExt, TryStreamExt,
};
use json_stream_parser::JsonStreamParser;
use minijinja::Value;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use surrealdb::{engine::any::Any, Surreal};
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{debug, error};
use common::storage::{
db::SurrealDbClient,
types::{
conversation::Conversation,
message::{Message, MessageRole},
system_settings::SystemSettings,
user::User,
},
};
use crate::html_state::HtmlState;
// Error handling function
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
// Helper function to get message and user
async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
) -> Result<
(Message, User, Conversation, Vec<Message>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
// Check authentication
let user = match current_user {
Some(user) => user,
None => {
return Err(Sse::new(create_error_stream(
"You must be signed in to use this feature",
)))
}
};
// Retrieve message
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
return Err(Sse::new(create_error_stream(
"Message not found: the specified message does not exist",
)))
}
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
return Err(Sse::new(create_error_stream(
"Failed to retrieve message: database error",
)));
}
};
// Get conversation history
let (conversation, mut history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
return Err(Sse::new(create_error_stream(
"Failed to retrieve message: database error",
)));
}
Ok((conversation, history)) => (conversation, history),
};
// Remove the last message, its the same as the message
history.pop();
Ok((message, user, conversation, history))
}
#[derive(Deserialize)]
pub struct QueryParams {
message_id: String,
}
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history)) => {
(user_message, user, conversation, history)
}
Err(error_stream) => return error_stream,
};
// 2. Retrieve knowledge entities
let entities = match retrieve_entities(
&state.db,
&state.openai_client,
&user_message.content,
&user.id,
)
.await
{
Ok(entities) => entities,
Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
}
};
// 3. Create the OpenAI request
let entities_json = format_entities_json(&entities);
let formatted_user_message =
create_user_message_with_history(&entities_json, &history, &user_message.content);
let settings = match SystemSettings::get_current(&state.db).await {
Ok(s) => s,
Err(_) => {
return Sse::new(create_error_stream("Failed to retrieve system settings"));
}
};
let request = match create_chat_request(formatted_user_message, &settings) {
Ok(req) => req,
Err(..) => {
return Sse::new(create_error_stream("Failed to create chat request"));
}
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
}
};
// 5. Create channel for collecting complete response
let (tx, mut rx) = channel::<String>(1000);
let tx_clone = tx.clone();
let (tx_final, mut rx_final) = channel::<Message>(1);
// 6. Set up the collection task for DB storage
let db_client = state.db.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new();
while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk);
}
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
response.answer,
Some(references),
);
let _ = tx_final.send(ai_message.clone()).await;
match db_client.store_item(ai_message).await {
Ok(_) => debug!("Successfully stored AI message with references"),
Err(e) => error!("Failed to store AI message: {:?}", e),
}
} else {
error!("Failed to parse LLM response as structured format");
// Fallback - store raw response
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
None,
);
let _ = db_client.store_item(ai_message).await;
}
});
// Create a shared state for tracking the JSON parsing
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
stream! {
match result {
Ok(response) => {
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {}", e)));
}
}
}
})
.flatten()
.chain(stream::once(async move {
if let Some(message) = rx_final.recv().await {
// Don't send any event if references is empty
if message.references.as_ref().is_some_and(|x| x.is_empty()) {
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Prepare data for template
#[derive(Serialize)]
struct ReferenceData {
message: Message,
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
&Value::from_serialize(ReferenceData { message }),
) {
Ok(html) => {
// Return the rendered HTML
Ok(Event::default().event("references").data(html))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// Handle case where no references were received
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
Sse::new(event_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
}
struct StreamParserState {
parser: JsonStreamParser,
last_answer_content: String,
in_answer_field: bool,
}
impl StreamParserState {
fn new() -> Self {
Self {
parser: JsonStreamParser::new(),
last_answer_content: String::new(),
in_answer_field: false,
}
}
fn process_chunk(&mut self, chunk: &str) -> String {
// Feed all characters into the parser
for c in chunk.chars() {
let _ = self.parser.add_char(c);
}
// Get the current state of the JSON
let json = self.parser.get_result();
// Check if we're in the answer field
if let Some(obj) = json.as_object() {
if let Some(answer) = obj.get("answer") {
self.in_answer_field = true;
// Get current answer content
let current_content = answer.as_str().unwrap_or_default().to_string();
// Calculate difference to send only new content
if current_content.len() > self.last_answer_content.len() {
let new_content = current_content[self.last_answer_content.len()..].to_string();
self.last_answer_content = current_content;
return new_content;
}
}
}
// No new content to return
String::new()
}
}

View File

@@ -0,0 +1,30 @@
mod chat_handlers;
mod message_response_stream;
mod references;
use axum::{
extract::FromRef,
routing::{get, post},
Router,
};
use chat_handlers::{
new_chat_user_message, new_user_message, show_chat_base, show_existing_chat,
show_initialized_chat,
};
use message_response_stream::get_response_stream;
use references::show_reference_tooltip;
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/chat", get(show_chat_base).post(new_chat_user_message))
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
.route("/initialized-chat", post(show_initialized_chat))
.route("/chat/response-stream", get(get_response_stream))
.route("/chat/reference/:id", get(show_reference_tooltip))
}

View File

@@ -0,0 +1,45 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use serde::Serialize;
use common::{
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
pub async fn show_reference_tooltip(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(reference_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let entity: KnowledgeEntity = state
.db
.get_item(&reference_id)
.await?
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
if entity.user_id != user.id {
return Ok(TemplateResponse::unauthorized());
}
#[derive(Serialize)]
struct ReferenceTooltipData {
entity: KnowledgeEntity,
user: User,
}
Ok(TemplateResponse::new_template(
"chat/reference_tooltip.html",
ReferenceTooltipData { entity, user },
))
}

View File

@@ -0,0 +1,90 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
Form,
};
use serde::{Deserialize, Serialize};
use common::storage::types::{text_content::TextContent, user::User};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
#[derive(Serialize)]
pub struct ContentPageData {
user: User,
text_contents: Vec<TextContent>,
}
pub async fn show_content_page(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let text_contents = User::get_text_contents(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"content/base.html",
ContentPageData {
user,
text_contents,
},
))
}
pub async fn show_text_content_edit_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
#[derive(Serialize)]
pub struct TextContentEditModal {
pub user: User,
pub text_content: TextContent,
}
Ok(TemplateResponse::new_template(
"content/edit_text_content_modal.html",
TextContentEditModal { user, text_content },
))
}
#[derive(Deserialize)]
pub struct PatchTextContentParams {
instructions: String,
category: String,
text: String,
}
pub async fn patch_text_content(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
Form(form): Form<PatchTextContentParams>,
) -> Result<impl IntoResponse, HtmlError> {
User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
TextContent::patch(
&id,
&form.instructions,
&form.category,
&form.text,
&state.db,
)
.await?;
let text_contents = User::get_text_contents(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"content/content_list.html",
ContentPageData {
user,
text_contents,
},
))
}

View File

@@ -0,0 +1,19 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::{patch_text_content, show_content_page, show_text_content_edit_form};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/content", get(show_content_page))
.route(
"/content/:id",
get(show_text_content_edit_form).patch(patch_text_content),
)
}

View File

@@ -0,0 +1,164 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use serde::Serialize;
use tokio::join;
use crate::{
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
AuthSessionType,
};
use common::{
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
text_content::TextContent, user::User,
},
};
use crate::html_state::HtmlState;
#[derive(Serialize)]
pub struct IndexPageData {
user: Option<User>,
latest_text_contents: Vec<TextContent>,
active_jobs: Vec<IngestionTask>,
}
pub async fn index_handler(
State(state): State<HtmlState>,
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
let Some(user) = auth.current_user else {
return Ok(TemplateResponse::new_template(
"index/index.html",
IndexPageData {
user: None,
latest_text_contents: vec![],
active_jobs: vec![],
},
));
};
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
let latest_text_contents = User::get_latest_text_contents(user.id.as_str(), &state.db).await?;
Ok(TemplateResponse::new_template(
"index/index.html",
IndexPageData {
user: Some(user),
latest_text_contents,
active_jobs,
},
))
}
#[derive(Serialize)]
pub struct LatestTextContentData {
latest_text_contents: Vec<TextContent>,
user: User,
}
pub async fn delete_text_content(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
// Get and validate TextContent
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
// Perform concurrent deletions
let (_res1, _res2, _res3, _res4, _res5) = join!(
async {
if let Some(file_info) = text_content.file_info {
FileInfo::delete_by_id(&file_info.id, &state.db).await
} else {
Ok(())
}
},
state.db.delete_item::<TextContent>(&text_content.id),
TextChunk::delete_by_source_id(&text_content.id, &state.db),
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db),
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db)
);
// Render updated content
let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/recent_content.html",
"latest_content_section",
LatestTextContentData {
user: user.to_owned(),
latest_text_contents,
},
))
}
// Helper function to get and validate text content
async fn get_and_validate_text_content(
state: &HtmlState,
id: &str,
user: &User,
) -> Result<TextContent, AppError> {
let text_content = state
.db
.get_item::<TextContent>(id)
.await?
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
if text_content.user_id != user.id {
return Err(AppError::Auth(
"You are not the owner of that content".to_string(),
));
}
Ok(text_content)
}
#[derive(Serialize)]
pub struct ActiveJobsData {
pub active_jobs: Vec<IngestionTask>,
pub user: User,
}
pub async fn delete_job(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
User::validate_and_delete_job(&id, &user.id, &state.db).await?;
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
))
}
pub async fn show_active_jobs(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
))
}

View File

@@ -0,0 +1,29 @@
pub mod handlers;
use axum::{
extract::FromRef,
routing::{delete, get},
Router,
};
use handlers::{delete_job, delete_text_content, index_handler, show_active_jobs};
use crate::html_state::HtmlState;
pub fn public_router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new().route("/", get(index_handler))
}
pub fn protected_router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/jobs/:job_id", delete(delete_job))
.route("/active-jobs", get(show_active_jobs))
.route("/text-content/:id", delete(delete_text_content))
}

View File

@@ -0,0 +1,127 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use futures::{future::try_join_all, TryFutureExt};
use serde::Serialize;
use tempfile::NamedTempFile;
use tracing::info;
use common::{
error::AppError,
storage::types::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
user::User,
},
};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
routes::index::handlers::ActiveJobsData,
};
pub async fn show_ingress_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
#[derive(Serialize)]
pub struct ShowIngressFormData {
user_categories: Vec<String>,
}
Ok(TemplateResponse::new_template(
"index/signed_in/ingress_modal.html",
ShowIngressFormData { user_categories },
))
}
pub async fn hide_ingress_form(
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
Ok(Html(
"<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
)
.into_response())
}
#[derive(Debug, TryFromMultipart)]
pub struct IngressParams {
pub content: Option<String>,
pub instructions: String,
pub category: String,
#[form_data(limit = "10000000")] // Adjust limit as needed
#[form_data(default)]
pub files: Vec<FieldData<NamedTempFile>>,
}
pub async fn process_ingress_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
TypedMultipart(input): TypedMultipart<IngressParams>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
pub struct IngressFormData {
instructions: String,
content: String,
category: String,
error: String,
}
if input.content.as_ref().map_or(true, |c| c.len() < 2) && input.files.is_empty() {
return Ok(TemplateResponse::new_template(
"index/signed_in/ingress_form.html",
IngressFormData {
instructions: input.instructions.clone(),
content: input.content.clone().unwrap_or_default(),
category: input.category.clone(),
error: "You need to either add files or content".to_string(),
},
));
}
info!("{:?}", input);
let file_infos = try_join_all(
input
.files
.into_iter()
.map(|file| FileInfo::new(file, &state.db, &user.id).map_err(AppError::from)),
)
.await?;
let payloads = IngestionPayload::create_ingestion_payload(
input.content,
input.instructions,
input.category,
file_infos,
user.id.as_str(),
)?;
let futures: Vec<_> = payloads
.into_iter()
.map(|object| {
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
})
.collect();
try_join_all(futures).await?;
// Update the active jobs page with the newly created job
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
Ok(TemplateResponse::new_partial(
"index/signed_in/active_jobs.html",
"active_jobs_section",
ActiveJobsData {
user: user.clone(),
active_jobs,
},
))
}

View File

@@ -0,0 +1,19 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::{hide_ingress_form, process_ingress_form, show_ingress_form};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route(
"/ingress-form",
get(show_ingress_form).post(process_ingress_form),
)
.route("/hide-ingress-form", get(hide_ingress_form))
}

View File

@@ -0,0 +1,291 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
Form,
};
use plotly::{
common::{Line, Marker, Mode},
layout::{Axis, Camera, LayoutScene, ProjectionType},
Layout, Plot, Scatter3D,
};
use serde::{Deserialize, Serialize};
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
user::User,
};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
pub async fn show_knowledge_page(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
pub struct KnowledgeBaseData {
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
user: User,
plot_html: String,
}
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
let mut plot = Plot::new();
// Fibonacci sphere distribution
let node_count = entities.len();
let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
let node_positions: Vec<(f64, f64, f64)> = (0..node_count)
.map(|i| {
let i = i as f64;
let theta = 2.0 * std::f64::consts::PI * i / golden_ratio;
let phi = (1.0 - 2.0 * (i + 0.5) / node_count as f64).acos();
let x = phi.sin() * theta.cos();
let y = phi.sin() * theta.sin();
let z = phi.cos();
(x, y, z)
})
.collect();
let node_x: Vec<f64> = node_positions.iter().map(|(x, _, _)| *x).collect();
let node_y: Vec<f64> = node_positions.iter().map(|(_, y, _)| *y).collect();
let node_z: Vec<f64> = node_positions.iter().map(|(_, _, z)| *z).collect();
// Nodes trace
let nodes = Scatter3D::new(node_x.clone(), node_y.clone(), node_z.clone())
.mode(Mode::Markers)
.marker(Marker::new().size(8).color("#1f77b4"))
.text_array(
entities
.iter()
.map(|e| e.description.clone())
.collect::<Vec<_>>(),
)
.hover_template("Entity: %{text}<br>");
// Edges traces
for rel in &relationships {
let from_idx = entities.iter().position(|e| e.id == rel.out).unwrap_or(0);
let to_idx = entities.iter().position(|e| e.id == rel.in_).unwrap_or(0);
let edge_x = vec![node_x[from_idx], node_x[to_idx]];
let edge_y = vec![node_y[from_idx], node_y[to_idx]];
let edge_z = vec![node_z[from_idx], node_z[to_idx]];
let edge_trace = Scatter3D::new(edge_x, edge_y, edge_z)
.mode(Mode::Lines)
.line(Line::new().color("#888").width(2.0))
.hover_template(format!(
"Relationship: {}<br>",
rel.metadata.relationship_type
))
.show_legend(false);
plot.add_trace(edge_trace);
}
plot.add_trace(nodes);
// Layout
let layout = Layout::new()
.scene(
LayoutScene::new()
.x_axis(Axis::new().visible(false))
.y_axis(Axis::new().visible(false))
.z_axis(Axis::new().visible(false))
.camera(
Camera::new()
.projection(ProjectionType::Perspective.into())
.eye((1.5, 1.5, 1.5).into()),
),
)
.show_legend(false)
.paper_background_color("rbga(250,100,0,0)")
.plot_background_color("rbga(0,0,0,0)");
plot.set_layout(layout);
// Convert to HTML
let html = plot.to_html();
Ok(TemplateResponse::new_template(
"knowledge/base.html",
KnowledgeBaseData {
entities,
relationships,
user,
plot_html: html,
},
))
}
pub async fn show_edit_knowledge_entity_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
pub struct EntityData {
entity: KnowledgeEntity,
entity_types: Vec<String>,
user: User,
}
// Get entity types
let entity_types: Vec<String> = KnowledgeEntityType::variants()
.iter()
.map(|s| s.to_string())
.collect();
// Get the entity and validate ownership
let entity = User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"knowledge/edit_knowledge_entity_modal.html",
EntityData {
entity,
user,
entity_types,
},
))
}
#[derive(Debug, Deserialize)]
pub struct PatchKnowledgeEntityParams {
pub id: String,
pub name: String,
pub entity_type: String,
pub description: String,
}
#[derive(Serialize)]
pub struct EntityListData {
entities: Vec<KnowledgeEntity>,
user: User,
}
pub async fn patch_knowledge_entity(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<PatchKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> {
// Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.db).await?;
let entity_type: KnowledgeEntityType = KnowledgeEntityType::from(form.entity_type);
// Update the entity
KnowledgeEntity::patch(
&form.id,
&form.name,
&form.description,
&entity_type,
&state.db,
&state.openai_client,
)
.await?;
// Get updated list of entities
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
// Render updated list
Ok(TemplateResponse::new_template(
"knowledge/entity_list.html",
EntityListData { entities, user },
))
}
pub async fn delete_knowledge_entity(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
// Get the existing entity and validate that the user is allowed
User::get_and_validate_knowledge_entity(&id, &user.id, &state.db).await?;
// Delete the entity
state.db.delete_item::<KnowledgeEntity>(&id).await?;
// Get updated list of entities
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template(
"knowledge/entity_list.html",
EntityListData { entities, user },
))
}
#[derive(Serialize)]
pub struct RelationshipTableData {
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
}
pub async fn delete_knowledge_relationship(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
// GOTTA ADD AUTH VALIDATION
KnowledgeRelationship::delete_relationship_by_id(&id, &state.db).await?;
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
// Render updated list
Ok(TemplateResponse::new_template(
"knowledge/relationship_table.html",
RelationshipTableData {
entities,
relationships,
},
))
}
#[derive(Deserialize)]
pub struct SaveKnowledgeRelationshipInput {
pub in_: String,
pub out: String,
pub relationship_type: String,
}
pub async fn save_knowledge_relationship(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(form): Form<SaveKnowledgeRelationshipInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Construct relationship
let relationship = KnowledgeRelationship::new(
form.in_,
form.out,
user.id.clone(),
"manual".into(),
form.relationship_type,
);
relationship.store_relationship(&state.db).await?;
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
// Render updated list
Ok(TemplateResponse::new_template(
"knowledge/relationship_table.html",
RelationshipTableData {
entities,
relationships,
},
))
}

View File

@@ -0,0 +1,33 @@
mod handlers;
use axum::{
extract::FromRef,
routing::{delete, get, post},
Router,
};
use handlers::{
delete_knowledge_entity, delete_knowledge_relationship, patch_knowledge_entity,
save_knowledge_relationship, show_edit_knowledge_entity_form, show_knowledge_page,
};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route("/knowledge", get(show_knowledge_page))
.route(
"/knowledge-entity/:id",
get(show_edit_knowledge_entity_form)
.delete(delete_knowledge_entity)
.patch(patch_knowledge_entity),
)
.route("/knowledge-relationship", post(save_knowledge_relationship))
.route(
"/knowledge-relationship/:id",
delete(delete_knowledge_relationship),
)
}

View File

@@ -0,0 +1,9 @@
pub mod account;
pub mod admin;
pub mod auth;
pub mod chat;
pub mod content;
pub mod index;
pub mod ingestion;
pub mod knowledge;
pub mod search;

View File

@@ -0,0 +1,44 @@
use axum::{
extract::{Query, State},
response::IntoResponse,
};
use composite_retrieval::answer_retrieval::get_answer_with_references;
use serde::{Deserialize, Serialize};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
};
#[derive(Deserialize)]
pub struct SearchParams {
query: String,
}
pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(query): Query<SearchParams>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
pub struct AnswerData {
user_query: String,
answer_content: String,
answer_references: Vec<String>,
}
let answer =
get_answer_with_references(&state.db, &state.openai_client, &query.query, &user.id).await?;
Ok(TemplateResponse::new_template(
"index/signed_in/search_response.html",
AnswerData {
user_query: query.query,
answer_content: answer.content,
answer_references: answer.references,
},
))
}

View File

@@ -0,0 +1,14 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::search_result_handler;
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new().route("/search", get(search_result_handler))
}