in progress, routers and main split up

This commit is contained in:
Per Stark
2025-03-04 07:44:00 +01:00
parent 091270b458
commit cdb55ed8c1
80 changed files with 599 additions and 1577 deletions

View File

@@ -0,0 +1,18 @@
[package]
name = "api-router"
version = "0.1.0"
edition = "2021"
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
tempfile = "3.12.0"
futures = "0.3.31"
axum_typed_multipart = "0.12.1"
common = { path = "../common" }

View File

@@ -0,0 +1,33 @@
use std::sync::Arc;
use common::{ingress::jobqueue::JobQueue, storage::db::SurrealDbClient, utils::config::AppConfig};
/// Shared state handed to the API router, its middleware and its handlers.
#[derive(Clone)]
pub struct ApiState {
/// Shared handle to the SurrealDB client.
pub surreal_db_client: Arc<SurrealDbClient>,
/// Shared handle to the job queue.
pub job_queue: Arc<JobQueue>,
}
impl ApiState {
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
surreal_db_client.ensure_initialized().await?;
let app_state = ApiState {
surreal_db_client: surreal_db_client.clone(),
job_queue: Arc::new(JobQueue::new(surreal_db_client)),
};
Ok(app_state)
}
}

View File

@@ -0,0 +1,25 @@
use api_state::ApiState;
use axum::{
extract::{DefaultBodyLimit, FromRef},
middleware::from_fn_with_state,
routing::post,
Router,
};
use middleware_api_auth::api_auth;
use routes::ingress::ingress_data;
pub mod api_state;
mod middleware_api_auth;
mod routes;
/// Router for API functionality, version 1
///
/// Registers `POST /ingress`, raises the default request body limit for it and
/// guards it with the [`api_auth`] middleware, which receives a clone of
/// `app_state`.
pub fn api_routes_v1<S>(app_state: &ApiState) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
    ApiState: FromRef<S>,
{
    // 1 GiB
    const MAX_BODY_BYTES: usize = 1024 * 1024 * 1024;

    let body_limit = DefaultBodyLimit::max(MAX_BODY_BYTES);
    let auth = from_fn_with_state(app_state.clone(), api_auth);

    Router::new()
        .route("/ingress", post(ingress_data))
        .layer(body_limit)
        .route_layer(auth)
}

View File

@@ -0,0 +1,43 @@
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use common::{error::ApiError, storage::types::user::User};
use crate::api_state::ApiState;
/// Authentication middleware for the API routes.
///
/// Reads an API key from the request (see `extract_api_key`), looks up the
/// matching user and stores that user in the request extensions before
/// passing the request on to `next`.
///
/// # Errors
/// Returns `ApiError::Unauthorized` when no key is present or no user matches
/// it; a failing lookup is converted into an `ApiError` by `?`.
pub async fn api_auth(
    State(state): State<ApiState>,
    mut request: Request,
    next: Next,
) -> Result<Response, ApiError> {
    fn unauthorized() -> ApiError {
        ApiError::Unauthorized("You have to be authenticated".to_string())
    }

    let api_key = match extract_api_key(&request) {
        Some(key) => key,
        None => return Err(unauthorized()),
    };

    let user = match User::find_by_api_key(&api_key, &state.surreal_db_client).await? {
        Some(found) => found,
        None => return Err(unauthorized()),
    };

    // Handlers downstream pick the user up through `Extension<User>`.
    request.extensions_mut().insert(user);

    Ok(next.run(request).await)
}
/// Pulls the API key out of the request headers.
///
/// The `X-API-Key` header wins when it is present and readable as a string;
/// otherwise the `Authorization` header is used, provided it starts with
/// `Bearer `, and the remainder is trimmed. Returns `None` when neither
/// header yields a key.
fn extract_api_key(request: &Request) -> Option<String> {
    let headers = request.headers();

    let direct_key = headers
        .get("X-API-Key")
        .and_then(|value| value.to_str().ok());

    let key = match direct_key {
        Some(key) => Some(key),
        None => headers
            .get("Authorization")
            .and_then(|value| value.to_str().ok())
            .and_then(|auth| auth.strip_prefix("Bearer "))
            .map(str::trim),
    };

    key.map(String::from)
}

View File

@@ -0,0 +1,57 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use common::{
error::{ApiError, AppError},
ingress::ingress_input::{create_ingress_objects, IngressInput},
storage::types::{file_info::FileInfo, user::User},
};
use futures::{future::try_join_all, TryFutureExt};
use tempfile::NamedTempFile;
use tracing::{debug, info};
use crate::api_state::ApiState;
/// Multipart form fields accepted by the ingress endpoint.
#[derive(Debug, TryFromMultipart)]
pub struct IngressParams {
/// Optional content field; the handler forwards it unchanged into `IngressInput`.
pub content: Option<String>,
/// Instructions submitted with the content; forwarded unchanged into `IngressInput`.
pub instructions: String,
/// Category submitted with the content; forwarded unchanged into `IngressInput`.
pub category: String,
// NOTE(review): the limit is presumably a per-field size in bytes — confirm against axum_typed_multipart.
#[form_data(limit = "10000000")] // Adjust limit as needed
#[form_data(default)]
/// Uploaded files, each held in a named temporary file.
pub files: Vec<FieldData<NamedTempFile>>,
}
/// Handler for `POST /ingress`.
///
/// Turns every uploaded file into a `FileInfo`, builds the ingress objects for
/// the submitted content and files, and enqueues one job per object for the
/// authenticated user. Responds with `200 OK` once every job is enqueued.
///
/// # Errors
/// Any failure while creating the file infos, building the ingress objects or
/// enqueueing is converted into an `ApiError`.
pub async fn ingress_data(
    State(state): State<ApiState>,
    Extension(user): Extension<User>,
    TypedMultipart(input): TypedMultipart<IngressParams>,
) -> Result<impl IntoResponse, ApiError> {
    info!("Received input: {:?}", input);

    let file_infos = try_join_all(input.files.into_iter().map(|file| {
        FileInfo::new(file, &state.surreal_db_client, &user.id).map_err(AppError::from)
    }))
    .await?;
    debug!("Got file infos");

    let ingress_objects = create_ingress_objects(
        IngressInput {
            content: input.content,
            instructions: input.instructions,
            category: input.category,
            files: file_infos,
        },
        user.id.as_str(),
    )?;
    debug!("Got ingress objects");

    // `into_iter` yields owned objects, so each one is moved into the queue
    // instead of being cloned first.
    let futures: Vec<_> = ingress_objects
        .into_iter()
        .map(|object| state.job_queue.enqueue(object, user.id.clone()))
        .collect();
    try_join_all(futures).await.map_err(AppError::from)?;

    Ok(StatusCode::OK)
}

View File

@@ -0,0 +1 @@
pub mod ingress;

47
crates/common/Cargo.toml Normal file
View File

@@ -0,0 +1,47 @@
[package]
name = "common"
version = "0.1.0"
edition = "2021"
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
async-openai = "0.24.1"
async-stream = "0.3.6"
axum-htmx = "0.6.0"
axum_session = "0.14.4"
axum_session_auth = "0.14.1"
axum_session_surreal = "0.2.1"
axum_typed_multipart = "0.12.1"
chrono = { version = "0.4.39", features = ["serde"] }
chrono-tz = "0.10.1"
config = "0.15.4"
futures = "0.3.31"
json-stream-parser = "0.1.4"
lettre = { version = "0.11.11", features = ["rustls-tls"] }
mime = "0.3.17"
mime_guess = "2.0.5"
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
minijinja-autoreload = "2.5.0"
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
mockall = "0.13.0"
plotly = "0.12.1"
reqwest = {version = "0.12.12", features = ["charset", "json"]}
scraper = "0.22.0"
sha2 = "0.10.8"
surrealdb = "2.0.4"
tempfile = "3.12.0"
text-splitter = "0.18.1"
tiktoken-rs = "0.6.0"
tower-http = { version = "0.6.2", features = ["fs"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }

254
crates/common/src/error.rs Normal file
View File

@@ -0,0 +1,254 @@
use std::sync::Arc;
use async_openai::error::OpenAIError;
use axum::{
http::StatusCode,
response::{Html, IntoResponse, Response},
Json,
};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use serde::Serialize;
use serde_json::json;
use thiserror::Error;
use tokio::task::JoinError;
use crate::{storage::types::file_info::FileError, utils::mailer::EmailError};
// Core internal errors
#[derive(Error, Debug)]
pub enum AppError {
#[error("Database error: {0}")]
Database(#[from] surrealdb::Error),
#[error("OpenAI error: {0}")]
OpenAI(#[from] OpenAIError),
#[error("File error: {0}")]
File(#[from] FileError),
#[error("Email error: {0}")]
Email(#[from] EmailError),
#[error("Not found: {0}")]
NotFound(String),
#[error("Validation error: {0}")]
Validation(String),
#[error("Authorization error: {0}")]
Auth(String),
#[error("LLM parsing error: {0}")]
LLMParsing(String),
#[error("Task join error: {0}")]
Join(#[from] JoinError),
#[error("Graph mapper error: {0}")]
GraphMapper(String),
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
#[error("Minijina error: {0}")]
MiniJinja(#[from] minijinja::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Tiktoken error: {0}")]
Tiktoken(#[from] anyhow::Error),
#[error("Ingress Processing error: {0}")]
Processing(String),
}
// API-specific errors
/// Error type returned by API handlers; its `IntoResponse` impl renders it as a JSON body.
#[derive(Debug, Serialize)]
pub enum ApiError {
/// Rendered with status 500.
InternalError(String),
/// Rendered with status 400.
ValidationError(String),
/// Rendered with status 404.
NotFound(String),
/// Rendered with status 401.
Unauthorized(String),
}
impl From<AppError> for ApiError {
fn from(err: AppError) -> Self {
match err {
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
tracing::error!("Internal error: {:?}", err);
ApiError::InternalError("Internal server error".to_string())
}
AppError::NotFound(msg) => ApiError::NotFound(msg),
AppError::Validation(msg) => ApiError::ValidationError(msg),
AppError::Auth(msg) => ApiError::Unauthorized(msg),
_ => ApiError::InternalError("Internal server error".to_string()),
}
}
}
impl IntoResponse for ApiError {
    /// Renders the error as a JSON object with the message under `"error"`
    /// and a constant `"status": "error"`, using the HTTP status that matches
    /// the variant.
    fn into_response(self) -> Response {
        let (status, message) = match self {
            ApiError::InternalError(message) => (StatusCode::INTERNAL_SERVER_ERROR, message),
            ApiError::ValidationError(message) => (StatusCode::BAD_REQUEST, message),
            ApiError::NotFound(message) => (StatusCode::NOT_FOUND, message),
            ApiError::Unauthorized(message) => (StatusCode::UNAUTHORIZED, message),
        };

        // The body shape is the same for every variant; only status and message vary.
        let body = json!({
            "error": message,
            "status": "error"
        });

        (status, Json(body)).into_response()
    }
}
/// Result alias whose error type is [`HtmlError`].
pub type TemplateResult<T> = Result<T, HtmlError>;
/// Helper trait for converting an error into an [`HtmlError`] that carries the template environment.
pub trait IntoHtmlError {
/// Consumes the error and combines it with `templates` into an [`HtmlError`].
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
}
/// Converts an [`AppError`] by delegating to [`HtmlError::new`].
impl IntoHtmlError for AppError {
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
HtmlError::new(self, templates)
}
}
/// Converts a `minijinja::Error` directly by delegating to [`HtmlError::from_template_error`].
impl IntoHtmlError for minijinja::Error {
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
HtmlError::from_template_error(self, templates)
}
}
/// Cloneable holder for the shared template environment.
#[derive(Clone)]
pub struct ErrorContext {
// NOTE(review): nothing in the visible code reads this field, hence the lint allowance — confirm it is still needed.
#[allow(dead_code)]
templates: Arc<AutoReloader>,
}
impl ErrorContext {
/// Wraps the given template environment.
pub fn new(templates: Arc<AutoReloader>) -> Self {
Self { templates }
}
}
/// Error type for HTML responses; every variant carries the template environment used to render the error page.
pub enum HtmlError {
/// Rendered as a 500 page.
ServerError(Arc<AutoReloader>),
/// Rendered as a 404 page.
NotFound(Arc<AutoReloader>),
/// Rendered as a 401 page.
Unauthorized(Arc<AutoReloader>),
/// Rendered as a 400 page; the string is shown as the page description.
BadRequest(String, Arc<AutoReloader>),
/// Template failure, rendered as a 500 page; the string is not shown on the page.
Template(String, Arc<AutoReloader>),
}
impl HtmlError {
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
match error {
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
_ => {
tracing::error!("Internal error: {:?}", error);
HtmlError::ServerError(templates)
}
}
}
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
tracing::error!("Template error: {:?}", error);
HtmlError::Template(error.to_string(), templates)
}
}
impl IntoResponse for HtmlError {
fn into_response(self) -> Response {
let (status, context, templates) = match self {
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
StatusCode::INTERNAL_SERVER_ERROR,
context! {
status_code => 500,
title => "Internal Server Error",
error => "Internal Server Error",
description => "Something went wrong on our end."
},
templates,
),
HtmlError::NotFound(templates) => (
StatusCode::NOT_FOUND,
context! {
status_code => 404,
title => "Page Not Found",
error => "Not Found",
description => "The page you're looking for doesn't exist or was removed."
},
templates,
),
HtmlError::Unauthorized(templates) => (
StatusCode::UNAUTHORIZED,
context! {
status_code => 401,
title => "Unauthorized",
error => "Access Denied",
description => "You need to be logged in to access this page."
},
templates,
),
HtmlError::BadRequest(msg, templates) => (
StatusCode::BAD_REQUEST,
context! {
status_code => 400,
title => "Bad Request",
error => "Bad Request",
description => msg
},
templates,
),
};
let html = match templates.acquire_env() {
Ok(env) => match env.get_template("errors/error.html") {
Ok(tmpl) => match tmpl.render(context) {
Ok(output) => output,
Err(e) => {
tracing::error!("Template render error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Template get error: {:?}", e);
Self::fallback_html()
}
},
Err(e) => {
tracing::error!("Environment acquire error: {:?}", e);
Self::fallback_html()
}
};
(status, Html(html)).into_response()
}
}
impl HtmlError {
    /// Static page sent when the error template itself cannot be rendered.
    fn fallback_html() -> String {
        // The literal starts at column 0 on purpose: every byte inside the raw
        // string is part of the response body.
        const FALLBACK_PAGE: &str = r#"
<html>
<body>
<div class="container mx-auto p-4">
<h1 class="text-4xl text-error">Error</h1>
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
</div>
</body>
</html>
"#;
        FALLBACK_PAGE.to_owned()
    }
}

View File

@@ -0,0 +1,145 @@
use crate::{
error::AppError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::combined_knowledge_entity_retrieval,
storage::types::knowledge_entity::KnowledgeEntity,
};
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
};
use serde_json::json;
use surrealdb::engine::any::Any;
use surrealdb::Surreal;
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
/// Runs the LLM analysis for ingressed content using borrowed database and OpenAI clients.
pub struct IngressAnalyzer<'a> {
/// Database handle passed to the knowledge-entity retrieval.
db_client: &'a Surreal<Any>,
/// OpenAI client used for the retrieval and for the chat completion request.
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl<'a> IngressAnalyzer<'a> {
pub fn new(
db_client: &'a Surreal<Any>,
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Self {
Self {
db_client,
openai_client,
}
}
pub async fn analyze_content(
&self,
category: &str,
instructions: &str,
text: &str,
user_id: &str,
) -> Result<LLMGraphAnalysisResult, AppError> {
let similar_entities = self
.find_similar_entities(category, instructions, text, user_id)
.await?;
let llm_request =
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
self.perform_analysis(llm_request).await
}
async fn find_similar_entities(
&self,
category: &str,
instructions: &str,
text: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let input_text = format!(
"content: {}, category: {}, user_instructions: {}",
text, category, instructions
);
combined_knowledge_entity_retrieval(
self.db_client,
self.openai_client,
&input_text,
user_id,
)
.await
}
fn prepare_llm_request(
&self,
category: &str,
instructions: &str,
text: &str,
similar_entities: &[KnowledgeEntity],
) -> Result<CreateChatCompletionRequest, OpenAIError> {
let entities_json = json!(similar_entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>());
let user_message = format!(
"Category:\n{}\nInstructions:\n{}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
category, instructions, text, entities_json
);
debug!("Prepared LLM request message: {}", user_message);
let response_format = ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
description: Some("Structured analysis of the submitted content".into()),
name: "content_analysis".into(),
schema: Some(get_ingress_analysis_schema()),
strict: Some(true),
},
};
CreateChatCompletionRequestArgs::default()
.model("gpt-4o-mini")
.temperature(0.2)
.max_tokens(3048u32)
.messages([
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
ChatCompletionRequestUserMessage::from(user_message).into(),
])
.response_format(response_format)
.build()
}
async fn perform_analysis(
&self,
request: CreateChatCompletionRequest,
) -> Result<LLMGraphAnalysisResult, AppError> {
let response = self.openai_client.chat().create(request).await?;
debug!("Received LLM response: {:?}", response);
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.ok_or(AppError::LLMParsing(
"No content found in LLM response".to_string(),
))
.and_then(|content| {
serde_json::from_str(content).map_err(|e| {
AppError::LLMParsing(format!(
"Failed to parse LLM response into analysis: {}",
e
))
})
})
}
}

View File

@@ -0,0 +1,3 @@
pub mod ingress_analyser;
pub mod prompt;
pub mod types;

View File

@@ -0,0 +1,81 @@
use serde_json::{json, Value};
/// System prompt for the ingress analysis request.
///
/// NOTE(review): the prompt lists entity types in CamelCase (`Idea`, `Project`, ...) while the
/// JSON schema in this file enumerates lowercase values; guideline 5 also ends with a stray
/// quote, and guidelines 7 and 9 read as contradictory. Left untouched here because any edit
/// changes what the model is told — confirm and fix deliberately.
pub static INGRESS_ANALYSIS_SYSTEM_MESSAGE: &str = r#"
You are an AI assistant. You will receive a text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users instructions and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
The JSON should have the following structure:
{
"knowledge_entities": [
{
"key": "unique-key-1",
"name": "Entity Name",
"description": "A detailed description of the entity.",
"entity_type": "TypeOfEntity"
},
// More entities...
],
"relationships": [
{
"type": "RelationshipType",
"source": "unique-key-1 or UUID from existing database",
"target": "unique-key-1 or UUID from existing database"
},
// More relationships...
]
}
Guidelines:
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
7. Only create relationships between existing KnowledgeEntities.
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
9. A new relationship MUST include a newly created KnowledgeEntity.
"#;
/// Returns the JSON schema the LLM response has to follow: an object with a
/// `knowledge_entities` array and a `relationships` array, both required and
/// closed to additional properties.
pub fn get_ingress_analysis_schema() -> Value {
    // Shape of one proposed entity.
    let entity_schema = json!({
        "type": "object",
        "properties": {
            "key": { "type": "string" },
            "name": { "type": "string" },
            "description": { "type": "string" },
            "entity_type": {
                "type": "string",
                "enum": ["idea", "project", "document", "page", "textsnippet"]
            }
        },
        "required": ["key", "name", "description", "entity_type"],
        "additionalProperties": false
    });

    // Shape of one proposed relationship.
    let relationship_schema = json!({
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
            },
            "source": { "type": "string" },
            "target": { "type": "string" }
        },
        "required": ["type", "source", "target"],
        "additionalProperties": false
    });

    json!({
        "type": "object",
        "properties": {
            "knowledge_entities": {
                "type": "array",
                "items": entity_schema
            },
            "relationships": {
                "type": "array",
                "items": relationship_schema
            }
        },
        "required": ["knowledge_entities", "relationships"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,42 @@
use std::collections::HashMap;
use uuid::Uuid;
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
#[derive(Clone)]
pub struct GraphMapper {
/// Maps each LLM-provided key to the UUID assigned to it.
pub key_to_id: HashMap<String, Uuid>,
}
impl Default for GraphMapper {
fn default() -> Self {
GraphMapper::new()
}
}
impl GraphMapper {
    /// Creates a mapper without any assigned keys.
    pub fn new() -> Self {
        GraphMapper {
            key_to_id: HashMap::new(),
        }
    }

    /// Resolves `key` to a UUID without panicking.
    ///
    /// A key that is itself a valid UUID is returned as parsed; any other key
    /// is looked up among the assigned IDs. Returns `None` when the key is
    /// neither a UUID nor assigned.
    pub fn try_get_or_parse_id(&self, key: &str) -> Option<Uuid> {
        Uuid::parse_str(key)
            .ok()
            .or_else(|| self.key_to_id.get(key).copied())
    }

    /// Get ID, tries to parse UUID
    ///
    /// # Panics
    /// Panics when `key` is not a valid UUID and no ID was assigned to it.
    /// Keys coming from LLM output can be unknown; prefer
    /// [`GraphMapper::try_get_or_parse_id`] for those.
    pub fn get_or_parse_id(&mut self, key: &str) -> Uuid {
        self.try_get_or_parse_id(key)
            .unwrap_or_else(|| panic!("no UUID assigned for key: {}", key))
    }

    /// Assigns a new UUID for a given key.
    ///
    /// Assigning the same key again replaces the earlier ID.
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let id = Uuid::new_v4();
        self.key_to_id.insert(key.to_string(), id);
        id
    }

    /// Retrieves the UUID for a given key.
    pub fn get_id(&self, key: &str) -> Option<&Uuid> {
        self.key_to_id.get(key)
    }
}

View File

@@ -0,0 +1,182 @@
use std::sync::{Arc, Mutex};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::{
error::AppError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
},
utils::embedding::generate_embedding,
};
use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
/// Represents a single knowledge entity as returned by the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
pub key: String, // Temporary identifier
pub name: String,
pub description: String,
pub entity_type: String, // Should match KnowledgeEntityType variants
}
/// Represents a single relationship from the LLM.
///
/// `source` and `target` hold either the key of an entity from the same
/// response or a UUID string; the conversion code accepts both.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMRelationship {
// Serialized as `type`, which is a reserved word in Rust.
#[serde(rename = "type")]
pub type_: String, // e.g., RelatedTo, RelevantTo
pub source: String, // Key of the source entity
pub target: String, // Key of the target entity
}
/// Represents the entire graph analysis result from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMGraphAnalysisResult {
/// Entities proposed by the LLM.
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
/// Relationships proposed by the LLM.
pub relationships: Vec<LLMRelationship>,
}
impl LLMGraphAnalysisResult {
/// Converts the LLM graph analysis result into database entities and relationships.
///
/// # Arguments
///
/// * `source_id` - A UUID representing the source identifier.
/// * `openai_client` - OpenAI client for LLM calls.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities(
&self,
source_id: &str,
user_id: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
// Create mapper and pre-assign IDs
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
// Process entities
let entities = self
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
.await?;
// Process relationships
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
Ok((entities, relationships))
}
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
let mut mapper = GraphMapper::new();
// Pre-assign all IDs
for entity in &self.knowledge_entities {
mapper.assign_id(&entity.key);
}
Ok(mapper)
}
async fn process_entities(
&self,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let futures: Vec<_> = self
.knowledge_entities
.iter()
.map(|entity| {
let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone();
let source_id = source_id.to_string();
let user_id = user_id.to_string();
let entity = entity.clone();
task::spawn(async move {
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
.await
})
})
.collect();
let results = try_join_all(futures)
.await?
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
Ok(results)
}
fn process_relationships(
&self,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
) -> Result<Vec<KnowledgeRelationship>, AppError> {
let mut mapper_guard = mapper
.lock()
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
self.relationships
.iter()
.map(|rel| {
let source_db_id = mapper_guard.get_or_parse_id(&rel.source);
let target_db_id = mapper_guard.get_or_parse_id(&rel.target);
Ok(KnowledgeRelationship::new(
source_db_id.to_string(),
target_db_id.to_string(),
user_id.to_string(),
source_id.to_string(),
rel.type_.clone(),
))
})
.collect()
}
}
/// Builds the database entity for one LLM entity.
///
/// Takes the pre-assigned ID for the entity's key from `mapper`, generates an
/// embedding from its name, description and type, and stamps the current UTC
/// time as both creation and update time.
///
/// # Errors
/// Fails when the mapper lock is poisoned, no ID was assigned for the key, or
/// the embedding cannot be generated.
async fn create_single_entity(
    llm_entity: &LLMKnowledgeEntity,
    source_id: &str,
    user_id: &str,
    mapper: Arc<Mutex<GraphMapper>>,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, AppError> {
    // The guard lives only inside this block, so it is released before the await below.
    let id = {
        let guard = mapper
            .lock()
            .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
        match guard.get_id(&llm_entity.key) {
            Some(assigned) => assigned.to_string(),
            None => {
                return Err(AppError::GraphMapper(format!(
                    "ID not found for key: {}",
                    llm_entity.key
                )))
            }
        }
    };

    let embedding_input = format!(
        "name: {}, description: {}, type: {}",
        llm_entity.name, llm_entity.description, llm_entity.entity_type
    );
    let embedding = generate_embedding(openai_client, &embedding_input).await?;

    let timestamp = Utc::now();
    let entity_type = KnowledgeEntityType::from(llm_entity.entity_type.clone());

    Ok(KnowledgeEntity {
        id,
        created_at: timestamp,
        updated_at: timestamp,
        name: llm_entity.name.clone(),
        description: llm_entity.description.clone(),
        entity_type,
        source_id: source_id.to_string(),
        metadata: None,
        embedding,
        user_id: user_id.into(),
    })
}

View File

@@ -0,0 +1,2 @@
pub mod graph_mapper;
pub mod llm_analysis_result;

View File

@@ -0,0 +1,124 @@
use std::{sync::Arc, time::Instant};
use text_splitter::TextSplitter;
use tracing::{debug, info};
use crate::{
error::AppError,
storage::{
db::{store_item, SurrealDbClient},
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::embedding::generate_embedding,
};
use super::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
};
/// Processes a `TextContent`: runs the analysis and stores the resulting entities, relationships, chunks and the content itself.
pub struct ContentProcessor {
/// Shared database client used for all storage operations.
db_client: Arc<SurrealDbClient>,
/// Shared OpenAI client used for the analysis and for embeddings.
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
}
impl ContentProcessor {
    /// Builds a processor around the shared database and OpenAI clients.
    ///
    /// The body neither awaits nor fails; the signature is async and fallible
    /// as written.
    pub async fn new(
        surreal_db_client: Arc<SurrealDbClient>,
        openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
    ) -> Result<Self, AppError> {
        let processor = Self {
            db_client: surreal_db_client,
            openai_client,
        };
        Ok(processor)
    }

    /// Runs the full pipeline for one piece of content: semantic analysis,
    /// conversion into entities and relationships, storing the graph and the
    /// embedded text chunks, storing the content itself, and rebuilding the
    /// database indexes.
    ///
    /// # Errors
    /// Stops at, and returns, the first error from any of these steps.
    pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
        // Perform analysis, this step also includes retrieval
        let started = Instant::now();
        let analysis = self.perform_semantic_analysis(content).await?;
        let analysis_duration = started.elapsed();
        info!(
            "{:?} time elapsed during creation of entities and relationships",
            analysis_duration
        );

        // Convert analysis to objects
        let (entities, relationships) = analysis
            .to_database_entities(&content.id, &content.user_id, &self.openai_client)
            .await?;

        // Graph data and vector chunks are stored concurrently
        tokio::try_join!(
            self.store_graph_entities(entities, relationships),
            self.store_vector_chunks(content),
        )?;

        // Store original content
        store_item(&self.db_client, content.to_owned()).await?;
        self.db_client.rebuild_indexes().await?;

        Ok(())
    }

    /// Runs the LLM analysis on the content's category, instructions and text
    /// for the content's user.
    async fn perform_semantic_analysis(
        &self,
        content: &TextContent,
    ) -> Result<LLMGraphAnalysisResult, AppError> {
        IngressAnalyzer::new(&self.db_client, &self.openai_client)
            .analyze_content(
                &content.category,
                &content.instructions,
                &content.text,
                &content.user_id,
            )
            .await
    }

    /// Stores all entities, then all relationships, one at a time, and logs
    /// how many of each were stored.
    async fn store_graph_entities(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
    ) -> Result<(), AppError> {
        let entity_count = entities.len();
        let relationship_count = relationships.len();

        for entity in entities {
            debug!("Storing entity: {:?}", entity);
            store_item(&self.db_client, entity).await?;
        }

        for relationship in relationships.iter() {
            debug!("Storing relationship: {:?}", relationship);
            relationship.store_relationship(&self.db_client).await?;
        }

        info!(
            "Stored {} entities and {} relationships",
            entity_count, relationship_count
        );
        Ok(())
    }

    /// Splits the content text into chunks, embeds each chunk and stores it.
    async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
        let splitter = TextSplitter::new(500..2000);

        // Could potentially process chunks in parallel with a bounded concurrent limit
        for piece in splitter.chunks(&content.text) {
            let embedding = generate_embedding(&self.openai_client, piece).await?;
            let text_chunk = TextChunk::new(
                content.id.to_string(),
                piece.to_string(),
                embedding,
                content.user_id.to_string(),
            );
            store_item(&self.db_client, text_chunk).await?;
        }

        Ok(())
    }
}

