mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-24 17:58:31 +02:00
in progress, routers and main split up
This commit is contained in:
18
crates/api-router/Cargo.toml
Normal file
18
crates/api-router/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "api-router"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
tempfile = "3.12.0"
|
||||
futures = "0.3.31"
|
||||
axum_typed_multipart = "0.12.1"
|
||||
|
||||
common = { path = "../common" }
|
||||
33
crates/api-router/src/api_state.rs
Normal file
33
crates/api-router/src/api_state.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{ingress::jobqueue::JobQueue, storage::db::SurrealDbClient, utils::config::AppConfig};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub surreal_db_client: Arc<SurrealDbClient>,
|
||||
pub job_queue: Arc<JobQueue>,
|
||||
}
|
||||
|
||||
impl ApiState {
|
||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
surreal_db_client.ensure_initialized().await?;
|
||||
|
||||
let app_state = ApiState {
|
||||
surreal_db_client: surreal_db_client.clone(),
|
||||
job_queue: Arc::new(JobQueue::new(surreal_db_client)),
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
}
|
||||
}
|
||||
25
crates/api-router/src/lib.rs
Normal file
25
crates/api-router/src/lib.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use api_state::ApiState;
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, FromRef},
|
||||
middleware::from_fn_with_state,
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::ingress::ingress_data;
|
||||
|
||||
pub mod api_state;
|
||||
mod middleware_api_auth;
|
||||
mod routes;
|
||||
|
||||
/// Router for API functionality, version 1
|
||||
pub fn api_routes_v1<S>(app_state: &ApiState) -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
ApiState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/ingress", post(ingress_data))
|
||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth))
|
||||
}
|
||||
43
crates/api-router/src/middleware_api_auth.rs
Normal file
43
crates/api-router/src/middleware_api_auth.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use common::{error::ApiError, storage::types::user::User};
|
||||
|
||||
use crate::api_state::ApiState;
|
||||
|
||||
pub async fn api_auth(
|
||||
State(state): State<ApiState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
||||
let user = user.ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
fn extract_api_key(request: &Request) -> Option<String> {
|
||||
request
|
||||
.headers()
|
||||
.get("X-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.or_else(|| {
|
||||
request
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(|s| s.trim()))
|
||||
})
|
||||
.map(String::from)
|
||||
}
|
||||
57
crates/api-router/src/routes/ingress.rs
Normal file
57
crates/api-router/src/routes/ingress.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use common::{
|
||||
error::{ApiError, AppError},
|
||||
ingress::ingress_input::{create_ingress_objects, IngressInput},
|
||||
storage::types::{file_info::FileInfo, user::User},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::api_state::ApiState;
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngressParams {
|
||||
pub content: Option<String>,
|
||||
pub instructions: String,
|
||||
pub category: String,
|
||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
||||
#[form_data(default)]
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
|
||||
pub async fn ingress_data(
|
||||
State(state): State<ApiState>,
|
||||
Extension(user): Extension<User>,
|
||||
TypedMultipart(input): TypedMultipart<IngressParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
info!("Received input: {:?}", input);
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new(file, &state.surreal_db_client, &user.id).map_err(AppError::from)
|
||||
}))
|
||||
.await?;
|
||||
|
||||
debug!("Got file infos");
|
||||
|
||||
let ingress_objects = create_ingress_objects(
|
||||
IngressInput {
|
||||
content: input.content,
|
||||
instructions: input.instructions,
|
||||
category: input.category,
|
||||
files: file_infos,
|
||||
},
|
||||
user.id.as_str(),
|
||||
)?;
|
||||
debug!("Got ingress objects");
|
||||
|
||||
let futures: Vec<_> = ingress_objects
|
||||
.into_iter()
|
||||
.map(|object| state.job_queue.enqueue(object.clone(), user.id.clone()))
|
||||
.collect();
|
||||
|
||||
try_join_all(futures).await.map_err(AppError::from)?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
1
crates/api-router/src/routes/mod.rs
Normal file
1
crates/api-router/src/routes/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod ingress;
|
||||
47
crates/common/Cargo.toml
Normal file
47
crates/common/Cargo.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[package]
|
||||
name = "common"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
async-openai = "0.24.1"
|
||||
async-stream = "0.3.6"
|
||||
axum-htmx = "0.6.0"
|
||||
axum_session = "0.14.4"
|
||||
axum_session_auth = "0.14.1"
|
||||
axum_session_surreal = "0.2.1"
|
||||
axum_typed_multipart = "0.12.1"
|
||||
chrono = { version = "0.4.39", features = ["serde"] }
|
||||
chrono-tz = "0.10.1"
|
||||
config = "0.15.4"
|
||||
futures = "0.3.31"
|
||||
json-stream-parser = "0.1.4"
|
||||
lettre = { version = "0.11.11", features = ["rustls-tls"] }
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
mockall = "0.13.0"
|
||||
plotly = "0.12.1"
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
scraper = "0.22.0"
|
||||
sha2 = "0.10.8"
|
||||
surrealdb = "2.0.4"
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = "0.18.1"
|
||||
tiktoken-rs = "0.6.0"
|
||||
tower-http = { version = "0.6.2", features = ["fs"] }
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
url = { version = "2.5.2", features = ["serde"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
|
||||
254
crates/common/src/error.rs
Normal file
254
crates/common/src/error.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::error::OpenAIError;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{Html, IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use minijinja::context;
|
||||
use minijinja_autoreload::AutoReloader;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
use crate::{storage::types::file_info::FileError, utils::mailer::EmailError};
|
||||
|
||||
// Core internal errors
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AppError {
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] surrealdb::Error),
|
||||
#[error("OpenAI error: {0}")]
|
||||
OpenAI(#[from] OpenAIError),
|
||||
#[error("File error: {0}")]
|
||||
File(#[from] FileError),
|
||||
#[error("Email error: {0}")]
|
||||
Email(#[from] EmailError),
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
#[error("Validation error: {0}")]
|
||||
Validation(String),
|
||||
#[error("Authorization error: {0}")]
|
||||
Auth(String),
|
||||
#[error("LLM parsing error: {0}")]
|
||||
LLMParsing(String),
|
||||
#[error("Task join error: {0}")]
|
||||
Join(#[from] JoinError),
|
||||
#[error("Graph mapper error: {0}")]
|
||||
GraphMapper(String),
|
||||
#[error("IoError: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Minijina error: {0}")]
|
||||
MiniJinja(#[from] minijinja::Error),
|
||||
#[error("Reqwest error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("Tiktoken error: {0}")]
|
||||
Tiktoken(#[from] anyhow::Error),
|
||||
#[error("Ingress Processing error: {0}")]
|
||||
Processing(String),
|
||||
}
|
||||
|
||||
// API-specific errors
|
||||
#[derive(Debug, Serialize)]
|
||||
pub enum ApiError {
|
||||
InternalError(String),
|
||||
ValidationError(String),
|
||||
NotFound(String),
|
||||
Unauthorized(String),
|
||||
}
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
fn from(err: AppError) -> Self {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
ApiError::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, body) = match self {
|
||||
ApiError::InternalError(message) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::ValidationError(message) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::NotFound(message) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::Unauthorized(message) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
), // ... other matches
|
||||
};
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub type TemplateResult<T> = Result<T, HtmlError>;
|
||||
|
||||
// Helper trait for converting to HtmlError with templates
|
||||
pub trait IntoHtmlError {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
|
||||
}
|
||||
// // Implement for AppError
|
||||
impl IntoHtmlError for AppError {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
|
||||
HtmlError::new(self, templates)
|
||||
}
|
||||
}
|
||||
// // Implement for minijinja::Error directly
|
||||
impl IntoHtmlError for minijinja::Error {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
|
||||
HtmlError::from_template_error(self, templates)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ErrorContext {
|
||||
#[allow(dead_code)]
|
||||
templates: Arc<AutoReloader>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
pub fn new(templates: Arc<AutoReloader>) -> Self {
|
||||
Self { templates }
|
||||
}
|
||||
}
|
||||
|
||||
pub enum HtmlError {
|
||||
ServerError(Arc<AutoReloader>),
|
||||
NotFound(Arc<AutoReloader>),
|
||||
Unauthorized(Arc<AutoReloader>),
|
||||
BadRequest(String, Arc<AutoReloader>),
|
||||
Template(String, Arc<AutoReloader>),
|
||||
}
|
||||
|
||||
impl HtmlError {
|
||||
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
|
||||
match error {
|
||||
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
|
||||
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
|
||||
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
|
||||
_ => {
|
||||
tracing::error!("Internal error: {:?}", error);
|
||||
HtmlError::ServerError(templates)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
|
||||
tracing::error!("Template error: {:?}", error);
|
||||
HtmlError::Template(error.to_string(), templates)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for HtmlError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, context, templates) = match self {
|
||||
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
context! {
|
||||
status_code => 500,
|
||||
title => "Internal Server Error",
|
||||
error => "Internal Server Error",
|
||||
description => "Something went wrong on our end."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::NotFound(templates) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
context! {
|
||||
status_code => 404,
|
||||
title => "Page Not Found",
|
||||
error => "Not Found",
|
||||
description => "The page you're looking for doesn't exist or was removed."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::Unauthorized(templates) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
context! {
|
||||
status_code => 401,
|
||||
title => "Unauthorized",
|
||||
error => "Access Denied",
|
||||
description => "You need to be logged in to access this page."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::BadRequest(msg, templates) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
context! {
|
||||
status_code => 400,
|
||||
title => "Bad Request",
|
||||
error => "Bad Request",
|
||||
description => msg
|
||||
},
|
||||
templates,
|
||||
),
|
||||
};
|
||||
|
||||
let html = match templates.acquire_env() {
|
||||
Ok(env) => match env.get_template("errors/error.html") {
|
||||
Ok(tmpl) => match tmpl.render(context) {
|
||||
Ok(output) => output,
|
||||
Err(e) => {
|
||||
tracing::error!("Template render error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Template get error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Environment acquire error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
};
|
||||
|
||||
(status, Html(html)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl HtmlError {
|
||||
fn fallback_html() -> String {
|
||||
r#"
|
||||
<html>
|
||||
<body>
|
||||
<div class="container mx-auto p-4">
|
||||
<h1 class="text-4xl text-error">Error</h1>
|
||||
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"#
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
145
crates/common/src/ingress/analysis/ingress_analyser.rs
Normal file
145
crates/common/src/ingress/analysis/ingress_analyser.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use crate::{
|
||||
error::AppError,
|
||||
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||
retrieval::combined_knowledge_entity_retrieval,
|
||||
storage::types::knowledge_entity::KnowledgeEntity,
|
||||
};
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use serde_json::json;
|
||||
use surrealdb::engine::any::Any;
|
||||
use surrealdb::Surreal;
|
||||
use tracing::debug;
|
||||
|
||||
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
|
||||
|
||||
pub struct IngressAnalyzer<'a> {
|
||||
db_client: &'a Surreal<Any>,
|
||||
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
}
|
||||
|
||||
impl<'a> IngressAnalyzer<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a Surreal<Any>,
|
||||
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn analyze_content(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||
let similar_entities = self
|
||||
.find_similar_entities(category, instructions, text, user_id)
|
||||
.await?;
|
||||
let llm_request =
|
||||
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
|
||||
self.perform_analysis(llm_request).await
|
||||
}
|
||||
|
||||
async fn find_similar_entities(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_instructions: {}",
|
||||
text, category, instructions
|
||||
);
|
||||
|
||||
combined_knowledge_entity_retrieval(
|
||||
self.db_client,
|
||||
self.openai_client,
|
||||
&input_text,
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn prepare_llm_request(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
similar_entities: &[KnowledgeEntity],
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let entities_json = json!(similar_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>());
|
||||
|
||||
let user_message = format!(
|
||||
"Category:\n{}\nInstructions:\n{}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
|
||||
category, instructions, text, entities_json
|
||||
);
|
||||
|
||||
debug!("Prepared LLM request message: {}", user_message);
|
||||
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Structured analysis of the submitted content".into()),
|
||||
name: "content_analysis".into(),
|
||||
schema: Some(get_ingress_analysis_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||
let response = self.openai_client.chat().create(request).await?;
|
||||
debug!("Received LLM response: {:?}", response);
|
||||
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".to_string(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!(
|
||||
"Failed to parse LLM response into analysis: {}",
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
3
crates/common/src/ingress/analysis/mod.rs
Normal file
3
crates/common/src/ingress/analysis/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod ingress_analyser;
|
||||
pub mod prompt;
|
||||
pub mod types;
|
||||
81
crates/common/src/ingress/analysis/prompt.rs
Normal file
81
crates/common/src/ingress/analysis/prompt.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub static INGRESS_ANALYSIS_SYSTEM_MESSAGE: &str = r#"
|
||||
You are an AI assistant. You will receive a text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users instructions and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
||||
|
||||
The JSON should have the following structure:
|
||||
|
||||
{
|
||||
"knowledge_entities": [
|
||||
{
|
||||
"key": "unique-key-1",
|
||||
"name": "Entity Name",
|
||||
"description": "A detailed description of the entity.",
|
||||
"entity_type": "TypeOfEntity"
|
||||
},
|
||||
// More entities...
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"type": "RelationshipType",
|
||||
"source": "unique-key-1 or UUID from existing database",
|
||||
"target": "unique-key-1 or UUID from existing database"
|
||||
},
|
||||
// More relationships...
|
||||
]
|
||||
}
|
||||
|
||||
Guidelines:
|
||||
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
|
||||
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
||||
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
||||
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
||||
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
|
||||
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
||||
7. Only create relationships between existing KnowledgeEntities.
|
||||
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
||||
9. A new relationship MUST include a newly created KnowledgeEntity.
|
||||
"#;
|
||||
|
||||
pub fn get_ingress_analysis_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"knowledge_entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": { "type": "string" },
|
||||
"name": { "type": "string" },
|
||||
"description": { "type": "string" },
|
||||
"entity_type": {
|
||||
"type": "string",
|
||||
"enum": ["idea", "project", "document", "page", "textsnippet"]
|
||||
}
|
||||
},
|
||||
"required": ["key", "name", "description", "entity_type"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"relationships": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
|
||||
},
|
||||
"source": { "type": "string" },
|
||||
"target": { "type": "string" }
|
||||
},
|
||||
"required": ["type", "source", "target"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["knowledge_entities", "relationships"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
42
crates/common/src/ingress/analysis/types/graph_mapper.rs
Normal file
42
crates/common/src/ingress/analysis/types/graph_mapper.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
|
||||
#[derive(Clone)]
|
||||
pub struct GraphMapper {
|
||||
pub key_to_id: HashMap<String, Uuid>,
|
||||
}
|
||||
|
||||
impl Default for GraphMapper {
|
||||
fn default() -> Self {
|
||||
GraphMapper::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMapper {
|
||||
pub fn new() -> Self {
|
||||
GraphMapper {
|
||||
key_to_id: HashMap::new(),
|
||||
}
|
||||
}
|
||||
/// Get ID, tries to parse UUID
|
||||
pub fn get_or_parse_id(&mut self, key: &str) -> Uuid {
|
||||
if let Ok(parsed_uuid) = Uuid::parse_str(key) {
|
||||
parsed_uuid
|
||||
} else {
|
||||
*self.key_to_id.get(key).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a new UUID for a given key.
|
||||
pub fn assign_id(&mut self, key: &str) -> Uuid {
|
||||
let id = Uuid::new_v4();
|
||||
self.key_to_id.insert(key.to_string(), id);
|
||||
id
|
||||
}
|
||||
|
||||
/// Retrieves the UUID for a given key.
|
||||
pub fn get_id(&self, key: &str) -> Option<&Uuid> {
|
||||
self.key_to_id.get(key)
|
||||
}
|
||||
}
|
||||
182
crates/common/src/ingress/analysis/types/llm_analysis_result.rs
Normal file
182
crates/common/src/ingress/analysis/types/llm_analysis_result.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::future::try_join_all;
|
||||
|
||||
use super::graph_mapper::GraphMapper; // For future parallelization
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMKnowledgeEntity {
|
||||
pub key: String, // Temporary identifier
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub entity_type: String, // Should match KnowledgeEntityType variants
|
||||
}
|
||||
|
||||
/// Represents a single relationship from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMRelationship {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String, // e.g., RelatedTo, RelevantTo
|
||||
pub source: String, // Key of the source entity
|
||||
pub target: String, // Key of the target entity
|
||||
}
|
||||
|
||||
/// Represents the entire graph analysis result from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMGraphAnalysisResult {
|
||||
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
|
||||
pub relationships: Vec<LLMRelationship>,
|
||||
}
|
||||
|
||||
impl LLMGraphAnalysisResult {
|
||||
/// Converts the LLM graph analysis result into database entities and relationships.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - A UUID representing the source identifier.
|
||||
/// * `openai_client` - OpenAI client for LLM calls.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
|
||||
pub async fn to_database_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||
// Create mapper and pre-assign IDs
|
||||
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
|
||||
|
||||
// Process entities
|
||||
let entities = self
|
||||
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
|
||||
.await?;
|
||||
|
||||
// Process relationships
|
||||
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
|
||||
|
||||
Ok((entities, relationships))
|
||||
}
|
||||
|
||||
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
|
||||
let mut mapper = GraphMapper::new();
|
||||
|
||||
// Pre-assign all IDs
|
||||
for entity in &self.knowledge_entities {
|
||||
mapper.assign_id(&entity.key);
|
||||
}
|
||||
|
||||
Ok(mapper)
|
||||
}
|
||||
|
||||
async fn process_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let futures: Vec<_> = self
|
||||
.knowledge_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let entity = entity.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
|
||||
.await
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = try_join_all(futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn process_relationships(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
) -> Result<Vec<KnowledgeRelationship>, AppError> {
|
||||
let mut mapper_guard = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
self.relationships
|
||||
.iter()
|
||||
.map(|rel| {
|
||||
let source_db_id = mapper_guard.get_or_parse_id(&rel.source);
|
||||
let target_db_id = mapper_guard.get_or_parse_id(&rel.target);
|
||||
|
||||
Ok(KnowledgeRelationship::new(
|
||||
source_db_id.to_string(),
|
||||
target_db_id.to_string(),
|
||||
user_id.to_string(),
|
||||
source_id.to_string(),
|
||||
rel.type_.clone(),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
async fn create_single_entity(
|
||||
llm_entity: &LLMKnowledgeEntity,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<KnowledgeEntity, AppError> {
|
||||
let assigned_id = {
|
||||
let mapper = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
mapper
|
||||
.get_id(&llm_entity.key)
|
||||
.ok_or_else(|| {
|
||||
AppError::GraphMapper(format!("ID not found for key: {}", llm_entity.key))
|
||||
})?
|
||||
.to_string()
|
||||
};
|
||||
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {}",
|
||||
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = generate_embedding(openai_client, &embedding_input).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
Ok(KnowledgeEntity {
|
||||
id: assigned_id,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
name: llm_entity.name.to_string(),
|
||||
description: llm_entity.description.to_string(),
|
||||
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()),
|
||||
source_id: source_id.to_string(),
|
||||
metadata: None,
|
||||
embedding,
|
||||
user_id: user_id.into(),
|
||||
})
|
||||
}
|
||||
2
crates/common/src/ingress/analysis/types/mod.rs
Normal file
2
crates/common/src/ingress/analysis/types/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod graph_mapper;
|
||||
pub mod llm_analysis_result;
|
||||
124
crates/common/src/ingress/content_processor.rs
Normal file
124
crates/common/src/ingress/content_processor.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use text_splitter::TextSplitter;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::{store_item, SurrealDbClient},
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent,
|
||||
},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
|
||||
use super::analysis::{
|
||||
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
|
||||
};
|
||||
|
||||
pub struct ContentProcessor {
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
}
|
||||
|
||||
impl ContentProcessor {
|
||||
pub async fn new(
|
||||
surreal_db_client: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Result<Self, AppError> {
|
||||
Ok(Self {
|
||||
db_client: surreal_db_client,
|
||||
openai_client,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
|
||||
let now = Instant::now();
|
||||
|
||||
// Perform analyis, this step also includes retrieval
|
||||
let analysis = self.perform_semantic_analysis(content).await?;
|
||||
|
||||
let end = now.elapsed();
|
||||
info!(
|
||||
"{:?} time elapsed during creation of entities and relationships",
|
||||
end
|
||||
);
|
||||
|
||||
// Convert analysis to objects
|
||||
let (entities, relationships) = analysis
|
||||
.to_database_entities(&content.id, &content.user_id, &self.openai_client)
|
||||
.await?;
|
||||
|
||||
// Store everything
|
||||
tokio::try_join!(
|
||||
self.store_graph_entities(entities, relationships),
|
||||
self.store_vector_chunks(content),
|
||||
)?;
|
||||
|
||||
// Store original content
|
||||
store_item(&self.db_client, content.to_owned()).await?;
|
||||
|
||||
self.db_client.rebuild_indexes().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn perform_semantic_analysis(
|
||||
&self,
|
||||
content: &TextContent,
|
||||
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||
analyser
|
||||
.analyze_content(
|
||||
&content.category,
|
||||
&content.instructions,
|
||||
&content.text,
|
||||
&content.user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_graph_entities(
|
||||
&self,
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
) -> Result<(), AppError> {
|
||||
for entity in &entities {
|
||||
debug!("Storing entity: {:?}", entity);
|
||||
store_item(&self.db_client, entity.clone()).await?;
|
||||
}
|
||||
|
||||
for relationship in &relationships {
|
||||
debug!("Storing relationship: {:?}", relationship);
|
||||
relationship.store_relationship(&self.db_client).await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Stored {} entities and {} relationships",
|
||||
entities.len(),
|
||||
relationships.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
|
||||
let splitter = TextSplitter::new(500..2000);
|
||||
let chunks = splitter.chunks(&content.text);
|
||||
|
||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||
for chunk in chunks {
|
||||
let embedding = generate_embedding(&self.openai_client, chunk).await?;
|
||||
let text_chunk = TextChunk::new(
|
||||
content.id.to_string(),
|
||||
chunk.to_string(),
|
||||
embedding,
|
||||
content.user_id.to_string(),
|
||||
);
|
||||
store_item(&self.db_client, text_chunk).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
74
crates/common/src/ingress/ingress_input.rs
Normal file
74
crates/common/src/ingress/ingress_input.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use super::ingress_object::IngressObject;
|
||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
|
||||
/// Struct defining the expected body when ingressing content.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct IngressInput {
|
||||
pub content: Option<String>,
|
||||
pub instructions: String,
|
||||
pub category: String,
|
||||
pub files: Vec<FileInfo>,
|
||||
}
|
||||
|
||||
/// Function to create ingress objects from input.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input` - IngressInput containing information needed to ingress content.
|
||||
/// * `user_id` - User id of the ingressing user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Vec<IngressObject>` - An array containing the ingressed objects, one file/contenttype per object.
|
||||
pub fn create_ingress_objects(
|
||||
input: IngressInput,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<IngressObject>, AppError> {
|
||||
// Initialize list
|
||||
let mut object_list = Vec::new();
|
||||
|
||||
// Create a IngressObject from input.content if it exists, checking for URL or text
|
||||
if let Some(input_content) = input.content {
|
||||
match Url::parse(&input_content) {
|
||||
Ok(url) => {
|
||||
info!("Detected URL: {}", url);
|
||||
object_list.push(IngressObject::Url {
|
||||
url: url.to_string(),
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
if input_content.len() > 2 {
|
||||
info!("Treating input as plain text");
|
||||
object_list.push(IngressObject::Text {
|
||||
text: input_content.to_string(),
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for file in input.files {
|
||||
object_list.push(IngressObject::File {
|
||||
file_info: file,
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
})
|
||||
}
|
||||
|
||||
// If no objects are constructed, we return Err
|
||||
if object_list.is_empty() {
|
||||
return Err(AppError::NotFound(
|
||||
"No valid content or files provided".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(object_list)
|
||||
}
|
||||
277
crates/common/src/ingress/ingress_object.rs
Normal file
277
crates/common/src/ingress/ingress_object.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::types::{file_info::FileInfo, text_content::TextContent},
|
||||
};
|
||||
use async_openai::types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequestArgs,
|
||||
};
|
||||
use reqwest;
|
||||
use scraper::{Html, Selector};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
use tiktoken_rs::{o200k_base, CoreBPE};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum IngressObject {
|
||||
Url {
|
||||
url: String,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
Text {
|
||||
text: String,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
File {
|
||||
file_info: FileInfo,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl IngressObject {
|
||||
/// Creates a new `TextContent` instance from a `IngressObject`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// `&self` - A reference to the `IngressObject`.
|
||||
///
|
||||
/// # Returns
|
||||
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
|
||||
pub async fn to_text_content(
|
||||
&self,
|
||||
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Result<TextContent, AppError> {
|
||||
match self {
|
||||
IngressObject::Url {
|
||||
url,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let text = Self::fetch_text_from_url(url, openai_client).await?;
|
||||
Ok(TextContent::new(
|
||||
text,
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
None,
|
||||
Some(url.into()),
|
||||
user_id.into(),
|
||||
))
|
||||
}
|
||||
IngressObject::Text {
|
||||
text,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => Ok(TextContent::new(
|
||||
text.into(),
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
None,
|
||||
None,
|
||||
user_id.into(),
|
||||
)),
|
||||
IngressObject::File {
|
||||
file_info,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let text = Self::extract_text_from_file(file_info).await?;
|
||||
Ok(TextContent::new(
|
||||
text,
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
Some(file_info.to_owned()),
|
||||
None,
|
||||
user_id.into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get text from url, will return it as a markdown formatted string
|
||||
async fn fetch_text_from_url(
|
||||
url: &str,
|
||||
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Result<String, AppError> {
|
||||
// Use a client with timeouts and reuse
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
let response = client.get(url).send().await?.text().await?;
|
||||
|
||||
// Preallocate string with capacity
|
||||
let mut structured_content = String::with_capacity(response.len() / 2);
|
||||
|
||||
let document = Html::parse_document(&response);
|
||||
let main_selectors = Selector::parse(
|
||||
"article, main, .article-content, .post-content, .entry-content, [role='main']",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let content_element = document
|
||||
.select(&main_selectors)
|
||||
.next()
|
||||
.or_else(|| document.select(&Selector::parse("body").unwrap()).next())
|
||||
.ok_or(AppError::NotFound("No content found".into()))?;
|
||||
|
||||
// Compile selectors once
|
||||
let heading_selector = Selector::parse("h1, h2, h3").unwrap();
|
||||
let paragraph_selector = Selector::parse("p").unwrap();
|
||||
|
||||
// Process content in one pass
|
||||
for element in content_element.select(&heading_selector) {
|
||||
let _ = writeln!(
|
||||
structured_content,
|
||||
"<heading>{}</heading>",
|
||||
element.text().collect::<String>().trim()
|
||||
);
|
||||
}
|
||||
for element in content_element.select(¶graph_selector) {
|
||||
let _ = writeln!(
|
||||
structured_content,
|
||||
"<paragraph>{}</paragraph>",
|
||||
element.text().collect::<String>().trim()
|
||||
);
|
||||
}
|
||||
|
||||
let content = structured_content
|
||||
.replace(|c: char| c.is_control(), " ")
|
||||
.replace(" ", " ");
|
||||
Self::process_web_content(content, openai_client).await
|
||||
}
|
||||
|
||||
pub async fn process_web_content(
|
||||
content: String,
|
||||
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Result<String, AppError> {
|
||||
const MAX_TOKENS: usize = 122000;
|
||||
const SYSTEM_PROMPT: &str = r#"
|
||||
You are a precise content extractor for web pages. Your task:
|
||||
|
||||
1. Extract ONLY the main article/content from the provided text
|
||||
2. Maintain the original content - do not summarize or modify the core information
|
||||
3. Ignore peripheral content such as:
|
||||
- Navigation elements
|
||||
- Error messages (e.g., "JavaScript required")
|
||||
- Related articles sections
|
||||
- Comments
|
||||
- Social media links
|
||||
- Advertisement text
|
||||
|
||||
FORMAT:
|
||||
- Convert <heading> tags to markdown headings (#, ##, ###)
|
||||
- Convert <paragraph> tags to markdown paragraphs
|
||||
- Preserve quotes and important formatting
|
||||
- Remove duplicate content
|
||||
- Remove any metadata or technical artifacts
|
||||
|
||||
OUTPUT RULES:
|
||||
- Output ONLY the cleaned content in markdown
|
||||
- Do not add any explanations or meta-commentary
|
||||
- Do not add summaries or conclusions
|
||||
- Do not use any XML/HTML tags in the output
|
||||
"#;
|
||||
|
||||
let bpe = o200k_base()?;
|
||||
|
||||
// Process content in chunks if needed
|
||||
let truncated_content = if bpe.encode_with_special_tokens(&content).len() > MAX_TOKENS {
|
||||
Self::truncate_content(&content, MAX_TOKENS, &bpe)?
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.0)
|
||||
.max_tokens(16200u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(),
|
||||
ChatCompletionRequestUserMessage::from(truncated_content).into(),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.map(|content| content.to_owned())
|
||||
.ok_or(AppError::LLMParsing("No content in response".into()))
|
||||
}
|
||||
|
||||
fn truncate_content(
|
||||
content: &str,
|
||||
max_tokens: usize,
|
||||
tokenizer: &CoreBPE,
|
||||
) -> Result<String, AppError> {
|
||||
// Pre-allocate with estimated size
|
||||
let mut result = String::with_capacity(content.len() / 2);
|
||||
let mut current_tokens = 0;
|
||||
|
||||
// Process content by paragraph to maintain context
|
||||
for paragraph in content.split("\n\n") {
|
||||
let tokens = tokenizer.encode_with_special_tokens(paragraph).len();
|
||||
|
||||
// Check if adding paragraph exceeds limit
|
||||
if current_tokens + tokens > max_tokens {
|
||||
break;
|
||||
}
|
||||
|
||||
result.push_str(paragraph);
|
||||
result.push_str("\n\n");
|
||||
current_tokens += tokens;
|
||||
}
|
||||
|
||||
// Ensure we return valid content
|
||||
if result.is_empty() {
|
||||
return Err(AppError::Processing("Content exceeds token limit".into()));
|
||||
}
|
||||
|
||||
Ok(result.trim_end().to_string())
|
||||
}
|
||||
|
||||
/// Extracts text from a file based on its MIME type.
|
||||
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
|
||||
match file_info.mime_type.as_str() {
|
||||
"text/plain" => {
|
||||
// Read the file and return its content
|
||||
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
"text/markdown" => {
|
||||
// Read the file and return its content
|
||||
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
"application/pdf" => {
|
||||
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
|
||||
Err(AppError::NotFound(file_info.mime_type.clone()))
|
||||
}
|
||||
"image/png" | "image/jpeg" => {
|
||||
// TODO: Implement OCR on image using a crate like `tesseract`
|
||||
Err(AppError::NotFound(file_info.mime_type.clone()))
|
||||
}
|
||||
"application/octet-stream" => {
|
||||
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
"text/x-rust" => {
|
||||
let content = tokio::fs::read_to_string(&file_info.path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
// Handle other MIME types as needed
|
||||
_ => Err(AppError::NotFound(file_info.mime_type.clone())),
|
||||
}
|
||||
}
|
||||
}
|
||||
182
crates/common/src/ingress/jobqueue.rs
Normal file
182
crates/common/src/ingress/jobqueue.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::sync::Arc;
|
||||
use surrealdb::{opt::PatchOp, Error, Notification};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::{delete_item, get_item, store_item, SurrealDbClient},
|
||||
types::{
|
||||
job::{Job, JobStatus},
|
||||
StoredObject,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use super::{content_processor::ContentProcessor, ingress_object::IngressObject};
|
||||
|
||||
pub struct JobQueue {
|
||||
pub db: Arc<SurrealDbClient>,
|
||||
}
|
||||
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
|
||||
impl JobQueue {
|
||||
pub fn new(db: Arc<SurrealDbClient>) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// Creates a new job and stores it in the database
|
||||
pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> {
|
||||
let job = Job::new(content, user_id).await;
|
||||
|
||||
info!("{:?}", job);
|
||||
|
||||
store_item(&self.db, job).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Gets all jobs for a specific user
|
||||
pub async fn get_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
|
||||
let jobs: Vec<Job> = self
|
||||
.db
|
||||
.query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC")
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
debug!("{:?}", jobs);
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
/// Gets all active jobs for a specific user
|
||||
pub async fn get_unfinished_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
|
||||
let jobs: Vec<Job> = self
|
||||
.db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
status = 'Created'
|
||||
OR (
|
||||
status.InProgress != NONE
|
||||
AND status.InProgress.attempts < $max_attempts
|
||||
)
|
||||
)
|
||||
ORDER BY created_at DESC",
|
||||
)
|
||||
.bind(("table", Job::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.await?
|
||||
.take(0)?;
|
||||
debug!("{:?}", jobs);
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
pub async fn delete_job(&self, id: &str, user_id: &str) -> Result<(), AppError> {
|
||||
get_item::<Job>(&self.db.client, id)
|
||||
.await?
|
||||
.filter(|job| job.user_id == user_id)
|
||||
.ok_or_else(|| {
|
||||
error!("Unauthorized attempt to delete job {id} by user {user_id}");
|
||||
AppError::Auth("Not authorized to delete this job".into())
|
||||
})?;
|
||||
|
||||
info!("Deleting job {id} for user {user_id}");
|
||||
delete_item::<Job>(&self.db.client, id)
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> {
|
||||
let _job: Option<Job> = self
|
||||
.db
|
||||
.update((Job::table_name(), id))
|
||||
.patch(PatchOp::replace("/status", status))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::sql::Datetime::default(),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Listen for new jobs
|
||||
pub async fn listen_for_jobs(
|
||||
&self,
|
||||
) -> Result<impl Stream<Item = Result<Notification<Job>, Error>>, Error> {
|
||||
self.db.select("job").live().await
|
||||
}
|
||||
|
||||
/// Get unfinished jobs, ie newly created and in progress up two times
|
||||
pub async fn get_unfinished_jobs(&self) -> Result<Vec<Job>, AppError> {
|
||||
let jobs: Vec<Job> = self
|
||||
.db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE
|
||||
status = 'Created'
|
||||
OR (
|
||||
status.InProgress != NONE
|
||||
AND status.InProgress.attempts < $max_attempts
|
||||
)
|
||||
ORDER BY created_at ASC",
|
||||
)
|
||||
.bind(("table", Job::table_name()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
// Method to process a single job
|
||||
pub async fn process_job(
|
||||
&self,
|
||||
job: Job,
|
||||
processor: &ContentProcessor,
|
||||
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
) -> Result<(), AppError> {
|
||||
let current_attempts = match job.status {
|
||||
JobStatus::InProgress { attempts, .. } => attempts + 1,
|
||||
_ => 1,
|
||||
};
|
||||
|
||||
// Update status to InProgress with attempt count
|
||||
self.update_status(
|
||||
&job.id,
|
||||
JobStatus::InProgress {
|
||||
attempts: current_attempts,
|
||||
last_attempt: Utc::now(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let text_content = job.content.to_text_content(&openai_client).await?;
|
||||
|
||||
match processor.process(&text_content).await {
|
||||
Ok(_) => {
|
||||
self.update_status(&job.id, JobStatus::Completed).await?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
if current_attempts >= MAX_ATTEMPTS {
|
||||
self.update_status(
|
||||
&job.id,
|
||||
JobStatus::Error(format!("Max attempts reached: {}", e)),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Err(AppError::Processing(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
6
crates/common/src/ingress/mod.rs
Normal file
6
crates/common/src/ingress/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod analysis;
|
||||
pub mod content_processor;
|
||||
pub mod ingress_input;
|
||||
pub mod ingress_object;
|
||||
pub mod jobqueue;
|
||||
pub mod queue_task;
|
||||
13
crates/common/src/ingress/queue_task.rs
Normal file
13
crates/common/src/ingress/queue_task.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use crate::ingress::ingress_object::IngressObject;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct QueueTask {
|
||||
pub delivery_tag: u64,
|
||||
pub content: IngressObject,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct QueueTaskResponse {
|
||||
pub tasks: Vec<QueueTask>,
|
||||
}
|
||||
6
crates/common/src/lib.rs
Normal file
6
crates/common/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod error;
|
||||
pub mod ingress;
|
||||
pub mod retrieval;
|
||||
pub mod server;
|
||||
pub mod storage;
|
||||
pub mod utils;
|
||||
74
crates/common/src/retrieval/graph.rs
Normal file
74
crates/common/src/retrieval/graph.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use surrealdb::{engine::any::Any, Error, Surreal};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
/// This function queries the database for all records in a specified table that have
|
||||
/// a matching `source_id` field. It's commonly used to find related entities or
|
||||
/// track the origin of database entries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
|
||||
/// * `Err(Error)` - An error if the database query fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return a `Error` if:
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
db_client: &Surreal<Any>,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
|
||||
db_client
|
||||
.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db_client: &Surreal<Any>,
|
||||
entity_id: String,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
|
||||
debug!("{}", query);
|
||||
|
||||
db_client.query(query).await?.take(0)
|
||||
}
|
||||
|
||||
/// Get a specific KnowledgeEntity by its id
|
||||
pub async fn get_entity_by_id(
|
||||
db_client: &Surreal<Any>,
|
||||
entity_id: &str,
|
||||
) -> Result<Option<KnowledgeEntity>, Error> {
|
||||
db_client
|
||||
.select((KnowledgeEntity::table_name(), entity_id))
|
||||
.await
|
||||
}
|
||||
92
crates/common/src/retrieval/mod.rs
Normal file
92
crates/common/src/retrieval/mod.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
pub mod graph;
|
||||
pub mod query_helper;
|
||||
pub mod query_helper_prompt;
|
||||
pub mod vector;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
retrieval::{
|
||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||
vector::find_items_by_vector_similarity,
|
||||
},
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use std::collections::HashMap;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
pub async fn combined_knowledge_entity_retrieval(
|
||||
db_client: &Surreal<Any>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
// info!("Received input: {:?}", query);
|
||||
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
193
crates/common/src/retrieval/query_helper.rs
Normal file
193
crates/common/src/retrieval/query_helper.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
retrieval::combined_knowledge_entity_retrieval,
|
||||
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
||||
};
|
||||
|
||||
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Reference {
|
||||
#[allow(dead_code)]
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLMResponseFormat {
|
||||
pub answer: String,
|
||||
#[allow(dead_code)]
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
// /// Orchestrator function that takes a query and clients and returns a answer with references
|
||||
// ///
|
||||
// /// # Arguments
|
||||
// /// * `surreal_db_client` - Client for interacting with SurrealDn
|
||||
// /// * `openai_client` - Client for interacting with openai
|
||||
// /// * `query` - The query
|
||||
// ///
|
||||
// /// # Returns
|
||||
// /// * `Result<(String, Vec<String>, ApiError)` - Will return the answer, and the list of references or Error
|
||||
// pub async fn get_answer_with_references(
|
||||
// surreal_db_client: &SurrealDbClient,
|
||||
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
// query: &str,
|
||||
// ) -> Result<(String, Vec<String>), ApiError> {
|
||||
// let entities =
|
||||
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
|
||||
|
||||
// // Format entities and create message
|
||||
// let entities_json = format_entities_json(&entities);
|
||||
// let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
// // Create and send request
|
||||
// let request = create_chat_request(user_message)?;
|
||||
// let response = openai_client
|
||||
// .chat()
|
||||
// .create(request)
|
||||
// .await
|
||||
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
|
||||
|
||||
// // Process response
|
||||
// let answer = process_llm_response(response).await?;
|
||||
|
||||
// let references: Vec<String> = answer
|
||||
// .references
|
||||
// .into_iter()
|
||||
// .map(|reference| reference.reference)
|
||||
// .collect();
|
||||
|
||||
// Ok((answer.answer, references))
|
||||
// }
|
||||
|
||||
/// Orchestrates query processing and returns an answer with references
|
||||
///
|
||||
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||
/// * `openai_client` - Client for OpenAI API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a tuple of the answer and its references, or an API error
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, AppError> {
|
||||
let entities =
|
||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
||||
.await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
debug!("{:?}", entities_json);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message)?;
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
let llm_response = process_llm_response(response).await?;
|
||||
|
||||
Ok(Answer {
|
||||
content: llm_response.answer,
|
||||
references: llm_response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
entities_json, query
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_chat_request(
|
||||
user_message: String,
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Query answering AI".into()),
|
||||
name: "query_answering_with_uuids".into(),
|
||||
schema: Some(get_query_response_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
|
||||
})
|
||||
})
|
||||
}
|
||||
48
crates/common/src/retrieval/query_helper_prompt.rs
Normal file
48
crates/common/src/retrieval/query_helper_prompt.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub static QUERY_SYSTEM_PROMPT: &str = r#"
|
||||
You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
||||
|
||||
Your task is to:
|
||||
1. Carefully analyze the provided knowledge entities in the context
|
||||
2. Answer user questions based on this information
|
||||
3. Provide clear, concise, and accurate responses
|
||||
4. When referencing information, briefly mention which knowledge entity it came from
|
||||
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
|
||||
6. If only partial information is available, explain what you can answer and what information is missing
|
||||
7. Avoid making assumptions or providing information not supported by the context
|
||||
8. Output the references to the documents. Use the UUIDs and make sure they are correct!
|
||||
|
||||
Remember:
|
||||
- Be direct and honest about the limitations of your knowledge
|
||||
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
|
||||
- If you need to combine information from multiple entities, explain how they connect
|
||||
- Don't speculate beyond what's provided in the context
|
||||
|
||||
Example response formats:
|
||||
"Based on [Entity Name], [answer...]"
|
||||
"I found relevant information in multiple entries: [explanation...]"
|
||||
"I apologize, but the provided context doesn't contain information about [topic]"
|
||||
"#;
|
||||
|
||||
pub fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
47
crates/common/src/retrieval/vector.rs
Normal file
47
crates/common/src/retrieval/vector.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use crate::{error::AppError, utils::embedding::generate_embedding};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: &str,
|
||||
db_client: &Surreal<Any>,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
180
crates/common/src/storage/db.rs
Normal file
180
crates/common/src/storage/db.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use crate::error::AppError;
|
||||
|
||||
use super::types::{analytics::Analytics, system_settings::SystemSettings, StoredObject};
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use std::ops::Deref;
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
opt::auth::Root,
|
||||
Error, Surreal,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurrealDbClient {
|
||||
pub client: Surreal<Any>,
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// # Initialize a new datbase client
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// * `SurrealDbClient` initialized
|
||||
pub async fn new(
|
||||
address: &str,
|
||||
username: &str,
|
||||
password: &str,
|
||||
namespace: &str,
|
||||
database: &str,
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
|
||||
// Sign in to database
|
||||
db.signin(Root { username, password }).await?;
|
||||
|
||||
// Set namespace
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
SessionStore::new(
|
||||
Some(self.client.clone().into()),
|
||||
SessionConfig::default()
|
||||
.with_table_name("test_session_table")
|
||||
.with_secure(true),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn ensure_initialized(&self) -> Result<(), AppError> {
|
||||
Self::build_indexes(&self).await?;
|
||||
Self::setup_auth(&self).await?;
|
||||
|
||||
Analytics::ensure_initialized(self).await?;
|
||||
SystemSettings::ensure_initialized(self).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn setup_auth(&self) -> Result<(), Error> {
|
||||
self.client.query(
|
||||
"DEFINE TABLE user SCHEMALESS;
|
||||
DEFINE INDEX unique_name ON TABLE user FIELDS email UNIQUE;
|
||||
DEFINE ACCESS account ON DATABASE TYPE RECORD
|
||||
SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id)
|
||||
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );",
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn build_indexes(&self) -> Result<(), Error> {
|
||||
self.client.query("DEFINE INDEX idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536").await?;
|
||||
self.client.query("DEFINE INDEX idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536").await?;
|
||||
|
||||
self.client
|
||||
.query("DEFINE INDEX idx_job_status ON job FIELDS status")
|
||||
.await?;
|
||||
self.client
|
||||
.query("DEFINE INDEX idx_job_user ON job FIELDS user_id")
|
||||
.await?;
|
||||
self.client
|
||||
.query("DEFINE INDEX idx_job_created ON job FIELDS created_at")
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embeddings_entities ON knowledge_entity")
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn drop_table<T>(&self) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
{
|
||||
self.client.delete(T::table_name()).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
type Target = Surreal<Any>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
|
||||
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - A initialized database client
|
||||
/// * `item` - The item to be stored
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result` - Item or Error
|
||||
pub async fn store_item<T>(db_client: &Surreal<Any>, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
{
|
||||
db_client
|
||||
.create((T::table_name(), item.get_id()))
|
||||
.content(item)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - A initialized database client
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result` - Vec<T> or Error
|
||||
pub async fn get_all_stored_items<T>(db_client: &Surreal<Any>) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
{
|
||||
db_client.select(T::table_name()).await
|
||||
}
|
||||
|
||||
/// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - An initialized database client
|
||||
/// * `id` - The ID of the item to retrieve
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Option<T>, Error>` - The found item or Error
|
||||
pub async fn get_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
{
|
||||
db_client.select((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Operation to delete a single object by its ID, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - An initialized database client
|
||||
/// * `id` - The ID of the item to delete
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Option<T>, Error>` - The deleted item or Error
|
||||
pub async fn delete_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
{
|
||||
db_client.delete((T::table_name(), id)).await
|
||||
}
|
||||
2
crates/common/src/storage/mod.rs
Normal file
2
crates/common/src/storage/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod db;
|
||||
pub mod types;
|
||||
78
crates/common/src/storage/types/analytics.rs
Normal file
78
crates/common/src/storage/types/analytics.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Analytics {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
pub page_loads: i64,
|
||||
pub visitors: i64,
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics = db.select(("analytics", "current")).await?;
|
||||
|
||||
if analytics.is_none() {
|
||||
let created: Option<Analytics> = db
|
||||
.create(("analytics", "current"))
|
||||
.content(Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
})
|
||||
.await?;
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
|
||||
analytics.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
}
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('analytics', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
|
||||
}
|
||||
|
||||
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
||||
}
|
||||
|
||||
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
||||
}
|
||||
|
||||
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CountResult {
|
||||
count: i64,
|
||||
}
|
||||
|
||||
let result: Option<CountResult> = db
|
||||
.client
|
||||
.query("SELECT count() as count FROM type::table($table) GROUP ALL")
|
||||
.bind(("table", User::table_name()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
}
|
||||
52
crates/common/src/storage/types/conversation.rs
Normal file
52
crates/common/src/storage/types/conversation.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::{get_item, SurrealDbClient},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
use super::message::Message;
|
||||
|
||||
stored_object!(Conversation, "conversation", {
|
||||
user_id: String,
|
||||
title: String
|
||||
});
|
||||
|
||||
impl Conversation {
|
||||
pub fn new(user_id: String, title: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
title,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_complete_conversation(
|
||||
conversation_id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(Self, Vec<Message>), AppError> {
|
||||
let conversation: Conversation = get_item(&db, conversation_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
||||
|
||||
if conversation.user_id != user_id {
|
||||
return Err(AppError::Auth(
|
||||
"You don't have access to this conversation".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let messages:Vec<Message> = db.client.
|
||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
||||
bind(("table_name", Message::table_name())).
|
||||
bind(("conversation_id", conversation_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok((conversation, messages))
|
||||
}
|
||||
}
|
||||
248
crates/common/src/storage/types/file_info.rs
Normal file
248
crates/common/src/storage/types/file_info.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
use axum_typed_multipart::FieldData;
|
||||
use mime_guess::from_path;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
io::{BufReader, Read},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use thiserror::Error;
|
||||
use tokio::fs::remove_dir_all;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
storage::db::{delete_item, get_item, store_item, SurrealDbClient},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum FileError {
|
||||
#[error("File not found for UUID: {0}")]
|
||||
FileNotFound(String),
|
||||
|
||||
#[error("IO error occurred: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Duplicate file detected with SHA256: {0}")]
|
||||
DuplicateFile(String),
|
||||
|
||||
#[error("SurrealDB error: {0}")]
|
||||
SurrealError(#[from] surrealdb::Error),
|
||||
|
||||
#[error("Failed to persist file: {0}")]
|
||||
PersistError(#[from] tempfile::PersistError),
|
||||
|
||||
#[error("File name missing in metadata")]
|
||||
MissingFileName,
|
||||
}
|
||||
|
||||
stored_object!(FileInfo, "file", {
|
||||
sha256: String,
|
||||
path: String,
|
||||
file_name: String,
|
||||
mime_type: String
|
||||
});
|
||||
|
||||
impl FileInfo {
|
||||
pub async fn new(
|
||||
field_data: FieldData<NamedTempFile>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Self, FileError> {
|
||||
let file = field_data.contents;
|
||||
let file_name = field_data
|
||||
.metadata
|
||||
.file_name
|
||||
.ok_or(FileError::MissingFileName)?;
|
||||
|
||||
// Calculate SHA256
|
||||
let sha256 = Self::get_sha(&file).await?;
|
||||
|
||||
// Early return if file already exists
|
||||
match Self::get_by_sha(&sha256, db_client).await {
|
||||
Ok(existing_file) => {
|
||||
info!("File already exists with SHA256: {}", sha256);
|
||||
return Ok(existing_file);
|
||||
}
|
||||
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
|
||||
Err(e) => return Err(e), // Propagate unexpected errors
|
||||
}
|
||||
|
||||
// Generate UUID and prepare paths
|
||||
let uuid = Uuid::new_v4();
|
||||
let sanitized_file_name = Self::sanitize_file_name(&file_name);
|
||||
|
||||
let now = Utc::now();
|
||||
// Create new FileInfo instance
|
||||
let file_info = Self {
|
||||
id: uuid.to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
file_name,
|
||||
sha256,
|
||||
path: Self::persist_file(&uuid, file, &sanitized_file_name, user_id)
|
||||
.await?
|
||||
.to_string_lossy()
|
||||
.into(),
|
||||
mime_type: Self::guess_mime_type(Path::new(&sanitized_file_name)),
|
||||
};
|
||||
|
||||
// Store in database
|
||||
store_item(&db_client.client, file_info.clone()).await?;
|
||||
|
||||
Ok(file_info)
|
||||
}
|
||||
|
||||
/// Guesses the MIME type based on the file extension.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - The path to the file.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `String` - The guessed MIME type as a string.
|
||||
fn guess_mime_type(path: &Path) -> String {
|
||||
from_path(path)
|
||||
.first_or(mime::APPLICATION_OCTET_STREAM)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Calculates the SHA256 hash of the given file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file` - The file to hash.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
|
||||
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
|
||||
let mut reader = BufReader::new(file.as_file());
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = [0u8; 8192]; // 8KB buffer
|
||||
|
||||
loop {
|
||||
let n = reader.read(&mut buffer)?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..n]);
|
||||
}
|
||||
|
||||
let digest = hasher.finalize();
|
||||
Ok(format!("{:x}", digest))
|
||||
}
|
||||
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `uuid` - The UUID of the file.
|
||||
/// * `file` - The temporary file to persist.
|
||||
/// * `file_name` - The sanitized file name.
|
||||
/// * `user-id` - User id
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
|
||||
async fn persist_file(
|
||||
uuid: &Uuid,
|
||||
file: NamedTempFile,
|
||||
file_name: &str,
|
||||
user_id: &str,
|
||||
) -> Result<PathBuf, FileError> {
|
||||
let base_dir = Path::new("./data");
|
||||
let user_dir = base_dir.join(user_id); // Create the user directory
|
||||
let uuid_dir = user_dir.join(uuid.to_string()); // Create the UUID directory under the user directory
|
||||
|
||||
// Create the user and UUID directories if they don't exist
|
||||
tokio::fs::create_dir_all(&uuid_dir)
|
||||
.await
|
||||
.map_err(FileError::Io)?;
|
||||
|
||||
// Define the final file path
|
||||
let final_path = uuid_dir.join(file_name);
|
||||
info!("Final path: {:?}", final_path);
|
||||
|
||||
// Persist the temporary file to the final path
|
||||
file.persist(&final_path)?;
|
||||
info!("Persisted file to {:?}", final_path);
|
||||
|
||||
Ok(final_path)
|
||||
}
|
||||
|
||||
/// Retrieves a `FileInfo` by SHA256.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `sha256` - The SHA256 hash string.
|
||||
/// * `db_client` - Reference to the SurrealDbClient.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
|
||||
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
|
||||
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
||||
|
||||
response
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or(FileError::FileNotFound(sha256.to_string()))
|
||||
}
|
||||
|
||||
/// Removes FileInfo from database and file from disk
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `id` - Id of the FileInfo
|
||||
/// * `db_client` - Reference to SurrealDbClient
|
||||
///
|
||||
/// # Returns
|
||||
/// `Result<(), FileError>`
|
||||
pub async fn delete_by_id(id: &str, db_client: &SurrealDbClient) -> Result<(), FileError> {
|
||||
// Get the FileInfo from the database
|
||||
let file_info = match get_item::<FileInfo>(db_client, id).await? {
|
||||
Some(info) => info,
|
||||
None => {
|
||||
return Err(FileError::FileNotFound(format!(
|
||||
"File with id {} was not found",
|
||||
id
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Remove the file and its parent directory
|
||||
let file_path = Path::new(&file_info.path);
|
||||
if file_path.exists() {
|
||||
// Get the parent directory of the file
|
||||
if let Some(parent_dir) = file_path.parent() {
|
||||
// Remove the entire directory containing the file
|
||||
remove_dir_all(parent_dir).await?;
|
||||
info!("Removed directory {:?} and its contents", parent_dir);
|
||||
} else {
|
||||
return Err(FileError::FileNotFound(
|
||||
"File has no parent directory".to_string(),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(FileError::FileNotFound(format!(
|
||||
"File at path {:?} was not found",
|
||||
file_path
|
||||
)));
|
||||
}
|
||||
|
||||
// Delete the FileInfo from the database
|
||||
delete_item::<FileInfo>(db_client, id).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
36
crates/common/src/storage/types/job.rs
Normal file
36
crates/common/src/storage/types/job.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{ingress::ingress_object::IngressObject, stored_object};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum JobStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
attempts: u32,
|
||||
last_attempt: DateTime<Utc>,
|
||||
},
|
||||
Completed,
|
||||
Error(String),
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
stored_object!(Job, "job", {
|
||||
content: IngressObject,
|
||||
status: JobStatus,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl Job {
|
||||
pub async fn new(content: IngressObject, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
content,
|
||||
status: JobStatus::Created,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
121
crates/common/src/storage/types/knowledge_entity.rs
Normal file
121
crates/common/src/storage/types/knowledge_entity.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum KnowledgeEntityType {
|
||||
Idea,
|
||||
Project,
|
||||
Document,
|
||||
Page,
|
||||
TextSnippet,
|
||||
// Add more types as needed
|
||||
}
|
||||
impl KnowledgeEntityType {
|
||||
pub fn variants() -> &'static [&'static str] {
|
||||
&["Idea", "Project", "Document", "Page", "TextSnippet"]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for KnowledgeEntityType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"idea" => KnowledgeEntityType::Idea,
|
||||
"project" => KnowledgeEntityType::Project,
|
||||
"document" => KnowledgeEntityType::Document,
|
||||
"page" => KnowledgeEntityType::Page,
|
||||
"textsnippet" => KnowledgeEntityType::TextSnippet,
|
||||
_ => KnowledgeEntityType::Document, // Default case
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
source_id: String,
|
||||
name: String,
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntity {
|
||||
pub fn new(
|
||||
source_id: String,
|
||||
name: String,
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
source_id,
|
||||
name,
|
||||
description,
|
||||
entity_type,
|
||||
metadata,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
id: &str,
|
||||
name: &str,
|
||||
description: &str,
|
||||
entity_type: &KnowledgeEntityType,
|
||||
db_client: &SurrealDbClient,
|
||||
ai_client: &Client<OpenAIConfig>,
|
||||
) -> Result<(), AppError> {
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
name, description, entity_type
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input).await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing($table, $id)
|
||||
SET name = $name,
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type,
|
||||
embedding = $embedding
|
||||
RETURN AFTER",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.bind(("name", name.to_string()))
|
||||
.bind(("updated_at", Utc::now()))
|
||||
.bind(("entity_type", entity_type.to_owned()))
|
||||
.bind(("embedding", embedding))
|
||||
.bind(("description", description.to_string()))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
86
crates/common/src/storage/types/knowledge_relationship.rs
Normal file
86
crates/common/src/storage/types/knowledge_relationship.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RelationshipMetadata {
|
||||
pub user_id: String,
|
||||
pub source_id: String,
|
||||
pub relationship_type: String,
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct KnowledgeRelationship {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
#[serde(rename = "in", deserialize_with = "deserialize_flexible_id")]
|
||||
pub in_: String,
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub out: String,
|
||||
pub metadata: RelationshipMetadata,
|
||||
}
|
||||
|
||||
impl KnowledgeRelationship {
|
||||
pub fn new(
|
||||
in_: String,
|
||||
out: String,
|
||||
user_id: String,
|
||||
source_id: String,
|
||||
relationship_type: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
in_,
|
||||
out,
|
||||
metadata: RelationshipMetadata {
|
||||
user_id,
|
||||
source_id,
|
||||
relationship_type,
|
||||
},
|
||||
}
|
||||
}
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
|
||||
SET
|
||||
metadata.user_id = '{}',
|
||||
metadata.source_id = '{}',
|
||||
metadata.relationship_type = '{}'"#,
|
||||
self.in_,
|
||||
self.id,
|
||||
self.out,
|
||||
self.metadata.user_id,
|
||||
self.metadata.source_id,
|
||||
self.metadata.relationship_type
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_relationships_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
|
||||
source_id
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_relationship_by_id(
|
||||
id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!("DELETE relates_to:`{}`", id);
|
||||
|
||||
db_client.query(query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
37
crates/common/src/storage/types/message.rs
Normal file
37
crates/common/src/storage/types/message.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Serialize)]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
AI,
|
||||
System,
|
||||
}
|
||||
|
||||
stored_object!(Message, "message", {
|
||||
conversation_id: String,
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
references: Option<Vec<String>>
|
||||
});
|
||||
|
||||
impl Message {
|
||||
pub fn new(
|
||||
conversation_id: String,
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
references: Option<Vec<String>>,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
references,
|
||||
}
|
||||
}
|
||||
}
|
||||
110
crates/common/src/storage/types/mod.rs
Normal file
110
crates/common/src/storage/types/mod.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod analytics;
|
||||
pub mod conversation;
|
||||
pub mod file_info;
|
||||
pub mod job;
|
||||
pub mod knowledge_entity;
|
||||
pub mod knowledge_relationship;
|
||||
pub mod message;
|
||||
pub mod system_settings;
|
||||
pub mod text_chunk;
|
||||
pub mod text_content;
|
||||
pub mod user;
|
||||
|
||||
#[async_trait]
|
||||
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||
fn table_name() -> &'static str;
|
||||
fn get_id(&self) -> &str;
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! stored_object {
|
||||
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
use axum::async_trait;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use surrealdb::sql::Thing;
|
||||
use $crate::storage::types::StoredObject;
|
||||
use serde::de::{self, Visitor};
|
||||
use std::fmt;
|
||||
use chrono::{DateTime, Utc };
|
||||
|
||||
struct FlexibleIdVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for FlexibleIdVisitor {
|
||||
type Value = String;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string or a Thing")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(value.to_string())
|
||||
}
|
||||
|
||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::MapAccess<'de>,
|
||||
{
|
||||
// Try to deserialize as Thing
|
||||
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
|
||||
Ok(thing.id.to_raw())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_any(FlexibleIdVisitor)
|
||||
}
|
||||
|
||||
fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
|
||||
}
|
||||
|
||||
fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
|
||||
Ok(DateTime::<Utc>::from(dt))
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct $name {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
$(pub $field: $ty),*
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StoredObject for $name {
|
||||
fn table_name() -> &'static str {
|
||||
$table
|
||||
}
|
||||
|
||||
fn get_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
56
crates/common/src/storage/types/system_settings.rs
Normal file
56
crates/common/src/storage/types/system_settings.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct SystemSettings {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
pub registrations_enabled: bool,
|
||||
pub require_email_verification: bool,
|
||||
}
|
||||
|
||||
impl SystemSettings {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings = db.select(("system_settings", "current")).await?;
|
||||
|
||||
if settings.is_none() {
|
||||
let created: Option<SystemSettings> = db
|
||||
.create(("system_settings", "current"))
|
||||
.content(SystemSettings {
|
||||
id: "current".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
})
|
||||
.await?;
|
||||
|
||||
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
};
|
||||
|
||||
settings.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
}
|
||||
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('system_settings', 'current')")
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
||||
}
|
||||
|
||||
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
|
||||
.bind(("changes", changes))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
updated.ok_or(AppError::Validation(
|
||||
"Something went wrong updating the settings".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
38
crates/common/src/storage/types/text_chunk.rs
Normal file
38
crates/common/src/storage/types/text_chunk.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
source_id: String,
|
||||
chunk: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunk {
|
||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
source_id,
|
||||
chunk,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
38
crates/common/src/storage/types/text_content.rs
Normal file
38
crates/common/src/storage/types/text_content.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
|
||||
use super::file_info::FileInfo;
|
||||
|
||||
stored_object!(TextContent, "text_content", {
|
||||
text: String,
|
||||
file_info: Option<FileInfo>,
|
||||
url: Option<String>,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextContent {
|
||||
pub fn new(
|
||||
text: String,
|
||||
instructions: String,
|
||||
category: String,
|
||||
file_info: Option<FileInfo>,
|
||||
url: Option<String>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
text,
|
||||
file_info,
|
||||
url,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
352
crates/common/src/storage/types/user.rs
Normal file
352
crates/common/src/storage/types/user.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::{get_item, SurrealDbClient},
|
||||
stored_object,
|
||||
};
|
||||
use axum_session_auth::Authentication;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
conversation::Conversation, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
|
||||
text_content::TextContent,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CategoryResponse {
|
||||
category: String,
|
||||
}
|
||||
|
||||
stored_object!(User, "user", {
|
||||
email: String,
|
||||
password: String,
|
||||
anonymous: bool,
|
||||
api_key: Option<String>,
|
||||
admin: bool,
|
||||
#[serde(default)]
|
||||
timezone: String
|
||||
});
|
||||
|
||||
#[async_trait]
|
||||
impl Authentication<User, String, Surreal<Any>> for User {
|
||||
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
|
||||
let db = db.unwrap();
|
||||
Ok(get_item::<Self>(db, userid.as_str()).await?.unwrap())
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
!self.anonymous
|
||||
}
|
||||
|
||||
fn is_active(&self) -> bool {
|
||||
!self.anonymous
|
||||
}
|
||||
|
||||
fn is_anonymous(&self) -> bool {
|
||||
self.anonymous
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_timezone(input: &str) -> String {
|
||||
use chrono_tz::Tz;
|
||||
|
||||
// Check if it's a valid IANA timezone identifier
|
||||
match input.parse::<Tz>() {
|
||||
Ok(_) => input.to_owned(),
|
||||
Err(_) => {
|
||||
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
|
||||
"UTC".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub async fn create_new(
|
||||
email: String,
|
||||
password: String,
|
||||
db: &SurrealDbClient,
|
||||
timezone: String,
|
||||
) -> Result<Self, AppError> {
|
||||
// verify that the application allows new creations
|
||||
let systemsettings = SystemSettings::get_current(db).await?;
|
||||
if !systemsettings.registrations_enabled {
|
||||
return Err(AppError::Auth("Registration is not allowed".into()));
|
||||
}
|
||||
|
||||
let validated_tz = validate_timezone(&timezone);
|
||||
let now = Utc::now();
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query(
|
||||
"LET $count = (SELECT count() FROM type::table($table))[0].count;
|
||||
CREATE type::thing('user', $id) SET
|
||||
email = $email,
|
||||
password = crypto::argon2::generate($password),
|
||||
admin = $count < 1, // Changed from == 0 to < 1
|
||||
anonymous = false,
|
||||
created_at = $created_at,
|
||||
updated_at = $updated_at,
|
||||
timezone = $timezone",
|
||||
)
|
||||
.bind(("table", "user"))
|
||||
.bind(("id", id))
|
||||
.bind(("email", email))
|
||||
.bind(("password", password))
|
||||
.bind(("created_at", now))
|
||||
.bind(("updated_at", now))
|
||||
.bind(("timezone", validated_tz))
|
||||
.await?
|
||||
.take(1)?;
|
||||
|
||||
user.ok_or(AppError::Auth("User failed to create".into()))
|
||||
}
|
||||
|
||||
pub async fn authenticate(
|
||||
email: String,
|
||||
password: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM user
|
||||
WHERE email = $email
|
||||
AND crypto::argon2::compare(password, $password)",
|
||||
)
|
||||
.bind(("email", email))
|
||||
.bind(("password", password))
|
||||
.await?
|
||||
.take(0)?;
|
||||
user.ok_or(AppError::Auth("User failed to authenticate".into()))
|
||||
}
|
||||
|
||||
pub async fn find_by_email(
|
||||
email: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
|
||||
.bind(("email", email.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub async fn find_by_api_key(
|
||||
api_key: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
|
||||
.bind(("api_key", api_key.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
|
||||
// Generate a secure random API key
|
||||
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
|
||||
|
||||
// Update the user record with the new API key
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
SET api_key = $api_key
|
||||
RETURN AFTER",
|
||||
)
|
||||
.bind(("id", id.to_owned()))
|
||||
.bind(("api_key", api_key.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
// If the user was found and updated, return the API key
|
||||
if user.is_some() {
|
||||
Ok(api_key)
|
||||
} else {
|
||||
Err(AppError::Auth("User not found".into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let user: Option<User> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
SET api_key = NULL
|
||||
RETURN AFTER",
|
||||
)
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
if user.is_some() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(AppError::Auth("User was not found".into()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_knowledge_entities(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let entities: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table) WHERE user_id = $user_id")
|
||||
.bind(("table", KnowledgeEntity::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
|
||||
pub async fn get_knowledge_relationships(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<KnowledgeRelationship>, AppError> {
|
||||
let relationships: Vec<KnowledgeRelationship> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table) WHERE metadata.user_id = $user_id")
|
||||
.bind(("table", "relates_to"))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(relationships)
|
||||
}
|
||||
|
||||
pub async fn get_latest_text_contents(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<TextContent>, AppError> {
|
||||
let items: Vec<TextContent> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5")
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
pub async fn get_text_contents(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<TextContent>, AppError> {
|
||||
let items: Vec<TextContent> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC")
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
pub async fn get_latest_knowledge_entities(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let items: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5",
|
||||
)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("table_name", KnowledgeEntity::table_name()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
pub async fn update_timezone(
|
||||
user_id: &str,
|
||||
timezone: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
|
||||
.bind(("table_name", User::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("timezone", timezone.to_string()))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_user_categories(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<String>, AppError> {
|
||||
// Query to select distinct categories for the user
|
||||
let response: Vec<CategoryResponse> = db
|
||||
.client
|
||||
.query("SELECT category FROM type::table($table_name) WHERE user_id = $user_id GROUP BY category")
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
// Extract the categories from the response
|
||||
let categories: Vec<String> = response.into_iter().map(|item| item.category).collect();
|
||||
|
||||
Ok(categories)
|
||||
}
|
||||
|
||||
pub async fn get_and_validate_knowledge_entity(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<KnowledgeEntity, AppError> {
|
||||
let entity: KnowledgeEntity = get_item(db, &id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Entity not found".into()))?;
|
||||
|
||||
if entity.user_id != user_id {
|
||||
return Err(AppError::Auth("Access denied".into()));
|
||||
}
|
||||
|
||||
Ok(entity)
|
||||
}
|
||||
|
||||
pub async fn get_and_validate_text_content(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<TextContent, AppError> {
|
||||
let text_content: TextContent = get_item(db, &id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Content not found".into()))?;
|
||||
|
||||
if text_content.user_id != user_id {
|
||||
return Err(AppError::Auth("Access denied".into()));
|
||||
}
|
||||
|
||||
Ok(text_content)
|
||||
}
|
||||
|
||||
pub async fn get_user_conversations(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<Conversation>, AppError> {
|
||||
let conversations: Vec<Conversation> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
|
||||
.bind(("table_name", Conversation::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
30
crates/common/src/utils/config.rs
Normal file
30
crates/common/src/utils/config.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use config::{Config, ConfigError, File};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppConfig {
|
||||
pub smtp_username: String,
|
||||
pub smtp_password: String,
|
||||
pub smtp_relayer: String,
|
||||
pub surrealdb_address: String,
|
||||
pub surrealdb_username: String,
|
||||
pub surrealdb_password: String,
|
||||
pub surrealdb_namespace: String,
|
||||
pub surrealdb_database: String,
|
||||
}
|
||||
|
||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
let config = Config::builder()
|
||||
.add_source(File::with_name("config"))
|
||||
.build()?;
|
||||
|
||||
Ok(AppConfig {
|
||||
smtp_username: config.get_string("SMTP_USERNAME")?,
|
||||
smtp_password: config.get_string("SMTP_PASSWORD")?,
|
||||
smtp_relayer: config.get_string("SMTP_RELAYER")?,
|
||||
surrealdb_address: config.get_string("SURREALDB_ADDRESS")?,
|
||||
surrealdb_username: config.get_string("SURREALDB_USERNAME")?,
|
||||
surrealdb_password: config.get_string("SURREALDB_PASSWORD")?,
|
||||
surrealdb_namespace: config.get_string("SURREALDB_NAMESPACE")?,
|
||||
surrealdb_database: config.get_string("SURREALDB_DATABASE")?,
|
||||
})
|
||||
}
|
||||
48
crates/common/src/utils/embedding.rs
Normal file
48
crates/common/src/utils/embedding.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||
|
||||
use crate::error::AppError;
|
||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// comparisons, vector search, and other natural language processing tasks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client`: The OpenAI client instance used to make API requests.
|
||||
/// * `input`: The text string to generate embeddings for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
|
||||
/// * `Err(ProcessingError)`: An error if the embedding generation fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can return a `AppError` in the following cases:
|
||||
/// * If the OpenAI API request fails
|
||||
/// * If the request building fails
|
||||
/// * If no embedding data is received in the response
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model("text-embedding-3-small")
|
||||
.input([input])
|
||||
.build()?;
|
||||
|
||||
// Send the request to OpenAI
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
// Extract the embedding vector
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
81
crates/common/src/utils/mailer.rs
Normal file
81
crates/common/src/utils/mailer.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use lettre::address::AddressError;
|
||||
use lettre::message::MultiPart;
|
||||
use lettre::{transport::smtp::authentication::Credentials, SmtpTransport};
|
||||
use lettre::{Message, Transport};
|
||||
use minijinja::context;
|
||||
use minijinja_autoreload::AutoReloader;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
pub struct Mailer {
|
||||
pub mailer: SmtpTransport,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum EmailError {
|
||||
#[error("Email construction error: {0}")]
|
||||
EmailParsingError(#[from] AddressError),
|
||||
|
||||
#[error("Email sending error: {0}")]
|
||||
SendingError(#[from] lettre::transport::smtp::Error),
|
||||
|
||||
#[error("Body constructing error: {0}")]
|
||||
BodyError(#[from] lettre::error::Error),
|
||||
|
||||
#[error("Templating error: {0}")]
|
||||
TemplatingError(#[from] minijinja::Error),
|
||||
}
|
||||
|
||||
impl Mailer {
|
||||
pub fn new(
|
||||
username: &str,
|
||||
relayer: &str,
|
||||
password: &str,
|
||||
) -> Result<Self, lettre::transport::smtp::Error> {
|
||||
let creds = Credentials::new(username.to_owned(), password.to_owned());
|
||||
|
||||
let mailer = SmtpTransport::relay(&relayer)?.credentials(creds).build();
|
||||
|
||||
Ok(Mailer { mailer })
|
||||
}
|
||||
|
||||
pub fn send_email_verification(
|
||||
&self,
|
||||
email_to: &str,
|
||||
verification_code: &str,
|
||||
templates: &AutoReloader,
|
||||
) -> Result<(), EmailError> {
|
||||
let name = email_to
|
||||
.split('@')
|
||||
.next()
|
||||
.unwrap_or("User")
|
||||
.chars()
|
||||
.enumerate()
|
||||
.map(|(i, c)| if i == 0 { c.to_ascii_uppercase() } else { c })
|
||||
.collect::<String>();
|
||||
|
||||
let context = context! {
|
||||
name => name,
|
||||
verification_code => verification_code
|
||||
};
|
||||
|
||||
let env = templates.acquire_env()?;
|
||||
let html = env
|
||||
.get_template("email/email_verification.html")?
|
||||
.render(&context)?;
|
||||
let plain = env
|
||||
.get_template("email/email_verification.txt")?
|
||||
.render(&context)?;
|
||||
|
||||
let email = Message::builder()
|
||||
.from("Admin <minne@starks.cloud>".parse()?)
|
||||
.reply_to("Admin <minne@starks.cloud>".parse()?)
|
||||
.to(format!("{} <{}>", name, email_to).parse()?)
|
||||
.subject("Verify Your Email Address")
|
||||
.multipart(MultiPart::alternative_plain_html(plain, html))?;
|
||||
|
||||
info!("Sending email to: {}", email_to);
|
||||
self.mailer.send(&email)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
3
crates/common/src/utils/mod.rs
Normal file
3
crates/common/src/utils/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod config;
|
||||
pub mod embedding;
|
||||
pub mod mailer;
|
||||
35
crates/html-router/Cargo.toml
Normal file
35
crates/html-router/Cargo.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
[package]
|
||||
name = "html-router"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
|
||||
axum-htmx = "0.6.0"
|
||||
axum_session = "0.14.4"
|
||||
axum_session_auth = "0.14.1"
|
||||
axum_session_surreal = "0.2.1"
|
||||
axum_typed_multipart = "0.12.1"
|
||||
futures = "0.3.31"
|
||||
tempfile = "3.12.0"
|
||||
async-stream = "0.3.6"
|
||||
json-stream-parser = "0.1.4"
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
plotly = "0.12.1"
|
||||
surrealdb = "2.0.4"
|
||||
tower-http = { version = "0.6.2", features = ["fs"] }
|
||||
chrono-tz = "0.10.1"
|
||||
async-openai = "0.24.1"
|
||||
|
||||
|
||||
|
||||
common = { path = "../common" }
|
||||
88
crates/html-router/src/html_state.rs
Normal file
88
crates/html-router/src/html_state.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use axum_session::SessionStore;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use common::ingress::jobqueue::JobQueue;
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use common::utils::config::AppConfig;
|
||||
use common::utils::mailer::Mailer;
|
||||
use minijinja::{path_loader, Environment};
|
||||
use minijinja_autoreload::AutoReloader;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use surrealdb::engine::any::Any;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HtmlState {
|
||||
pub surreal_db_client: Arc<SurrealDbClient>,
|
||||
pub openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||
pub templates: Arc<AutoReloader>,
|
||||
pub mailer: Arc<Mailer>,
|
||||
pub job_queue: Arc<JobQueue>,
|
||||
pub session_store: Arc<SessionStore<SessionSurrealPool<Any>>>,
|
||||
}
|
||||
|
||||
impl HtmlState {
|
||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let reloader = AutoReloader::new(move |notifier| {
|
||||
let template_path = get_templates_dir();
|
||||
let mut env = Environment::new();
|
||||
env.set_loader(path_loader(&template_path));
|
||||
|
||||
notifier.set_fast_reload(true);
|
||||
notifier.watch_path(&template_path, true);
|
||||
minijinja_contrib::add_to_environment(&mut env);
|
||||
Ok(env)
|
||||
});
|
||||
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
surreal_db_client.ensure_initialized().await?;
|
||||
|
||||
let openai_client = Arc::new(async_openai::Client::new());
|
||||
|
||||
let session_store = Arc::new(surreal_db_client.create_session_store().await?);
|
||||
|
||||
let app_state = HtmlState {
|
||||
surreal_db_client: surreal_db_client.clone(),
|
||||
templates: Arc::new(reloader),
|
||||
openai_client: openai_client.clone(),
|
||||
mailer: Arc::new(Mailer::new(
|
||||
&config.smtp_username,
|
||||
&config.smtp_relayer,
|
||||
&config.smtp_password,
|
||||
)?),
|
||||
job_queue: Arc::new(JobQueue::new(surreal_db_client)),
|
||||
session_store,
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_workspace_root() -> PathBuf {
|
||||
// Starts from CARGO_MANIFEST_DIR (e.g., /project/crates/html-router/)
|
||||
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
|
||||
// Navigate up to /path/to/project/crates
|
||||
let crates_dir = manifest_dir
|
||||
.parent()
|
||||
.expect("Failed to find parent of manifest directory");
|
||||
|
||||
// Navigate up to workspace root
|
||||
crates_dir
|
||||
.parent()
|
||||
.expect("Failed to find workspace root")
|
||||
.to_path_buf()
|
||||
}
|
||||
|
||||
pub fn get_templates_dir() -> PathBuf {
|
||||
get_workspace_root().join("templates")
|
||||
}
|
||||
110
crates/html-router/src/lib.rs
Normal file
110
crates/html-router/src/lib.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
pub mod html_state;
|
||||
mod middleware_analytics;
|
||||
mod routes;
|
||||
|
||||
use axum::{
|
||||
extract::FromRef,
|
||||
middleware::from_fn_with_state,
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
};
|
||||
use axum_session::SessionLayer;
|
||||
use axum_session_auth::{AuthConfig, AuthSessionLayer};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use common::storage::types::user::User;
|
||||
use html_state::HtmlState;
|
||||
use middleware_analytics::analytics_middleware;
|
||||
use routes::{
|
||||
account::{delete_account, set_api_key, show_account_page, update_timezone},
|
||||
admin_panel::{show_admin_panel, toggle_registration_status},
|
||||
chat::{
|
||||
message_response_stream::get_response_stream, new_chat_user_message, new_user_message,
|
||||
references::show_reference_tooltip, show_chat_base, show_existing_chat,
|
||||
show_initialized_chat,
|
||||
},
|
||||
content::{patch_text_content, show_content_page, show_text_content_edit_form},
|
||||
documentation::{
|
||||
show_documentation_index, show_get_started, show_mobile_friendly, show_privacy_policy,
|
||||
},
|
||||
gdpr::{accept_gdpr, deny_gdpr},
|
||||
index::{delete_job, delete_text_content, index_handler, show_active_jobs},
|
||||
ingress_form::{hide_ingress_form, process_ingress_form, show_ingress_form},
|
||||
knowledge::{
|
||||
delete_knowledge_entity, delete_knowledge_relationship, patch_knowledge_entity,
|
||||
save_knowledge_relationship, show_edit_knowledge_entity_form, show_knowledge_page,
|
||||
},
|
||||
search_result::search_result_handler,
|
||||
signin::{authenticate_user, show_signin_form},
|
||||
signout::sign_out_user,
|
||||
signup::{process_signup_and_show_verification, show_signup_form},
|
||||
};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tower_http::services::ServeDir;
|
||||
|
||||
/// Router for HTML endpoints
|
||||
pub fn html_routes<S>(app_state: &HtmlState) -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/", get(index_handler))
|
||||
.route("/gdpr/accept", post(accept_gdpr))
|
||||
.route("/gdpr/deny", post(deny_gdpr))
|
||||
.route("/search", get(search_result_handler))
|
||||
.route("/chat", get(show_chat_base).post(new_chat_user_message))
|
||||
.route("/initialized-chat", post(show_initialized_chat))
|
||||
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/knowledge/:id", get(show_reference_tooltip))
|
||||
.route("/signout", get(sign_out_user))
|
||||
.route("/signin", get(show_signin_form).post(authenticate_user))
|
||||
.route(
|
||||
"/ingress-form",
|
||||
get(show_ingress_form).post(process_ingress_form),
|
||||
)
|
||||
.route("/hide-ingress-form", get(hide_ingress_form))
|
||||
.route("/text-content/:id", delete(delete_text_content))
|
||||
.route("/jobs/:job_id", delete(delete_job))
|
||||
.route("/active-jobs", get(show_active_jobs))
|
||||
.route("/content", get(show_content_page))
|
||||
.route(
|
||||
"/content/:id",
|
||||
get(show_text_content_edit_form).patch(patch_text_content),
|
||||
)
|
||||
.route("/knowledge", get(show_knowledge_page))
|
||||
.route(
|
||||
"/knowledge-entity/:id",
|
||||
get(show_edit_knowledge_entity_form)
|
||||
.delete(delete_knowledge_entity)
|
||||
.patch(patch_knowledge_entity),
|
||||
)
|
||||
.route("/knowledge-relationship", post(save_knowledge_relationship))
|
||||
.route(
|
||||
"/knowledge-relationship/:id",
|
||||
delete(delete_knowledge_relationship),
|
||||
)
|
||||
.route("/account", get(show_account_page))
|
||||
.route("/admin", get(show_admin_panel))
|
||||
.route("/toggle-registrations", patch(toggle_registration_status))
|
||||
.route("/set-api-key", post(set_api_key))
|
||||
.route("/update-timezone", patch(update_timezone))
|
||||
.route("/delete-account", delete(delete_account))
|
||||
.route(
|
||||
"/signup",
|
||||
get(show_signup_form).post(process_signup_and_show_verification),
|
||||
)
|
||||
.route("/documentation", get(show_documentation_index))
|
||||
.route("/documentation/privacy-policy", get(show_privacy_policy))
|
||||
.route("/documentation/get-started", get(show_get_started))
|
||||
.route("/documentation/mobile-friendly", get(show_mobile_friendly))
|
||||
.nest_service("/assets", ServeDir::new("assets/"))
|
||||
.layer(from_fn_with_state(app_state.clone(), analytics_middleware))
|
||||
.layer(
|
||||
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
|
||||
app_state.surreal_db_client.client.clone(),
|
||||
))
|
||||
.with_config(AuthConfig::<String>::default()),
|
||||
)
|
||||
.layer(SessionLayer::new((*app_state.session_store).clone()))
|
||||
}
|
||||
33
crates/html-router/src/middleware_analytics.rs
Normal file
33
crates/html-router/src/middleware_analytics.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::engine::any::Any;
|
||||
|
||||
use common::storage::types::analytics::Analytics;
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
pub async fn analytics_middleware(
|
||||
State(state): State<HtmlState>,
|
||||
session: axum_session::Session<SessionSurrealPool<Any>>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// Get the path from the request
|
||||
let path = request.uri().path();
|
||||
|
||||
// Only count if it's a main page request (not assets or other resources)
|
||||
if !path.starts_with("/assets") && !path.starts_with("/_next") && !path.contains('.') {
|
||||
if !session.get::<bool>("counted_visitor").unwrap_or(false) {
|
||||
let _ = Analytics::increment_visitors(&state.surreal_db_client).await;
|
||||
session.set("counted_visitor", true);
|
||||
}
|
||||
|
||||
let _ = Analytics::increment_page_loads(&state.surreal_db_client).await;
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
147
crates/html-router/src/routes/account.rs
Normal file
147
crates/html-router/src/routes/account.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{StatusCode, Uri},
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_htmx::HxRedirect;
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use chrono_tz::TZ_VARIANTS;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError},
|
||||
storage::{db::delete_item, types::user::User},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data};
|
||||
|
||||
use super::{render_block, render_template};
|
||||
|
||||
page_data!(AccountData, "auth/account_settings.html", {
|
||||
user: User,
|
||||
timezones: Vec<String>
|
||||
});
|
||||
|
||||
pub async fn show_account_page(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
|
||||
|
||||
let output = render_template(
|
||||
AccountData::template_name(),
|
||||
AccountData { user, timezones },
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn set_api_key(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match &auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
// Generate and set the API key
|
||||
let api_key = User::set_api_key(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
|
||||
// Update the user's API key
|
||||
let updated_user = User {
|
||||
api_key: Some(api_key),
|
||||
..user.clone()
|
||||
};
|
||||
|
||||
// Render the API key section block
|
||||
let output = render_block(
|
||||
AccountData::template_name(),
|
||||
"api_key_section",
|
||||
AccountData {
|
||||
user: updated_user,
|
||||
timezones: vec![],
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn delete_account(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match &auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
delete_item::<User>(&state.surreal_db_client, &user.id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
auth.logout_user();
|
||||
|
||||
auth.session.destroy();
|
||||
|
||||
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateTimezoneForm {
|
||||
timezone: String,
|
||||
}
|
||||
|
||||
pub async fn update_timezone(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<UpdateTimezoneForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match &auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
User::update_timezone(&user.id, &form.timezone, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
|
||||
// Update the user's API key
|
||||
let updated_user = User {
|
||||
timezone: form.timezone,
|
||||
..user.clone()
|
||||
};
|
||||
|
||||
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
|
||||
|
||||
// Render the API key section block
|
||||
let output = render_block(
|
||||
AccountData::template_name(),
|
||||
"timezone_section",
|
||||
AccountData {
|
||||
user: updated_user,
|
||||
timezones,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
118
crates/html-router/src/routes/admin_panel.rs
Normal file
118
crates/html-router/src/routes/admin_panel.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{
|
||||
error::HtmlError,
|
||||
storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data};
|
||||
|
||||
use super::{render_block, render_template};
|
||||
|
||||
page_data!(AdminPanelData, "auth/admin_panel.html", {
|
||||
user: User,
|
||||
settings: SystemSettings,
|
||||
analytics: Analytics,
|
||||
users: i64,
|
||||
});
|
||||
|
||||
pub async fn show_admin_panel(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated and admin
|
||||
let user = match auth.current_user {
|
||||
Some(user) if user.admin => user,
|
||||
_ => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let settings = SystemSettings::get_current(&state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let analytics = Analytics::get_current(&state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let users_count = Analytics::get_users_amount(&state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
AdminPanelData::template_name(),
|
||||
AdminPanelData {
|
||||
user,
|
||||
settings,
|
||||
analytics,
|
||||
users: users_count,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
match String::deserialize(deserializer) {
|
||||
Ok(string) => Ok(string == "on"),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct RegistrationToggleInput {
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "checkbox_to_bool")]
|
||||
registration_open: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RegistrationToggleData {
|
||||
settings: SystemSettings,
|
||||
}
|
||||
|
||||
pub async fn toggle_registration_status(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(input): Form<RegistrationToggleInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated and admin
|
||||
let _user = match auth.current_user {
|
||||
Some(user) if user.admin => user,
|
||||
_ => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
registrations_enabled: input.registration_open,
|
||||
..current_settings.clone()
|
||||
};
|
||||
|
||||
SystemSettings::update(&state.surreal_db_client, new_settings.clone())
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_block(
|
||||
AdminPanelData::template_name(),
|
||||
"registration_status_input",
|
||||
RegistrationToggleData {
|
||||
settings: new_settings,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
336
crates/html-router/src/routes/chat/message_response_stream.rs
Normal file
336
crates/html-router/src/routes/chat/message_response_stream.rs
Normal file
@@ -0,0 +1,336 @@
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
Sse,
|
||||
},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::{
|
||||
stream::{self, once},
|
||||
Stream, StreamExt, TryStreamExt,
|
||||
};
|
||||
use json_stream_parser::JsonStreamParser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
use tracing::{error, info};
|
||||
|
||||
use common::{
|
||||
retrieval::{
|
||||
combined_knowledge_entity_retrieval,
|
||||
query_helper::{
|
||||
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
|
||||
},
|
||||
},
|
||||
storage::{
|
||||
db::{get_item, store_item, SurrealDbClient},
|
||||
types::{
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, routes::render_template};
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
|
||||
// Helper function to get message and user
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<(Message, User), Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match get_item::<Message>(db, message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Message not found: the specified message does not exist",
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok((message, user))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user) = match get_message_and_user(
|
||||
&state.surreal_db_client,
|
||||
auth.current_user,
|
||||
¶ms.message_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok((user_message, user)) => (user_message, user),
|
||||
Err(error_stream) => return error_stream,
|
||||
};
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let entities = match combined_knowledge_entity_retrieval(
|
||||
&state.surreal_db_client,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(entities) => entities,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Create the OpenAI request
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let formatted_user_message = create_user_message(&entities_json, &user_message.content);
|
||||
let request = match create_chat_request(formatted_user_message) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
// 5. Create channel for collecting complete response
|
||||
let (tx, mut rx) = channel::<String>(1000);
|
||||
let tx_clone = tx.clone();
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.surreal_db_client.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
// Collect full response
|
||||
let mut full_json = String::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
full_json.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
|
||||
match store_item(&db_client, ai_message).await {
|
||||
Ok(_) => info!("Successfully stored AI message with references"),
|
||||
Err(e) => error!("Failed to store AI message: {:?}", e),
|
||||
}
|
||||
} else {
|
||||
error!("Failed to parse LLM response as structured format");
|
||||
|
||||
// Fallback - store raw response
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
full_json,
|
||||
None,
|
||||
);
|
||||
|
||||
let _ = store_item(&db_client, ai_message).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create a shared state for tracking the JSON parsing
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
// 7. Create the response event stream
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
// Always send raw content to storage
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
// Process through JSON parser
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
// If display_content is empty, don't yield anything
|
||||
}
|
||||
// If content is empty, don't yield anything
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
// Don't send any event if references is empty
|
||||
if message.references.as_ref().is_some_and(|x| x.is_empty()) {
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match render_template(
|
||||
"chat/reference_list.html",
|
||||
ReferenceData { message },
|
||||
state.templates.clone(),
|
||||
) {
|
||||
Ok(html) => {
|
||||
// Extract the String from Html<String>
|
||||
let html_string = html.0;
|
||||
|
||||
// Return the rendered HTML
|
||||
Ok(Event::default().event("references").data(html_string))
|
||||
}
|
||||
Err(_) => {
|
||||
// Handle template rendering error
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle case where no references were received
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
info!("OpenAI streaming started");
|
||||
Sse::new(event_stream.boxed()).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
// Replace JsonParseState with StreamParserState
|
||||
struct StreamParserState {
|
||||
parser: JsonStreamParser,
|
||||
last_answer_content: String,
|
||||
in_answer_field: bool,
|
||||
}
|
||||
|
||||
impl StreamParserState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
parser: JsonStreamParser::new(),
|
||||
last_answer_content: String::new(),
|
||||
in_answer_field: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> String {
|
||||
// Feed all characters into the parser
|
||||
for c in chunk.chars() {
|
||||
let _ = self.parser.add_char(c);
|
||||
}
|
||||
|
||||
// Get the current state of the JSON
|
||||
let json = self.parser.get_result();
|
||||
|
||||
// Check if we're in the answer field
|
||||
if let Some(obj) = json.as_object() {
|
||||
if let Some(answer) = obj.get("answer") {
|
||||
self.in_answer_field = true;
|
||||
|
||||
// Get current answer content
|
||||
let current_content = answer.as_str().unwrap_or_default().to_string();
|
||||
|
||||
// Calculate difference to send only new content
|
||||
if current_content.len() > self.last_answer_content.len() {
|
||||
let new_content = current_content[self.last_answer_content.len()..].to_string();
|
||||
self.last_answer_content = current_content;
|
||||
return new_content;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No new content to return
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
289
crates/html-router/src/routes/chat/mod.rs
Normal file
289
crates/html-router/src/routes/chat/mod.rs
Normal file
@@ -0,0 +1,289 @@
|
||||
pub mod message_response_stream;
|
||||
pub mod references;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::HeaderValue,
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tracing::info;
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError},
|
||||
storage::{
|
||||
db::{get_item, store_item},
|
||||
types::{
|
||||
conversation::Conversation,
|
||||
message::{Message, MessageRole},
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data, routes::render_template};
|
||||
|
||||
// Update your ChatStartParams struct to properly deserialize the references
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatStartParams {
|
||||
user_query: String,
|
||||
llm_response: String,
|
||||
#[serde(deserialize_with = "deserialize_references")]
|
||||
references: Vec<String>,
|
||||
}
|
||||
|
||||
// Custom deserializer function
|
||||
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
serde_json::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
page_data!(ChatData, "chat/base.html", {
|
||||
user: User,
|
||||
history: Vec<Message>,
|
||||
conversation: Option<Conversation>,
|
||||
conversation_archive: Vec<Conversation>
|
||||
});
|
||||
|
||||
pub async fn show_initialized_chat(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<ChatStartParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying chat start");
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let user_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
);
|
||||
|
||||
let (conversation_result, ai_message_result, user_message_result) = futures::join!(
|
||||
store_item(&state.surreal_db_client, conversation.clone()),
|
||||
store_item(&state.surreal_db_client, ai_message.clone()),
|
||||
store_item(&state.surreal_db_client, user_message.clone())
|
||||
);
|
||||
|
||||
// Check each result individually
|
||||
conversation_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let messages = vec![user_message, ai_message];
|
||||
|
||||
let output = render_template(
|
||||
ChatData::template_name(),
|
||||
ChatData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation_archive,
|
||||
conversation: Some(conversation.clone()),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
let mut response = output.into_response();
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn show_chat_base(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying empty chat start");
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
ChatData::template_name(),
|
||||
ChatData {
|
||||
history: vec![],
|
||||
user,
|
||||
conversation_archive,
|
||||
conversation: None,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NewMessageForm {
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub async fn show_existing_chat(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying initialized chat with id: {}", conversation_id);
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let (conversation, messages) = Conversation::get_complete_conversation(
|
||||
conversation_id.as_str(),
|
||||
&user.id,
|
||||
&state.surreal_db_client,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
ChatData::template_name(),
|
||||
ChatData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation: Some(conversation.clone()),
|
||||
conversation_archive,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation: Conversation = get_item(&state.surreal_db_client, &conversation_id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?
|
||||
.ok_or_else(|| {
|
||||
HtmlError::new(
|
||||
AppError::NotFound("Conversation was not found".to_string()),
|
||||
state.templates.clone(),
|
||||
)
|
||||
})?;
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Err(HtmlError::new(
|
||||
AppError::Auth("The user does not have permission for this conversation".to_string()),
|
||||
state.templates.clone(),
|
||||
));
|
||||
};
|
||||
|
||||
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
|
||||
|
||||
store_item(&state.surreal_db_client, user_message.clone())
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let output = render_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
let mut response = output.into_response();
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id, "New chat".to_string());
|
||||
let user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
form.content,
|
||||
None,
|
||||
);
|
||||
|
||||
store_item(&state.surreal_db_client, conversation.clone())
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
store_item(&state.surreal_db_client, user_message.clone())
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
|
||||
let output = render_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
user_message,
|
||||
conversation: conversation.clone(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
let mut response = output.into_response();
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
Ok(response)
|
||||
}
|
||||
63
crates/html-router/src/routes/chat/references.rs
Normal file
63
crates/html-router/src/routes/chat/references.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::Serialize;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tracing::info;
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError},
|
||||
storage::{
|
||||
db::get_item,
|
||||
types::{knowledge_entity::KnowledgeEntity, user::User},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, routes::render_template};
|
||||
|
||||
pub async fn show_reference_tooltip(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(reference_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Showing reference");
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let entity: KnowledgeEntity = get_item(&state.surreal_db_client, &reference_id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?
|
||||
.ok_or_else(|| {
|
||||
HtmlError::new(
|
||||
AppError::NotFound("Item was not found".to_string()),
|
||||
state.templates.clone(),
|
||||
)
|
||||
})?;
|
||||
|
||||
if entity.user_id != user.id {
|
||||
return Err(HtmlError::new(
|
||||
AppError::Auth("You dont have access to this entity".to_string()),
|
||||
state.templates.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
entity: KnowledgeEntity,
|
||||
user: User,
|
||||
}
|
||||
|
||||
let output = render_template(
|
||||
"chat/reference_tooltip.html",
|
||||
ReferenceTooltipData { entity, user },
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
108
crates/html-router/src/routes/content/mod.rs
Normal file
108
crates/html-router/src/routes/content/mod.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{
|
||||
error::HtmlError,
|
||||
storage::types::{text_content::TextContent, user::User},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data};
|
||||
|
||||
use super::render_template;
|
||||
|
||||
page_data!(ContentPageData, "content/base.html", {
|
||||
user: User,
|
||||
text_contents: Vec<TextContent>
|
||||
});
|
||||
|
||||
pub async fn show_content_page(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
ContentPageData::template_name(),
|
||||
ContentPageData {
|
||||
user,
|
||||
text_contents,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentEditModal {
|
||||
pub user: User,
|
||||
pub text_content: TextContent,
|
||||
}
|
||||
|
||||
pub async fn show_text_content_edit_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
"content/edit_text_content_modal.html",
|
||||
TextContentEditModal { user, text_content },
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn patch_text_content(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
"content/content_list.html",
|
||||
ContentPageData {
|
||||
user,
|
||||
text_contents,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
78
crates/html-router/src/routes/documentation/mod.rs
Normal file
78
crates/html-router/src/routes/documentation/mod.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use axum::{extract::State, response::IntoResponse};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::HtmlError, storage::types::user::User};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data};
|
||||
|
||||
use super::render_template;
|
||||
|
||||
page_data!(DocumentationData, "do_not_use_this", {
|
||||
user: Option<User>,
|
||||
current_path: String
|
||||
});
|
||||
|
||||
pub async fn show_privacy_policy(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let output = render_template(
|
||||
"documentation/privacy.html",
|
||||
DocumentationData {
|
||||
user: auth.current_user,
|
||||
current_path: "/privacy_policy".to_string(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn show_get_started(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let output = render_template(
|
||||
"documentation/get_started.html",
|
||||
DocumentationData {
|
||||
user: auth.current_user,
|
||||
current_path: "/get-started".to_string(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
pub async fn show_mobile_friendly(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let output = render_template(
|
||||
"documentation/mobile_friendly.html",
|
||||
DocumentationData {
|
||||
user: auth.current_user,
|
||||
current_path: "/mobile-friendly".to_string(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn show_documentation_index(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let output = render_template(
|
||||
"documentation/index.html",
|
||||
DocumentationData {
|
||||
user: auth.current_user,
|
||||
current_path: "/index".to_string(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
22
crates/html-router/src/routes/gdpr.rs
Normal file
22
crates/html-router/src/routes/gdpr.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum_session::Session;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::engine::any::Any;
|
||||
|
||||
use common::error::HtmlError;
|
||||
|
||||
pub async fn accept_gdpr(
|
||||
session: Session<SessionSurrealPool<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
session.set("gdpr_accepted", true);
|
||||
|
||||
Ok(Html("").into_response())
|
||||
}
|
||||
|
||||
pub async fn deny_gdpr(
|
||||
session: Session<SessionSurrealPool<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
session.set("gdpr_accepted", true);
|
||||
|
||||
Ok(Html("").into_response())
|
||||
}
|
||||
246
crates/html-router/src/routes/index.rs
Normal file
246
crates/html-router/src/routes/index.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_session::Session;
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tokio::join;
|
||||
use tracing::info;
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError},
|
||||
storage::{
|
||||
db::{delete_item, get_item},
|
||||
types::{
|
||||
file_info::FileInfo, job::Job, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent, user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data, routes::render_template};
|
||||
|
||||
use super::render_block;
|
||||
|
||||
page_data!(IndexData, "index/index.html", {
|
||||
gdpr_accepted: bool,
|
||||
user: Option<User>,
|
||||
latest_text_contents: Vec<TextContent>,
|
||||
active_jobs: Vec<Job>
|
||||
});
|
||||
|
||||
pub async fn index_handler(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
session: Session<SessionSurrealPool<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying index page");
|
||||
|
||||
let gdpr_accepted = auth.current_user.is_some() | session.get("gdpr_accepted").unwrap_or(false);
|
||||
|
||||
let active_jobs = match auth.current_user.is_some() {
|
||||
true => state
|
||||
.job_queue
|
||||
.get_unfinished_user_jobs(&auth.current_user.clone().unwrap().id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?,
|
||||
false => vec![],
|
||||
};
|
||||
|
||||
let latest_text_contents = match auth.current_user.clone().is_some() {
|
||||
true => User::get_latest_text_contents(
|
||||
auth.current_user.clone().unwrap().id.as_str(),
|
||||
&state.surreal_db_client,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?,
|
||||
false => vec![],
|
||||
};
|
||||
|
||||
// let latest_knowledge_entities = match auth.current_user.is_some() {
|
||||
// true => User::get_latest_knowledge_entities(
|
||||
// auth.current_user.clone().unwrap().id.as_str(),
|
||||
// &state.surreal_db_client,
|
||||
// )
|
||||
// .await
|
||||
// .map_err(|e| HtmlError::new(e, state.templates.clone()))?,
|
||||
// false => vec![],
|
||||
// };
|
||||
|
||||
let output = render_template(
|
||||
IndexData::template_name(),
|
||||
IndexData {
|
||||
gdpr_accepted,
|
||||
user: auth.current_user,
|
||||
latest_text_contents,
|
||||
active_jobs,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LatestTextContentData {
|
||||
latest_text_contents: Vec<TextContent>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
pub async fn delete_text_content(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match &auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
// Get and validate TextContent
|
||||
let text_content = get_and_validate_text_content(&state, &id, user).await?;
|
||||
|
||||
// Perform concurrent deletions
|
||||
let deletion_tasks = join!(
|
||||
async {
|
||||
if let Some(file_info) = text_content.file_info {
|
||||
FileInfo::delete_by_id(&file_info.id, &state.surreal_db_client).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
delete_item::<TextContent>(&state.surreal_db_client, &text_content.id),
|
||||
TextChunk::delete_by_source_id(&text_content.id, &state.surreal_db_client),
|
||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.surreal_db_client),
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(
|
||||
&text_content.id,
|
||||
&state.surreal_db_client
|
||||
)
|
||||
);
|
||||
|
||||
// Handle potential errors from concurrent operations
|
||||
match deletion_tasks {
|
||||
(Ok(_), Ok(_), Ok(_), Ok(_), Ok(_)) => (),
|
||||
_ => {
|
||||
return Err(HtmlError::new(
|
||||
AppError::Processing("Failed to delete one or more items".to_string()),
|
||||
state.templates.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// Render updated content
|
||||
let latest_text_contents = User::get_latest_text_contents(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_block(
|
||||
"index/signed_in/recent_content.html",
|
||||
"latest_content_section",
|
||||
LatestTextContentData {
|
||||
user: user.clone(),
|
||||
latest_text_contents,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
// Helper function to get and validate text content
|
||||
async fn get_and_validate_text_content(
|
||||
state: &HtmlState,
|
||||
id: &str,
|
||||
user: &User,
|
||||
) -> Result<TextContent, HtmlError> {
|
||||
let text_content = get_item::<TextContent>(&state.surreal_db_client, id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?
|
||||
.ok_or_else(|| {
|
||||
HtmlError::new(
|
||||
AppError::NotFound("No item found".to_string()),
|
||||
state.templates.clone(),
|
||||
)
|
||||
})?;
|
||||
|
||||
if text_content.user_id != user.id {
|
||||
return Err(HtmlError::new(
|
||||
AppError::Auth("You are not the owner of that content".to_string()),
|
||||
state.templates.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(text_content)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ActiveJobsData {
|
||||
pub active_jobs: Vec<Job>,
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
pub async fn delete_job(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
state
|
||||
.job_queue
|
||||
.delete_job(&id, &user.id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let active_jobs = state
|
||||
.job_queue
|
||||
.get_unfinished_user_jobs(&user.id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_block(
|
||||
"index/signed_in/active_jobs.html",
|
||||
"active_jobs_section",
|
||||
ActiveJobsData {
|
||||
user: user.clone(),
|
||||
active_jobs,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn show_active_jobs(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
let active_jobs = state
|
||||
.job_queue
|
||||
.get_unfinished_user_jobs(&user.id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_block(
|
||||
"index/signed_in/active_jobs.html",
|
||||
"active_jobs_section",
|
||||
ActiveJobsData {
|
||||
user: user.clone(),
|
||||
active_jobs,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
154
crates/html-router/src/routes/ingress_form.rs
Normal file
154
crates/html-router/src/routes/ingress_form.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{Html, IntoResponse, Redirect},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError, IntoHtmlError},
|
||||
ingress::ingress_input::{create_ingress_objects, IngressInput},
|
||||
storage::types::{file_info::FileInfo, user::User},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
page_data,
|
||||
routes::{index::ActiveJobsData, render_block},
|
||||
};
|
||||
|
||||
use super::render_template;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ShowIngressFormData {
|
||||
user_categories: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn show_ingress_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if !auth.is_authenticated() {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
}
|
||||
|
||||
let user_categories = User::get_user_categories(&auth.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
"index/signed_in/ingress_modal.html",
|
||||
ShowIngressFormData { user_categories },
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn hide_ingress_form(
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if !auth.is_authenticated() {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
}
|
||||
|
||||
Ok(Html(
|
||||
"<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
pub struct IngressParams {
|
||||
pub content: Option<String>,
|
||||
pub instructions: String,
|
||||
pub category: String,
|
||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
||||
#[form_data(default)]
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
|
||||
page_data!(IngressFormData, "ingress_form.html", {
|
||||
instructions: String,
|
||||
content: String,
|
||||
category: String,
|
||||
error: String,
|
||||
});
|
||||
|
||||
pub async fn process_ingress_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
TypedMultipart(input): TypedMultipart<IngressParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = auth.current_user.ok_or_else(|| {
|
||||
AppError::Auth("You must be signed in".to_string()).with_template(state.templates.clone())
|
||||
})?;
|
||||
|
||||
if input.content.clone().is_some_and(|c| c.len() < 2) && input.files.is_empty() {
|
||||
let output = render_template(
|
||||
IngressFormData::template_name(),
|
||||
IngressFormData {
|
||||
instructions: input.instructions.clone(),
|
||||
content: input.content.clone().unwrap(),
|
||||
category: input.category.clone(),
|
||||
error: "You need to either add files or content".to_string(),
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
return Ok(output.into_response());
|
||||
}
|
||||
|
||||
info!("{:?}", input);
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new(file, &state.surreal_db_client, &user.id)
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let ingress_objects = create_ingress_objects(
|
||||
IngressInput {
|
||||
content: input.content,
|
||||
instructions: input.instructions,
|
||||
category: input.category,
|
||||
files: file_infos,
|
||||
},
|
||||
user.id.as_str(),
|
||||
)
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let futures: Vec<_> = ingress_objects
|
||||
.into_iter()
|
||||
.map(|object| state.job_queue.enqueue(object.clone(), user.id.clone()))
|
||||
.collect();
|
||||
|
||||
try_join_all(futures)
|
||||
.await
|
||||
.map_err(AppError::from)
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Update the active jobs page with the newly created job
|
||||
let active_jobs = state
|
||||
.job_queue
|
||||
.get_unfinished_user_jobs(&user.id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_block(
|
||||
"index/signed_in/active_jobs.html",
|
||||
"active_jobs_section",
|
||||
ActiveJobsData {
|
||||
user: user.clone(),
|
||||
active_jobs,
|
||||
},
|
||||
state.templates.clone(),
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
379
crates/html-router/src/routes/knowledge/mod.rs
Normal file
379
crates/html-router/src/routes/knowledge/mod.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use plotly::{
|
||||
common::{Line, Marker, Mode},
|
||||
layout::{Axis, Camera, LayoutScene, ProjectionType},
|
||||
Layout, Plot, Scatter3D,
|
||||
};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tracing::info;
|
||||
|
||||
use common::{
|
||||
error::{AppError, HtmlError},
|
||||
storage::{
|
||||
db::delete_item,
|
||||
types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
user::User,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data, routes::render_template};
|
||||
|
||||
page_data!(KnowledgeBaseData, "knowledge/base.html", {
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
user: User,
|
||||
plot_html: String
|
||||
});
|
||||
|
||||
pub async fn show_knowledge_page(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
info!("Got entities ok");
|
||||
|
||||
let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let mut plot = Plot::new();
|
||||
|
||||
// Fibonacci sphere distribution
|
||||
let node_count = entities.len();
|
||||
let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
|
||||
let node_positions: Vec<(f64, f64, f64)> = (0..node_count)
|
||||
.map(|i| {
|
||||
let i = i as f64;
|
||||
let theta = 2.0 * std::f64::consts::PI * i / golden_ratio;
|
||||
let phi = (1.0 - 2.0 * (i + 0.5) / node_count as f64).acos();
|
||||
let x = phi.sin() * theta.cos();
|
||||
let y = phi.sin() * theta.sin();
|
||||
let z = phi.cos();
|
||||
(x, y, z)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let node_x: Vec<f64> = node_positions.iter().map(|(x, _, _)| *x).collect();
|
||||
let node_y: Vec<f64> = node_positions.iter().map(|(_, y, _)| *y).collect();
|
||||
let node_z: Vec<f64> = node_positions.iter().map(|(_, _, z)| *z).collect();
|
||||
|
||||
// Nodes trace
|
||||
let nodes = Scatter3D::new(node_x.clone(), node_y.clone(), node_z.clone())
|
||||
.mode(Mode::Markers)
|
||||
.marker(Marker::new().size(8).color("#1f77b4"))
|
||||
.text_array(
|
||||
entities
|
||||
.iter()
|
||||
.map(|e| e.description.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.hover_template("Entity: %{text}<br>");
|
||||
|
||||
// Edges traces
|
||||
for rel in &relationships {
|
||||
let from_idx = entities.iter().position(|e| e.id == rel.out).unwrap_or(0);
|
||||
let to_idx = entities.iter().position(|e| e.id == rel.in_).unwrap_or(0);
|
||||
|
||||
let edge_x = vec![node_x[from_idx], node_x[to_idx]];
|
||||
let edge_y = vec![node_y[from_idx], node_y[to_idx]];
|
||||
let edge_z = vec![node_z[from_idx], node_z[to_idx]];
|
||||
|
||||
let edge_trace = Scatter3D::new(edge_x, edge_y, edge_z)
|
||||
.mode(Mode::Lines)
|
||||
.line(Line::new().color("#888").width(2.0))
|
||||
.hover_template(&format!(
|
||||
"Relationship: {}<br>",
|
||||
rel.metadata.relationship_type
|
||||
))
|
||||
.show_legend(false);
|
||||
|
||||
plot.add_trace(edge_trace);
|
||||
}
|
||||
plot.add_trace(nodes);
|
||||
|
||||
// Layout
|
||||
let layout = Layout::new()
|
||||
.scene(
|
||||
LayoutScene::new()
|
||||
.x_axis(Axis::new().visible(false))
|
||||
.y_axis(Axis::new().visible(false))
|
||||
.z_axis(Axis::new().visible(false))
|
||||
.camera(
|
||||
Camera::new()
|
||||
.projection(ProjectionType::Perspective.into())
|
||||
.eye((1.5, 1.5, 1.5).into()),
|
||||
),
|
||||
)
|
||||
.show_legend(false)
|
||||
.paper_background_color("rbga(250,100,0,0)")
|
||||
.plot_background_color("rbga(0,0,0,0)");
|
||||
|
||||
plot.set_layout(layout);
|
||||
// Convert to HTML
|
||||
let html = plot.to_html();
|
||||
|
||||
let output = render_template(
|
||||
KnowledgeBaseData::template_name(),
|
||||
KnowledgeBaseData {
|
||||
entities,
|
||||
relationships,
|
||||
user,
|
||||
plot_html: html,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EntityData {
|
||||
entity: KnowledgeEntity,
|
||||
entity_types: Vec<String>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
pub async fn show_edit_knowledge_entity_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// Get entity types
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
|
||||
// Get the entity and validate ownership
|
||||
let entity = User::get_and_validate_knowledge_entity(&id, &user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let output = render_template(
|
||||
"knowledge/edit_knowledge_entity_modal.html",
|
||||
EntityData {
|
||||
entity,
|
||||
user,
|
||||
entity_types,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EntityListData {
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PatchKnowledgeEntityParams {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub entity_type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
pub async fn patch_knowledge_entity(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<PatchKnowledgeEntityParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// Get the existing entity and validate that the user is allowed
|
||||
User::get_and_validate_knowledge_entity(&form.id, &user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let entity_type: KnowledgeEntityType = KnowledgeEntityType::from(form.entity_type);
|
||||
|
||||
// Update the entity
|
||||
KnowledgeEntity::patch(
|
||||
&form.id,
|
||||
&form.name,
|
||||
&form.description,
|
||||
&entity_type,
|
||||
&state.surreal_db_client,
|
||||
&state.openai_client,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
// Get updated list of entities
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Render updated list
|
||||
let output = render_template(
|
||||
"knowledge/entity_list.html",
|
||||
EntityListData { entities, user },
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn delete_knowledge_entity(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// Get the existing entity and validate that the user is allowed
|
||||
User::get_and_validate_knowledge_entity(&id, &user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Delete the entity
|
||||
delete_item::<KnowledgeEntity>(&state.surreal_db_client, &id)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
|
||||
|
||||
// Get updated list of entities
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Render updated list
|
||||
let output = render_template(
|
||||
"knowledge/entity_list.html",
|
||||
EntityListData { entities, user },
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RelationshipTableData {
|
||||
entities: Vec<KnowledgeEntity>,
|
||||
relationships: Vec<KnowledgeRelationship>,
|
||||
}
|
||||
|
||||
pub async fn delete_knowledge_relationship(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// GOTTA ADD AUTH VALIDATION
|
||||
|
||||
KnowledgeRelationship::delete_relationship_by_id(&id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Render updated list
|
||||
let output = render_template(
|
||||
"knowledge/relationship_table.html",
|
||||
RelationshipTableData {
|
||||
entities,
|
||||
relationships,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SaveKnowledgeRelationshipInput {
|
||||
pub in_: String,
|
||||
pub out: String,
|
||||
pub relationship_type: String,
|
||||
}
|
||||
|
||||
pub async fn save_knowledge_relationship(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<SaveKnowledgeRelationshipInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not authenticated
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// Construct relationship
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
form.in_,
|
||||
form.out,
|
||||
user.id.clone(),
|
||||
"manual".into(),
|
||||
form.relationship_type,
|
||||
);
|
||||
|
||||
relationship
|
||||
.store_relationship(&state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client)
|
||||
.await
|
||||
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
// Render updated list
|
||||
let output = render_template(
|
||||
"knowledge/relationship_table.html",
|
||||
RelationshipTableData {
|
||||
entities,
|
||||
relationships,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
91
crates/html-router/src/routes/mod.rs
Normal file
91
crates/html-router/src/routes/mod.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::response::Html;
|
||||
use minijinja_autoreload::AutoReloader;
|
||||
|
||||
use common::error::{HtmlError, IntoHtmlError};
|
||||
|
||||
pub mod account;
|
||||
pub mod admin_panel;
|
||||
pub mod chat;
|
||||
pub mod content;
|
||||
pub mod documentation;
|
||||
pub mod gdpr;
|
||||
pub mod index;
|
||||
pub mod ingress_form;
|
||||
pub mod knowledge;
|
||||
pub mod search_result;
|
||||
pub mod signin;
|
||||
pub mod signout;
|
||||
pub mod signup;
|
||||
|
||||
pub trait PageData {
|
||||
fn template_name() -> &'static str;
|
||||
}
|
||||
|
||||
// Helper function for render_template
|
||||
pub fn render_template<T>(
|
||||
template_name: &str,
|
||||
context: T,
|
||||
templates: Arc<AutoReloader>,
|
||||
) -> Result<Html<String>, HtmlError>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
{
|
||||
let env = templates
|
||||
.acquire_env()
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
let tmpl = env
|
||||
.get_template(template_name)
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
let context = minijinja::Value::from_serialize(&context);
|
||||
let output = tmpl
|
||||
.render(context)
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
Ok(Html(output))
|
||||
}
|
||||
|
||||
pub fn render_block<T>(
|
||||
template_name: &str,
|
||||
block: &str,
|
||||
context: T,
|
||||
templates: Arc<AutoReloader>,
|
||||
) -> Result<Html<String>, HtmlError>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
{
|
||||
let env = templates
|
||||
.acquire_env()
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
let tmpl = env
|
||||
.get_template(template_name)
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
|
||||
let context = minijinja::Value::from_serialize(&context);
|
||||
let output = tmpl
|
||||
.eval_to_state(context)
|
||||
.map_err(|e| e.with_template(templates.clone()))?
|
||||
.render_block(block)
|
||||
.map_err(|e| e.with_template(templates.clone()))?;
|
||||
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! page_data {
|
||||
($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
|
||||
use serde::{Serialize, Deserialize};
|
||||
use $crate::routes::PageData;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct $name {
|
||||
$($(#[$attr])* pub $field: $ty),*
|
||||
}
|
||||
|
||||
impl PageData for $name {
|
||||
fn template_name() -> &'static str {
|
||||
$template_name
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
68
crates/html-router/src/routes/search_result.rs
Normal file
68
crates/html-router/src/routes/search_result.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tracing::info;
|
||||
|
||||
use common::{error::HtmlError, storage::types::user::User};
|
||||
|
||||
use crate::{html_state::HtmlState, routes::render_template};
|
||||
#[derive(Deserialize)]
|
||||
pub struct SearchParams {
|
||||
query: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AnswerData {
|
||||
user_query: String,
|
||||
answer_content: String,
|
||||
answer_references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn search_result_handler(
|
||||
State(state): State<HtmlState>,
|
||||
Query(query): Query<SearchParams>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
info!("Displaying search results");
|
||||
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/signin").into_response()),
|
||||
};
|
||||
|
||||
// let answer = get_answer_with_references(
|
||||
// &state.surreal_db_client,
|
||||
// &state.openai_client,
|
||||
// &query.query,
|
||||
// &user.id,
|
||||
// )
|
||||
// .await
|
||||
// .map_err(|e| HtmlError::new(e, state.templates.clone()))?;
|
||||
|
||||
let answer = "The Minne project is focused on simplifying knowledge management through features such as easy capture, smart analysis, and visualization of connections between ideas. It includes various functionalities like the Smart Analysis Feature, which provides content analysis and organization, and the Easy Capture Feature, which allows users to effortlessly capture and retrieve knowledge in various formats. Additionally, it offers tools like Knowledge Graph Visualization to enhance understanding and organization of knowledge. The project also emphasizes a user-friendly onboarding experience and mobile-friendly options for accessing its services.".to_string();
|
||||
|
||||
let references = vec![
|
||||
"i81cd5be8-557c-4b2b-ba3a-4b8d28e74b9b".to_string(),
|
||||
"5f72a724-d7a3-467d-8783-7cca6053ddc7".to_string(),
|
||||
"ad106a1f-ccda-415e-9e87-c3a34e202624".to_string(),
|
||||
"8797b57d-094d-4ee9-a3a7-c3195b246254".to_string(),
|
||||
"69763f43-82e6-4cb5-ba3e-f6da13777dab".to_string(),
|
||||
];
|
||||
|
||||
let output = render_template(
|
||||
"index/signed_in/search_response.html",
|
||||
AnswerData {
|
||||
user_query: query.query,
|
||||
answer_content: answer,
|
||||
answer_references: references,
|
||||
},
|
||||
state.templates,
|
||||
)?;
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
71
crates/html-router/src/routes/signin.rs
Normal file
71
crates/html-router/src/routes/signin.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{StatusCode, Uri},
|
||||
response::{Html, IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_htmx::{HxBoosted, HxRedirect};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::HtmlError, storage::types::user::User};
|
||||
|
||||
use crate::{html_state::HtmlState, page_data};
|
||||
|
||||
use super::{render_block, render_template};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct SignupParams {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub remember_me: Option<String>,
|
||||
}
|
||||
|
||||
page_data!(ShowSignInForm, "auth/signin_form.html", {});
|
||||
|
||||
pub async fn show_signin_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if auth.is_authenticated() {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
}
|
||||
let output = match boosted {
|
||||
true => render_block(
|
||||
ShowSignInForm::template_name(),
|
||||
"body",
|
||||
ShowSignInForm {},
|
||||
state.templates.clone(),
|
||||
)?,
|
||||
false => render_template(
|
||||
ShowSignInForm::template_name(),
|
||||
ShowSignInForm {},
|
||||
state.templates.clone(),
|
||||
)?,
|
||||
};
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn authenticate_user(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<SignupParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match User::authenticate(form.email, form.password, &state.surreal_db_client).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return Ok(Html("<p>Incorrect email or password </p>").into_response());
|
||||
}
|
||||
};
|
||||
|
||||
auth.login_user(user.id);
|
||||
|
||||
if form.remember_me.is_some_and(|string| string == *"on") {
|
||||
auth.remember_user(true);
|
||||
}
|
||||
|
||||
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
|
||||
}
|
||||
18
crates/html-router/src/routes/signout.rs
Normal file
18
crates/html-router/src/routes/signout.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use axum::response::{IntoResponse, Redirect};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::ApiError, storage::types::user::User};
|
||||
|
||||
pub async fn sign_out_user(
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
if !auth.is_authenticated() {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
}
|
||||
|
||||
auth.logout_user();
|
||||
|
||||
Ok(Redirect::to("/").into_response())
|
||||
}
|
||||
65
crates/html-router/src/routes/signup.rs
Normal file
65
crates/html-router/src/routes/signup.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{StatusCode, Uri},
|
||||
response::{Html, IntoResponse, Redirect},
|
||||
Form,
|
||||
};
|
||||
use axum_htmx::{HxBoosted, HxRedirect};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::HtmlError, storage::types::user::User};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
use super::{render_block, render_template};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct SignupParams {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub timezone: String,
|
||||
}
|
||||
|
||||
pub async fn show_signup_form(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
HxBoosted(boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
if auth.is_authenticated() {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
}
|
||||
let output = match boosted {
|
||||
true => render_block("auth/signup_form.html", "body", {}, state.templates.clone())?,
|
||||
false => render_template("auth/signup_form.html", {}, state.templates.clone())?,
|
||||
};
|
||||
|
||||
Ok(output.into_response())
|
||||
}
|
||||
|
||||
pub async fn process_signup_and_show_verification(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<SignupParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match User::create_new(
|
||||
form.email,
|
||||
form.password,
|
||||
&state.surreal_db_client,
|
||||
form.timezone,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
tracing::error!("{:?}", e);
|
||||
return Ok(Html(format!("<p>{}</p>", e)).into_response());
|
||||
}
|
||||
};
|
||||
|
||||
auth.login_user(user.id);
|
||||
|
||||
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
|
||||
}
|
||||
58
crates/main/Cargo.toml
Normal file
58
crates/main/Cargo.toml
Normal file
@@ -0,0 +1,58 @@
|
||||
[package]
|
||||
name = "main"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
|
||||
async-openai = "0.24.1"
|
||||
async-stream = "0.3.6"
|
||||
axum-htmx = "0.6.0"
|
||||
axum_session = "0.14.4"
|
||||
axum_session_auth = "0.14.1"
|
||||
axum_session_surreal = "0.2.1"
|
||||
axum_typed_multipart = "0.12.1"
|
||||
chrono = { version = "0.4.39", features = ["serde"] }
|
||||
chrono-tz = "0.10.1"
|
||||
config = "0.15.4"
|
||||
futures = "0.3.31"
|
||||
json-stream-parser = "0.1.4"
|
||||
lettre = { version = "0.11.11", features = ["rustls-tls"] }
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
mockall = "0.13.0"
|
||||
plotly = "0.12.1"
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
scraper = "0.22.0"
|
||||
sha2 = "0.10.8"
|
||||
surrealdb = "2.0.4"
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = "0.18.1"
|
||||
tiktoken-rs = "0.6.0"
|
||||
tower-http = { version = "0.6.2", features = ["fs"] }
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
url = { version = "2.5.2", features = ["serde"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
|
||||
# Reference to api-router
|
||||
api-router = { path = "../api-router" }
|
||||
html-router = { path = "../html-router" }
|
||||
common = { path = "../common" }
|
||||
|
||||
[[bin]]
|
||||
name = "server"
|
||||
path = "src/server.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "worker"
|
||||
path = "src/worker.rs"
|
||||
47
crates/main/src/server.rs
Normal file
47
crates/main/src/server.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use api_router::{api_routes_v1, api_state::ApiState};
|
||||
use axum::{extract::FromRef, Router};
|
||||
use common::utils::config::get_config;
|
||||
use html_router::{html_routes, html_state::HtmlState};
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Set up tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(fmt::layer())
|
||||
.with(EnvFilter::from_default_env())
|
||||
.try_init()
|
||||
.ok();
|
||||
|
||||
// Get config
|
||||
let config = get_config()?;
|
||||
|
||||
// Set up router states
|
||||
let html_state = HtmlState::new(&config).await?;
|
||||
let api_state = ApiState {
|
||||
surreal_db_client: html_state.surreal_db_client.clone(),
|
||||
job_queue: html_state.job_queue.clone(),
|
||||
};
|
||||
|
||||
// Create Axum router
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", api_routes_v1(&api_state))
|
||||
.nest("/", html_routes(&html_state))
|
||||
.with_state(AppState {
|
||||
api_state,
|
||||
html_state,
|
||||
});
|
||||
|
||||
info!("Listening on 0.0.0.0:3000");
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, FromRef)]
|
||||
struct AppState {
|
||||
api_state: ApiState,
|
||||
html_state: HtmlState,
|
||||
}
|
||||
150
crates/main/src/worker.rs
Normal file
150
crates/main/src/worker.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
ingress::{
|
||||
content_processor::ContentProcessor,
|
||||
jobqueue::{JobQueue, MAX_ATTEMPTS},
|
||||
},
|
||||
storage::{
|
||||
db::{get_item, SurrealDbClient},
|
||||
types::job::{Job, JobStatus},
|
||||
},
|
||||
utils::config::get_config,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use surrealdb::Action;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Set up tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(fmt::layer())
|
||||
.with(EnvFilter::from_default_env())
|
||||
.try_init()
|
||||
.ok();
|
||||
|
||||
let config = get_config()?;
|
||||
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
&config.surrealdb_namespace,
|
||||
&config.surrealdb_database,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
let openai_client = Arc::new(async_openai::Client::new());
|
||||
|
||||
let job_queue = JobQueue::new(surreal_db_client.clone());
|
||||
|
||||
let content_processor = ContentProcessor::new(surreal_db_client, openai_client.clone()).await?;
|
||||
|
||||
loop {
|
||||
// First, check for any unfinished jobs
|
||||
let unfinished_jobs = job_queue.get_unfinished_jobs().await?;
|
||||
|
||||
if !unfinished_jobs.is_empty() {
|
||||
info!("Found {} unfinished jobs", unfinished_jobs.len());
|
||||
|
||||
for job in unfinished_jobs {
|
||||
job_queue
|
||||
.process_job(job, &content_processor, openai_client.clone())
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
// If no unfinished jobs, start listening for new ones
|
||||
info!("Listening for new jobs...");
|
||||
let mut job_stream = job_queue.listen_for_jobs().await?;
|
||||
|
||||
while let Some(notification) = job_stream.next().await {
|
||||
match notification {
|
||||
Ok(notification) => {
|
||||
info!("Received notification: {:?}", notification);
|
||||
|
||||
match notification.action {
|
||||
Action::Create => {
|
||||
if let Err(e) = job_queue
|
||||
.process_job(
|
||||
notification.data,
|
||||
&content_processor,
|
||||
openai_client.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Error processing job: {}", e);
|
||||
}
|
||||
}
|
||||
Action::Update => {
|
||||
match notification.data.status {
|
||||
JobStatus::Completed
|
||||
| JobStatus::Error(_)
|
||||
| JobStatus::Cancelled => {
|
||||
info!(
|
||||
"Skipping already completed/error/cancelled job: {}",
|
||||
notification.data.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
JobStatus::InProgress { attempts, .. } => {
|
||||
// Only process if this is a retry after an error, not our own update
|
||||
if let Ok(Some(current_job)) =
|
||||
get_item::<Job>(&job_queue.db.client, ¬ification.data.id)
|
||||
.await
|
||||
{
|
||||
match current_job.status {
|
||||
JobStatus::Error(_) if attempts < MAX_ATTEMPTS => {
|
||||
// This is a retry after an error
|
||||
if let Err(e) = job_queue
|
||||
.process_job(
|
||||
current_job,
|
||||
&content_processor,
|
||||
openai_client.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Error processing job retry: {}", e);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
info!(
|
||||
"Skipping in-progress update for job: {}",
|
||||
notification.data.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
JobStatus::Created => {
|
||||
// Shouldn't happen with Update action, but process if it does
|
||||
if let Err(e) = job_queue
|
||||
.process_job(
|
||||
notification.data,
|
||||
&content_processor,
|
||||
openai_client.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Error processing job: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {} // Ignore other actions
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Error in job notification: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, the stream has ended (connection lost?)
|
||||
error!("Job stream ended unexpectedly, reconnecting...");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user