View File

@@ -0,0 +1,74 @@
use super::ingress_object::IngressObject;
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
use url::Url;
/// Struct defining the expected body when ingressing content.
#[derive(Serialize, Deserialize, Debug)]
pub struct IngressInput {
/// Optional free-form content; `create_ingress_objects` treats it as a URL or as plain text.
pub content: Option<String>,
/// Instructions copied onto every ingress object created from this input.
pub instructions: String,
/// Category copied onto every ingress object created from this input.
pub category: String,
/// Files to ingress, one ingress object per file.
pub files: Vec<FileInfo>,
}
/// Function to create ingress objects from input.
///
/// # Arguments
/// * `input` - IngressInput containing information needed to ingress content.
/// * `user_id` - User id of the ingressing user
///
/// # Returns
/// * `Vec<IngressObject>` - An array containing the ingressed objects, one file/contenttype per object.
pub fn create_ingress_objects(
input: IngressInput,
user_id: &str,
) -> Result<Vec<IngressObject>, AppError> {
// Initialize list
let mut object_list = Vec::new();
// Create a IngressObject from input.content if it exists, checking for URL or text
if let Some(input_content) = input.content {
match Url::parse(&input_content) {
Ok(url) => {
info!("Detected URL: {}", url);
object_list.push(IngressObject::Url {
url: url.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
Err(_) => {
if input_content.len() > 2 {
info!("Treating input as plain text");
object_list.push(IngressObject::Text {
text: input_content.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
}
}
}
for file in input.files {
object_list.push(IngressObject::File {
file_info: file,
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
})
}
// If no objects are constructed, we return Err
if object_list.is_empty() {
return Err(AppError::NotFound(
"No valid content or files provided".into(),
));
}
Ok(object_list)
}

View File

@@ -0,0 +1,277 @@
use std::{sync::Arc, time::Duration};
use crate::{
error::AppError,
storage::types::{file_info::FileInfo, text_content::TextContent},
};
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs,
};
use reqwest;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use tiktoken_rs::{o200k_base, CoreBPE};
/// One unit of content to ingress; every variant carries the instructions, category and user id it was submitted with.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {
/// A web page; its text is fetched from `url` when converted to `TextContent`.
Url {
url: String,
instructions: String,
category: String,
user_id: String,
},
/// Plain text, used as-is when converted to `TextContent`.
Text {
text: String,
instructions: String,
category: String,
user_id: String,
},
/// An uploaded file; its text is extracted from `file_info` when converted to `TextContent`.
File {
file_info: FileInfo,
instructions: String,
category: String,
user_id: String,
},
}
impl IngressObject {
/// Creates a new `TextContent` instance from a `IngressObject`.
///
/// # Arguments
/// `&self` - A reference to the `IngressObject`.
///
/// # Returns
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
pub async fn to_text_content(
&self,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<TextContent, AppError> {
match self {
IngressObject::Url {
url,
instructions,
category,
user_id,
} => {
let text = Self::fetch_text_from_url(url, openai_client).await?;
Ok(TextContent::new(
text,
instructions.into(),
category.into(),
None,
Some(url.into()),
user_id.into(),
))
}
IngressObject::Text {
text,
instructions,
category,
user_id,
} => Ok(TextContent::new(
text.into(),
instructions.into(),
category.into(),
None,
None,
user_id.into(),
)),
IngressObject::File {
file_info,
instructions,
category,
user_id,
} => {
let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent::new(
text,
instructions.into(),
category.into(),
Some(file_info.to_owned()),
None,
user_id.into(),
))
}
}
}
/// Fetches `url` and returns its main content as a markdown formatted string.
///
/// The page is downloaded with a 30 second timeout, its headings and paragraphs are
/// wrapped in `<heading>` / `<paragraph>` markers, and the result is passed to
/// `process_web_content` for conversion to markdown.
///
/// # Errors
/// Returns an error if the request fails or times out, if the server answers with an
/// error status, if the page has no usable content element, or if the LLM step fails.
async fn fetch_text_from_url(
    url: &str,
    openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
    let client = reqwest::ClientBuilder::new()
        .timeout(Duration::from_secs(30))
        .build()?;
    // Reject 4xx/5xx answers so that error pages are not ingested as content.
    let response = client
        .get(url)
        .send()
        .await?
        .error_for_status()?
        .text()
        .await?;
    // Parsing happens in a synchronous helper, so the parsed document is dropped
    // before the next await point instead of being kept in this future's state.
    let content = Self::extract_structured_content(&response)?;
    Self::process_web_content(content, openai_client).await
}

/// Renders the headings and paragraphs of the main content element of `html` as
/// `<heading>…</heading>` / `<paragraph>…</paragraph>` markers, in document order,
/// on a single line with control characters removed and runs of spaces collapsed.
fn extract_structured_content(html: &str) -> Result<String, AppError> {
    let document = Html::parse_document(html);
    let main_selector = Selector::parse(
        "article, main, .article-content, .post-content, .entry-content, [role='main']",
    )
    .expect("static selector is valid");
    let body_selector = Selector::parse("body").expect("static selector is valid");
    let content_element = document
        .select(&main_selector)
        .next()
        .or_else(|| document.select(&body_selector).next())
        .ok_or_else(|| AppError::NotFound("No content found".into()))?;

    // One selector for both kinds of element keeps headings next to the paragraphs
    // they introduce; selecting them separately would list every heading first.
    let block_selector = Selector::parse("h1, h2, h3, p").expect("static selector is valid");
    let mut structured_content = String::with_capacity(html.len() / 2);
    for element in content_element.select(&block_selector) {
        let tag = if element.value().name() == "p" {
            "paragraph"
        } else {
            "heading"
        };
        let text = element.text().collect::<String>();
        // Writing to a `String` cannot fail.
        let _ = writeln!(structured_content, "<{}>{}</{}>", tag, text.trim(), tag);
    }

    // Turn control characters (including the newlines written above) into spaces and
    // collapse every run of spaces into a single one.
    let mut content = String::with_capacity(structured_content.len());
    let mut previous_was_space = false;
    for c in structured_content.chars() {
        let c = if c.is_control() { ' ' } else { c };
        if c == ' ' {
            if !previous_was_space {
                content.push(' ');
            }
            previous_was_space = true;
        } else {
            content.push(c);
            previous_was_space = false;
        }
    }
    Ok(content)
}
/// Asks the LLM to extract the main content of a scraped page as markdown.
///
/// `content` is expected to carry `<heading>` / `<paragraph>` markers, as described in
/// the system prompt below. If it encodes to more than `MAX_TOKENS` tokens with the
/// `o200k_base` tokenizer it is shortened with `truncate_content` before being sent.
///
/// # Errors
/// Returns an error if the tokenizer cannot be created, truncation fails, the request
/// cannot be built or sent, or the response carries no content (`AppError::LLMParsing`).
pub async fn process_web_content(
content: String,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
// Upper bound on the number of input tokens sent as the user message.
const MAX_TOKENS: usize = 122000;
const SYSTEM_PROMPT: &str = r#"
You are a precise content extractor for web pages. Your task:
1. Extract ONLY the main article/content from the provided text
2. Maintain the original content - do not summarize or modify the core information
3. Ignore peripheral content such as:
- Navigation elements
- Error messages (e.g., "JavaScript required")
- Related articles sections
- Comments
- Social media links
- Advertisement text
FORMAT:
- Convert <heading> tags to markdown headings (#, ##, ###)
- Convert <paragraph> tags to markdown paragraphs
- Preserve quotes and important formatting
- Remove duplicate content
- Remove any metadata or technical artifacts
OUTPUT RULES:
- Output ONLY the cleaned content in markdown
- Do not add any explanations or meta-commentary
- Do not add summaries or conclusions
- Do not use any XML/HTML tags in the output
"#;
let bpe = o200k_base()?;
// Truncate the content to the token budget if needed (it is not split into chunks)
let truncated_content = if bpe.encode_with_special_tokens(&content).len() > MAX_TOKENS {
Self::truncate_content(&content, MAX_TOKENS, &bpe)?
} else {
content
};
let request = CreateChatCompletionRequestArgs::default()
.model("gpt-4o-mini")
.temperature(0.0)
.max_tokens(16200u32)
.messages([
ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(),
ChatCompletionRequestUserMessage::from(truncated_content).into(),
])
.build()?;
let response = openai_client.chat().create(request).await?;
// Only the first choice is used; a missing choice or empty message is an error.
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.map(|content| content.to_owned())
.ok_or(AppError::LLMParsing("No content in response".into()))
}
/// Shortens `content` so that it fits within `max_tokens` tokens.
///
/// Whole paragraphs (separated by a blank line) are kept for as long as they fit. If
/// not even the first non-empty paragraph fits — which is always the case for input
/// that contains no blank lines — that paragraph is cut at token level instead of
/// rejecting the whole content.
///
/// # Errors
/// Returns `AppError::Processing` if no content at all can be kept.
fn truncate_content(
    content: &str,
    max_tokens: usize,
    tokenizer: &CoreBPE,
) -> Result<String, AppError> {
    // Pre-allocate with estimated size
    let mut result = String::with_capacity(content.len() / 2);
    let mut current_tokens = 0;

    // Process content by paragraph to maintain context
    for paragraph in content.split("\n\n") {
        let tokens = tokenizer.encode_with_special_tokens(paragraph);
        if current_tokens + tokens.len() > max_tokens {
            if current_tokens == 0 {
                // Nothing has been kept yet: fall back to the longest token prefix
                // of this paragraph that fits and still decodes to valid text.
                let mut prefix = tokens;
                prefix.truncate(max_tokens);
                result.clear();
                // A cut can land inside a multi-byte character; backing off a few
                // tokens is enough to reach a valid boundary.
                for _ in 0..4 {
                    if prefix.is_empty() {
                        break;
                    }
                    if let Ok(text) = tokenizer.decode(prefix.clone()) {
                        result = text;
                        break;
                    }
                    prefix.pop();
                }
            }
            break;
        }
        result.push_str(paragraph);
        result.push_str("\n\n");
        current_tokens += tokens.len();
    }

    // Ensure we return valid content
    if result.is_empty() {
        return Err(AppError::Processing("Content exceeds token limit".into()));
    }
    Ok(result.trim_end().to_string())
}
/// Reads the text of a stored file, dispatching on its MIME type.
///
/// Plain text, markdown, Rust source and `application/octet-stream` files are read
/// as UTF-8 text. Every other MIME type yields `AppError::NotFound` carrying the
/// MIME type.
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
    let mime_type = file_info.mime_type.as_str();
    match mime_type {
        "text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
            Ok(tokio::fs::read_to_string(&file_info.path).await?)
        }
        // TODO: Implement PDF text extraction ("application/pdf") using a crate like
        // `pdf-extract` or `lopdf`
        // TODO: Implement OCR on images ("image/png", "image/jpeg") using a crate like
        // `tesseract`
        // Until then these are rejected like any other unsupported MIME type.
        _ => Err(AppError::NotFound(file_info.mime_type.clone())),
    }
}
}

View File

@@ -0,0 +1,182 @@
use chrono::Utc;
use futures::Stream;
use std::sync::Arc;
use surrealdb::{opt::PatchOp, Error, Notification};
use tracing::{debug, error, info};
use crate::{
error::AppError,
storage::{
db::{delete_item, get_item, store_item, SurrealDbClient},
types::{
job::{Job, JobStatus},
StoredObject,
},
},
};
use super::{content_processor::ContentProcessor, ingress_object::IngressObject};
/// Database-backed queue of ingress jobs.
pub struct JobQueue {
/// Shared database client used for every queue operation.
pub db: Arc<SurrealDbClient>,
}
/// Number of processing attempts after which a job is no longer treated as unfinished
/// and is marked as failed.
pub const MAX_ATTEMPTS: u32 = 3;
impl JobQueue {
/// Creates a queue that stores its jobs through `db`.
pub fn new(db: Arc<SurrealDbClient>) -> Self {
Self { db }
}
/// Creates a new job for `content` on behalf of `user_id` and stores it in the database.
///
/// # Errors
/// Returns an error if the job cannot be stored.
pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> {
let job = Job::new(content, user_id).await;
// Logs the complete job, including its content, at info level.
info!("{:?}", job);
store_item(&self.db, job).await?;
Ok(())
}
/// Gets all jobs for a specific user
pub async fn get_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
let jobs: Vec<Job> = self
.db
.query("SELECT * FROM job WHERE user_id = $user_id ORDER BY created_at DESC")
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)?;
debug!("{:?}", jobs);
Ok(jobs)
}
/// Gets all active jobs for a specific user, newest first.
///
/// A job counts as active when its status is `Created`, or when it is `InProgress`
/// with fewer than `MAX_ATTEMPTS` attempts.
///
/// # Errors
/// Returns an error if the query fails or its result cannot be deserialized.
pub async fn get_unfinished_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
let jobs: Vec<Job> = self
.db
.query(
"SELECT * FROM type::table($table)
WHERE user_id = $user_id
AND (
status = 'Created'
OR (
status.InProgress != NONE
AND status.InProgress.attempts < $max_attempts
)
)
ORDER BY created_at DESC",
)
.bind(("table", Job::table_name()))
.bind(("user_id", user_id.to_owned()))
.bind(("max_attempts", MAX_ATTEMPTS))
.await?
.take(0)?;
debug!("{:?}", jobs);
Ok(jobs)
}
/// Deletes the job `id`, provided it belongs to `user_id`.
///
/// # Errors
/// Returns `AppError::Auth` if the job does not exist or belongs to another user,
/// and `AppError::Database` if the deletion itself fails.
pub async fn delete_job(&self, id: &str, user_id: &str) -> Result<(), AppError> {
    // A missing job is reported the same way as a job owned by somebody else.
    let owned_by_user = matches!(
        get_item::<Job>(&self.db.client, id).await?,
        Some(job) if job.user_id == user_id
    );
    if !owned_by_user {
        error!("Unauthorized attempt to delete job {id} by user {user_id}");
        return Err(AppError::Auth("Not authorized to delete this job".into()));
    }

    info!("Deleting job {id} for user {user_id}");
    match delete_item::<Job>(&self.db.client, id).await {
        Ok(_) => Ok(()),
        Err(e) => Err(AppError::Database(e)),
    }
}
/// Replaces the status of job `id` and refreshes its `updated_at` field.
///
/// # Errors
/// Returns an error if the update fails.
pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> {
    let status_patch = PatchOp::replace("/status", status);
    // NOTE(review): `Datetime::default()` is presumably the current time — confirm.
    let timestamp_patch = PatchOp::replace("/updated_at", surrealdb::sql::Datetime::default());

    // The updated record is not needed; only the outcome matters.
    let _updated: Option<Job> = self
        .db
        .update((Job::table_name(), id))
        .patch(status_patch)
        .patch(timestamp_patch)
        .await?;
    Ok(())
}
/// Opens a live query on the `job` table and returns its notification stream.
///
/// # Errors
/// Returns an error if the live query cannot be started.
pub async fn listen_for_jobs(
&self,
) -> Result<impl Stream<Item = Result<Notification<Job>, Error>>, Error> {
self.db.select("job").live().await
}
/// Gets the unfinished jobs of all users, oldest first.
///
/// A job is unfinished when its status is `Created`, or when it is `InProgress`
/// with fewer than `MAX_ATTEMPTS` attempts.
///
/// # Errors
/// Returns an error if the query fails or its result cannot be deserialized.
pub async fn get_unfinished_jobs(&self) -> Result<Vec<Job>, AppError> {
let jobs: Vec<Job> = self
.db
.query(
"SELECT * FROM type::table($table)
WHERE
status = 'Created'
OR (
status.InProgress != NONE
AND status.InProgress.attempts < $max_attempts
)
ORDER BY created_at ASC",
)
.bind(("table", Job::table_name()))
.bind(("max_attempts", MAX_ATTEMPTS))
.await?
.take(0)?;
Ok(jobs)
}
// Method to process a single job
pub async fn process_job(
&self,
job: Job,
processor: &ContentProcessor,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<(), AppError> {
let current_attempts = match job.status {
JobStatus::InProgress { attempts, .. } => attempts + 1,
_ => 1,
};
// Update status to InProgress with attempt count
self.update_status(
&job.id,
JobStatus::InProgress {
attempts: current_attempts,
last_attempt: Utc::now(),
},
)
.await?;
let text_content = job.content.to_text_content(&openai_client).await?;
match processor.process(&text_content).await {
Ok(_) => {
self.update_status(&job.id, JobStatus::Completed).await?;
Ok(())
}
Err(e) => {
if current_attempts >= MAX_ATTEMPTS {
self.update_status(
&job.id,
JobStatus::Error(format!("Max attempts reached: {}", e)),
)
.await?;
}
Err(AppError::Processing(e.to_string()))
}
}
}
}

View File

@@ -0,0 +1,6 @@
pub mod analysis;
pub mod content_processor;
pub mod ingress_input;
pub mod ingress_object;
pub mod jobqueue;
pub mod queue_task;

View File

@@ -0,0 +1,13 @@
use crate::ingress::ingress_object::IngressObject;
use serde::Serialize;
/// A queued piece of ingress content, serializable for responses.
#[derive(Serialize)]
pub struct QueueTask {
/// Identifier of the delivery — presumably the tag assigned by the queue; confirm with the producer.
pub delivery_tag: u64,
/// The content to be ingested.
pub content: IngressObject,
}
/// Serializable wrapper around a list of queue tasks.
#[derive(Serialize)]
pub struct QueueTaskResponse {
/// The tasks contained in the response.
pub tasks: Vec<QueueTask>,
}

6
crates/common/src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod error;
pub mod ingress;
pub mod retrieval;
pub mod server;
pub mod storage;
pub mod utils;

View File

@@ -0,0 +1,74 @@
use surrealdb::{engine::any::Any, Error, Surreal};
use tracing::debug;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject};
/// Retrieves all records of a table whose `source_id` is one of the given identifiers.
///
/// # Arguments
///
/// * `source_id` - The source identifiers to match against
/// * `table_name` - The name of the table to search in
/// * `db_client` - The SurrealDB client instance for database operations
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
///
/// # Returns
///
/// * `Ok(Vec<T>)` - The matching records deserialized into type `T`
///
/// # Errors
///
/// Returns an `Error` if the query fails to execute or its results cannot be
/// deserialized into type `T`.
pub async fn find_entities_by_source_ids<T>(
    source_id: Vec<String>,
    table_name: String,
    db_client: &Surreal<Any>,
) -> Result<Vec<T>, Error>
where
    T: for<'de> serde::Deserialize<'de>,
{
    const QUERY: &str = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";

    let mut response = db_client
        .query(QUERY)
        .bind(("table", table_name))
        .bind(("source_ids", source_id))
        .await?;
    response.take(0)
}
/// Find entities by their relationship to the id
///
/// Selects the `knowledge_entity` record with the given id, together with the
/// entities connected to it through `relates_to` edges (exposed as `related`).
///
/// # Errors
///
/// Returns an `Error` if the query fails or its result cannot be deserialized.
pub async fn find_entities_by_relationship_by_id(
    db_client: &Surreal<Any>,
    entity_id: String,
) -> Result<Vec<KnowledgeEntity>, Error> {
    // The id is passed as a bound parameter rather than spliced into the query text,
    // so its content can never be interpreted as query syntax.
    let query = "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM type::thing('knowledge_entity', $entity_id)";
    debug!("{} (entity_id = {})", query, entity_id);
    db_client
        .query(query)
        .bind(("entity_id", entity_id))
        .await?
        .take(0)
}
/// Looks up a single `KnowledgeEntity` by its id.
///
/// Returns `Ok(None)` when no record with that id exists.
pub async fn get_entity_by_id(
    db_client: &Surreal<Any>,
    entity_id: &str,
) -> Result<Option<KnowledgeEntity>, Error> {
    let record = (KnowledgeEntity::table_name(), entity_id);
    db_client.select(record).await
}

View File

@@ -0,0 +1,92 @@
pub mod graph;
pub mod query_helper;
pub mod query_helper_prompt;
pub mod vector;
use crate::{
error::AppError,
retrieval::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
vector::find_items_by_vector_similarity,
},
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
};
use futures::future::{try_join, try_join_all};
use std::collections::HashMap;
use surrealdb::{engine::any::Any, Surreal};
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
///
/// # Strategy
/// The function employs a three-pronged approach to knowledge retrieval:
/// 1. Direct vector similarity search on knowledge entities
/// 2. Text chunk similarity search with source entity lookup
/// 3. Graph relationship traversal from the entities found in step 2
///
/// This combined approach ensures both semantic similarity matches and structurally
/// related content are included in the results.
///
/// # Arguments
/// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities
/// * `user_id` - The user id of the current user
///
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, AppError>` - The relevant knowledge entities,
///   deduplicated by id and in no particular order, or an error if the retrieval
///   process fails
pub async fn combined_knowledge_entity_retrieval(
    db_client: &Surreal<Any>,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    query: &str,
    user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
    // Both similarity searches are independent, so they run concurrently.
    let (items_from_knowledge_entity_similarity, closest_chunks): (
        Vec<KnowledgeEntity>,
        Vec<TextChunk>,
    ) = try_join(
        find_items_by_vector_similarity(
            10,
            query,
            db_client,
            "knowledge_entity",
            openai_client,
            user_id,
        ),
        find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
    )
    .await?;

    let source_ids: Vec<String> = closest_chunks
        .iter()
        .map(|chunk| chunk.source_id.clone())
        .collect();

    let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
        find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;

    // Only the ids are needed for the relationship lookups, so the entities are
    // borrowed here instead of cloning the whole vector.
    let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
        .iter()
        .map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
        .collect();
    let items_from_relationships = try_join_all(items_from_relationships_futures).await?;

    // Collecting into a map keyed by id removes duplicates; for a repeated id the
    // entity seen last wins.
    let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
        .into_iter()
        .chain(items_from_text_chunk_similarity)
        .chain(items_from_relationships.into_iter().flatten())
        .map(|entity| (entity.id.clone(), entity))
        .collect::<HashMap<_, _>>()
        .into_values()
        .collect();

    Ok(entities)
}

View File

@@ -0,0 +1,193 @@
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
},
};
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::debug;
use crate::{
error::AppError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
/// A single reference entry as it appears in the LLM's JSON response.
#[derive(Debug, Deserialize)]
pub struct Reference {
/// The reference value; the system prompt asks the model to put entity UUIDs here.
#[allow(dead_code)]
pub reference: String,
}
/// Shape of the JSON document the LLM is asked to return for a query.
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
/// The answer text.
pub answer: String,
/// The references supporting the answer.
#[allow(dead_code)]
pub references: Vec<Reference>,
}
// /// Orchestrator function that takes a query and clients and returns a answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDB
// /// * `openai_client` - Client for interacting with openai
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>, ApiError)` - Will return the answer, and the list of references or Error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// An answer to a user query together with the references that support it.
#[derive(Debug)]
pub struct Answer {
    /// The answer text produced by the LLM.
    pub content: String,
    /// The reference values returned by the LLM alongside the answer.
    pub references: Vec<String>,
}

/// Orchestrates query processing and returns an answer with references
///
/// Retrieves the knowledge entities relevant to the query, passes them to the LLM as
/// context together with the query, and parses the structured response.
///
/// # Arguments
///
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
/// * `user_id` - The user's id
///
/// # Returns
///
/// Returns an [`Answer`] holding the answer text and its references.
///
/// # Errors
///
/// Returns an `AppError` if retrieval, request construction, the LLM call, or parsing
/// of the LLM response fails.
pub async fn get_answer_with_references(
    surreal_db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    query: &str,
    user_id: &str,
) -> Result<Answer, AppError> {
    let entities =
        combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
            .await?;

    let entities_json = format_entities_json(&entities);
    debug!("{:?}", entities_json);
    let user_message = create_user_message(&entities_json, query);

    let request = create_chat_request(user_message)?;
    let response = openai_client.chat().create(request).await?;

    let llm_response = process_llm_response(response).await?;

    Ok(Answer {
        content: llm_response.answer,
        references: llm_response
            .references
            .into_iter()
            .map(|r| r.reference)
            .collect(),
    })
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
json!(entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>())
}
/// Builds the user message sent to the LLM: the context entities first, then the question.
///
/// # Arguments
/// * `entities_json` - The context, rendered with its `Display` implementation
/// * `query` - The user's question, inserted verbatim
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r#"
Context Information:
==================
{}
User Question:
==================
{}
"#,
entities_json, query
)
}
/// Builds the chat completion request for answering a query.
///
/// The request targets `gpt-4o-mini`, pairs `QUERY_SYSTEM_PROMPT` with the given
/// user message, and requires the response to follow the strict JSON schema from
/// `get_query_response_schema`.
///
/// # Errors
/// Returns an `OpenAIError` if the request cannot be built.
pub fn create_chat_request(
    user_message: String,
) -> Result<CreateChatCompletionRequest, OpenAIError> {
    let json_schema = ResponseFormatJsonSchema {
        description: Some("Query answering AI".into()),
        name: "query_answering_with_uuids".into(),
        schema: Some(get_query_response_schema()),
        strict: Some(true),
    };
    let system_message = ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT);
    let user_message = ChatCompletionRequestUserMessage::from(user_message);

    CreateChatCompletionRequestArgs::default()
        .model("gpt-4o-mini")
        .temperature(0.2)
        .max_tokens(3048u32)
        .messages([system_message.into(), user_message.into()])
        .response_format(ResponseFormat::JsonSchema { json_schema })
        .build()
}
/// Extracts and parses the structured answer from a chat completion response.
///
/// Only the first choice is considered.
///
/// # Errors
/// Returns `AppError::LLMParsing` if the response has no content or the content is
/// not valid JSON for `LLMResponseFormat`.
pub async fn process_llm_response(
    response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> {
    let first_content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref());
    let Some(content) = first_content else {
        return Err(AppError::LLMParsing(
            "No content found in LLM response".into(),
        ));
    };

    match serde_json::from_str::<LLMResponseFormat>(content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(AppError::LLMParsing(format!(
            "Failed to parse LLM response into analysis: {}",
            e
        ))),
    }
}

View File

@@ -0,0 +1,48 @@
use serde_json::{json, Value};
/// System prompt for query answering: instructs the model to answer strictly from the
/// provided knowledge entities and to list the UUIDs of the entities it relied on.
pub static QUERY_SYSTEM_PROMPT: &str = r#"
You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
Your task is to:
1. Carefully analyze the provided knowledge entities in the context
2. Answer user questions based on this information
3. Provide clear, concise, and accurate responses
4. When referencing information, briefly mention which knowledge entity it came from
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
6. If only partial information is available, explain what you can answer and what information is missing
7. Avoid making assumptions or providing information not supported by the context
8. Output the references to the documents. Use the UUIDs and make sure they are correct!
Remember:
- Be direct and honest about the limitations of your knowledge
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
- If you need to combine information from multiple entities, explain how they connect
- Don't speculate beyond what's provided in the context
Example response formats:
"Based on [Entity Name], [answer...]"
"I found relevant information in multiple entries: [explanation...]"
"I apologize, but the provided context doesn't contain information about [topic]"
"#;
/// Returns the JSON schema the LLM's query response must follow.
///
/// The schema describes an object with a string `answer` and an array `references`
/// of objects that each hold a single string `reference`. No additional properties
/// are allowed at either level.
pub fn get_query_response_schema() -> Value {
    // Schema of one element of the `references` array.
    let reference_item = json!({
        "type": "object",
        "properties": {
            "reference": { "type": "string" },
        },
        "required": ["reference"],
        "additionalProperties": false,
    });

    json!({
        "type": "object",
        "properties": {
            "answer": { "type": "string" },
            "references": {
                "type": "array",
                "items": reference_item
            }
        },
        "required": ["answer", "references"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,47 @@
use surrealdb::{engine::any::Any, Surreal};
use crate::{error::AppError, utils::embedding::generate_embedding};
/// Compares vectors and retrieves a number of items from the specified table.
///
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
/// and then deserializes the results into the specified type `T`.
///
/// # Arguments
///
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database. It is inserted into the query text as is and must therefore never come from user input.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user; only that user's items are returned.
///
/// # Returns
///
/// A vector of type `T` containing the closest matches to the input text, ordered by distance.
///
/// # Errors
///
/// Returns an `AppError` if the embedding cannot be generated, the query fails, or the results cannot be deserialized.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
    take: u8,
    input_text: &str,
    db_client: &Surreal<Any>,
    table: &str,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    user_id: &str,
) -> Result<Vec<T>, AppError>
where
    T: for<'de> serde::Deserialize<'de>,
{
    // Generate embeddings
    let input_embedding = generate_embedding(openai_client, input_text).await?;

    // Construct the query. The embedding and the user id are bound as parameters so
    // that their content can never be interpreted as query syntax; `take` is a plain
    // number and is kept inline inside the KNN operator.
    let closest_query = format!(
        "SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> $embedding AND user_id = $user_id ORDER BY distance",
        table, take
    );

    // Perform query and deserialize to struct
    let closest_entities: Vec<T> = db_client
        .query(closest_query)
        .bind(("embedding", input_embedding))
        .bind(("user_id", user_id.to_owned()))
        .await?
        .take(0)?;

    Ok(closest_entities)
}

View File

@@ -0,0 +1,180 @@
use crate::error::AppError;
use super::types::{analytics::Analytics, system_settings::SystemSettings, StoredObject};
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use std::ops::Deref;
use surrealdb::{
engine::any::{connect, Any},
opt::auth::Root,
Error, Surreal,
};
/// Thin wrapper around a SurrealDB connection; dereferences to the inner client.
#[derive(Clone)]
pub struct SurrealDbClient {
/// The underlying SurrealDB client.
pub client: Surreal<Any>,
}
impl SurrealDbClient {
/// # Initialize a new datbase client
///
/// # Arguments
///
/// # Returns
/// * `SurrealDbClient` initialized
pub async fn new(
address: &str,
username: &str,
password: &str,
namespace: &str,
database: &str,
) -> Result<Self, Error> {
let db = connect(address).await?;
// Sign in to database
db.signin(Root { username, password }).await?;
// Set namespace
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
/// Creates a session store backed by this database connection.
///
/// Sessions are kept in the table `test_session_table` and the session configuration
/// has its `secure` option enabled.
///
/// # Errors
/// Returns a `SessionError` if the store cannot be created.
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
SessionStore::new(
Some(self.client.clone().into()),
SessionConfig::default()
// NOTE(review): the table name looks like a leftover from testing — confirm.
.with_table_name("test_session_table")
.with_secure(true),
)
.await
}
/// Prepares the database for use.
///
/// Defines the indexes, sets up authentication, and makes sure the analytics and
/// system settings records exist — in that order.
///
/// # Errors
/// Returns the first error encountered; later steps are then skipped.
pub async fn ensure_initialized(&self) -> Result<(), AppError> {
    self.build_indexes().await?;
    self.setup_auth().await?;
    Analytics::ensure_initialized(self).await?;
    SystemSettings::ensure_initialized(self).await?;
    Ok(())
}
/// Defines the `user` table, a unique index on its `email` field, and the `account`
/// record access method used for sign-up and sign-in.
///
/// Passwords are hashed with argon2 on sign-up and verified with argon2 on sign-in.
///
/// # Errors
/// Returns an `Error` if the query cannot be sent.
pub async fn setup_auth(&self) -> Result<(), Error> {
self.client.query(
"DEFINE TABLE user SCHEMALESS;
DEFINE INDEX unique_name ON TABLE user FIELDS email UNIQUE;
DEFINE ACCESS account ON DATABASE TYPE RECORD
SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id)
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );",
)
.await?;
Ok(())
}
/// Defines the indexes used by the application.
///
/// These are HNSW indexes (dimension 1536) on the `embedding` field of `text_chunk`
/// and `knowledge_entity`, plus indexes on the `status`, `user_id` and `created_at`
/// fields of `job`. The statements are sent one at a time, in that order.
///
/// # Errors
/// Returns an `Error` as soon as one of the statements cannot be sent.
pub async fn build_indexes(&self) -> Result<(), Error> {
    const INDEX_DEFINITIONS: [&str; 5] = [
        "DEFINE INDEX idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536",
        "DEFINE INDEX idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536",
        "DEFINE INDEX idx_job_status ON job FIELDS status",
        "DEFINE INDEX idx_job_user ON job FIELDS user_id",
        "DEFINE INDEX idx_job_created ON job FIELDS created_at",
    ];

    for definition in INDEX_DEFINITIONS {
        self.client.query(definition).await?;
    }
    Ok(())
}
/// Rebuilds the two embedding indexes defined by `build_indexes`.
///
/// # Errors
/// Returns an `Error` if one of the statements cannot be sent.
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
    self.client
        .query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
        .await?;
    // The name must match the one used in `build_indexes` exactly: because of
    // `IF EXISTS`, a misspelled name is skipped silently instead of failing.
    self.client
        .query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
        .await?;
    Ok(())
}
/// Deletes the records of `T`'s table and returns the deleted records.
///
/// NOTE(review): despite the name, this issues a record deletion on the table; the
/// table definition itself is presumably left in place — confirm.
///
/// # Errors
/// Returns an `Error` if the deletion fails.
pub async fn drop_table<T>(&self) -> Result<Vec<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
{
self.client.delete(T::table_name()).await
}
}
/// Lets a `SurrealDbClient` be used wherever a `&Surreal<Any>` is expected.
impl Deref for SurrealDbClient {
type Target = Surreal<Any>;
fn deref(&self) -> &Self::Target {
&self.client
}
}
/// Stores an object in SurrealDB; requires the type to implement `StoredObject`.
///
/// The record is created in `T`'s table under the item's own id.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `item` - The item to be stored
///
/// # Returns
/// * `Result` - The stored item as returned by the database, or an `Error`
pub async fn store_item<T>(db_client: &Surreal<Any>, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
{
db_client
.create((T::table_name(), item.get_id()))
.content(item)
.await
}
/// Retrieves all objects from `T`'s table; requires the type to implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
///
/// # Returns
/// * `Result` - `Vec<T>` with every record of the table, or an `Error`
pub async fn get_all_stored_items<T>(db_client: &Surreal<Any>) -> Result<Vec<T>, Error>
where
T: for<'de> StoredObject,
{
db_client.select(T::table_name()).await
}
/// Retrieves a single object by its ID; requires the type to implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `id` - The ID of the item to retrieve
///
/// # Returns
/// * `Result<Option<T>, Error>` - The selected item (an `Option`), or an `Error`
pub async fn get_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
{
db_client.select((T::table_name(), id)).await
}
/// Deletes a single object by its ID; requires the type to implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `id` - The ID of the item to delete
///
/// # Returns
/// * `Result<Option<T>, Error>` - The deleted item as returned by the database, or an `Error`
pub async fn delete_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
{
db_client.delete((T::table_name(), id)).await
}

View File

@@ -0,0 +1,2 @@
pub mod db;
pub mod types;

View File

@@ -0,0 +1,78 @@
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
use serde::{Deserialize, Serialize};
use crate::{error::AppError, storage::db::SurrealDbClient};
/// Usage counters, stored as the single record `analytics:current`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Analytics {
/// Record id, deserialized with `deserialize_flexible_id`.
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
/// Counter incremented by `increment_page_loads`.
pub page_loads: i64,
/// Counter incremented by `increment_visitors`.
pub visitors: i64,
}
impl Analytics {
/// Returns the `analytics:current` record, creating it with zeroed counters if it
/// does not exist yet.
///
/// # Errors
/// Returns an error if the database operations fail, or `AppError::Validation` if
/// the record could not be created.
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
    let existing: Option<Self> = db.select(("analytics", "current")).await?;
    if let Some(analytics) = existing {
        return Ok(analytics);
    }

    let created: Option<Self> = db
        .create(("analytics", "current"))
        .content(Analytics {
            id: "current".to_string(),
            visitors: 0,
            page_loads: 0,
        })
        .await?;
    created.ok_or_else(|| AppError::Validation("Failed to initialize analytics".into()))
}
/// Fetches the `analytics:current` record.
///
/// # Errors
/// Returns `AppError::NotFound` if the record does not exist, or a database error if
/// the query fails.
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::thing('analytics', 'current')")
        .await?;
    let current: Option<Self> = response.take(0)?;
    match current {
        Some(analytics) => Ok(analytics),
        None => Err(AppError::NotFound("Analytics not found".into())),
    }
}
/// Increments the visitor counter by one and returns the updated record.
///
/// # Errors
/// Returns `AppError::Validation` if the update yields no record, or a database
/// error if the query fails.
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
        .await?;
    let after: Option<Self> = response.take(0)?;
    match after {
        Some(analytics) => Ok(analytics),
        None => Err(AppError::Validation("Failed to update analytics".into())),
    }
}
/// Increments the page load counter by one and returns the updated record.
///
/// # Errors
/// Returns `AppError::Validation` if the update yields no record, or a database
/// error if the query fails.
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
        .await?;
    let after: Option<Self> = response.take(0)?;
    match after {
        Some(analytics) => Ok(analytics),
        None => Err(AppError::Validation("Failed to update analytics".into())),
    }
}
/// Counts the records in the user table.
///
/// Returns `0` when the count query yields no row.
///
/// # Errors
/// Returns a database error if the query fails or its result cannot be deserialized.
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
    /// Row shape of the count query.
    #[derive(Debug, Deserialize)]
    struct CountResult {
        count: i64,
    }

    let mut response = db
        .client
        .query("SELECT count() as count FROM type::table($table) GROUP ALL")
        .bind(("table", User::table_name()))
        .await?;
    let row: Option<CountResult> = response.take(0)?;
    Ok(row.map_or(0, |r| r.count))
}
}

View File

@@ -0,0 +1,52 @@
use uuid::Uuid;
use crate::{
error::AppError,
storage::db::{get_item, SurrealDbClient},
stored_object,
};
use super::message::Message;
// Declares the `Conversation` stored object, kept in the "conversation" table.
// Besides the fields listed here, the type also has `id`, `created_at` and
// `updated_at` (see `Conversation::new`) — presumably added by the macro.
stored_object!(Conversation, "conversation", {
user_id: String,
title: String
});
impl Conversation {
/// Creates a conversation owned by `user_id` with a fresh UUID as id and both
/// timestamps set to the current time. Nothing is stored in the database.
pub fn new(user_id: String, title: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
user_id,
title,
}
}
/// Loads a conversation together with its messages, ordered by `updated_at`.
///
/// # Errors
/// Returns `AppError::NotFound` if the conversation does not exist,
/// `AppError::Auth` if it belongs to another user, or a database error if a query
/// fails.
pub async fn get_complete_conversation(
    conversation_id: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<(Self, Vec<Message>), AppError> {
    let Some(conversation) = get_item::<Conversation>(&db.client, conversation_id).await? else {
        return Err(AppError::NotFound("Conversation not found".to_string()));
    };

    // Only the owner may read a conversation.
    if conversation.user_id != user_id {
        return Err(AppError::Auth(
            "You don't have access to this conversation".to_string(),
        ));
    }

    let mut response = db
        .client
        .query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at")
        .bind(("table_name", Message::table_name()))
        .bind(("conversation_id", conversation_id.to_string()))
        .await?;
    let messages: Vec<Message> = response.take(0)?;

    Ok((conversation, messages))
}
}

View File

@@ -0,0 +1,248 @@
use axum_typed_multipart::FieldData;
use mime_guess::from_path;
use sha2::{Digest, Sha256};
use std::{
io::{BufReader, Read},
path::{Path, PathBuf},
};
use tempfile::NamedTempFile;
use thiserror::Error;
use tokio::fs::remove_dir_all;
use tracing::info;
use uuid::Uuid;
use crate::{
storage::db::{delete_item, get_item, store_item, SurrealDbClient},
stored_object,
};
/// Errors that can occur while handling uploaded files.
#[derive(Error, Debug)]
pub enum FileError {
/// No file was found for the given identifier.
#[error("File not found for UUID: {0}")]
FileNotFound(String),
/// An I/O operation failed.
#[error("IO error occurred: {0}")]
Io(#[from] std::io::Error),
/// A file with the same SHA256 hash already exists.
#[error("Duplicate file detected with SHA256: {0}")]
DuplicateFile(String),
/// A database operation failed.
#[error("SurrealDB error: {0}")]
SurrealError(#[from] surrealdb::Error),
/// The temporary file could not be persisted.
#[error("Failed to persist file: {0}")]
PersistError(#[from] tempfile::PersistError),
/// The upload did not carry a file name.
#[error("File name missing in metadata")]
MissingFileName,
}
// Declares the `FileInfo` stored object, kept in the "file" table.
// Besides the fields listed here, the type also has `id`, `created_at` and
// `updated_at` (see `FileInfo::new`) — presumably added by the macro.
stored_object!(FileInfo, "file", {
sha256: String,
path: String,
file_name: String,
mime_type: String
});
impl FileInfo {
/// Registers an uploaded file and returns its `FileInfo`.
///
/// The file is hashed first. If a file with the same SHA256 is already known, that
/// existing record is returned and nothing is written. Otherwise the file is
/// persisted, its MIME type is guessed from the sanitized file name, and a new
/// record is stored in the database.
///
/// # Errors
/// Returns `FileError::MissingFileName` if the upload has no file name, or any error
/// raised while hashing, looking up, persisting or storing the file.
pub async fn new(
    field_data: FieldData<NamedTempFile>,
    db_client: &SurrealDbClient,
    user_id: &str,
) -> Result<Self, FileError> {
    let file = field_data.contents;
    let file_name = field_data
        .metadata
        .file_name
        .ok_or(FileError::MissingFileName)?;

    let sha256 = Self::get_sha(&file).await?;

    // An already known hash short-circuits; only "not found" lets us continue.
    match Self::get_by_sha(&sha256, db_client).await {
        Ok(existing_file) => {
            info!("File already exists with SHA256: {}", sha256);
            return Ok(existing_file);
        }
        Err(FileError::FileNotFound(_)) => {}
        Err(other) => return Err(other),
    }

    let uuid = Uuid::new_v4();
    let sanitized_file_name = Self::sanitize_file_name(&file_name);
    let now = Utc::now();

    let path: String = Self::persist_file(&uuid, file, &sanitized_file_name, user_id)
        .await?
        .to_string_lossy()
        .into();
    let mime_type = Self::guess_mime_type(Path::new(&sanitized_file_name));

    let file_info = Self {
        id: uuid.to_string(),
        created_at: now,
        updated_at: now,
        file_name,
        sha256,
        path,
        mime_type,
    };

    store_item(&db_client.client, file_info.clone()).await?;
    Ok(file_info)
}
/// Guesses a MIME type from the extension of `path`.
///
/// Falls back to `application/octet-stream` when no guess is available.
///
/// # Arguments
/// * `path` - Path (or bare file name) whose extension is inspected.
///
/// # Returns
/// * `String` - The guessed MIME type rendered as text.
fn guess_mime_type(path: &Path) -> String {
    let guess = from_path(path);
    let mime_type = guess.first_or(mime::APPLICATION_OCTET_STREAM);
    mime_type.to_string()
}
/// Computes the SHA256 digest of `file`, streaming it in 8 KiB chunks.
///
/// NOTE(review): the reads are synchronous `std::io` calls inside an
/// `async fn`.
///
/// # Arguments
/// * `file` - The temporary file to hash.
///
/// # Returns
/// * `Result<String, FileError>` - Lowercase hex digest, or the I/O error.
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
    let mut hasher = Sha256::new();
    let mut reader = BufReader::new(file.as_file());
    let mut chunk = [0u8; 8192];

    loop {
        match reader.read(&mut chunk)? {
            // Zero bytes read means end of file.
            0 => break,
            len => hasher.update(&chunk[..len]),
        }
    }

    Ok(format!("{:x}", hasher.finalize()))
}
/// Produces a file name that contains only ASCII letters, digits, `.` and `_`.
///
/// Every other character (path separators, spaces, non-ASCII characters, ...)
/// is replaced by a single underscore, so the result can never contain a
/// directory separator.
fn sanitize_file_name(file_name: &str) -> String {
    let mut sanitized = String::with_capacity(file_name.len());
    for ch in file_name.chars() {
        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_');
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}
/// Moves the temporary upload to `./data/{user_id}/{uuid}/{file_name}`.
///
/// Missing directories are created first.
///
/// # Arguments
/// * `uuid` - The UUID of the file; used as a directory name.
/// * `file` - The temporary file to persist.
/// * `file_name` - The sanitized file name.
/// * `user_id` - Id of the owning user; used as a directory name.
///
/// # Returns
/// * `Result<PathBuf, FileError>` - The final path, or the I/O / persist error.
async fn persist_file(
    uuid: &Uuid,
    file: NamedTempFile,
    file_name: &str,
    user_id: &str,
) -> Result<PathBuf, FileError> {
    let target_dir = Path::new("./data").join(user_id).join(uuid.to_string());

    // `?` converts `std::io::Error` into `FileError::Io`.
    tokio::fs::create_dir_all(&target_dir).await?;

    let final_path = target_dir.join(file_name);
    info!("Final path: {:?}", final_path);

    file.persist(&final_path)?;
    info!("Persisted file to {:?}", final_path);

    Ok(final_path)
}
/// Retrieves a `FileInfo` by SHA256.
///
/// The hash is passed to the database as a bound parameter rather than being
/// interpolated into the query text.
///
/// # Arguments
/// * `sha256` - The SHA256 hash string.
/// * `db_client` - Reference to the SurrealDbClient.
///
/// # Returns
/// * `Result<FileInfo, FileError>` - The first matching `FileInfo`.
///
/// # Errors
/// * `FileError::FileNotFound` (carrying the hash) when no record matches.
/// * `FileError::SurrealError` when the query fails.
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
    let matches: Vec<FileInfo> = db_client
        .client
        .query("SELECT * FROM file WHERE sha256 = $sha256")
        .bind(("sha256", sha256.to_owned()))
        .await?
        .take(0)?;

    matches
        .into_iter()
        .next()
        .ok_or_else(|| FileError::FileNotFound(sha256.to_string()))
}
/// Deletes a stored file: removes its directory on disk, then its record.
///
/// The whole parent directory of the stored file is removed.
/// NOTE(review): when the file is already missing on disk this returns an
/// error before the database record is deleted, so the record stays.
///
/// # Arguments
/// * `id` - Id of the FileInfo
/// * `db_client` - Reference to SurrealDbClient
///
/// # Returns
/// `Result<(), FileError>` - `FileNotFound` when the record, the file, or
/// its parent directory is missing; otherwise any I/O or database error.
pub async fn delete_by_id(id: &str, db_client: &SurrealDbClient) -> Result<(), FileError> {
    let file_info = get_item::<FileInfo>(db_client, id)
        .await?
        .ok_or_else(|| FileError::FileNotFound(format!("File with id {} was not found", id)))?;

    let file_path = Path::new(&file_info.path);
    if !file_path.exists() {
        return Err(FileError::FileNotFound(format!(
            "File at path {:?} was not found",
            file_path
        )));
    }

    let parent_dir = file_path
        .parent()
        .ok_or_else(|| FileError::FileNotFound("File has no parent directory".to_string()))?;

    // Remove the per-file directory together with the file inside it.
    remove_dir_all(parent_dir).await?;
    info!("Removed directory {:?} and its contents", parent_dir);

    delete_item::<FileInfo>(db_client, id).await?;
    Ok(())
}
}

View File

@@ -0,0 +1,36 @@
use uuid::Uuid;
use crate::{ingress::ingress_object::IngressObject, stored_object};
/// Lifecycle state of an ingress `Job`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum JobStatus {
/// Initial state assigned by `Job::new`.
Created,
/// The job is being processed; records the attempt count and the time of
/// the most recent attempt.
InProgress {
attempts: u32,
last_attempt: DateTime<Utc>,
},
/// Processing finished successfully.
Completed,
/// Processing failed; the payload is an error description.
Error(String),
/// The job was cancelled.
Cancelled,
}
// Declares the `Job` record stored in the "job" table. `stored_object!`
// additionally generates `id`, `created_at` and `updated_at`.
stored_object!(Job, "job", {
// The ingress payload to process.
content: IngressObject,
// Current lifecycle state.
status: JobStatus,
// Id of the user the job belongs to.
user_id: String
});
impl Job {
    /// Builds a new job in the `Created` state for `user_id`.
    ///
    /// The id is a random v4 UUID string and both timestamps are set to the
    /// same current-UTC instant. The function is `async` but awaits nothing.
    pub async fn new(content: IngressObject, user_id: String) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            created_at: timestamp,
            updated_at: timestamp,
            user_id,
            status: JobStatus::Created,
            content,
        }
    }
}

View File

@@ -0,0 +1,121 @@
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use uuid::Uuid;
/// Category of a `KnowledgeEntity`.
///
/// When adding a variant, also update `variants()` and the `From<String>`
/// conversion, which list the names by hand.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum KnowledgeEntityType {
Idea,
Project,
Document,
Page,
TextSnippet,
// Add more types as needed
}
impl KnowledgeEntityType {
    /// Returns the names of all entity types, in declaration order.
    pub fn variants() -> &'static [&'static str] {
        const NAMES: &[&str] = &["Idea", "Project", "Document", "Page", "TextSnippet"];
        NAMES
    }
}
impl From<String> for KnowledgeEntityType {
fn from(s: String) -> Self {
match s.to_lowercase().as_str() {
"idea" => KnowledgeEntityType::Idea,
"project" => KnowledgeEntityType::Project,
"document" => KnowledgeEntityType::Document,
"page" => KnowledgeEntityType::Page,
"textsnippet" => KnowledgeEntityType::TextSnippet,
_ => KnowledgeEntityType::Document, // Default case
}
}
}
// Declares the `KnowledgeEntity` record stored in the "knowledge_entity"
// table. `stored_object!` additionally generates `id`, `created_at` and
// `updated_at`.
stored_object!(KnowledgeEntity, "knowledge_entity", {
// Id of the source this entity was derived from; used for bulk deletion.
source_id: String,
name: String,
description: String,
entity_type: KnowledgeEntityType,
// Optional free-form JSON metadata.
metadata: Option<serde_json::Value>,
// Embedding vector for the entity.
embedding: Vec<f32>,
// Id of the owning user.
user_id: String
});
impl KnowledgeEntity {
/// Builds a new knowledge entity from its parts.
///
/// The id is a random v4 UUID string; `created_at` and `updated_at` are both
/// set to the same current-UTC instant. Nothing is written to the database.
pub fn new(
    source_id: String,
    name: String,
    description: String,
    entity_type: KnowledgeEntityType,
    metadata: Option<serde_json::Value>,
    embedding: Vec<f32>,
    user_id: String,
) -> Self {
    let timestamp = Utc::now();
    let id = Uuid::new_v4().to_string();
    Self {
        id,
        created_at: timestamp,
        updated_at: timestamp,
        user_id,
        source_id,
        entity_type,
        name,
        description,
        metadata,
        embedding,
    }
}
/// Deletes every knowledge entity whose `source_id` matches.
///
/// `source_id` is sent as a bound query parameter, so quotes or other special
/// characters in it cannot alter the statement. Only the table name, a
/// compile-time constant, is formatted into the query text.
///
/// # Errors
/// Returns an `AppError` when the database query fails.
pub async fn delete_by_source_id(
    source_id: &str,
    db_client: &SurrealDbClient,
) -> Result<(), AppError> {
    let query = format!(
        "DELETE {} WHERE source_id = $source_id",
        Self::table_name()
    );
    db_client
        .query(query)
        .bind(("source_id", source_id.to_owned()))
        .await?;

    Ok(())
}
/// Updates name, description and type of an entity and refreshes its embedding.
///
/// A new embedding is generated from the updated fields before the record is
/// written; `updated_at` is set to the current UTC time.
///
/// # Errors
/// Returns an `AppError` when embedding generation or the database update fails.
pub async fn patch(
    id: &str,
    name: &str,
    description: &str,
    entity_type: &KnowledgeEntityType,
    db_client: &SurrealDbClient,
    ai_client: &Client<OpenAIConfig>,
) -> Result<(), AppError> {
    // Text the embedding is computed from.
    let embedding_text = format!(
        "name: {}, description: {}, type: {:?}",
        name, description, entity_type
    );
    let new_embedding = generate_embedding(ai_client, &embedding_text).await?;

    let update = db_client
        .client
        .query(
            "UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type,
embedding = $embedding
RETURN AFTER",
        )
        .bind(("table", Self::table_name()))
        .bind(("id", id.to_string()))
        .bind(("name", name.to_string()))
        .bind(("description", description.to_string()))
        .bind(("entity_type", entity_type.to_owned()))
        .bind(("embedding", new_embedding))
        .bind(("updated_at", Utc::now()));
    update.await?;

    Ok(())
}
}

View File

@@ -0,0 +1,86 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Metadata stored on a relationship edge between two knowledge entities.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RelationshipMetadata {
/// Id of the user owning the relationship.
pub user_id: String,
/// Id of the source the relationship was derived from; used for bulk deletion.
pub source_id: String,
/// Free-text label describing the kind of relationship.
pub relationship_type: String,
}
/// An edge in the `relates_to` table linking two `knowledge_entity` records.
///
/// The id fields are deserialized with `deserialize_flexible_id`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct KnowledgeRelationship {
/// Id of the edge itself.
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
/// Id of the entity the edge starts from (serialized as `in`).
#[serde(rename = "in", deserialize_with = "deserialize_flexible_id")]
pub in_: String,
/// Id of the entity the edge points to.
#[serde(deserialize_with = "deserialize_flexible_id")]
pub out: String,
pub metadata: RelationshipMetadata,
}
impl KnowledgeRelationship {
/// Builds a relationship from entity `in_` to entity `out`.
///
/// The edge id is a random v4 UUID string. Nothing is written to the database.
pub fn new(
    in_: String,
    out: String,
    user_id: String,
    source_id: String,
    relationship_type: String,
) -> Self {
    let metadata = RelationshipMetadata {
        user_id,
        source_id,
        relationship_type,
    };
    let id = Uuid::new_v4().to_string();
    Self {
        id,
        in_,
        out,
        metadata,
    }
}
/// Writes this relationship to the database as a `relates_to` edge.
///
/// The record ids are formatted into the `RELATE` statement (escaped with
/// backticks). The free-text metadata values are sent as bound parameters, so
/// a value containing a quote (for example a relationship type such as
/// `author's work`) can neither break nor alter the query.
///
/// # Errors
/// Returns an `AppError` when the database query fails.
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
    let query = format!(
        r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
        SET
            metadata.user_id = $user_id,
            metadata.source_id = $source_id,
            metadata.relationship_type = $relationship_type"#,
        self.in_, self.id, self.out,
    );

    db_client
        .query(query)
        .bind(("user_id", self.metadata.user_id.clone()))
        .bind(("source_id", self.metadata.source_id.clone()))
        .bind((
            "relationship_type",
            self.metadata.relationship_type.clone(),
        ))
        .await?;

    Ok(())
}
/// Deletes every `relates_to` edge whose `metadata.source_id` matches.
///
/// `source_id` is sent as a bound query parameter instead of being formatted
/// into the statement, so special characters in it cannot alter the query.
///
/// # Errors
/// Returns an `AppError` when the database query fails.
pub async fn delete_relationships_by_source_id(
    source_id: &str,
    db_client: &SurrealDbClient,
) -> Result<(), AppError> {
    db_client
        .query("DELETE knowledge_entity -> relates_to WHERE metadata.source_id = $source_id")
        .bind(("source_id", source_id.to_owned()))
        .await?;

    Ok(())
}
/// Deletes the single `relates_to` edge with the given id.
///
/// NOTE(review): `id` is formatted directly into the statement between
/// backticks; callers should pass ids that come from the database.
///
/// # Errors
/// Returns an `AppError` when the database query fails.
pub async fn delete_relationship_by_id(
    id: &str,
    db_client: &SurrealDbClient,
) -> Result<(), AppError> {
    let statement = format!("DELETE relates_to:`{}`", id);
    db_client.query(statement).await?;
    Ok(())
}
}

View File

@@ -0,0 +1,37 @@
use uuid::Uuid;
use crate::stored_object;
/// Author role of a chat `Message`.
#[derive(Deserialize, Debug, Clone, Serialize)]
pub enum MessageRole {
User,
AI,
System,
}
// Declares the `Message` record stored in the "message" table.
// `stored_object!` additionally generates `id`, `created_at` and `updated_at`.
stored_object!(Message, "message", {
// Id of the conversation this message belongs to.
conversation_id: String,
role: MessageRole,
content: String,
// Optional list of reference identifiers attached to the message.
references: Option<Vec<String>>
});
impl Message {
    /// Builds a new message for `conversation_id`.
    ///
    /// The id is a random v4 UUID string; both timestamps are set to the same
    /// current-UTC instant. Nothing is written to the database.
    pub fn new(
        conversation_id: String,
        role: MessageRole,
        content: String,
        references: Option<Vec<String>>,
    ) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            conversation_id,
            role,
            content,
            references,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }
}

View File

@@ -0,0 +1,110 @@
use axum::async_trait;
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
pub mod file_info;
pub mod job;
pub mod knowledge_entity;
pub mod knowledge_relationship;
pub mod message;
pub mod system_settings;
pub mod text_chunk;
pub mod text_content;
pub mod user;
/// Common interface of record types that are persisted in the database.
///
/// Implementations are generated by the `stored_object!` macro.
#[async_trait]
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
/// Name of the database table the type is stored in.
fn table_name() -> &'static str;
/// Id of this record.
fn get_id(&self) -> &str;
}
/// Declares a database-backed record type.
///
/// `stored_object!(Name, "table", { field: Type, ... })` generates:
///
/// * `pub struct Name` with the fields `id: String`, `created_at` and
///   `updated_at` (`DateTime<Utc>`) followed by the listed fields, all `pub`;
/// * an implementation of `StoredObject` returning the table name and the id;
/// * the serde helpers `deserialize_flexible_id`, `serialize_datetime` and
///   `deserialize_datetime`, plus several `use` items.
///
/// Attributes written on a listed field (for example `#[serde(default)]`) are
/// forwarded to the generated struct field. A trailing comma after the last
/// field is accepted.
///
/// Because the expansion defines free functions and `use` items in the calling
/// module, the macro can be invoked only once per module.
#[macro_export]
macro_rules! stored_object {
    ($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),* $(,)?}) => {
        use axum::async_trait;
        use serde::{Deserialize, Deserializer, Serialize};
        use surrealdb::sql::Thing;
        use $crate::storage::types::StoredObject;
        use serde::de::{self, Visitor};
        use std::fmt;
        use chrono::{DateTime, Utc };

        // Accepts an id given either as a plain string or as a record `Thing`.
        struct FlexibleIdVisitor;

        impl<'de> Visitor<'de> for FlexibleIdVisitor {
            type Value = String;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string or a Thing")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(value.to_string())
            }

            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(value)
            }

            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
            where
                A: de::MapAccess<'de>,
            {
                // Try to deserialize as Thing and keep only its raw id part.
                let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
                Ok(thing.id.to_raw())
            }
        }

        pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_any(FlexibleIdVisitor)
        }

        // Timestamps are converted to and from the database's datetime type.
        fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
        }

        fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
            Ok(DateTime::<Utc>::from(dt))
        }

        #[derive(Debug, Clone, Serialize, Deserialize)]
        pub struct $name {
            #[serde(deserialize_with = "deserialize_flexible_id")]
            pub id: String,
            #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
            pub created_at: DateTime<Utc>,
            #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
            pub updated_at: DateTime<Utc>,
            // Forward any per-field attributes captured by the matcher.
            $( $(#[$attr])* pub $field: $ty ),*
        }

        #[async_trait]
        impl StoredObject for $name {
            fn table_name() -> &'static str {
                $table
            }

            fn get_id(&self) -> &str {
                &self.id
            }
        }
    };
}

View File

@@ -0,0 +1,56 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use serde::{Deserialize, Serialize};
use crate::{error::AppError, storage::db::SurrealDbClient};
/// Application-wide settings, stored as the single record
/// `system_settings:current`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SystemSettings {
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
/// Whether new user registrations are allowed.
pub registrations_enabled: bool,
/// Whether new users must verify their email address.
pub require_email_verification: bool,
}
impl SystemSettings {
/// Returns the current settings, creating the default record if none exists.
///
/// Defaults: registrations enabled, email verification not required.
///
/// # Errors
/// Returns `AppError::Validation` when the record could not be created, or
/// any database error.
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
    let existing: Option<SystemSettings> = db.select(("system_settings", "current")).await?;
    if let Some(settings) = existing {
        return Ok(settings);
    }

    let defaults = SystemSettings {
        id: "current".to_string(),
        registrations_enabled: true,
        require_email_verification: false,
    };
    let created: Option<SystemSettings> = db
        .create(("system_settings", "current"))
        .content(defaults)
        .await?;

    created.ok_or_else(|| AppError::Validation("Failed to initialize settings".into()))
}
/// Fetches the `system_settings:current` record.
///
/// # Errors
/// Returns `AppError::NotFound` when the record does not exist, or any
/// database error.
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::thing('system_settings', 'current')")
        .await?;
    let current: Option<Self> = response.take(0)?;

    current.ok_or_else(|| AppError::NotFound("System settings not found".into()))
}
/// Merges `changes` into the `system_settings:current` record and returns
/// the record as stored afterwards.
///
/// # Errors
/// Returns `AppError::Validation` when the update returned no record, or any
/// database error.
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
        .bind(("changes", changes))
        .await?;
    let after: Option<Self> = response.take(0)?;

    after.ok_or_else(|| {
        AppError::Validation("Something went wrong updating the settings".into())
    })
}
}

View File

@@ -0,0 +1,38 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use uuid::Uuid;
// Declares the `TextChunk` record stored in the "text_chunk" table.
// `stored_object!` additionally generates `id`, `created_at` and `updated_at`.
stored_object!(TextChunk, "text_chunk", {
// Id of the source the chunk was taken from; used for bulk deletion.
source_id: String,
// The chunk text.
chunk: String,
// Embedding vector for the chunk.
embedding: Vec<f32>,
// Id of the owning user.
user_id: String
});
impl TextChunk {
    /// Builds a new text chunk.
    ///
    /// The id is a random v4 UUID string; both timestamps are set to the same
    /// current-UTC instant. Nothing is written to the database.
    pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id,
            chunk,
            embedding,
            user_id,
        }
    }

    /// Deletes every text chunk whose `source_id` matches.
    ///
    /// `source_id` is sent as a bound query parameter, so quotes or other
    /// special characters in it cannot alter the statement. Only the table
    /// name, a compile-time constant, is formatted into the query text.
    ///
    /// # Errors
    /// Returns an `AppError` when the database query fails.
    pub async fn delete_by_source_id(
        source_id: &str,
        db_client: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = format!(
            "DELETE {} WHERE source_id = $source_id",
            Self::table_name()
        );
        db_client
            .query(query)
            .bind(("source_id", source_id.to_owned()))
            .await?;

        Ok(())
    }
}

View File

@@ -0,0 +1,38 @@
use uuid::Uuid;
use crate::stored_object;
use super::file_info::FileInfo;
// Declares the `TextContent` record stored in the "text_content" table.
// `stored_object!` additionally generates `id`, `created_at` and `updated_at`.
stored_object!(TextContent, "text_content", {
text: String,
// Present when the content originates from an uploaded file.
file_info: Option<FileInfo>,
// Present when the content originates from a URL.
url: Option<String>,
instructions: String,
category: String,
// Id of the owning user.
user_id: String
});
impl TextContent {
    /// Builds a new text content record.
    ///
    /// The id is a random v4 UUID string; both timestamps are set to the same
    /// current-UTC instant. Nothing is written to the database.
    pub fn new(
        text: String,
        instructions: String,
        category: String,
        file_info: Option<FileInfo>,
        url: Option<String>,
        user_id: String,
    ) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            user_id,
            category,
            instructions,
            text,
            url,
            file_info,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }
}

View File

@@ -0,0 +1,352 @@
use crate::{
error::AppError,
storage::db::{get_item, SurrealDbClient},
stored_object,
};
use axum_session_auth::Authentication;
use surrealdb::{engine::any::Any, Surreal};
use uuid::Uuid;
use super::{
conversation::Conversation, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
text_content::TextContent,
};
/// Row shape of the category query in `User::get_user_categories`.
#[derive(Deserialize)]
pub struct CategoryResponse {
category: String,
}
// Declares the `User` record stored in the "user" table. `stored_object!`
// additionally generates `id`, `created_at` and `updated_at`.
stored_object!(User, "user", {
email: String,
// Password hash; `create_new` stores the output of `crypto::argon2::generate`.
password: String,
anonymous: bool,
// API key, if one has been issued (see `set_api_key` / `revoke_api_key`).
api_key: Option<String>,
admin: bool,
// Timezone identifier; `create_new` falls back to "UTC" for invalid input.
#[serde(default)]
timezone: String
});
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
    /// Loads the user with id `userid` for the session layer.
    ///
    /// # Errors
    /// Returns an error, rather than panicking, when no database handle was
    /// supplied, when the lookup fails, or when no user with that id exists
    /// (for example because the account was deleted while a session was
    /// still alive).
    async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
        let db = db.ok_or_else(|| anyhow::anyhow!("No database connection available"))?;

        get_item::<Self>(db, userid.as_str())
            .await?
            .ok_or_else(|| anyhow::anyhow!("User {} was not found", userid))
    }

    /// A user counts as authenticated unless it is anonymous.
    fn is_authenticated(&self) -> bool {
        !self.anonymous
    }

    /// A user counts as active unless it is anonymous.
    fn is_active(&self) -> bool {
        !self.anonymous
    }

    /// Whether this is an anonymous user.
    fn is_anonymous(&self) -> bool {
        self.anonymous
    }
}
/// Returns `input` when it parses as a `chrono_tz::Tz` timezone identifier,
/// otherwise logs a warning and returns `"UTC"`.
fn validate_timezone(input: &str) -> String {
    use chrono_tz::Tz;

    if input.parse::<Tz>().is_ok() {
        return input.to_owned();
    }

    tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
    "UTC".to_owned()
}
impl User {
/// Registers a new user.
///
/// Registration is refused when the system settings have
/// `registrations_enabled` set to false. The timezone is validated (invalid
/// values become "UTC"), the password is hashed inside the database with
/// argon2, and the user is made admin when the user table was empty.
///
/// # Errors
/// * `AppError::Auth` when registration is disabled or no user record was
///   returned by the create statement.
/// * Any database error.
pub async fn create_new(
email: String,
password: String,
db: &SurrealDbClient,
timezone: String,
) -> Result<Self, AppError> {
// verify that the application allows new creations
let systemsettings = SystemSettings::get_current(db).await?;
if !systemsettings.registrations_enabled {
return Err(AppError::Auth("Registration is not allowed".into()));
}
let validated_tz = validate_timezone(&timezone);
let now = Utc::now();
let id = Uuid::new_v4().to_string();
// The query holds two statements: statement 0 counts the existing users,
// statement 1 creates the new user.
let user: Option<User> = db
.client
.query(
"LET $count = (SELECT count() FROM type::table($table))[0].count;
CREATE type::thing('user', $id) SET
email = $email,
password = crypto::argon2::generate($password),
admin = $count < 1, // Changed from == 0 to < 1
anonymous = false,
created_at = $created_at,
updated_at = $updated_at,
timezone = $timezone",
)
.bind(("table", "user"))
.bind(("id", id))
.bind(("email", email))
.bind(("password", password))
.bind(("created_at", now))
.bind(("updated_at", now))
.bind(("timezone", validated_tz))
.await?
// Index 1 selects the result of the CREATE statement (index 0 is the LET).
.take(1)?;
user.ok_or(AppError::Auth("User failed to create".into()))
}
/// Looks up the user with `email` whose stored argon2 hash matches `password`.
///
/// The password comparison is performed inside the database query.
///
/// # Errors
/// Returns `AppError::Auth` when no user matches, or any database error.
pub async fn authenticate(
    email: String,
    password: String,
    db: &SurrealDbClient,
) -> Result<Self, AppError> {
    let mut response = db
        .client
        .query(
            "SELECT * FROM user
WHERE email = $email
AND crypto::argon2::compare(password, $password)",
        )
        .bind(("email", email))
        .bind(("password", password))
        .await?;
    let matched: Option<User> = response.take(0)?;

    matched.ok_or_else(|| AppError::Auth("User failed to authenticate".into()))
}
/// Finds a user by email address.
///
/// Returns `Ok(None)` when no user has that address.
pub async fn find_by_email(
    email: &str,
    db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM user WHERE email = $email LIMIT 1")
        .bind(("email", email.to_string()))
        .await?;
    let found: Option<User> = response.take(0)?;

    Ok(found)
}
/// Finds the user that owns `api_key`.
///
/// Returns `Ok(None)` when no user has that key.
pub async fn find_by_api_key(
    api_key: &str,
    db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
        .bind(("api_key", api_key.to_string()))
        .await?;
    let found: Option<User> = response.take(0)?;

    Ok(found)
}
/// Generates a new API key for user `id`, stores it, and returns it.
///
/// The key has the form `sk_` followed by a v4 UUID without dashes.
///
/// # Errors
/// Returns `AppError::Auth` when the update returned no user, or any
/// database error.
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
    let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));

    let mut response = db
        .client
        .query(
            "UPDATE type::thing('user', $id)
SET api_key = $api_key
RETURN AFTER",
        )
        .bind(("id", id.to_owned()))
        .bind(("api_key", api_key.clone()))
        .await?;
    let updated: Option<User> = response.take(0)?;

    // Hand the key back only when a user record was actually updated.
    match updated {
        Some(_) => Ok(api_key),
        None => Err(AppError::Auth("User not found".into())),
    }
}
/// Clears the API key of user `id`.
///
/// # Errors
/// Returns `AppError::Auth` when the update returned no user, or any
/// database error.
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
    let mut response = db
        .client
        .query(
            "UPDATE type::thing('user', $id)
SET api_key = NULL
RETURN AFTER",
        )
        .bind(("id", id.to_owned()))
        .await?;
    let updated: Option<User> = response.take(0)?;

    match updated {
        Some(_) => Ok(()),
        None => Err(AppError::Auth("User was not found".into())),
    }
}
/// Returns all knowledge entities owned by `user_id`.
pub async fn get_knowledge_entities(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<KnowledgeEntity>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::table($table) WHERE user_id = $user_id")
        .bind(("table", KnowledgeEntity::table_name()))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let owned: Vec<KnowledgeEntity> = response.take(0)?;

    Ok(owned)
}
/// Returns all `relates_to` edges whose `metadata.user_id` equals `user_id`.
pub async fn get_knowledge_relationships(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<KnowledgeRelationship>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::table($table) WHERE metadata.user_id = $user_id")
        .bind(("table", "relates_to"))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let edges: Vec<KnowledgeRelationship> = response.take(0)?;

    Ok(edges)
}
/// Returns the five most recently created text contents of `user_id`,
/// newest first.
pub async fn get_latest_text_contents(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<TextContent>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5")
        .bind(("user_id", user_id.to_owned()))
        .bind(("table_name", TextContent::table_name()))
        .await?;
    let latest: Vec<TextContent> = response.take(0)?;

    Ok(latest)
}
/// Returns all text contents of `user_id`, newest first.
pub async fn get_text_contents(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<TextContent>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC")
        .bind(("user_id", user_id.to_owned()))
        .bind(("table_name", TextContent::table_name()))
        .await?;
    let contents: Vec<TextContent> = response.take(0)?;

    Ok(contents)
}
/// Returns the five most recently created knowledge entities of `user_id`,
/// newest first.
pub async fn get_latest_knowledge_entities(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<KnowledgeEntity>, AppError> {
    let mut response = db
        .client
        .query(
            "SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5",
        )
        .bind(("user_id", user_id.to_owned()))
        .bind(("table_name", KnowledgeEntity::table_name()))
        .await?;
    let latest: Vec<KnowledgeEntity> = response.take(0)?;

    Ok(latest)
}
pub async fn update_timezone(
user_id: &str,
timezone: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
.bind(("table_name", User::table_name()))
.bind(("user_id", user_id.to_string()))
.bind(("timezone", timezone.to_string()))
.await?;
Ok(())
}
/// Returns the distinct categories used by the text contents of `user_id`.
pub async fn get_user_categories(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<String>, AppError> {
    // Grouping by category yields one row per distinct category.
    let mut response = db
        .client
        .query("SELECT category FROM type::table($table_name) WHERE user_id = $user_id GROUP BY category")
        .bind(("user_id", user_id.to_owned()))
        .bind(("table_name", TextContent::table_name()))
        .await?;
    let rows: Vec<CategoryResponse> = response.take(0)?;

    Ok(rows.into_iter().map(|row| row.category).collect())
}
/// Loads knowledge entity `id` and checks that it belongs to `user_id`.
///
/// # Errors
/// * `AppError::NotFound` when the entity does not exist.
/// * `AppError::Auth` when it belongs to a different user.
/// * Any database error.
pub async fn get_and_validate_knowledge_entity(
    id: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<KnowledgeEntity, AppError> {
    let found: Option<KnowledgeEntity> = get_item(db, id).await?;
    let entity = found.ok_or_else(|| AppError::NotFound("Entity not found".into()))?;

    if entity.user_id == user_id {
        Ok(entity)
    } else {
        Err(AppError::Auth("Access denied".into()))
    }
}
/// Loads text content `id` and checks that it belongs to `user_id`.
///
/// # Errors
/// * `AppError::NotFound` when the content does not exist.
/// * `AppError::Auth` when it belongs to a different user.
/// * Any database error.
pub async fn get_and_validate_text_content(
    id: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<TextContent, AppError> {
    let found: Option<TextContent> = get_item(db, id).await?;
    let text_content = found.ok_or_else(|| AppError::NotFound("Content not found".into()))?;

    if text_content.user_id == user_id {
        Ok(text_content)
    } else {
        Err(AppError::Auth("Access denied".into()))
    }
}
/// Returns all conversations owned by `user_id`.
pub async fn get_user_conversations(
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<Conversation>, AppError> {
    let mut response = db
        .client
        .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
        .bind(("table_name", Conversation::table_name()))
        .bind(("user_id", user_id.to_string()))
        .await?;
    let owned: Vec<Conversation> = response.take(0)?;

    Ok(owned)
}
}

View File

@@ -0,0 +1,30 @@
use config::{Config, ConfigError, File};
/// Application configuration values, as loaded by `get_config`.
///
/// NOTE(review): `Debug` is derived, so formatting this struct prints the
/// password fields in clear text.
#[derive(Clone, Debug)]
pub struct AppConfig {
// SMTP settings.
pub smtp_username: String,
pub smtp_password: String,
pub smtp_relayer: String,
// SurrealDB connection settings.
pub surrealdb_address: String,
pub surrealdb_username: String,
pub surrealdb_password: String,
pub surrealdb_namespace: String,
pub surrealdb_database: String,
}
/// Loads the application configuration from the `config` file source.
///
/// # Errors
/// Returns a `ConfigError` when the source cannot be built or when any of the
/// required keys is missing or not a string.
pub fn get_config() -> Result<AppConfig, ConfigError> {
    let settings = Config::builder()
        .add_source(File::with_name("config"))
        .build()?;

    // Every key is a required string value.
    let read = |key: &str| settings.get_string(key);

    Ok(AppConfig {
        smtp_username: read("SMTP_USERNAME")?,
        smtp_password: read("SMTP_PASSWORD")?,
        smtp_relayer: read("SMTP_RELAYER")?,
        surrealdb_address: read("SURREALDB_ADDRESS")?,
        surrealdb_username: read("SURREALDB_USERNAME")?,
        surrealdb_password: read("SURREALDB_PASSWORD")?,
        surrealdb_namespace: read("SURREALDB_NAMESPACE")?,
        surrealdb_database: read("SURREALDB_DATABASE")?,
    })
}

View File

@@ -0,0 +1,48 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use crate::error::AppError;
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// The request uses the `text-embedding-3-small` model and returns the first
/// embedding in the response.
///
/// # Arguments
///
/// * `client`: The OpenAI client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
/// * `Err(AppError)`: An error if the embedding generation fails
///
/// # Errors
///
/// This function returns an `AppError` in the following cases:
/// * If the request building fails
/// * If the OpenAI API request fails
/// * If no embedding data is received in the response (`AppError::LLMParsing`)
pub async fn generate_embedding(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: &str,
) -> Result<Vec<f32>, AppError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-3-small")
        .input([input])
        .build()?;

    // Send the request to OpenAI
    let response = client.embeddings().create(request).await?;

    // Take ownership of the first embedding instead of cloning it out of the
    // response.
    response
        .data
        .into_iter()
        .next()
        .map(|item| item.embedding)
        .ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))
}

View File

@@ -0,0 +1,81 @@
use lettre::address::AddressError;
use lettre::message::MultiPart;
use lettre::{transport::smtp::authentication::Credentials, SmtpTransport};
use lettre::{Message, Transport};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use thiserror::Error;
use tracing::info;
/// Sends application emails through an SMTP relay.
pub struct Mailer {
/// The configured SMTP transport.
pub mailer: SmtpTransport,
}
/// Errors that can occur while building or sending an email.
#[derive(Error, Debug)]
pub enum EmailError {
/// A sender or recipient address could not be parsed.
#[error("Email construction error: {0}")]
EmailParsingError(#[from] AddressError),
/// The SMTP transport failed to send the message.
#[error("Email sending error: {0}")]
SendingError(#[from] lettre::transport::smtp::Error),
/// The message could not be assembled.
#[error("Body constructing error: {0}")]
BodyError(#[from] lettre::error::Error),
/// A template could not be loaded or rendered.
#[error("Templating error: {0}")]
TemplatingError(#[from] minijinja::Error),
}
impl Mailer {
pub fn new(
username: &str,
relayer: &str,
password: &str,
) -> Result<Self, lettre::transport::smtp::Error> {
let creds = Credentials::new(username.to_owned(), password.to_owned());
let mailer = SmtpTransport::relay(&relayer)?.credentials(creds).build();
Ok(Mailer { mailer })
}
/// Sends the email-verification message containing `verification_code`.
///
/// The recipient's display name is the part of `email_to` before the first
/// `@`, with its first character upper-cased (ASCII only). The message is
/// sent as a multipart alternative with a plain-text and an HTML body, both
/// rendered from the `email/email_verification` templates.
///
/// # Errors
/// Returns an `EmailError` when a template cannot be loaded or rendered, an
/// address cannot be parsed, the message cannot be built, or sending fails.
pub fn send_email_verification(
    &self,
    email_to: &str,
    verification_code: &str,
    templates: &AutoReloader,
) -> Result<(), EmailError> {
    // Derive a display name from the local part of the address.
    let local_part = email_to.split('@').next().unwrap_or("User");
    let mut rest = local_part.chars();
    let name: String = match rest.next() {
        Some(first) => std::iter::once(first.to_ascii_uppercase())
            .chain(rest)
            .collect(),
        None => String::new(),
    };

    let template_ctx = context! {
        name => name,
        verification_code => verification_code
    };

    let env = templates.acquire_env()?;
    let html = env
        .get_template("email/email_verification.html")?
        .render(&template_ctx)?;
    let plain = env
        .get_template("email/email_verification.txt")?
        .render(&template_ctx)?;

    let recipient = format!("{} <{}>", name, email_to);
    let email = Message::builder()
        .from("Admin <minne@starks.cloud>".parse()?)
        .reply_to("Admin <minne@starks.cloud>".parse()?)
        .to(recipient.parse()?)
        .subject("Verify Your Email Address")
        .multipart(MultiPart::alternative_plain_html(plain, html))?;

    info!("Sending email to: {}", email_to);
    self.mailer.send(&email)?;

    Ok(())
}
}

View File

@@ -0,0 +1,3 @@
pub mod config;
pub mod embedding;
pub mod mailer;

View File

@@ -0,0 +1,35 @@
[package]
name = "html-router"
version = "0.1.0"
edition = "2021"
[dependencies]
# Workspace dependencies
tokio = { workspace = true }
serde = { workspace = true }
axum = { workspace = true }
tracing = { workspace = true }
serde_json = { workspace = true }
axum-htmx = "0.6.0"
axum_session = "0.14.4"
axum_session_auth = "0.14.1"
axum_session_surreal = "0.2.1"
axum_typed_multipart = "0.12.1"
futures = "0.3.31"
tempfile = "3.12.0"
async-stream = "0.3.6"
json-stream-parser = "0.1.4"
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
minijinja-autoreload = "2.5.0"
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
plotly = "0.12.1"
surrealdb = "2.0.4"
tower-http = { version = "0.6.2", features = ["fs"] }
chrono-tz = "0.10.1"
async-openai = "0.24.1"
common = { path = "../common" }

View File

@@ -0,0 +1,88 @@
use axum_session::SessionStore;
use axum_session_surreal::SessionSurrealPool;
use common::ingress::jobqueue::JobQueue;
use common::storage::db::SurrealDbClient;
use common::utils::config::AppConfig;
use common::utils::mailer::Mailer;
use minijinja::{path_loader, Environment};
use minijinja_autoreload::AutoReloader;
use std::path::PathBuf;
use std::sync::Arc;
use surrealdb::engine::any::Any;
/// Shared state for the HTML router. Every member is behind an `Arc`, so
/// cloning the state only bumps reference counts.
#[derive(Clone)]
pub struct HtmlState {
/// Database client.
pub surreal_db_client: Arc<SurrealDbClient>,
/// OpenAI API client.
pub openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
/// Auto-reloading template environment.
pub templates: Arc<AutoReloader>,
/// SMTP mailer.
pub mailer: Arc<Mailer>,
/// Ingress job queue.
pub job_queue: Arc<JobQueue>,
/// Session store backed by the database.
pub session_store: Arc<SessionStore<SessionSurrealPool<Any>>>,
}
impl HtmlState {
/// Builds the HTML router state from the application configuration.
///
/// Sets up the auto-reloading template environment, connects to the database
/// and makes sure it is initialized, then creates the OpenAI client, the
/// session store, the mailer and the job queue.
///
/// # Errors
/// Returns an error when the database connection or initialization, the
/// session store creation, or the mailer setup fails.
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
    let templates = AutoReloader::new(move |notifier| {
        let template_dir = get_templates_dir();
        let mut environment = Environment::new();
        environment.set_loader(path_loader(&template_dir));

        // Watch the template directory so edits are picked up.
        notifier.set_fast_reload(true);
        notifier.watch_path(&template_dir, true);

        minijinja_contrib::add_to_environment(&mut environment);
        Ok(environment)
    });

    let db = SurrealDbClient::new(
        &config.surrealdb_address,
        &config.surrealdb_username,
        &config.surrealdb_password,
        &config.surrealdb_namespace,
        &config.surrealdb_database,
    )
    .await?;
    let surreal_db_client = Arc::new(db);
    surreal_db_client.ensure_initialized().await?;

    let openai_client = Arc::new(async_openai::Client::new());
    let session_store = Arc::new(surreal_db_client.create_session_store().await?);

    let mailer = Mailer::new(
        &config.smtp_username,
        &config.smtp_relayer,
        &config.smtp_password,
    )?;
    let job_queue = JobQueue::new(surreal_db_client.clone());

    Ok(HtmlState {
        surreal_db_client,
        openai_client,
        templates: Arc::new(templates),
        mailer: Arc::new(mailer),
        job_queue: Arc::new(job_queue),
        session_store,
    })
}
}
/// Returns the workspace root: two directory levels above this crate's
/// manifest directory.
///
/// `CARGO_MANIFEST_DIR` is read with `env!`, i.e. at compile time, so the
/// returned path is the one that existed on the build machine.
///
/// # Panics
/// Panics when the manifest directory has fewer than two parent directories.
pub fn get_workspace_root() -> PathBuf {
    // e.g. /project/crates/html-router
    let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));

    // e.g. /project/crates
    let crates_dir = crate_dir
        .parent()
        .expect("Failed to find parent of manifest directory");

    // e.g. /project
    let workspace_root = crates_dir.parent().expect("Failed to find workspace root");
    workspace_root.to_path_buf()
}
/// Returns the `templates` directory that sits directly under the workspace root.
pub fn get_templates_dir() -> PathBuf {
    let mut dir = get_workspace_root();
    dir.push("templates");
    dir
}

View File

@@ -0,0 +1,110 @@
pub mod html_state;
mod middleware_analytics;
mod routes;
use axum::{
extract::FromRef,
middleware::from_fn_with_state,
routing::{delete, get, patch, post},
Router,
};
use axum_session::SessionLayer;
use axum_session_auth::{AuthConfig, AuthSessionLayer};
use axum_session_surreal::SessionSurrealPool;
use common::storage::types::user::User;
use html_state::HtmlState;
use middleware_analytics::analytics_middleware;
use routes::{
account::{delete_account, set_api_key, show_account_page, update_timezone},
admin_panel::{show_admin_panel, toggle_registration_status},
chat::{
message_response_stream::get_response_stream, new_chat_user_message, new_user_message,
references::show_reference_tooltip, show_chat_base, show_existing_chat,
show_initialized_chat,
},
content::{patch_text_content, show_content_page, show_text_content_edit_form},
documentation::{
show_documentation_index, show_get_started, show_mobile_friendly, show_privacy_policy,
},
gdpr::{accept_gdpr, deny_gdpr},
index::{delete_job, delete_text_content, index_handler, show_active_jobs},
ingress_form::{hide_ingress_form, process_ingress_form, show_ingress_form},
knowledge::{
delete_knowledge_entity, delete_knowledge_relationship, patch_knowledge_entity,
save_knowledge_relationship, show_edit_knowledge_entity_form, show_knowledge_page,
},
search_result::search_result_handler,
signin::{authenticate_user, show_signin_form},
signout::sign_out_user,
signup::{process_signup_and_show_verification, show_signup_form},
};
use surrealdb::{engine::any::Any, Surreal};
use tower_http::services::ServeDir;
/// Router for HTML endpoints
///
/// Registers every HTML route, serves static files under `/assets`, and then applies
/// three layers in this order of registration: the analytics middleware, the auth
/// session layer (backed by the SurrealDB client), and the session layer (backed by the
/// session store held in `app_state`).
///
/// The router is generic over the outer state `S`; `HtmlState` must be extractable
/// from it via `FromRef`.
pub fn html_routes<S>(app_state: &HtmlState) -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
// Landing page and GDPR consent
.route("/", get(index_handler))
.route("/gdpr/accept", post(accept_gdpr))
.route("/gdpr/deny", post(deny_gdpr))
// Search and chat
.route("/search", get(search_result_handler))
.route("/chat", get(show_chat_base).post(new_chat_user_message))
.route("/initialized-chat", post(show_initialized_chat))
.route("/chat/:id", get(show_existing_chat).post(new_user_message))
.route("/chat/response-stream", get(get_response_stream))
// Reference tooltip for a knowledge entity cited in a chat answer
.route("/knowledge/:id", get(show_reference_tooltip))
// Sign-in / sign-out
.route("/signout", get(sign_out_user))
.route("/signin", get(show_signin_form).post(authenticate_user))
// Ingress form and job management
.route(
"/ingress-form",
get(show_ingress_form).post(process_ingress_form),
)
.route("/hide-ingress-form", get(hide_ingress_form))
.route("/text-content/:id", delete(delete_text_content))
.route("/jobs/:job_id", delete(delete_job))
.route("/active-jobs", get(show_active_jobs))
// Stored text content
.route("/content", get(show_content_page))
.route(
"/content/:id",
get(show_text_content_edit_form).patch(patch_text_content),
)
// Knowledge entities and relationships
.route("/knowledge", get(show_knowledge_page))
.route(
"/knowledge-entity/:id",
get(show_edit_knowledge_entity_form)
.delete(delete_knowledge_entity)
.patch(patch_knowledge_entity),
)
.route("/knowledge-relationship", post(save_knowledge_relationship))
.route(
"/knowledge-relationship/:id",
delete(delete_knowledge_relationship),
)
// Account settings and admin panel
.route("/account", get(show_account_page))
.route("/admin", get(show_admin_panel))
.route("/toggle-registrations", patch(toggle_registration_status))
.route("/set-api-key", post(set_api_key))
.route("/update-timezone", patch(update_timezone))
.route("/delete-account", delete(delete_account))
.route(
"/signup",
get(show_signup_form).post(process_signup_and_show_verification),
)
// Documentation pages
.route("/documentation", get(show_documentation_index))
.route("/documentation/privacy-policy", get(show_privacy_policy))
.route("/documentation/get-started", get(show_get_started))
.route("/documentation/mobile-friendly", get(show_mobile_friendly))
// Static files; the "assets/" path is relative, so it presumably resolves against
// the process working directory — TODO confirm
.nest_service("/assets", ServeDir::new("assets/"))
// NOTE(review): in axum, layers added later wrap the ones added earlier, so the
// session layer presumably runs first on a request, then auth, then analytics —
// confirm before reordering, since the analytics middleware extracts the session.
.layer(from_fn_with_state(app_state.clone(), analytics_middleware))
.layer(
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
app_state.surreal_db_client.client.clone(),
))
.with_config(AuthConfig::<String>::default()),
)
.layer(SessionLayer::new((*app_state.session_store).clone()))
}

View File

@@ -0,0 +1,33 @@
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use axum_session_surreal::SessionSurrealPool;
use surrealdb::engine::any::Any;
use common::storage::types::analytics::Analytics;
use crate::html_state::HtmlState;
/// Middleware that records visitor and page-load counters before passing the request on.
///
/// Only paths that look like page requests are counted: anything under `/assets` or
/// `/_next`, or containing a `.`, is ignored. A session is counted as a visitor once,
/// tracked through the `counted_visitor` session flag; every counted request increments
/// the page-load counter. Counter failures are deliberately ignored so that analytics
/// never blocks a request.
pub async fn analytics_middleware(
    State(state): State<HtmlState>,
    session: axum_session::Session<SessionSurrealPool<Any>>,
    request: Request,
    next: Next,
) -> Response {
    let path = request.uri().path();
    let is_page_request =
        !(path.starts_with("/assets") || path.starts_with("/_next") || path.contains('.'));

    if is_page_request {
        let already_counted = session.get::<bool>("counted_visitor").unwrap_or(false);
        if !already_counted {
            let _ = Analytics::increment_visitors(&state.surreal_db_client).await;
            session.set("counted_visitor", true);
        }

        let _ = Analytics::increment_page_loads(&state.surreal_db_client).await;
    }

    next.run(request).await
}

View File

@@ -0,0 +1,147 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{IntoResponse, Redirect},
Form,
};
use axum_htmx::HxRedirect;
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use chrono_tz::TZ_VARIANTS;
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::{AppError, HtmlError},
storage::{db::delete_item, types::user::User},
};
use crate::{html_state::HtmlState, page_data};
use super::{render_block, render_template};
// Template context for the account settings page ("auth/account_settings.html"):
// the user being shown and the timezone names offered in the timezone selector.
page_data!(AccountData, "auth/account_settings.html", {
user: User,
timezones: Vec<String>
});
/// Renders the account settings page for the signed-in user.
///
/// Unauthenticated requests are redirected to `/`. The page receives every timezone
/// name from `TZ_VARIANTS`.
pub async fn show_account_page(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let timezones: Vec<String> = TZ_VARIANTS.iter().map(ToString::to_string).collect();
    let page = AccountData { user, timezones };

    let output = render_template(AccountData::template_name(), page, state.templates.clone())?;
    Ok(output.into_response())
}
/// Generates a new API key for the signed-in user and re-renders the API key section.
///
/// Unauthenticated requests are redirected to `/`. After the key is stored, the cached
/// auth entry for the user is cleared and the `api_key_section` block of the account
/// template is rendered with the new key.
pub async fn set_api_key(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user.as_ref() else {
        return Ok(Redirect::to("/").into_response());
    };

    let api_key = User::set_api_key(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    // Drop the cached user so later requests see the new key.
    auth.cache_clear_user(user.id.to_string());

    // Render from a local copy that already carries the fresh key.
    let mut updated_user = user.clone();
    updated_user.api_key = Some(api_key);

    let page = AccountData {
        user: updated_user,
        timezones: Vec::new(),
    };
    let output = render_block(
        AccountData::template_name(),
        "api_key_section",
        page,
        state.templates.clone(),
    )?;
    Ok(output.into_response())
}
/// Deletes the signed-in user's account, ends the session and redirects to `/`.
///
/// Unauthenticated requests are redirected to `/` with a plain redirect; after a
/// successful deletion the redirect is issued through an `HxRedirect` header with a
/// 200 status.
pub async fn delete_account(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user.as_ref() else {
        return Ok(Redirect::to("/").into_response());
    };

    delete_item::<User>(&state.surreal_db_client, &user.id)
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;

    // The account is gone: log out and discard the session.
    auth.logout_user();
    auth.session.destroy();

    let redirect_home = HxRedirect::from(Uri::from_static("/"));
    Ok((redirect_home, StatusCode::OK).into_response())
}
/// Form payload accepted by `update_timezone`.
#[derive(Deserialize)]
pub struct UpdateTimezoneForm {
/// Timezone name submitted by the user; stored as-is without validation in this file.
timezone: String,
}
/// Stores the submitted timezone for the signed-in user and re-renders the timezone
/// section of the account page.
///
/// Unauthenticated requests are redirected to `/`. After the update, the cached auth
/// entry for the user is cleared and the `timezone_section` block is rendered with the
/// new value and the full list of timezone names.
pub async fn update_timezone(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<UpdateTimezoneForm>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user.as_ref() else {
        return Ok(Redirect::to("/").into_response());
    };

    User::update_timezone(&user.id, &form.timezone, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    // Drop the cached user so later requests see the new timezone.
    auth.cache_clear_user(user.id.to_string());

    // Render from a local copy that already carries the new timezone.
    let mut updated_user = user.clone();
    updated_user.timezone = form.timezone;

    let timezones: Vec<String> = TZ_VARIANTS.iter().map(ToString::to_string).collect();
    let page = AccountData {
        user: updated_user,
        timezones,
    };

    let output = render_block(
        AccountData::template_name(),
        "timezone_section",
        page,
        state.templates.clone(),
    )?;
    Ok(output.into_response())
}

View File

@@ -0,0 +1,118 @@
use axum::{
extract::State,
response::{IntoResponse, Redirect},
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::HtmlError,
storage::types::{analytics::Analytics, system_settings::SystemSettings, user::User},
};
use crate::{html_state::HtmlState, page_data};
use super::{render_block, render_template};
// Template context for the admin panel ("auth/admin_panel.html"): the admin user,
// the current system settings, the analytics record and the number of users.
page_data!(AdminPanelData, "auth/admin_panel.html", {
user: User,
settings: SystemSettings,
analytics: Analytics,
users: i64,
});
/// Renders the admin panel with current settings, analytics and the user count.
///
/// Requests from anyone who is not a signed-in admin are redirected to `/`.
pub async fn show_admin_panel(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user.filter(|candidate| candidate.admin) else {
        return Ok(Redirect::to("/").into_response());
    };

    let settings = SystemSettings::get_current(&state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;
    let analytics = Analytics::get_current(&state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;
    let users = Analytics::get_users_amount(&state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let page = AdminPanelData {
        user,
        settings,
        analytics,
        users,
    };
    let output = render_template(
        AdminPanelData::template_name(),
        page,
        state.templates.clone(),
    )?;
    Ok(output.into_response())
}
/// Serde helper that turns an HTML checkbox value into a `bool`.
///
/// Yields `true` only when the field deserializes to the string `"on"`. Any other
/// string, or a value that fails to deserialize as a string, yields `false`; this
/// function never returns an error.
fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let checked = String::deserialize(deserializer).map_or(false, |value| value == "on");
    Ok(checked)
}
/// Form payload accepted by `toggle_registration_status`.
#[derive(Deserialize)]
pub struct RegistrationToggleInput {
/// Checkbox state: defaults to `false` when the field is absent and is `true` only
/// when the submitted value is `"on"` (see `checkbox_to_bool`).
#[serde(default)]
#[serde(deserialize_with = "checkbox_to_bool")]
registration_open: bool,
}
/// Template context used when re-rendering the `registration_status_input` block.
#[derive(Serialize)]
pub struct RegistrationToggleData {
/// System settings after the toggle has been applied.
settings: SystemSettings,
}
/// Sets whether registrations are enabled and re-renders the toggle block.
///
/// Requests from anyone who is not a signed-in admin are redirected to `/`. The current
/// settings are loaded, `registrations_enabled` is replaced with the submitted checkbox
/// state, the result is persisted, and the `registration_status_input` block of the
/// admin template is rendered with the new settings.
pub async fn toggle_registration_status(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> {
    let is_admin = auth.current_user.as_ref().is_some_and(|user| user.admin);
    if !is_admin {
        return Ok(Redirect::to("/").into_response());
    }

    let mut new_settings = SystemSettings::get_current(&state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;
    new_settings.registrations_enabled = input.registration_open;

    SystemSettings::update(&state.surreal_db_client, new_settings.clone())
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let block_data = RegistrationToggleData {
        settings: new_settings,
    };
    let output = render_block(
        AdminPanelData::template_name(),
        "registration_status_input",
        block_data,
        state.templates.clone(),
    )?;
    Ok(output.into_response())
}

View File

@@ -0,0 +1,336 @@
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
use axum::{
extract::{Query, State},
response::{
sse::{Event, KeepAlive},
Sse,
},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use futures::{
stream::{self, once},
Stream, StreamExt, TryStreamExt,
};
use json_stream_parser::JsonStreamParser;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use surrealdb::{engine::any::Any, Surreal};
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{error, info};
use common::{
retrieval::{
combined_knowledge_entity_retrieval,
query_helper::{
create_chat_request, create_user_message, format_entities_json, LLMResponseFormat,
},
},
storage::{
db::{get_item, store_item, SurrealDbClient},
types::{
message::{Message, MessageRole},
user::User,
},
},
};
use crate::{html_state::HtmlState, routes::render_template};
// Error handling function
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
/// Resolves the signed-in user and loads the message with id `message_id`.
///
/// On success returns the message together with the user. On failure returns a
/// ready-to-send SSE response containing a single `error` event: when nobody is signed
/// in, when the message does not exist, or when the database lookup fails (the database
/// error is logged).
async fn get_message_and_user(
    db: &SurrealDbClient,
    current_user: Option<User>,
    message_id: &str,
) -> Result<(Message, User), Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>> {
    let Some(user) = current_user else {
        return Err(Sse::new(create_error_stream(
            "You must be signed in to use this feature",
        )));
    };

    // Either return early with the message, or pick the text for the error event.
    let failure = match get_item::<Message>(db, message_id).await {
        Ok(Some(message)) => return Ok((message, user)),
        Ok(None) => "Message not found: the specified message does not exist",
        Err(e) => {
            error!("Database error retrieving message {}: {:?}", message_id, e);
            "Failed to retrieve message: database error"
        }
    };

    Err(Sse::new(create_error_stream(failure)))
}
/// Query-string parameters accepted by `get_response_stream`.
#[derive(Deserialize)]
pub struct QueryParams {
/// Id of the stored message that the streamed response answers.
message_id: String,
}
/// Streams the AI answer to a stored user message as server-sent events.
///
/// Flow: authenticate and load the message, retrieve related knowledge entities, build
/// and start an OpenAI chat stream, then forward the stream to the client while a
/// background task collects the raw output and stores the finished AI message.
///
/// Events emitted:
/// - `chat_message`: newly available text of the answer.
/// - `references`: rendered `chat/reference_list.html` once the stored message is known.
/// - `error`: any failure, either as the only event or inline in the stream.
/// - `close_stream`: always the last event of a started stream.
///
/// Any failure before streaming starts returns a response holding a single `error` event.
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
// 1. Authentication and initial data validation
let (user_message, user) = match get_message_and_user(
&state.surreal_db_client,
auth.current_user,
&params.message_id,
)
.await
{
Ok((user_message, user)) => (user_message, user),
Err(error_stream) => return error_stream,
};
// 2. Retrieve knowledge entities
let entities = match combined_knowledge_entity_retrieval(
&state.surreal_db_client,
&state.openai_client,
&user_message.content,
&user.id,
)
.await
{
Ok(entities) => entities,
// The underlying error is discarded; only a generic message reaches the client.
Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge entities"));
}
};
// 3. Create the OpenAI request
let entities_json = format_entities_json(&entities);
let formatted_user_message = create_user_message(&entities_json, &user_message.content);
let request = match create_chat_request(formatted_user_message) {
Ok(req) => req,
Err(..) => {
return Sse::new(create_error_stream("Failed to create chat request"));
}
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
}
};
// 5. Create channel for collecting complete response
// `tx`/`rx` carry raw content chunks to the storage task; `tx_final`/`rx_final` carry
// the finished AI message back so the references event can be rendered.
let (tx, mut rx) = channel::<String>(1000);
let tx_clone = tx.clone();
let (tx_final, mut rx_final) = channel::<Message>(1);
// 6. Set up the collection task for DB storage
let db_client = state.surreal_db_client.clone();
tokio::spawn(async move {
// Drop the original sender right away so that `rx` closes once the clones held by
// the response stream are gone.
drop(tx); // Close sender when no longer needed
// Collect full response
let mut full_json = String::new();
while let Some(chunk) = rx.recv().await {
full_json.push_str(&chunk);
}
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
response.answer,
Some(references),
);
// Hand the message to the response stream before persisting it; a send failure
// (receiver already gone) is ignored.
let _ = tx_final.send(ai_message.clone()).await;
match store_item(&db_client, ai_message).await {
Ok(_) => info!("Successfully stored AI message with references"),
Err(e) => error!("Failed to store AI message: {:?}", e),
}
} else {
error!("Failed to parse LLM response as structured format");
// Fallback - store raw response
// Nothing is sent on `tx_final` in this branch, so the response stream ends with
// the "Failed to retrieve references" error event.
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
full_json,
None,
);
let _ = store_item(&db_client, ai_message).await;
}
});
// Create a shared state for tracking the JSON parsing
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
// 7. Create the response event stream
let event_stream = openai_stream
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
stream! {
match result {
Ok(response) => {
// Only the first choice's delta is used; a missing delta counts as empty.
let content = response
.choices
.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
if !content.is_empty() {
// Always send raw content to storage
let _ = tx_storage.send(content.clone()).await;
// Process through JSON parser
let mut state = json_state.lock().await;
let display_content = state.process_chunk(&content);
// Release the lock before yielding.
drop(state);
if !display_content.is_empty() {
yield Ok(Event::default()
.event("chat_message")
.data(display_content));
}
// If display_content is empty, don't yield anything
}
// If content is empty, don't yield anything
}
Err(e) => {
yield Ok(Event::default()
.event("error")
.data(format!("Stream error: {}", e)));
}
}
}
})
.flatten()
// After the model output ends, wait for the stored message and emit its references.
.chain(stream::once(async move {
if let Some(message) = rx_final.recv().await {
// Skip rendering when the reference list is present but empty
if message.references.as_ref().is_some_and(|x| x.is_empty()) {
// NOTE(review): an `empty` event without data is still produced here;
// presumably the client never dispatches it — confirm.
return Ok(Event::default().event("empty"));
}
// Prepare data for template
#[derive(Serialize)]
struct ReferenceData {
message: Message,
}
// Render template with references
match render_template(
"chat/reference_list.html",
ReferenceData { message },
state.templates.clone(),
) {
Ok(html) => {
// Extract the String from Html<String>
let html_string = html.0;
// Return the rendered HTML
Ok(Event::default().event("references").data(html_string))
}
Err(_) => {
// Handle template rendering error
Ok(Event::default()
.event("error")
.data("Failed to render references"))
}
}
} else {
// The channel closed without a message (see the fallback branch of the storage task)
Ok(Event::default()
.event("error")
.data("Failed to retrieve references"))
}
}))
// Final marker event so the client knows the stream is complete.
.chain(once(async {
Ok(Event::default()
.event("close_stream")
.data("Stream complete"))
}));
info!("OpenAI streaming started");
Sse::new(event_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
}
/// Incremental extractor for the `answer` field of a JSON document that arrives in
/// chunks.
///
/// Each chunk is fed to a streaming JSON parser; after every chunk the text of the
/// top-level `answer` field is compared with what has already been handed out, and
/// only the new tail is returned.
struct StreamParserState {
    /// Streaming parser holding the partially received document.
    parser: JsonStreamParser,
    /// Answer text that has already been returned to the caller.
    last_answer_content: String,
    /// Set once an `answer` field has been seen in the parsed document.
    in_answer_field: bool,
}

impl StreamParserState {
    /// Creates a state with an empty parser and no answer text emitted yet.
    fn new() -> Self {
        Self {
            parser: JsonStreamParser::new(),
            last_answer_content: String::new(),
            in_answer_field: false,
        }
    }

    /// Feeds `chunk` into the parser and returns the part of the `answer` field that
    /// has not been returned before.
    ///
    /// Returns an empty string when the document has no top-level `answer` field yet or
    /// when the answer did not grow. Parser errors on individual characters are ignored.
    fn process_chunk(&mut self, chunk: &str) -> String {
        for c in chunk.chars() {
            let _ = self.parser.add_char(c);
        }

        let json = self.parser.get_result();
        let Some(answer) = json.as_object().and_then(|obj| obj.get("answer")) else {
            return String::new();
        };
        self.in_answer_field = true;

        // A non-string `answer` is treated as empty text.
        let current = answer.as_str().unwrap_or_default();
        let emitted_len = self.last_answer_content.len();
        if current.len() <= emitted_len {
            return String::new();
        }

        // Checked slicing: if the previously emitted length does not fall on a char
        // boundary of the new text, indexing would panic. In that case nothing is
        // emitted for this chunk, but the bookkeeping is still brought up to date.
        let new_content = current.get(emitted_len..).unwrap_or_default().to_owned();
        self.last_answer_content = current.to_owned();
        new_content
    }
}

View File

@@ -0,0 +1,289 @@
pub mod message_response_stream;
pub mod references;
use axum::{
extract::{Path, State},
http::HeaderValue,
response::{IntoResponse, Redirect},
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{
error::{AppError, HtmlError},
storage::{
db::{get_item, store_item},
types::{
conversation::Conversation,
message::{Message, MessageRole},
user::User,
},
},
};
use crate::{html_state::HtmlState, page_data, routes::render_template};
/// Form payload accepted by `show_initialized_chat`: an already answered query that
/// is turned into a new conversation.
#[derive(Debug, Deserialize)]
pub struct ChatStartParams {
/// The question the user asked.
user_query: String,
/// The answer text that was produced for the question.
llm_response: String,
/// References for the answer; submitted as a JSON-encoded array inside a single
/// string field and decoded by `deserialize_references`.
#[serde(deserialize_with = "deserialize_references")]
references: Vec<String>,
}
// Custom deserializer function
fn deserialize_references<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(serde::de::Error::custom)
}
// Template context for the chat page ("chat/base.html"): the signed-in user, the
// messages of the open conversation, the open conversation itself (if any) and the
// user's list of conversations.
page_data!(ChatData, "chat/base.html", {
user: User,
history: Vec<Message>,
conversation: Option<Conversation>,
conversation_archive: Vec<Conversation>
});
/// Creates a conversation from an already answered query and renders the chat page.
///
/// Unauthenticated requests are redirected to `/`. The conversation, the user message
/// and the AI message are stored concurrently; if any store fails an error is returned.
/// The response carries an `HX-Push` header pointing at `/chat/{conversation id}`.
pub async fn show_initialized_chat(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<ChatStartParams>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Displaying chat start");

    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
    let user_message = Message::new(
        conversation.id.to_string(),
        MessageRole::User,
        form.user_query,
        None,
    );
    let ai_message = Message::new(
        conversation.id.to_string(),
        MessageRole::AI,
        form.llm_response,
        Some(form.references),
    );

    // Run the three inserts concurrently, then surface failures in a fixed order.
    let (conversation_result, ai_message_result, user_message_result) = futures::join!(
        store_item(&state.surreal_db_client, conversation.clone()),
        store_item(&state.surreal_db_client, ai_message.clone()),
        store_item(&state.surreal_db_client, user_message.clone())
    );
    conversation_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
    user_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
    ai_message_result.map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;

    let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let push_url = format!("/chat/{}", conversation.id);
    let page = ChatData {
        history: vec![user_message, ai_message],
        user,
        conversation_archive,
        conversation: Some(conversation),
    };
    let output = render_template(ChatData::template_name(), page, state.templates.clone())?;

    let mut response = output.into_response();
    response
        .headers_mut()
        .insert("HX-Push", HeaderValue::from_str(&push_url).unwrap());
    Ok(response)
}
/// Renders the chat page without an open conversation.
///
/// Unauthenticated requests are redirected to `/`. The page shows an empty history and
/// the user's list of conversations.
pub async fn show_chat_base(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Displaying empty chat start");

    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let page = ChatData {
        history: Vec::new(),
        user,
        conversation_archive,
        conversation: None,
    };
    let output = render_template(ChatData::template_name(), page, state.templates.clone())?;
    Ok(output.into_response())
}
/// Form payload for posting a chat message; used by `new_user_message` and
/// `new_chat_user_message`.
#[derive(Deserialize)]
pub struct NewMessageForm {
/// Text of the user's message.
content: String,
}
/// Renders the chat page for an existing conversation.
///
/// Unauthenticated requests are redirected to `/`. The conversation and its messages
/// are loaded for the signed-in user's id; a failed lookup is returned as an error.
pub async fn show_existing_chat(
    Path(conversation_id): Path<String>,
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Displaying initialized chat with id: {}", conversation_id);

    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let conversation_archive = User::get_user_conversations(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let (conversation, history) = Conversation::get_complete_conversation(
        conversation_id.as_str(),
        &user.id,
        &state.surreal_db_client,
    )
    .await
    .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let page = ChatData {
        history,
        user,
        conversation: Some(conversation),
        conversation_archive,
    };
    let output = render_template(ChatData::template_name(), page, state.templates.clone())?;
    Ok(output.into_response())
}
/// Stores a new user message in an existing conversation and renders the streaming
/// response fragment.
///
/// Unauthenticated requests are redirected to `/`. Returns a not-found error when the
/// conversation does not exist and an auth error when it belongs to another user. The
/// response carries an `HX-Push` header pointing at `/chat/{conversation id}`.
pub async fn new_user_message(
    Path(conversation_id): Path<String>,
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let lookup: Option<Conversation> = get_item(&state.surreal_db_client, &conversation_id)
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
    let Some(conversation) = lookup else {
        return Err(HtmlError::new(
            AppError::NotFound("Conversation was not found".to_string()),
            state.templates.clone(),
        ));
    };

    // Only the owner of the conversation may post to it.
    if conversation.user_id != user.id {
        return Err(HtmlError::new(
            AppError::Auth("The user does not have permission for this conversation".to_string()),
            state.templates.clone(),
        ));
    }

    let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
    store_item(&state.surreal_db_client, user_message.clone())
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;

    #[derive(Serialize)]
    struct SSEResponseInitData {
        user_message: Message,
    }

    let output = render_template(
        "chat/streaming_response.html",
        SSEResponseInitData { user_message },
        state.templates.clone(),
    )?;

    let push_url = format!("/chat/{}", conversation.id);
    let mut response = output.into_response();
    response
        .headers_mut()
        .insert("HX-Push", HeaderValue::from_str(&push_url).unwrap());
    Ok(response)
}
/// Starts a new conversation from the user's first message and renders the first
/// response fragment.
///
/// Unauthenticated requests are redirected to `/`. The conversation is stored first,
/// then the message. The response carries an `HX-Push` header pointing at
/// `/chat/{conversation id}`.
pub async fn new_chat_user_message(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<NewMessageForm>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let conversation = Conversation::new(user.id, "New chat".to_string());
    let user_message = Message::new(
        conversation.id.clone(),
        MessageRole::User,
        form.content,
        None,
    );

    store_item(&state.surreal_db_client, conversation.clone())
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
    store_item(&state.surreal_db_client, user_message.clone())
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;

    #[derive(Serialize)]
    struct SSEResponseInitData {
        user_message: Message,
        conversation: Conversation,
    }

    let push_url = format!("/chat/{}", conversation.id);
    let output = render_template(
        "chat/new_chat_first_response.html",
        SSEResponseInitData {
            user_message,
            conversation,
        },
        state.templates.clone(),
    )?;

    let mut response = output.into_response();
    response
        .headers_mut()
        .insert("HX-Push", HeaderValue::from_str(&push_url).unwrap());
    Ok(response)
}

View File

@@ -0,0 +1,63 @@
use axum::{
extract::{Path, State},
response::{IntoResponse, Redirect},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::Serialize;
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{
error::{AppError, HtmlError},
storage::{
db::get_item,
types::{knowledge_entity::KnowledgeEntity, user::User},
},
};
use crate::{html_state::HtmlState, routes::render_template};
/// Renders the tooltip for a knowledge entity referenced in a chat answer.
///
/// Unauthenticated requests are redirected to `/`. Returns a not-found error when the
/// entity does not exist and an auth error when it belongs to another user.
pub async fn show_reference_tooltip(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(reference_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Showing reference");

    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    let lookup: Option<KnowledgeEntity> = get_item(&state.surreal_db_client, &reference_id)
        .await
        .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))?;
    let Some(entity) = lookup else {
        return Err(HtmlError::new(
            AppError::NotFound("Item was not found".to_string()),
            state.templates.clone(),
        ));
    };

    // Entities are private to their owner.
    if entity.user_id != user.id {
        return Err(HtmlError::new(
            AppError::Auth("You dont have access to this entity".to_string()),
            state.templates.clone(),
        ));
    }

    #[derive(Serialize)]
    struct ReferenceTooltipData {
        entity: KnowledgeEntity,
        user: User,
    }

    let tooltip = ReferenceTooltipData { entity, user };
    let output = render_template(
        "chat/reference_tooltip.html",
        tooltip,
        state.templates.clone(),
    )?;
    Ok(output.into_response())
}

View File

@@ -0,0 +1,108 @@
use axum::{
extract::{Path, State},
response::{IntoResponse, Redirect},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{
error::HtmlError,
storage::types::{text_content::TextContent, user::User},
};
use crate::{html_state::HtmlState, page_data};
use super::render_template;
// Template context for the content page ("content/base.html"): the signed-in user
// and that user's text contents. Also reused when rendering "content/content_list.html".
page_data!(ContentPageData, "content/base.html", {
user: User,
text_contents: Vec<TextContent>
});
/// Renders the content page listing the signed-in user's text contents.
///
/// Unauthenticated requests are redirected to `/signin`.
pub async fn show_content_page(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let page = ContentPageData {
        user,
        text_contents,
    };
    let output = render_template(ContentPageData::template_name(), page, state.templates)?;
    Ok(output.into_response())
}
/// Template context for the text content edit modal
/// ("content/edit_text_content_modal.html").
#[derive(Serialize)]
pub struct TextContentEditModal {
/// The signed-in user.
pub user: User,
/// The text content being edited.
pub text_content: TextContent,
}
/// Renders the edit modal for one of the signed-in user's text contents.
///
/// Unauthenticated requests are redirected to `/signin`. The text content is fetched
/// through `User::get_and_validate_text_content` with the user's id; a failure there is
/// returned as an error.
pub async fn show_text_content_edit_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let text_content = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let modal = TextContentEditModal { user, text_content };
    let output = render_template(
        "content/edit_text_content_modal.html",
        modal,
        state.templates,
    )?;
    Ok(output.into_response())
}
/// Handles a PATCH for a text content and re-renders the content list.
///
/// Unauthenticated requests are redirected to `/signin`. The text content is fetched
/// through `User::get_and_validate_text_content` with the user's id, then the user's
/// full list is loaded and rendered with `content/content_list.html`.
///
/// Note: as written, this handler reads no request body and writes nothing to the
/// database; the fetched text content is only used as an access check.
pub async fn patch_text_content(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    // Fetched for the access check only; the value itself is not used further.
    let _validated = User::get_and_validate_text_content(&id, &user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let text_contents = User::get_text_contents(&user.id, &state.surreal_db_client)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let list = ContentPageData {
        user,
        text_contents,
    };
    let output = render_template("content/content_list.html", list, state.templates)?;
    Ok(output.into_response())
}

View File

@@ -0,0 +1,78 @@
use axum::{extract::State, response::IntoResponse};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{error::HtmlError, storage::types::user::User};
use crate::{html_state::HtmlState, page_data};
use super::render_template;
// Template context shared by the documentation handlers in this file.
// The template name given here ("do_not_use_this") is a placeholder: the handlers below
// pass explicit template paths to `render_template` and never call
// `DocumentationData::template_name()`.
page_data!(DocumentationData, "do_not_use_this", {
user: Option<User>,
current_path: String
});
/// Renders the privacy policy page (`documentation/privacy.html`).
///
/// No sign-in is required; the current user, if any, is handed to the template together
/// with the path of this page.
pub async fn show_privacy_policy(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let context = DocumentationData {
        user: auth.current_user,
        current_path: String::from("/privacy_policy"),
    };
    let page = render_template("documentation/privacy.html", context, state.templates)?;

    Ok(page.into_response())
}
/// Renders the "get started" page (`documentation/get_started.html`).
///
/// No sign-in is required; the current user, if any, is handed to the template together
/// with the path of this page.
pub async fn show_get_started(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let context = DocumentationData {
        user: auth.current_user,
        current_path: String::from("/get-started"),
    };
    let page = render_template("documentation/get_started.html", context, state.templates)?;

    Ok(page.into_response())
}
/// Renders the "mobile friendly" page (`documentation/mobile_friendly.html`).
///
/// No sign-in is required; the current user, if any, is handed to the template together
/// with the path of this page.
pub async fn show_mobile_friendly(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let context = DocumentationData {
        user: auth.current_user,
        current_path: String::from("/mobile-friendly"),
    };
    let page = render_template(
        "documentation/mobile_friendly.html",
        context,
        state.templates,
    )?;

    Ok(page.into_response())
}
/// Renders the documentation index page (`documentation/index.html`).
///
/// No sign-in is required; the current user, if any, is handed to the template together
/// with the path of this page.
pub async fn show_documentation_index(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let context = DocumentationData {
        user: auth.current_user,
        current_path: String::from("/index"),
    };
    let page = render_template("documentation/index.html", context, state.templates)?;

    Ok(page.into_response())
}

View File

@@ -0,0 +1,22 @@
use axum::response::{Html, IntoResponse};
use axum_session::Session;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::engine::any::Any;
use common::error::HtmlError;
/// Records in the session that the visitor accepted the GDPR notice.
///
/// Stores `gdpr_accepted = true` and answers with an empty HTML body.
pub async fn accept_gdpr(
    session: Session<SessionSurrealPool<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    session.set("gdpr_accepted", true);

    let empty_body = Html("");
    Ok(empty_body.into_response())
}
/// Handles the visitor declining the GDPR notice.
///
/// Answers with an empty HTML body.
pub async fn deny_gdpr(
    session: Session<SessionSurrealPool<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    // NOTE(review): this stores `true` under the same key as `accept_gdpr` —
    // presumably so the notice stays dismissed after a refusal; confirm this is intended.
    session.set("gdpr_accepted", true);

    let empty_body = Html("");
    Ok(empty_body.into_response())
}

View File

@@ -0,0 +1,246 @@
use axum::{
extract::{Path, State},
response::{IntoResponse, Redirect},
};
use axum_session::Session;
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use tokio::join;
use tracing::info;
use common::{
error::{AppError, HtmlError},
storage::{
db::{delete_item, get_item},
types::{
file_info::FileInfo, job::Job, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
text_content::TextContent, user::User,
},
},
};
use crate::{html_state::HtmlState, page_data, routes::render_template};
use super::render_block;
// Template context for the landing page (`index/index.html`), filled in by
// `index_handler`:
// - `gdpr_accepted`: true for signed-in users or when the session flag is set
// - `user`: the signed-in user, if any
// - `latest_text_contents` / `active_jobs`: empty for anonymous visitors
page_data!(IndexData, "index/index.html", {
gdpr_accepted: bool,
user: Option<User>,
latest_text_contents: Vec<TextContent>,
active_jobs: Vec<Job>
});
/// Renders the landing page.
///
/// Anonymous visitors get the page with empty job and content lists. For a signed-in
/// user the unfinished jobs and the latest text contents are loaded first; a failure in
/// either lookup is rendered as an [`HtmlError`].
///
/// The GDPR notice counts as accepted when a user is signed in or when the session
/// carries `gdpr_accepted = true`.
pub async fn index_handler(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    session: Session<SessionSurrealPool<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Displaying index page");

    let gdpr_accepted =
        auth.current_user.is_some() || session.get("gdpr_accepted").unwrap_or(false);

    // Borrow the user instead of cloning and unwrapping it for every lookup.
    let (active_jobs, latest_text_contents) = match &auth.current_user {
        Some(user) => {
            let active_jobs = state
                .job_queue
                .get_unfinished_user_jobs(&user.id)
                .await
                .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

            let latest_text_contents =
                User::get_latest_text_contents(user.id.as_str(), &state.surreal_db_client)
                    .await
                    .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

            (active_jobs, latest_text_contents)
        }
        None => (Vec::new(), Vec::new()),
    };

    let output = render_template(
        IndexData::template_name(),
        IndexData {
            gdpr_accepted,
            user: auth.current_user,
            latest_text_contents,
            active_jobs,
        },
        state.templates.clone(),
    )?;

    Ok(output.into_response())
}
/// Template context for the `latest_content_section` block of
/// `index/signed_in/recent_content.html`.
#[derive(Serialize)]
pub struct LatestTextContentData {
/// The user's latest text contents, as returned by `User::get_latest_text_contents`.
latest_text_contents: Vec<TextContent>,
/// The signed-in user the section is rendered for.
user: User,
}
/// Deletes a text content item together with its dependent records and re-renders the
/// "latest content" section.
///
/// Visitors without a session are redirected to `/`. The item must exist and belong to
/// the current user (see `get_and_validate_text_content`). The attached file info (if
/// any), the item itself, its text chunks, knowledge entities and knowledge
/// relationships are deleted concurrently; if any of those deletions fails a generic
/// processing error is returned.
pub async fn delete_text_content(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = &auth.current_user else {
        return Ok(Redirect::to("/").into_response());
    };

    // Load the item and make sure the caller owns it.
    let text_content = get_and_validate_text_content(&state, &id, user).await?;

    // Run all deletions concurrently and look at every outcome afterwards.
    let (file_result, content_result, chunk_result, entity_result, relationship_result) = join!(
        async {
            if let Some(file_info) = text_content.file_info {
                FileInfo::delete_by_id(&file_info.id, &state.surreal_db_client).await
            } else {
                Ok(())
            }
        },
        delete_item::<TextContent>(&state.surreal_db_client, &text_content.id),
        TextChunk::delete_by_source_id(&text_content.id, &state.surreal_db_client),
        KnowledgeEntity::delete_by_source_id(&text_content.id, &state.surreal_db_client),
        KnowledgeRelationship::delete_relationships_by_source_id(
            &text_content.id,
            &state.surreal_db_client
        )
    );

    let any_deletion_failed = file_result.is_err()
        || content_result.is_err()
        || chunk_result.is_err()
        || entity_result.is_err()
        || relationship_result.is_err();
    if any_deletion_failed {
        return Err(HtmlError::new(
            AppError::Processing("Failed to delete one or more items".to_string()),
            state.templates.clone(),
        ));
    }

    // Re-render the section with whatever is left.
    let latest_text_contents = User::get_latest_text_contents(&user.id, &state.surreal_db_client)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let section = LatestTextContentData {
        user: user.clone(),
        latest_text_contents,
    };
    let fragment = render_block(
        "index/signed_in/recent_content.html",
        "latest_content_section",
        section,
        state.templates.clone(),
    )?;

    Ok(fragment.into_response())
}
/// Loads the text content with the given `id` and checks that it belongs to `user`.
///
/// # Errors
///
/// Returns an [`HtmlError`] wrapping
/// - the converted database error when the lookup itself fails,
/// - `AppError::NotFound` when no item with that id exists,
/// - `AppError::Auth` when the item's `user_id` differs from `user.id`.
async fn get_and_validate_text_content(
    state: &HtmlState,
    id: &str,
    user: &User,
) -> Result<TextContent, HtmlError> {
    let fetched = get_item::<TextContent>(&state.surreal_db_client, id)
        .await
        .map_err(|err| HtmlError::new(AppError::from(err), state.templates.clone()))?;

    let Some(text_content) = fetched else {
        return Err(HtmlError::new(
            AppError::NotFound("No item found".to_string()),
            state.templates.clone(),
        ));
    };

    if text_content.user_id == user.id {
        Ok(text_content)
    } else {
        Err(HtmlError::new(
            AppError::Auth("You are not the owner of that content".to_string()),
            state.templates.clone(),
        ))
    }
}
/// Template context for the `active_jobs_section` block of
/// `index/signed_in/active_jobs.html`.
#[derive(Serialize)]
pub struct ActiveJobsData {
/// The user's unfinished jobs, as returned by the job queue.
pub active_jobs: Vec<Job>,
/// The signed-in user the section is rendered for.
pub user: User,
}
/// Deletes one of the current user's jobs and re-renders the active-jobs section.
///
/// Visitors without a session are redirected to `/signin`. Both the deletion and the
/// follow-up lookup of unfinished jobs report failures as [`HtmlError`].
pub async fn delete_job(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let queue = &state.job_queue;

    queue
        .delete_job(&id, &user.id)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let active_jobs = queue
        .get_unfinished_user_jobs(&user.id)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let section = ActiveJobsData { user, active_jobs };
    let fragment = render_block(
        "index/signed_in/active_jobs.html",
        "active_jobs_section",
        section,
        state.templates.clone(),
    )?;

    Ok(fragment.into_response())
}
/// Renders the active-jobs section for the current user.
///
/// Visitors without a session are redirected to `/signin`. A failing job lookup is
/// rendered as an [`HtmlError`].
pub async fn show_active_jobs(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let active_jobs = state
        .job_queue
        .get_unfinished_user_jobs(&user.id)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let section = ActiveJobsData { user, active_jobs };
    let fragment = render_block(
        "index/signed_in/active_jobs.html",
        "active_jobs_section",
        section,
        state.templates.clone(),
    )?;

    Ok(fragment.into_response())
}

View File

@@ -0,0 +1,154 @@
use axum::{
extract::State,
response::{Html, IntoResponse, Redirect},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use futures::{future::try_join_all, TryFutureExt};
use surrealdb::{engine::any::Any, Surreal};
use tempfile::NamedTempFile;
use tracing::info;
use common::{
error::{AppError, HtmlError, IntoHtmlError},
ingress::ingress_input::{create_ingress_objects, IngressInput},
storage::types::{file_info::FileInfo, user::User},
};
use crate::{
html_state::HtmlState,
page_data,
routes::{index::ActiveJobsData, render_block},
};
use super::render_template;
/// Template context for the ingress modal (`index/signed_in/ingress_modal.html`).
#[derive(Serialize)]
pub struct ShowIngressFormData {
/// Categories returned by `User::get_user_categories` for the current user.
user_categories: Vec<String>,
}
/// Renders the ingress modal, pre-filled with the user's categories.
///
/// Unauthenticated visitors are redirected to `/`. A failing category lookup is rendered
/// as an [`HtmlError`].
pub async fn show_ingress_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    if auth.is_authenticated() {
        let user_categories = User::get_user_categories(&auth.id, &state.surreal_db_client)
            .await
            .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

        let modal = ShowIngressFormData { user_categories };
        let page = render_template(
            "index/signed_in/ingress_modal.html",
            modal,
            state.templates.clone(),
        )?;

        Ok(page.into_response())
    } else {
        Ok(Redirect::to("/").into_response())
    }
}
/// Replaces the ingress form with the "Add Content" button again.
///
/// Unauthenticated visitors are redirected to `/`; everyone else receives the static
/// button markup, which fetches `/ingress-form` via htmx when activated.
pub async fn hide_ingress_form(
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    let response = if auth.is_authenticated() {
        Html(
            "<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
        )
        .into_response()
    } else {
        Redirect::to("/").into_response()
    };

    Ok(response)
}
/// Multipart form payload accepted by `process_ingress_form`.
#[derive(Debug, TryFromMultipart)]
pub struct IngressParams {
// Optional free-text content submitted with the form.
pub content: Option<String>,
// Instructions submitted with the form; forwarded unchanged to `IngressInput`.
pub instructions: String,
// Category submitted with the form; forwarded unchanged to `IngressInput`.
pub category: String,
// Uploaded files, each buffered into a temporary file. `default` makes the field
// optional (an absent field yields an empty list).
// NOTE(review): the limit is presumably in bytes (~10 MB) — confirm against the
// axum_typed_multipart documentation.
#[form_data(limit = "10000000")] // Adjust limit as needed
#[form_data(default)]
pub files: Vec<FieldData<NamedTempFile>>,
}
// Template context for re-rendering the ingress form (`ingress_form.html`) after a
// rejected submission: the submitted values are echoed back together with an error
// message (see `process_ingress_form`).
page_data!(IngressFormData, "ingress_form.html", {
instructions: String,
content: String,
category: String,
error: String,
});
/// Processes a submitted ingress form and queues one job per resulting ingress object.
///
/// Steps:
/// 1. Require a signed-in user (otherwise an `AppError::Auth` is rendered).
/// 2. Reject submissions that carry neither files nor at least two bytes of text; the
///    form is re-rendered with the submitted values and an error message. A missing
///    `content` field is treated the same as empty text.
/// 3. Store every uploaded file as a `FileInfo`.
/// 4. Build the ingress objects and enqueue them on the job queue.
/// 5. Render the refreshed active-jobs section.
///
/// # Errors
///
/// Any failure while storing files, building ingress objects, enqueueing jobs, loading
/// the job list or rendering is returned as an [`HtmlError`].
pub async fn process_ingress_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    TypedMultipart(input): TypedMultipart<IngressParams>,
) -> Result<impl IntoResponse, HtmlError> {
    let user = auth.current_user.ok_or_else(|| {
        AppError::Auth("You must be signed in".to_string()).with_template(state.templates.clone())
    })?;

    // `None` must count as "no content": previously an absent content field combined
    // with no files slipped past this check.
    let content_missing = input
        .content
        .as_deref()
        .map_or(true, |text| text.len() < 2);

    if content_missing && input.files.is_empty() {
        // This branch always returns, so the submitted values can be moved out of
        // `input` instead of being cloned.
        let output = render_template(
            IngressFormData::template_name(),
            IngressFormData {
                instructions: input.instructions,
                content: input.content.unwrap_or_default(),
                category: input.category,
                error: "You need to either add files or content".to_string(),
            },
            state.templates.clone(),
        )?;

        return Ok(output.into_response());
    }

    info!("{:?}", input);

    // Persist all uploads concurrently; the first failure aborts the request.
    let file_infos = try_join_all(input.files.into_iter().map(|file| {
        FileInfo::new(file, &state.surreal_db_client, &user.id)
            .map_err(|e| HtmlError::new(AppError::from(e), state.templates.clone()))
    }))
    .await?;

    let ingress_objects = create_ingress_objects(
        IngressInput {
            content: input.content,
            instructions: input.instructions,
            category: input.category,
            files: file_infos,
        },
        user.id.as_str(),
    )
    .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    // Each object is owned by the iterator, so it is handed to the queue without cloning.
    let futures: Vec<_> = ingress_objects
        .into_iter()
        .map(|object| state.job_queue.enqueue(object, user.id.clone()))
        .collect();

    try_join_all(futures)
        .await
        .map_err(AppError::from)
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    // Update the active jobs page with the newly created job
    let active_jobs = state
        .job_queue
        .get_unfinished_user_jobs(&user.id)
        .await
        .map_err(|e| HtmlError::new(e, state.templates.clone()))?;

    let output = render_block(
        "index/signed_in/active_jobs.html",
        "active_jobs_section",
        ActiveJobsData { user, active_jobs },
        state.templates.clone(),
    )?;

    Ok(output.into_response())
}

View File

@@ -0,0 +1,379 @@
use axum::{
extract::{Path, State},
response::{IntoResponse, Redirect},
Form,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use plotly::{
common::{Line, Marker, Mode},
layout::{Axis, Camera, LayoutScene, ProjectionType},
Layout, Plot, Scatter3D,
};
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{
error::{AppError, HtmlError},
storage::{
db::delete_item,
types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
user::User,
},
},
};
use crate::{html_state::HtmlState, page_data, routes::render_template};
// Template context for the knowledge overview page (`knowledge/base.html`), filled in
// by `show_knowledge_page`: the user's entities and relationships plus the HTML of the
// generated 3D plot.
page_data!(KnowledgeBaseData, "knowledge/base.html", {
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
user: User,
plot_html: String
});
pub async fn show_knowledge_page(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/signin").into_response()),
};
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
info!("Got entities ok");
let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let mut plot = Plot::new();
// Fibonacci sphere distribution
let node_count = entities.len();
let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
let node_positions: Vec<(f64, f64, f64)> = (0..node_count)
.map(|i| {
let i = i as f64;
let theta = 2.0 * std::f64::consts::PI * i / golden_ratio;
let phi = (1.0 - 2.0 * (i + 0.5) / node_count as f64).acos();
let x = phi.sin() * theta.cos();
let y = phi.sin() * theta.sin();
let z = phi.cos();
(x, y, z)
})
.collect();
let node_x: Vec<f64> = node_positions.iter().map(|(x, _, _)| *x).collect();
let node_y: Vec<f64> = node_positions.iter().map(|(_, y, _)| *y).collect();
let node_z: Vec<f64> = node_positions.iter().map(|(_, _, z)| *z).collect();
// Nodes trace
let nodes = Scatter3D::new(node_x.clone(), node_y.clone(), node_z.clone())
.mode(Mode::Markers)
.marker(Marker::new().size(8).color("#1f77b4"))
.text_array(
entities
.iter()
.map(|e| e.description.clone())
.collect::<Vec<_>>(),
)
.hover_template("Entity: %{text}<br>");
// Edges traces
for rel in &relationships {
let from_idx = entities.iter().position(|e| e.id == rel.out).unwrap_or(0);
let to_idx = entities.iter().position(|e| e.id == rel.in_).unwrap_or(0);
let edge_x = vec![node_x[from_idx], node_x[to_idx]];
let edge_y = vec![node_y[from_idx], node_y[to_idx]];
let edge_z = vec![node_z[from_idx], node_z[to_idx]];
let edge_trace = Scatter3D::new(edge_x, edge_y, edge_z)
.mode(Mode::Lines)
.line(Line::new().color("#888").width(2.0))
.hover_template(&format!(
"Relationship: {}<br>",
rel.metadata.relationship_type
))
.show_legend(false);
plot.add_trace(edge_trace);
}
plot.add_trace(nodes);
// Layout
let layout = Layout::new()
.scene(
LayoutScene::new()
.x_axis(Axis::new().visible(false))
.y_axis(Axis::new().visible(false))
.z_axis(Axis::new().visible(false))
.camera(
Camera::new()
.projection(ProjectionType::Perspective.into())
.eye((1.5, 1.5, 1.5).into()),
),
)
.show_legend(false)
.paper_background_color("rbga(250,100,0,0)")
.plot_background_color("rbga(0,0,0,0)");
plot.set_layout(layout);
// Convert to HTML
let html = plot.to_html();
let output = render_template(
KnowledgeBaseData::template_name(),
KnowledgeBaseData {
entities,
relationships,
user,
plot_html: html,
},
state.templates,
)?;
Ok(output.into_response())
}
/// Template context for the knowledge entity edit modal
/// (`knowledge/edit_knowledge_entity_modal.html`).
#[derive(Serialize)]
pub struct EntityData {
/// The entity being edited.
entity: KnowledgeEntity,
/// String form of every `KnowledgeEntityType::variants()` entry.
entity_types: Vec<String>,
/// The signed-in user the modal is rendered for.
user: User,
}
/// Renders the edit modal for one knowledge entity.
///
/// Visitors without a session are redirected to `/signin`. The entity is loaded through
/// `User::get_and_validate_knowledge_entity`; any error it reports is rendered as an
/// [`HtmlError`].
pub async fn show_edit_knowledge_entity_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    // Load the entity; the lookup also validates it against the user's id.
    let entity = User::get_and_validate_knowledge_entity(&id, &user.id, &state.surreal_db_client)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    // All selectable entity types, as strings.
    let mut entity_types: Vec<String> = Vec::new();
    for variant in KnowledgeEntityType::variants().iter() {
        entity_types.push(variant.to_string());
    }

    let modal = EntityData {
        entity,
        user,
        entity_types,
    };
    let page = render_template(
        "knowledge/edit_knowledge_entity_modal.html",
        modal,
        state.templates,
    )?;

    Ok(page.into_response())
}
/// Template context for the entity list (`knowledge/entity_list.html`).
#[derive(Serialize)]
pub struct EntityListData {
/// The user's knowledge entities, as returned by `User::get_knowledge_entities`.
entities: Vec<KnowledgeEntity>,
/// The signed-in user the list is rendered for.
user: User,
}
/// Form payload accepted by `patch_knowledge_entity`.
#[derive(Debug, Deserialize)]
pub struct PatchKnowledgeEntityParams {
/// Id of the entity to update.
pub id: String,
/// New name of the entity.
pub name: String,
/// New entity type; converted with `KnowledgeEntityType::from`.
pub entity_type: String,
/// New description of the entity.
pub description: String,
}
/// Updates a knowledge entity from the submitted form and re-renders the entity list.
///
/// Visitors without a session are redirected to `/signin`. Before patching, the entity
/// is looked up through `User::get_and_validate_knowledge_entity`; a failure there, in
/// the patch itself, or in the follow-up list query is rendered as an [`HtmlError`].
pub async fn patch_knowledge_entity(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<PatchKnowledgeEntityParams>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let db = &state.surreal_db_client;

    // Validation only: the fetched entity itself is not needed.
    User::get_and_validate_knowledge_entity(&form.id, &user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let entity_type: KnowledgeEntityType = KnowledgeEntityType::from(form.entity_type);

    KnowledgeEntity::patch(
        &form.id,
        &form.name,
        &form.description,
        &entity_type,
        db,
        &state.openai_client,
    )
    .await
    .map_err(|err| HtmlError::new(AppError::from(err), state.templates.clone()))?;

    // Fetch the refreshed list and render it.
    let entities = User::get_knowledge_entities(&user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let list = EntityListData { entities, user };
    let page = render_template("knowledge/entity_list.html", list, state.templates)?;

    Ok(page.into_response())
}
/// Deletes a knowledge entity and re-renders the entity list.
///
/// Visitors without a session are redirected to `/signin`. Before deleting, the entity
/// is looked up through `User::get_and_validate_knowledge_entity`; a failure there, in
/// the deletion, or in the follow-up list query is rendered as an [`HtmlError`].
pub async fn delete_knowledge_entity(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let db = &state.surreal_db_client;

    // Validation only: the fetched entity itself is not needed.
    User::get_and_validate_knowledge_entity(&id, &user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    delete_item::<KnowledgeEntity>(db, &id)
        .await
        .map_err(|err| HtmlError::new(AppError::from(err), state.templates.clone()))?;

    // Fetch the refreshed list and render it.
    let entities = User::get_knowledge_entities(&user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let list = EntityListData { entities, user };
    let page = render_template("knowledge/entity_list.html", list, state.templates)?;

    Ok(page.into_response())
}
/// Template context for the relationship table (`knowledge/relationship_table.html`).
#[derive(Serialize)]
pub struct RelationshipTableData {
/// The user's knowledge entities, as returned by `User::get_knowledge_entities`.
entities: Vec<KnowledgeEntity>,
/// The user's relationships, as returned by `User::get_knowledge_relationships`.
relationships: Vec<KnowledgeRelationship>,
}
/// Deletes a knowledge relationship by id and re-renders the relationship table.
///
/// Visitors without a session are redirected to `/signin`. Failures in the deletion or
/// in the follow-up queries are rendered as [`HtmlError`].
pub async fn delete_knowledge_relationship(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
    let Some(user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    let db = &state.surreal_db_client;

    // SECURITY TODO: auth validation is still missing. The delete call below receives
    // only the relationship id, so this handler does not verify that the relationship
    // belongs to the signed-in user.
    KnowledgeRelationship::delete_relationship_by_id(&id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let entities = User::get_knowledge_entities(&user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    let relationships = User::get_knowledge_relationships(&user.id, db)
        .await
        .map_err(|err| HtmlError::new(err, state.templates.clone()))?;

    // Render the refreshed table.
    let table = RelationshipTableData {
        entities,
        relationships,
    };
    let page = render_template("knowledge/relationship_table.html", table, state.templates)?;

    Ok(page.into_response())
}
/// Form payload accepted by `save_knowledge_relationship`.
#[derive(Deserialize)]
pub struct SaveKnowledgeRelationshipInput {
/// First endpoint; passed as the first argument to `KnowledgeRelationship::new`.
pub in_: String,
/// Second endpoint; passed as the second argument to `KnowledgeRelationship::new`.
pub out: String,
/// Relationship type stored with the new relationship.
pub relationship_type: String,
}
pub async fn save_knowledge_relationship(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<SaveKnowledgeRelationshipInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not authenticated
let user = match auth.current_user {
Some(user) => user,
None => return Ok(Redirect::to("/signin").into_response()),
};
// Construct relationship
let relationship = KnowledgeRelationship::new(
form.in_,
form.out,
user.id.clone(),
"manual".into(),
form.relationship_type,
);
relationship
.store_relationship(&state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let entities = User::get_knowledge_entities(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
let relationships = User::get_knowledge_relationships(&user.id, &state.surreal_db_client)
.await
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;
// Render updated list
let output = render_template(
"knowledge/relationship_table.html",
RelationshipTableData {
entities,
relationships,
},
state.templates,
)?;
Ok(output.into_response())
}

View File

@@ -0,0 +1,91 @@
use std::sync::Arc;
use axum::response::Html;
use minijinja_autoreload::AutoReloader;
use common::error::{HtmlError, IntoHtmlError};
pub mod account;
pub mod admin_panel;
pub mod chat;
pub mod content;
pub mod documentation;
pub mod gdpr;
pub mod index;
pub mod ingress_form;
pub mod knowledge;
pub mod search_result;
pub mod signin;
pub mod signout;
pub mod signup;
/// Associates a template-context type with the template it is meant to be rendered
/// with. Implementations are generated by the `page_data!` macro.
pub trait PageData {
/// Returns the name of the template for this context type.
fn template_name() -> &'static str;
}
// Helper function for render_template
pub fn render_template<T>(
template_name: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, HtmlError>
where
T: serde::Serialize,
{
let env = templates
.acquire_env()
.map_err(|e| e.with_template(templates.clone()))?;
let tmpl = env
.get_template(template_name)
.map_err(|e| e.with_template(templates.clone()))?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl
.render(context)
.map_err(|e| e.with_template(templates.clone()))?;
Ok(Html(output))
}
pub fn render_block<T>(
template_name: &str,
block: &str,
context: T,
templates: Arc<AutoReloader>,
) -> Result<Html<String>, HtmlError>
where
T: serde::Serialize,
{
let env = templates
.acquire_env()
.map_err(|e| e.with_template(templates.clone()))?;
let tmpl = env
.get_template(template_name)
.map_err(|e| e.with_template(templates.clone()))?;
let context = minijinja::Value::from_serialize(&context);
let output = tmpl
.eval_to_state(context)
.map_err(|e| e.with_template(templates.clone()))?
.render_block(block)
.map_err(|e| e.with_template(templates.clone()))?;
Ok(output.into())
}
/// Declares a template-context struct and ties it to a template name.
///
/// `page_data!(Name, "template.html", { field: Type, ... })` expands to
/// - a `pub struct Name` with the listed fields (all `pub`, attributes on fields are
///   forwarded) deriving `Debug`, `Deserialize` and `Serialize`, and
/// - an `impl PageData for Name` whose `template_name()` returns the given expression.
///
/// The expansion also emits `use serde::{Serialize, Deserialize};` and
/// `use $crate::routes::PageData;` at the call site.
/// NOTE(review): because of those `use` items, invoking the macro more than once in
/// the same module, or next to an explicit import of the same names, presumably
/// produces duplicate-import errors — confirm before adding a second invocation.
#[macro_export]
macro_rules! page_data {
// $name: struct identifier, $template_name: expression returned by `template_name()`,
// followed by a brace-enclosed, comma-separated field list (trailing comma allowed).
($name:ident, $template_name:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*$(,)?}) => {
use serde::{Serialize, Deserialize};
use $crate::routes::PageData;
#[derive(Debug, Deserialize, Serialize)]
pub struct $name {
$($(#[$attr])* pub $field: $ty),*
}
impl PageData for $name {
fn template_name() -> &'static str {
$template_name
}
}
};
}

View File

@@ -0,0 +1,68 @@
use axum::{
extract::{Query, State},
response::{IntoResponse, Redirect},
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use tracing::info;
use common::{error::HtmlError, storage::types::user::User};
use crate::{html_state::HtmlState, routes::render_template};
/// Query-string parameters accepted by `search_result_handler`.
#[derive(Deserialize)]
pub struct SearchParams {
/// The search text entered by the user.
query: String,
}
/// Template context for the search response
/// (`index/signed_in/search_response.html`).
#[derive(Serialize)]
pub struct AnswerData {
/// The query as submitted by the user.
user_query: String,
/// The answer text shown for the query.
answer_content: String,
/// Reference identifiers listed alongside the answer.
answer_references: Vec<String>,
}
/// Renders the search response for a query.
///
/// Visitors without a session are redirected to `/signin`.
///
/// NOTE: the answer and its references are currently hard-coded placeholders; the
/// submitted query is only echoed back to the template.
pub async fn search_result_handler(
    State(state): State<HtmlState>,
    Query(query): Query<SearchParams>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, HtmlError> {
    info!("Displaying search results");

    let Some(_user) = auth.current_user else {
        return Ok(Redirect::to("/signin").into_response());
    };

    // TODO: replace the placeholder below with the real lookup, i.e.
    // get_answer_with_references(&state.surreal_db_client, &state.openai_client,
    // &query.query, &_user.id), mapping its error into an HtmlError.
    let answer_content = String::from("The Minne project is focused on simplifying knowledge management through features such as easy capture, smart analysis, and visualization of connections between ideas. It includes various functionalities like the Smart Analysis Feature, which provides content analysis and organization, and the Easy Capture Feature, which allows users to effortlessly capture and retrieve knowledge in various formats. Additionally, it offers tools like Knowledge Graph Visualization to enhance understanding and organization of knowledge. The project also emphasizes a user-friendly onboarding experience and mobile-friendly options for accessing its services.");

    let answer_references: Vec<String> = [
        "i81cd5be8-557c-4b2b-ba3a-4b8d28e74b9b",
        "5f72a724-d7a3-467d-8783-7cca6053ddc7",
        "ad106a1f-ccda-415e-9e87-c3a34e202624",
        "8797b57d-094d-4ee9-a3a7-c3195b246254",
        "69763f43-82e6-4cb5-ba3e-f6da13777dab",
    ]
    .into_iter()
    .map(String::from)
    .collect();

    let answer = AnswerData {
        user_query: query.query,
        answer_content,
        answer_references,
    };
    let page = render_template(
        "index/signed_in/search_response.html",
        answer,
        state.templates,
    )?;

    Ok(page.into_response())
}

View File

@@ -0,0 +1,71 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect},
Form,
};
use axum_htmx::{HxBoosted, HxRedirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{error::HtmlError, storage::types::user::User};
use crate::{html_state::HtmlState, page_data};
use super::{render_block, render_template};
/// Form payload accepted by `authenticate_user` (the sign-in form).
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
/// Email address entered in the form.
pub email: String,
/// Password entered in the form.
pub password: String,
/// "Remember me" checkbox; treated as checked only when the value is `"on"`.
pub remember_me: Option<String>,
}
// Field-less template context for the sign-in form (`auth/signin_form.html`).
page_data!(ShowSignInForm, "auth/signin_form.html", {});
/// Renders the sign-in form.
///
/// Already authenticated visitors are redirected to `/`. For htmx-boosted requests only
/// the `body` block of the template is rendered; otherwise the full template is.
pub async fn show_signin_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
    if auth.is_authenticated() {
        return Ok(Redirect::to("/").into_response());
    }

    let template = ShowSignInForm::template_name();
    let templates = state.templates.clone();

    let page = if boosted {
        render_block(template, "body", ShowSignInForm {}, templates)?
    } else {
        render_template(template, ShowSignInForm {}, templates)?
    };

    Ok(page.into_response())
}
/// Signs a user in from the submitted credentials.
///
/// On failed authentication a short HTML error snippet is returned (the underlying
/// error is not exposed). On success the user is logged into the session, optionally
/// remembered when the "remember me" field equals `"on"`, and an htmx redirect to `/`
/// is sent with status 200.
pub async fn authenticate_user(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
    let authentication =
        User::authenticate(form.email, form.password, &state.surreal_db_client).await;

    let Ok(user) = authentication else {
        return Ok(Html("<p>Incorrect email or password </p>").into_response());
    };

    auth.login_user(user.id);

    let remember = form.remember_me.as_deref() == Some("on");
    if remember {
        auth.remember_user(true);
    }

    let redirect_home = HxRedirect::from(Uri::from_static("/"));
    Ok((redirect_home, StatusCode::OK).into_response())
}

View File

@@ -0,0 +1,18 @@
use axum::response::{IntoResponse, Redirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use surrealdb::{engine::any::Any, Surreal};
use common::{error::ApiError, storage::types::user::User};
/// Signs the current user out and redirects to `/`.
///
/// Unauthenticated visitors are redirected as well; the logout call is simply skipped
/// for them.
pub async fn sign_out_user(
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<impl IntoResponse, ApiError> {
    if auth.is_authenticated() {
        auth.logout_user();
    }

    Ok(Redirect::to("/").into_response())
}

View File

@@ -0,0 +1,65 @@
use axum::{
extract::State,
http::{StatusCode, Uri},
response::{Html, IntoResponse, Redirect},
Form,
};
use axum_htmx::{HxBoosted, HxRedirect};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::any::Any, Surreal};
use common::{error::HtmlError, storage::types::user::User};
use crate::html_state::HtmlState;
use super::{render_block, render_template};
/// Form payload accepted by `process_signup_and_show_verification`.
#[derive(Deserialize, Serialize)]
pub struct SignupParams {
/// Email address for the new account.
pub email: String,
/// Password for the new account.
pub password: String,
/// Timezone submitted with the form; forwarded to `User::create_new`.
pub timezone: String,
}
/// Renders the sign-up form.
///
/// Already authenticated visitors are redirected to `/`. For htmx-boosted requests only
/// the `body` block of the template is rendered; otherwise the full template is. The
/// template receives an empty (unit) context.
pub async fn show_signup_form(
    State(state): State<HtmlState>,
    auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
    HxBoosted(boosted): HxBoosted,
) -> Result<impl IntoResponse, HtmlError> {
    if auth.is_authenticated() {
        return Ok(Redirect::to("/").into_response());
    }

    let templates = state.templates.clone();

    let page = if boosted {
        render_block("auth/signup_form.html", "body", (), templates)?
    } else {
        render_template("auth/signup_form.html", (), templates)?
    };

    Ok(page.into_response())
}
pub async fn process_signup_and_show_verification(
State(state): State<HtmlState>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Form(form): Form<SignupParams>,
) -> Result<impl IntoResponse, HtmlError> {
let user = match User::create_new(
form.email,
form.password,
&state.surreal_db_client,
form.timezone,
)
.await
{
Ok(user) => user,
Err(e) => {
tracing::error!("{:?}", e);
return Ok(Html(format!("<p>{}</p>", e)).into_response());
}
};
auth.login_user(user.id);
Ok((HxRedirect::from(Uri::from_static("/")), StatusCode::OK).into_response())
}

58
crates/main/Cargo.toml Normal file
View File

@@ -0,0 +1,58 @@
[package]
name = "main"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
axum = { workspace = true }
async-openai = "0.24.1"
async-stream = "0.3.6"
axum-htmx = "0.6.0"
axum_session = "0.14.4"
axum_session_auth = "0.14.1"
axum_session_surreal = "0.2.1"
axum_typed_multipart = "0.12.1"
chrono = { version = "0.4.39", features = ["serde"] }
chrono-tz = "0.10.1"
config = "0.15.4"
futures = "0.3.31"
json-stream-parser = "0.1.4"
lettre = { version = "0.11.11", features = ["rustls-tls"] }
mime = "0.3.17"
mime_guess = "2.0.5"
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
minijinja-autoreload = "2.5.0"
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
mockall = "0.13.0"
plotly = "0.12.1"
reqwest = {version = "0.12.12", features = ["charset", "json"]}
scraper = "0.22.0"
sha2 = "0.10.8"
surrealdb = "2.0.4"
tempfile = "3.12.0"
text-splitter = "0.18.1"
tiktoken-rs = "0.6.0"
tower-http = { version = "0.6.2", features = ["fs"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
# Workspace crates (path dependencies)
api-router = { path = "../api-router" }
html-router = { path = "../html-router" }
common = { path = "../common" }
[[bin]]
name = "server"
path = "src/server.rs"
[[bin]]
name = "worker"
path = "src/worker.rs"

47
crates/main/src/server.rs Normal file
View File

@@ -0,0 +1,47 @@
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use common::utils::config::get_config;
use html_router::{html_routes, html_state::HtmlState};
use tracing::info;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
/// Entry point for the HTTP server binary.
///
/// Initialises tracing, loads the application config, builds the HTML and API
/// router states (the API state shares the database client and job queue of
/// the HTML state), mounts both routers and serves them over TCP.
///
/// # Errors
///
/// Returns an error if the config cannot be loaded, the HTML state cannot be
/// created, the listening socket cannot be bound, or the server terminates
/// with an error.
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Single source of truth for the socket address used for binding and logging.
    const BIND_ADDR: &str = "0.0.0.0:3000";

    // Set up tracing. A failing `try_init` is deliberately ignored via `.ok()`.
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(EnvFilter::from_default_env())
        .try_init()
        .ok();

    // Get config
    let config = get_config()?;

    // Set up router states. The API state reuses the handles owned by the HTML
    // state instead of opening a second database connection.
    let html_state = HtmlState::new(&config).await?;
    let api_state = ApiState {
        surreal_db_client: html_state.surreal_db_client.clone(),
        job_queue: html_state.job_queue.clone(),
    };

    // Create Axum router
    let app = Router::new()
        .nest("/api/v1", api_routes_v1(&api_state))
        .nest("/", html_routes(&html_state))
        .with_state(AppState {
            api_state,
            html_state,
        });

    // Bind first, then log: a failed bind must not leave a "Listening" line
    // in the logs.
    let listener = tokio::net::TcpListener::bind(BIND_ADDR).await?;
    info!("Listening on {}", BIND_ADDR);

    axum::serve(listener, app).await?;

    Ok(())
}
/// Top-level state passed to the combined router via `with_state`.
///
/// The `FromRef` derive makes each field obtainable from a reference to
/// `AppState`; `api_routes_v1` relies on this through its
/// `ApiState: FromRef<S>` bound.
// NOTE(review): `html_routes` presumably has the matching `HtmlState` bound — confirm.
#[derive(Clone, FromRef)]
struct AppState {
    /// State consumed by the `/api/v1` routes.
    api_state: ApiState,
    /// State consumed by the HTML routes mounted at `/`.
    html_state: HtmlState,
}

150
crates/main/src/worker.rs Normal file
View File

@@ -0,0 +1,150 @@
use std::sync::Arc;
use common::{
ingress::{
content_processor::ContentProcessor,
jobqueue::{JobQueue, MAX_ATTEMPTS},
},
storage::{
db::{get_item, SurrealDbClient},
types::job::{Job, JobStatus},
},
utils::config::get_config,
};
use futures::StreamExt;
use surrealdb::Action;
use tracing::{error, info};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Set up tracing
tracing_subscriber::registry()
.with(fmt::layer())
.with(EnvFilter::from_default_env())
.try_init()
.ok();
let config = get_config()?;
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
let openai_client = Arc::new(async_openai::Client::new());
let job_queue = JobQueue::new(surreal_db_client.clone());
let content_processor = ContentProcessor::new(surreal_db_client, openai_client.clone()).await?;
loop {
// First, check for any unfinished jobs
let unfinished_jobs = job_queue.get_unfinished_jobs().await?;
if !unfinished_jobs.is_empty() {
info!("Found {} unfinished jobs", unfinished_jobs.len());
for job in unfinished_jobs {
job_queue
.process_job(job, &content_processor, openai_client.clone())
.await?;
}
}
// If no unfinished jobs, start listening for new ones
info!("Listening for new jobs...");
let mut job_stream = job_queue.listen_for_jobs().await?;
while let Some(notification) = job_stream.next().await {
match notification {
Ok(notification) => {
info!("Received notification: {:?}", notification);
match notification.action {
Action::Create => {
if let Err(e) = job_queue
.process_job(
notification.data,
&content_processor,
openai_client.clone(),
)
.await
{
error!("Error processing job: {}", e);
}
}
Action::Update => {
match notification.data.status {
JobStatus::Completed
| JobStatus::Error(_)
| JobStatus::Cancelled => {
info!(
"Skipping already completed/error/cancelled job: {}",
notification.data.id
);
continue;
}
JobStatus::InProgress { attempts, .. } => {
// Only process if this is a retry after an error, not our own update
if let Ok(Some(current_job)) =
get_item::<Job>(&job_queue.db.client, &notification.data.id)
.await
{
match current_job.status {
JobStatus::Error(_) if attempts < MAX_ATTEMPTS => {
// This is a retry after an error
if let Err(e) = job_queue
.process_job(
current_job,
&content_processor,
openai_client.clone(),
)
.await
{
error!("Error processing job retry: {}", e);
}
}
_ => {
info!(
"Skipping in-progress update for job: {}",
notification.data.id
);
continue;
}
}
}
}
JobStatus::Created => {
// Shouldn't happen with Update action, but process if it does
if let Err(e) = job_queue
.process_job(
notification.data,
&content_processor,
openai_client.clone(),
)
.await
{
error!("Error processing job: {}", e);
}
}
}
}
_ => {} // Ignore other actions
}
}
Err(e) => error!("Error in job notification: {}", e),
}
}
// If we reach here, the stream has ended (connection lost?)
error!("Job stream ended unexpectedly, reconnecting...");
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
}