in progress, routers and main split up

This commit is contained in:
Per Stark
2025-03-04 07:44:00 +01:00
parent 037bc52a64
commit 847571729b
80 changed files with 599 additions and 1577 deletions

254
crates/common/src/error.rs Normal file
View File

@@ -0,0 +1,254 @@
use std::sync::Arc;
use async_openai::error::OpenAIError;
use axum::{
http::StatusCode,
response::{Html, IntoResponse, Response},
Json,
};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use serde::Serialize;
use serde_json::json;
use thiserror::Error;
use tokio::task::JoinError;
use crate::{storage::types::file_info::FileError, utils::mailer::EmailError};
// Core internal errors
#[derive(Error, Debug)]
pub enum AppError {
#[error("Database error: {0}")]
Database(#[from] surrealdb::Error),
#[error("OpenAI error: {0}")]
OpenAI(#[from] OpenAIError),
#[error("File error: {0}")]
File(#[from] FileError),
#[error("Email error: {0}")]
Email(#[from] EmailError),
#[error("Not found: {0}")]
NotFound(String),
#[error("Validation error: {0}")]
Validation(String),
#[error("Authorization error: {0}")]
Auth(String),
#[error("LLM parsing error: {0}")]
LLMParsing(String),
#[error("Task join error: {0}")]
Join(#[from] JoinError),
#[error("Graph mapper error: {0}")]
GraphMapper(String),
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
#[error("Minijina error: {0}")]
MiniJinja(#[from] minijinja::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Tiktoken error: {0}")]
Tiktoken(#[from] anyhow::Error),
#[error("Ingress Processing error: {0}")]
Processing(String),
}
// API-specific errors
/// Errors surfaced at the JSON API boundary; converted into HTTP responses
/// by the [`IntoResponse`] impl below.
#[derive(Debug, Serialize)]
pub enum ApiError {
    /// 500 Internal Server Error — message is kept generic on purpose.
    InternalError(String),
    /// 400 Bad Request — the request payload failed validation.
    ValidationError(String),
    /// 404 Not Found — the requested resource does not exist.
    NotFound(String),
    /// 401 Unauthorized — the caller is not authenticated/authorized.
    Unauthorized(String),
}
impl From<AppError> for ApiError {
    /// Maps internal application errors onto the public API error surface.
    /// Infrastructure failures are logged and hidden behind a generic
    /// 500 message; user-facing variants pass their message through.
    fn from(err: AppError) -> Self {
        match err {
            AppError::NotFound(msg) => Self::NotFound(msg),
            AppError::Validation(msg) => Self::ValidationError(msg),
            AppError::Auth(msg) => Self::Unauthorized(msg),
            AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
                tracing::error!("Internal error: {:?}", err);
                Self::InternalError("Internal server error".to_string())
            }
            // Remaining variants are treated as internal (not logged here).
            _ => Self::InternalError("Internal server error".to_string()),
        }
    }
}
impl IntoResponse for ApiError {
    /// Serializes the error into a JSON response with the matching HTTP status.
    fn into_response(self) -> Response {
        // Every variant carries a message; only the status code differs.
        let (status, message) = match self {
            ApiError::InternalError(message) => (StatusCode::INTERNAL_SERVER_ERROR, message),
            ApiError::ValidationError(message) => (StatusCode::BAD_REQUEST, message),
            ApiError::NotFound(message) => (StatusCode::NOT_FOUND, message),
            ApiError::Unauthorized(message) => (StatusCode::UNAUTHORIZED, message),
        };
        // Shared JSON envelope: { "error": <message>, "status": "error" }.
        let body = json!({
            "error": message,
            "status": "error"
        });
        (status, Json(body)).into_response()
    }
}
/// Result alias for handlers that render HTML templates.
pub type TemplateResult<T> = Result<T, HtmlError>;
// Helper trait for converting to HtmlError with templates
/// Converts an error into an [`HtmlError`], attaching the template engine
/// handle needed to render the error page.
pub trait IntoHtmlError {
    fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
}
// Implement for AppError: classify into the matching HTML error variant.
impl IntoHtmlError for AppError {
    fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
        HtmlError::new(self, templates)
    }
}
// Implement for minijinja::Error directly: template failures get their own variant.
impl IntoHtmlError for minijinja::Error {
    fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
        HtmlError::from_template_error(self, templates)
    }
}
/// Shared handle to the template engine for error-rendering code.
#[derive(Clone)]
pub struct ErrorContext {
    // Currently unused; kept for future template-aware error handling.
    #[allow(dead_code)]
    templates: Arc<AutoReloader>,
}
impl ErrorContext {
    /// Wraps the given template auto-reloader.
    pub fn new(templates: Arc<AutoReloader>) -> Self {
        Self { templates }
    }
}
/// Error type for HTML routes. Each variant carries the template engine so
/// the [`IntoResponse`] impl can render a styled error page.
pub enum HtmlError {
    /// 500 Internal Server Error page.
    ServerError(Arc<AutoReloader>),
    /// 404 Not Found page.
    NotFound(Arc<AutoReloader>),
    /// 401 Unauthorized page.
    Unauthorized(Arc<AutoReloader>),
    /// 400 Bad Request page; the message is shown to the user.
    BadRequest(String, Arc<AutoReloader>),
    /// Failure inside the template engine itself; rendered as a 500.
    Template(String, Arc<AutoReloader>),
}
impl HtmlError {
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
match error {
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
_ => {
tracing::error!("Internal error: {:?}", error);
HtmlError::ServerError(templates)
}
}
}
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
tracing::error!("Template error: {:?}", error);
HtmlError::Template(error.to_string(), templates)
}
}
impl IntoResponse for HtmlError {
    /// Renders the matching error page via the `errors/error.html` template,
    /// degrading to a hard-coded HTML snippet if the template engine fails
    /// at any step (env acquire, template lookup, or render).
    fn into_response(self) -> Response {
        // Pick status code and template context per variant.
        let (status, context, templates) = match self {
            // Template failures are shown as a generic 500; the original
            // error message was already logged in from_template_error.
            HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                context! {
                    status_code => 500,
                    title => "Internal Server Error",
                    error => "Internal Server Error",
                    description => "Something went wrong on our end."
                },
                templates,
            ),
            HtmlError::NotFound(templates) => (
                StatusCode::NOT_FOUND,
                context! {
                    status_code => 404,
                    title => "Page Not Found",
                    error => "Not Found",
                    description => "The page you're looking for doesn't exist or was removed."
                },
                templates,
            ),
            HtmlError::Unauthorized(templates) => (
                StatusCode::UNAUTHORIZED,
                context! {
                    status_code => 401,
                    title => "Unauthorized",
                    error => "Access Denied",
                    description => "You need to be logged in to access this page."
                },
                templates,
            ),
            // BadRequest is the only variant whose message reaches the user.
            HtmlError::BadRequest(msg, templates) => (
                StatusCode::BAD_REQUEST,
                context! {
                    status_code => 400,
                    title => "Bad Request",
                    error => "Bad Request",
                    description => msg
                },
                templates,
            ),
        };
        // Render the shared error template; every failure point logs and
        // falls back so the user always gets *some* HTML back.
        let html = match templates.acquire_env() {
            Ok(env) => match env.get_template("errors/error.html") {
                Ok(tmpl) => match tmpl.render(context) {
                    Ok(output) => output,
                    Err(e) => {
                        tracing::error!("Template render error: {:?}", e);
                        Self::fallback_html()
                    }
                },
                Err(e) => {
                    tracing::error!("Template get error: {:?}", e);
                    Self::fallback_html()
                }
            },
            Err(e) => {
                tracing::error!("Environment acquire error: {:?}", e);
                Self::fallback_html()
            }
        };
        (status, Html(html)).into_response()
    }
}
impl HtmlError {
    /// Minimal hard-coded error page used when template rendering itself fails.
    fn fallback_html() -> String {
        String::from(
            r#"
        <html>
            <body>
                <div class="container mx-auto p-4">
                    <h1 class="text-4xl text-error">Error</h1>
                    <p class="mt-4">Sorry, something went wrong displaying this page.</p>
                </div>
            </body>
        </html>
        "#,
        )
    }
}

View File

@@ -0,0 +1,145 @@
use crate::{
error::AppError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::combined_knowledge_entity_retrieval,
storage::types::knowledge_entity::KnowledgeEntity,
};
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
};
use serde_json::json;
use surrealdb::engine::any::Any;
use surrealdb::Surreal;
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
/// Runs LLM-based graph analysis of ingressed content, using similar
/// existing entities retrieved from the database as context for the prompt.
pub struct IngressAnalyzer<'a> {
    db_client: &'a Surreal<Any>,
    openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl<'a> IngressAnalyzer<'a> {
    /// Creates an analyzer borrowing shared database and OpenAI clients.
    pub fn new(
        db_client: &'a Surreal<Any>,
        openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Self {
        Self {
            db_client,
            openai_client,
        }
    }

    /// Analyzes submitted content into a graph structure.
    ///
    /// Pipeline: retrieve similar existing entities -> build a structured
    /// chat-completion request -> parse the JSON answer into
    /// [`LLMGraphAnalysisResult`].
    pub async fn analyze_content(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
        user_id: &str,
    ) -> Result<LLMGraphAnalysisResult, AppError> {
        // Retrieval step: existing entities give the LLM dedup/link context.
        let similar_entities = self
            .find_similar_entities(category, instructions, text, user_id)
            .await?;
        let llm_request =
            self.prepare_llm_request(category, instructions, text, &similar_entities)?;
        self.perform_analysis(llm_request).await
    }

    /// Retrieves existing knowledge entities similar to the submitted content,
    /// scoped to `user_id`.
    async fn find_similar_entities(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
        user_id: &str,
    ) -> Result<Vec<KnowledgeEntity>, AppError> {
        // Combine all signals into one retrieval query string.
        let input_text = format!(
            "content: {}, category: {}, user_instructions: {}",
            text, category, instructions
        );
        combined_knowledge_entity_retrieval(
            self.db_client,
            self.openai_client,
            &input_text,
            user_id,
        )
        .await
    }

    /// Builds the chat-completion request: system prompt, user message with
    /// the content plus existing entities, and a strict JSON schema response
    /// format so the answer is machine-parseable.
    fn prepare_llm_request(
        &self,
        category: &str,
        instructions: &str,
        text: &str,
        similar_entities: &[KnowledgeEntity],
    ) -> Result<CreateChatCompletionRequest, OpenAIError> {
        // Only expose id/name/description of existing entities to the LLM.
        let entities_json = json!(similar_entities
            .iter()
            .map(|entity| {
                json!({
                    "KnowledgeEntity": {
                        "id": entity.id,
                        "name": entity.name,
                        "description": entity.description
                    }
                })
            })
            .collect::<Vec<_>>());
        let user_message = format!(
            "Category:\n{}\nInstructions:\n{}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
            category, instructions, text, entities_json
        );
        debug!("Prepared LLM request message: {}", user_message);
        // strict: true enforces the schema server-side.
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: Some("Structured analysis of the submitted content".into()),
                name: "content_analysis".into(),
                schema: Some(get_ingress_analysis_schema()),
                strict: Some(true),
            },
        };
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o-mini")
            .temperature(0.2)
            .max_tokens(3048u32)
            .messages([
                ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
                ChatCompletionRequestUserMessage::from(user_message).into(),
            ])
            .response_format(response_format)
            .build()
    }

    /// Sends the request and deserializes the first choice's content into
    /// [`LLMGraphAnalysisResult`]; missing or malformed content becomes
    /// [`AppError::LLMParsing`].
    async fn perform_analysis(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<LLMGraphAnalysisResult, AppError> {
        let response = self.openai_client.chat().create(request).await?;
        debug!("Received LLM response: {:?}", response);
        response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            .ok_or(AppError::LLMParsing(
                "No content found in LLM response".to_string(),
            ))
            .and_then(|content| {
                serde_json::from_str(content).map_err(|e| {
                    AppError::LLMParsing(format!(
                        "Failed to parse LLM response into analysis: {}",
                        e
                    ))
                })
            })
    }
}

View File

@@ -0,0 +1,3 @@
pub mod ingress_analyser;
pub mod prompt;
pub mod types;

View File

@@ -0,0 +1,81 @@
use serde_json::{json, Value};
/// System prompt instructing the LLM to emit a graph-structured JSON analysis
/// of submitted content. Kept in sync with `get_ingress_analysis_schema`.
pub static INGRESS_ANALYSIS_SYSTEM_MESSAGE: &str = r#"
You are an AI assistant. You will receive a text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users instructions and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
The JSON should have the following structure:
{
"knowledge_entities": [
{
"key": "unique-key-1",
"name": "Entity Name",
"description": "A detailed description of the entity.",
"entity_type": "TypeOfEntity"
},
// More entities...
],
"relationships": [
{
"type": "RelationshipType",
"source": "unique-key-1 or UUID from existing database",
"target": "unique-key-1 or UUID from existing database"
},
// More relationships...
]
}
Guidelines:
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
7. Only create relationships between existing KnowledgeEntities.
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
9. A new relationship MUST include a newly created KnowledgeEntity.
"#;
/// Builds the strict JSON schema enforced on the LLM's analysis response.
/// Mirrors the structure described in `INGRESS_ANALYSIS_SYSTEM_MESSAGE`.
pub fn get_ingress_analysis_schema() -> Value {
    // Schema for one proposed knowledge entity.
    let entity_item = json!({
        "type": "object",
        "properties": {
            "key": { "type": "string" },
            "name": { "type": "string" },
            "description": { "type": "string" },
            "entity_type": {
                "type": "string",
                "enum": ["idea", "project", "document", "page", "textsnippet"]
            }
        },
        "required": ["key", "name", "description", "entity_type"],
        "additionalProperties": false
    });
    // Schema for one relationship between entities (by key or existing UUID).
    let relationship_item = json!({
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
            },
            "source": { "type": "string" },
            "target": { "type": "string" }
        },
        "required": ["type", "source", "target"],
        "additionalProperties": false
    });
    json!({
        "type": "object",
        "properties": {
            "knowledge_entities": {
                "type": "array",
                "items": entity_item
            },
            "relationships": {
                "type": "array",
                "items": relationship_item
            }
        },
        "required": ["knowledge_entities", "relationships"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,42 @@
use std::collections::HashMap;
use uuid::Uuid;
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
#[derive(Clone)]
pub struct GraphMapper {
    /// Maps the LLM-provided temporary `key` of each newly proposed entity
    /// to the UUID generated for it on our side.
    pub key_to_id: HashMap<String, Uuid>,
}
impl Default for GraphMapper {
    /// Same as [`GraphMapper::new`]: an empty mapping.
    fn default() -> Self {
        GraphMapper::new()
    }
}
impl GraphMapper {
    /// Creates an empty mapper with no key/ID assignments.
    pub fn new() -> Self {
        GraphMapper {
            key_to_id: HashMap::new(),
        }
    }

    /// Resolves `key` to a UUID: a key that parses as a UUID refers to an
    /// existing database entity and is returned as-is; otherwise the UUID
    /// previously assigned via [`Self::assign_id`] is looked up.
    ///
    /// Relaxed from `&mut self` to `&self` — the method never mutates.
    ///
    /// # Panics
    /// Panics if `key` is neither a valid UUID nor a previously assigned key
    /// (the panic message names the offending key, instead of the former
    /// bare `unwrap()`).
    pub fn get_or_parse_id(&self, key: &str) -> Uuid {
        if let Ok(parsed_uuid) = Uuid::parse_str(key) {
            parsed_uuid
        } else {
            *self
                .key_to_id
                .get(key)
                .unwrap_or_else(|| panic!("GraphMapper: no id assigned for key {key:?}"))
        }
    }

    /// Assigns a new random UUID for a given key, overwriting any previous one.
    pub fn assign_id(&mut self, key: &str) -> Uuid {
        let id = Uuid::new_v4();
        self.key_to_id.insert(key.to_string(), id);
        id
    }

    /// Retrieves the UUID for a given key, if one was assigned.
    pub fn get_id(&self, key: &str) -> Option<&Uuid> {
        self.key_to_id.get(key)
    }
}

View File

@@ -0,0 +1,182 @@
use std::sync::{Arc, Mutex};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use tokio::task;
use crate::{
error::AppError,
storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
},
utils::embedding::generate_embedding,
};
use futures::future::try_join_all;
use super::graph_mapper::GraphMapper; // For future parallelization
/// A single knowledge entity proposed by the LLM, keyed by a temporary
/// identifier rather than a database ID.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMKnowledgeEntity {
    pub key: String, // Temporary identifier (LLM-chosen, not a DB id)
    pub name: String,
    pub description: String,
    pub entity_type: String, // Should match KnowledgeEntityType variants
}
/// Represents a single relationship from the LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMRelationship {
    #[serde(rename = "type")]
    pub type_: String, // e.g., RelatedTo, RelevantTo
    pub source: String, // Key of the source entity (or UUID of an existing one)
    pub target: String, // Key of the target entity (or UUID of an existing one)
}
/// Represents the entire graph analysis result from the LLM:
/// new entities plus relationships among them and existing DB entities.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMGraphAnalysisResult {
    pub knowledge_entities: Vec<LLMKnowledgeEntity>,
    pub relationships: Vec<LLMRelationship>,
}
impl LLMGraphAnalysisResult {
    /// Converts the LLM graph analysis result into database entities and relationships.
    ///
    /// # Arguments
    ///
    /// * `source_id` - Identifier of the originating content record.
    /// * `user_id` - Owner of the resulting entities/relationships.
    /// * `openai_client` - OpenAI client, used for embedding generation.
    ///
    /// # Returns
    ///
    /// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
    pub async fn to_database_entities(
        &self,
        source_id: &str,
        user_id: &str,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
        // Create mapper and pre-assign IDs so relationships can be resolved
        // regardless of entity processing order.
        let mapper = Arc::new(Mutex::new(self.create_mapper()?));
        // Process entities (embedding generation happens concurrently).
        let entities = self
            .process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
            .await?;
        // Process relationships (pure in-memory mapping, no awaits).
        let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
        Ok((entities, relationships))
    }

    /// Pre-assigns a fresh UUID for every LLM-proposed entity key.
    /// Currently infallible; the `Result` is kept for future fallible mapping.
    fn create_mapper(&self) -> Result<GraphMapper, AppError> {
        let mut mapper = GraphMapper::new();
        // Pre-assign all IDs
        for entity in &self.knowledge_entities {
            mapper.assign_id(&entity.key);
        }
        Ok(mapper)
    }

    /// Builds one `KnowledgeEntity` per LLM entity, spawning a tokio task per
    /// entity since each requires an independent embedding network call.
    async fn process_entities(
        &self,
        source_id: &str,
        user_id: &str,
        mapper: Arc<Mutex<GraphMapper>>,
        openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    ) -> Result<Vec<KnowledgeEntity>, AppError> {
        let futures: Vec<_> = self
            .knowledge_entities
            .iter()
            .map(|entity| {
                // Clone everything the 'static task needs to own.
                let mapper = Arc::clone(&mapper);
                let openai_client = openai_client.clone();
                let source_id = source_id.to_string();
                let user_id = user_id.to_string();
                let entity = entity.clone();
                task::spawn(async move {
                    create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
                        .await
                })
            })
            .collect();
        // First `?` surfaces JoinError (task panic/cancellation); the inner
        // collect short-circuits on the first per-entity AppError.
        let results = try_join_all(futures)
            .await?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;
        Ok(results)
    }

    /// Maps LLM relationship endpoints (temp keys or existing UUIDs) to
    /// database IDs. Holds the mapper lock for the whole pass; acceptable
    /// because the work is synchronous and in-memory.
    fn process_relationships(
        &self,
        source_id: &str,
        user_id: &str,
        mapper: Arc<Mutex<GraphMapper>>,
    ) -> Result<Vec<KnowledgeRelationship>, AppError> {
        let mut mapper_guard = mapper
            .lock()
            .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
        self.relationships
            .iter()
            .map(|rel| {
                let source_db_id = mapper_guard.get_or_parse_id(&rel.source);
                let target_db_id = mapper_guard.get_or_parse_id(&rel.target);
                Ok(KnowledgeRelationship::new(
                    source_db_id.to_string(),
                    target_db_id.to_string(),
                    user_id.to_string(),
                    source_id.to_string(),
                    rel.type_.clone(),
                ))
            })
            .collect()
    }
}
/// Builds one `KnowledgeEntity` from an LLM proposal: resolves its
/// pre-assigned UUID, generates an embedding for it, and assembles the record.
async fn create_single_entity(
    llm_entity: &LLMKnowledgeEntity,
    source_id: &str,
    user_id: &str,
    mapper: Arc<Mutex<GraphMapper>>,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, AppError> {
    // Resolve the pre-assigned ID, holding the lock only for the lookup.
    let entity_id = {
        let guard = mapper
            .lock()
            .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
        match guard.get_id(&llm_entity.key) {
            Some(id) => id.to_string(),
            None => {
                return Err(AppError::GraphMapper(format!(
                    "ID not found for key: {}",
                    llm_entity.key
                )))
            }
        }
    };

    // Embed a compact textual summary of the entity.
    let embedding_input = format!(
        "name: {}, description: {}, type: {}",
        llm_entity.name, llm_entity.description, llm_entity.entity_type
    );
    let embedding = generate_embedding(openai_client, &embedding_input).await?;

    let timestamp = Utc::now();
    Ok(KnowledgeEntity {
        id: entity_id,
        created_at: timestamp,
        updated_at: timestamp,
        name: llm_entity.name.clone(),
        description: llm_entity.description.clone(),
        entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
        source_id: source_id.to_string(),
        metadata: None,
        embedding,
        user_id: user_id.to_string(),
    })
}

View File

@@ -0,0 +1,2 @@
pub mod graph_mapper;
pub mod llm_analysis_result;

View File

@@ -0,0 +1,124 @@
use std::{sync::Arc, time::Instant};
use text_splitter::TextSplitter;
use tracing::{debug, info};
use crate::{
error::AppError,
storage::{
db::{store_item, SurrealDbClient},
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk, text_content::TextContent,
},
},
utils::embedding::generate_embedding,
};
use super::analysis::{
ingress_analyser::IngressAnalyzer, types::llm_analysis_result::LLMGraphAnalysisResult,
};
/// Pipeline that turns ingressed text content into graph entities,
/// relationships, and vector chunks, and persists all of them.
pub struct ContentProcessor {
    db_client: Arc<SurrealDbClient>,
    openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
}
impl ContentProcessor {
    /// Builds a processor from shared database and OpenAI clients.
    /// Currently infallible; async + `Result` kept for interface stability.
    pub async fn new(
        surreal_db_client: Arc<SurrealDbClient>,
        openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
    ) -> Result<Self, AppError> {
        Ok(Self {
            db_client: surreal_db_client,
            openai_client,
        })
    }

    /// Full processing pipeline for one piece of content:
    /// analyze -> convert to DB objects -> store graph + chunks concurrently
    /// -> store original content -> rebuild indexes.
    pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
        let now = Instant::now();
        // Perform analysis, this step also includes retrieval
        let analysis = self.perform_semantic_analysis(content).await?;
        let end = now.elapsed();
        info!(
            "{:?} time elapsed during creation of entities and relationships",
            end
        );
        // Convert analysis to objects
        let (entities, relationships) = analysis
            .to_database_entities(&content.id, &content.user_id, &self.openai_client)
            .await?;
        // Store graph objects and vector chunks concurrently; fails fast on
        // either branch.
        tokio::try_join!(
            self.store_graph_entities(entities, relationships),
            self.store_vector_chunks(content),
        )?;
        // Store original content
        store_item(&self.db_client, content.to_owned()).await?;
        self.db_client.rebuild_indexes().await?;
        Ok(())
    }

    /// Runs the LLM graph analysis (includes similar-entity retrieval).
    async fn perform_semantic_analysis(
        &self,
        content: &TextContent,
    ) -> Result<LLMGraphAnalysisResult, AppError> {
        let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
        analyser
            .analyze_content(
                &content.category,
                &content.instructions,
                &content.text,
                &content.user_id,
            )
            .await
    }

    /// Persists entities then relationships, sequentially.
    async fn store_graph_entities(
        &self,
        entities: Vec<KnowledgeEntity>,
        relationships: Vec<KnowledgeRelationship>,
    ) -> Result<(), AppError> {
        for entity in &entities {
            debug!("Storing entity: {:?}", entity);
            store_item(&self.db_client, entity.clone()).await?;
        }
        for relationship in &relationships {
            debug!("Storing relationship: {:?}", relationship);
            relationship.store_relationship(&self.db_client).await?;
        }
        info!(
            "Stored {} entities and {} relationships",
            entities.len(),
            relationships.len()
        );
        Ok(())
    }

    /// Splits content into 500-2000 character chunks, embeds each, and stores
    /// the resulting `TextChunk`s.
    async fn store_vector_chunks(&self, content: &TextContent) -> Result<(), AppError> {
        let splitter = TextSplitter::new(500..2000);
        let chunks = splitter.chunks(&content.text);
        // Could potentially process chunks in parallel with a bounded concurrent limit
        for chunk in chunks {
            let embedding = generate_embedding(&self.openai_client, chunk).await?;
            let text_chunk = TextChunk::new(
                content.id.to_string(),
                chunk.to_string(),
                embedding,
                content.user_id.to_string(),
            );
            store_item(&self.db_client, text_chunk).await?;
        }
        Ok(())
    }
}

View File

@@ -0,0 +1,74 @@
use super::ingress_object::IngressObject;
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
use url::Url;
/// Struct defining the expected body when ingressing content.
#[derive(Serialize, Deserialize, Debug)]
pub struct IngressInput {
    /// Optional free text; may also be a URL (detected via parsing).
    pub content: Option<String>,
    /// User instructions forwarded to the analysis step.
    pub instructions: String,
    /// User-chosen category forwarded to the analysis step.
    pub category: String,
    /// Uploaded files to ingest alongside (or instead of) `content`.
    pub files: Vec<FileInfo>,
}
/// Function to create ingress objects from input.
///
/// # Arguments
/// * `input` - IngressInput containing information needed to ingress content.
/// * `user_id` - User id of the ingressing user
///
/// # Returns
/// * `Vec<IngressObject>` - The ingressed objects, one file/content type per object.
///
/// # Errors
/// * `AppError::Validation` when neither usable content nor files were provided.
pub fn create_ingress_objects(
    input: IngressInput,
    user_id: &str,
) -> Result<Vec<IngressObject>, AppError> {
    // Initialize list
    let mut object_list = Vec::new();
    // Create an IngressObject from input.content if it exists, checking for URL or text
    if let Some(input_content) = input.content {
        match Url::parse(&input_content) {
            Ok(url) => {
                info!("Detected URL: {}", url);
                object_list.push(IngressObject::Url {
                    url: url.to_string(),
                    instructions: input.instructions.clone(),
                    category: input.category.clone(),
                    user_id: user_id.into(),
                });
            }
            Err(_) => {
                // Trim before the length check so whitespace-only input is
                // not accepted as text content (was: untrimmed `len() > 2`).
                if input_content.trim().len() > 2 {
                    info!("Treating input as plain text");
                    object_list.push(IngressObject::Text {
                        // Move the already-owned String instead of cloning it.
                        text: input_content,
                        instructions: input.instructions.clone(),
                        category: input.category.clone(),
                        user_id: user_id.into(),
                    });
                }
            }
        }
    }
    for file in input.files {
        object_list.push(IngressObject::File {
            file_info: file,
            instructions: input.instructions.clone(),
            category: input.category.clone(),
            user_id: user_id.into(),
        })
    }
    // If no objects are constructed, we return Err. This is a client-input
    // problem, so use Validation (renders as 400) rather than NotFound (404).
    if object_list.is_empty() {
        return Err(AppError::Validation(
            "No valid content or files provided".into(),
        ));
    }
    Ok(object_list)
}

View File

@@ -0,0 +1,277 @@
use std::{sync::Arc, time::Duration};
use crate::{
error::AppError,
storage::types::{file_info::FileInfo, text_content::TextContent},
};
use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs,
};
use reqwest;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use tiktoken_rs::{o200k_base, CoreBPE};
/// A single unit of content to ingest; each variant carries the user's
/// instructions, category, and id alongside the payload.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {
    /// A web page to fetch, scrape, and clean up.
    Url {
        url: String,
        instructions: String,
        category: String,
        user_id: String,
    },
    /// Raw text submitted directly.
    Text {
        text: String,
        instructions: String,
        category: String,
        user_id: String,
    },
    /// An uploaded file whose text is extracted based on MIME type.
    File {
        file_info: FileInfo,
        instructions: String,
        category: String,
        user_id: String,
    },
}
impl IngressObject {
/// Creates a new `TextContent` instance from a `IngressObject`.
///
/// # Arguments
/// `&self` - A reference to the `IngressObject`.
///
/// # Returns
/// `TextContent` - An object containing a text representation of the object, could be a scraped URL, parsed PDF, etc.
pub async fn to_text_content(
&self,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<TextContent, AppError> {
match self {
IngressObject::Url {
url,
instructions,
category,
user_id,
} => {
let text = Self::fetch_text_from_url(url, openai_client).await?;
Ok(TextContent::new(
text,
instructions.into(),
category.into(),
None,
Some(url.into()),
user_id.into(),
))
}
IngressObject::Text {
text,
instructions,
category,
user_id,
} => Ok(TextContent::new(
text.into(),
instructions.into(),
category.into(),
None,
None,
user_id.into(),
)),
IngressObject::File {
file_info,
instructions,
category,
user_id,
} => {
let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent::new(
text,
instructions.into(),
category.into(),
Some(file_info.to_owned()),
None,
user_id.into(),
))
}
}
}
/// Get text from url, will return it as a markdown formatted string
async fn fetch_text_from_url(
url: &str,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
// Use a client with timeouts and reuse
let client = reqwest::ClientBuilder::new()
.timeout(Duration::from_secs(30))
.build()?;
let response = client.get(url).send().await?.text().await?;
// Preallocate string with capacity
let mut structured_content = String::with_capacity(response.len() / 2);
let document = Html::parse_document(&response);
let main_selectors = Selector::parse(
"article, main, .article-content, .post-content, .entry-content, [role='main']",
)
.unwrap();
let content_element = document
.select(&main_selectors)
.next()
.or_else(|| document.select(&Selector::parse("body").unwrap()).next())
.ok_or(AppError::NotFound("No content found".into()))?;
// Compile selectors once
let heading_selector = Selector::parse("h1, h2, h3").unwrap();
let paragraph_selector = Selector::parse("p").unwrap();
// Process content in one pass
for element in content_element.select(&heading_selector) {
let _ = writeln!(
structured_content,
"<heading>{}</heading>",
element.text().collect::<String>().trim()
);
}
for element in content_element.select(&paragraph_selector) {
let _ = writeln!(
structured_content,
"<paragraph>{}</paragraph>",
element.text().collect::<String>().trim()
);
}
let content = structured_content
.replace(|c: char| c.is_control(), " ")
.replace(" ", " ");
Self::process_web_content(content, openai_client).await
}
pub async fn process_web_content(
content: String,
openai_client: &Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Result<String, AppError> {
const MAX_TOKENS: usize = 122000;
const SYSTEM_PROMPT: &str = r#"
You are a precise content extractor for web pages. Your task:
1. Extract ONLY the main article/content from the provided text
2. Maintain the original content - do not summarize or modify the core information
3. Ignore peripheral content such as:
- Navigation elements
- Error messages (e.g., "JavaScript required")
- Related articles sections
- Comments
- Social media links
- Advertisement text
FORMAT:
- Convert <heading> tags to markdown headings (#, ##, ###)
- Convert <paragraph> tags to markdown paragraphs
- Preserve quotes and important formatting
- Remove duplicate content
- Remove any metadata or technical artifacts
OUTPUT RULES:
- Output ONLY the cleaned content in markdown
- Do not add any explanations or meta-commentary
- Do not add summaries or conclusions
- Do not use any XML/HTML tags in the output
"#;
let bpe = o200k_base()?;
// Process content in chunks if needed
let truncated_content = if bpe.encode_with_special_tokens(&content).len() > MAX_TOKENS {
Self::truncate_content(&content, MAX_TOKENS, &bpe)?
} else {
content
};
let request = CreateChatCompletionRequestArgs::default()
.model("gpt-4o-mini")
.temperature(0.0)
.max_tokens(16200u32)
.messages([
ChatCompletionRequestSystemMessage::from(SYSTEM_PROMPT).into(),
ChatCompletionRequestUserMessage::from(truncated_content).into(),
])
.build()?;
let response = openai_client.chat().create(request).await?;
response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.map(|content| content.to_owned())
.ok_or(AppError::LLMParsing("No content in response".into()))
}
fn truncate_content(
content: &str,
max_tokens: usize,
tokenizer: &CoreBPE,
) -> Result<String, AppError> {
// Pre-allocate with estimated size
let mut result = String::with_capacity(content.len() / 2);
let mut current_tokens = 0;
// Process content by paragraph to maintain context
for paragraph in content.split("\n\n") {
let tokens = tokenizer.encode_with_special_tokens(paragraph).len();
// Check if adding paragraph exceeds limit
if current_tokens + tokens > max_tokens {
break;
}
result.push_str(paragraph);
result.push_str("\n\n");
current_tokens += tokens;
}
// Ensure we return valid content
if result.is_empty() {
return Err(AppError::Processing("Content exceeds token limit".into()));
}
Ok(result.trim_end().to_string())
}
/// Extracts text from a file based on its MIME type.
async fn extract_text_from_file(file_info: &FileInfo) -> Result<String, AppError> {
match file_info.mime_type.as_str() {
"text/plain" => {
// Read the file and return its content
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
"text/markdown" => {
// Read the file and return its content
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
"application/pdf" => {
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
Err(AppError::NotFound(file_info.mime_type.clone()))
}
"image/png" | "image/jpeg" => {
// TODO: Implement OCR on image using a crate like `tesseract`
Err(AppError::NotFound(file_info.mime_type.clone()))
}
"application/octet-stream" => {
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
"text/x-rust" => {
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
// Handle other MIME types as needed
_ => Err(AppError::NotFound(file_info.mime_type.clone())),
}
}
}

View File

@@ -0,0 +1,182 @@
use chrono::Utc;
use futures::Stream;
use std::sync::Arc;
use surrealdb::{opt::PatchOp, Error, Notification};
use tracing::{debug, error, info};
use crate::{
error::AppError,
storage::{
db::{delete_item, get_item, store_item, SurrealDbClient},
types::{
job::{Job, JobStatus},
StoredObject,
},
},
};
use super::{content_processor::ContentProcessor, ingress_object::IngressObject};
/// Database-backed queue of ingress-processing jobs.
pub struct JobQueue {
    pub db: Arc<SurrealDbClient>,
}
/// Maximum number of processing attempts before a job is no longer retried.
pub const MAX_ATTEMPTS: u32 = 3;
impl JobQueue {
    /// Creates a queue backed by the shared database client.
    pub fn new(db: Arc<SurrealDbClient>) -> Self {
        Self { db }
    }

    /// Creates a new job for `user_id` and stores it in the database.
    pub async fn enqueue(&self, content: IngressObject, user_id: String) -> Result<(), AppError> {
        let job = Job::new(content, user_id).await;
        info!("{:?}", job);
        store_item(&self.db, job).await?;
        Ok(())
    }

    /// Gets all jobs for a specific user, newest first.
    pub async fn get_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
        // Bind the table name like every other query in this impl instead of
        // hard-coding `job`, so the table name stays defined in one place.
        let jobs: Vec<Job> = self
            .db
            .query("SELECT * FROM type::table($table) WHERE user_id = $user_id ORDER BY created_at DESC")
            .bind(("table", Job::table_name()))
            .bind(("user_id", user_id.to_owned()))
            .await?
            .take(0)?;
        debug!("{:?}", jobs);
        Ok(jobs)
    }

    /// Gets all active jobs for a specific user: newly created jobs plus
    /// in-progress jobs that still have attempts left.
    pub async fn get_unfinished_user_jobs(&self, user_id: &str) -> Result<Vec<Job>, AppError> {
        let jobs: Vec<Job> = self
            .db
            .query(
                "SELECT * FROM type::table($table)
            WHERE user_id = $user_id
            AND (
                status = 'Created'
                OR (
                    status.InProgress != NONE
                    AND status.InProgress.attempts < $max_attempts
                )
            )
            ORDER BY created_at DESC",
            )
            .bind(("table", Job::table_name()))
            .bind(("user_id", user_id.to_owned()))
            .bind(("max_attempts", MAX_ATTEMPTS))
            .await?
            .take(0)?;
        debug!("{:?}", jobs);
        Ok(jobs)
    }

    /// Deletes a job after verifying that it belongs to `user_id`.
    ///
    /// # Errors
    /// * `AppError::Auth` when the job does not exist or belongs to another
    ///   user (the two cases are deliberately not distinguished).
    pub async fn delete_job(&self, id: &str, user_id: &str) -> Result<(), AppError> {
        get_item::<Job>(&self.db.client, id)
            .await?
            .filter(|job| job.user_id == user_id)
            .ok_or_else(|| {
                error!("Unauthorized attempt to delete job {id} by user {user_id}");
                AppError::Auth("Not authorized to delete this job".into())
            })?;

        info!("Deleting job {id} for user {user_id}");
        delete_item::<Job>(&self.db.client, id)
            .await
            .map_err(AppError::Database)?;

        Ok(())
    }

    /// Replaces a job's status and bumps its `updated_at` timestamp.
    pub async fn update_status(&self, id: &str, status: JobStatus) -> Result<(), AppError> {
        let _job: Option<Job> = self
            .db
            .update((Job::table_name(), id))
            .patch(PatchOp::replace("/status", status))
            .patch(PatchOp::replace(
                "/updated_at",
                surrealdb::sql::Datetime::default(),
            ))
            .await?;
        Ok(())
    }

    /// Opens a live query that yields a notification for every change on the
    /// `job` table.
    pub async fn listen_for_jobs(
        &self,
    ) -> Result<impl Stream<Item = Result<Notification<Job>, Error>>, Error> {
        self.db.select("job").live().await
    }

    /// Gets unfinished jobs across all users: newly created jobs plus
    /// in-progress jobs with attempts remaining, oldest first.
    pub async fn get_unfinished_jobs(&self) -> Result<Vec<Job>, AppError> {
        let jobs: Vec<Job> = self
            .db
            .query(
                "SELECT * FROM type::table($table)
            WHERE
                status = 'Created'
                OR (
                    status.InProgress != NONE
                    AND status.InProgress.attempts < $max_attempts
                )
            ORDER BY created_at ASC",
            )
            .bind(("table", Job::table_name()))
            .bind(("max_attempts", MAX_ATTEMPTS))
            .await?
            .take(0)?;
        Ok(jobs)
    }

    /// Processes a single job: marks it in-progress, extracts its text
    /// content and runs it through the processor.
    ///
    /// On success the job is marked `Completed`. On failure the job is only
    /// marked `Error` once `MAX_ATTEMPTS` is reached; below that it stays
    /// `InProgress` (set at the start of this call) so the unfinished-jobs
    /// queries pick it up again. An `Err` is returned on every failed attempt.
    pub async fn process_job(
        &self,
        job: Job,
        processor: &ContentProcessor,
        openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
    ) -> Result<(), AppError> {
        let current_attempts = match job.status {
            JobStatus::InProgress { attempts, .. } => attempts + 1,
            _ => 1,
        };

        // Record the attempt before doing any work, so a crash mid-processing
        // still counts against MAX_ATTEMPTS.
        self.update_status(
            &job.id,
            JobStatus::InProgress {
                attempts: current_attempts,
                last_attempt: Utc::now(),
            },
        )
        .await?;

        let text_content = job.content.to_text_content(&openai_client).await?;

        match processor.process(&text_content).await {
            Ok(_) => {
                self.update_status(&job.id, JobStatus::Completed).await?;
                Ok(())
            }
            Err(e) => {
                if current_attempts >= MAX_ATTEMPTS {
                    self.update_status(
                        &job.id,
                        JobStatus::Error(format!("Max attempts reached: {}", e)),
                    )
                    .await?;
                }
                Err(AppError::Processing(e.to_string()))
            }
        }
    }
}

View File

@@ -0,0 +1,6 @@
pub mod analysis;
pub mod content_processor;
pub mod ingress_input;
pub mod ingress_object;
pub mod jobqueue;
pub mod queue_task;

View File

@@ -0,0 +1,13 @@
use crate::ingress::ingress_object::IngressObject;
use serde::Serialize;
/// A queued ingress task paired with its broker delivery tag.
#[derive(Serialize)]
pub struct QueueTask {
    // NOTE(review): presumably an AMQP-style delivery tag used to ack the
    // message — confirm against the queue consumer.
    pub delivery_tag: u64,
    pub content: IngressObject,
}
/// Serializable response payload listing the currently queued tasks.
#[derive(Serialize)]
pub struct QueueTaskResponse {
    pub tasks: Vec<QueueTask>,
}

6
crates/common/src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod error;
pub mod ingress;
pub mod retrieval;
pub mod server;
pub mod storage;
pub mod utils;

View File

@@ -0,0 +1,74 @@
use surrealdb::{engine::any::Any, Error, Surreal};
use tracing::debug;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject};
/// Retrieves every record in `table_name` whose `source_id` field matches one
/// of the given identifiers.
///
/// Commonly used to find entities derived from the same source document.
///
/// # Arguments
///
/// * `source_id` - The source identifiers to match against
/// * `table_name` - The table to search in
/// * `db_client` - The SurrealDB client to query with
///
/// # Type Parameters
///
/// * `T` - Target type for deserializing the rows; must implement `serde::Deserialize`
///
/// # Returns
///
/// * `Ok(Vec<T>)` - All matching records
/// * `Err(Error)` - When the query fails or rows cannot be deserialized into `T`
pub async fn find_entities_by_source_ids<T>(
    source_id: Vec<String>,
    table_name: String,
    db_client: &Surreal<Any>,
) -> Result<Vec<T>, Error>
where
    T: for<'de> serde::Deserialize<'de>,
{
    db_client
        .query("SELECT * FROM type::table($table) WHERE source_id IN $source_ids")
        .bind(("table", table_name))
        .bind(("source_ids", source_id))
        .await?
        .take(0)
}
/// Find entities related to the given entity id via `relates_to` edges.
///
/// The record id is bound through `type::thing` instead of being interpolated
/// into the query string, which prevents SurrealQL injection through
/// `entity_id`.
pub async fn find_entities_by_relationship_by_id(
    db_client: &Surreal<Any>,
    entity_id: String,
) -> Result<Vec<KnowledgeEntity>, Error> {
    let query =
        "SELECT *, <-> relates_to <-> knowledge_entity AS related FROM type::thing('knowledge_entity', $id)";
    debug!("{}", query);
    db_client
        .query(query)
        .bind(("id", entity_id))
        .await?
        .take(0)
}
/// Get a specific KnowledgeEntity by its id.
///
/// # Arguments
/// * `db_client` - The SurrealDB client to query with
/// * `entity_id` - The record id of the entity
///
/// # Returns
/// * `Ok(Some(entity))` when found, `Ok(None)` when no record has that id
pub async fn get_entity_by_id(
    db_client: &Surreal<Any>,
    entity_id: &str,
) -> Result<Option<KnowledgeEntity>, Error> {
    db_client
        .select((KnowledgeEntity::table_name(), entity_id))
        .await
}

View File

@@ -0,0 +1,92 @@
pub mod graph;
pub mod query_helper;
pub mod query_helper_prompt;
pub mod vector;
use crate::{
error::AppError,
retrieval::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
vector::find_items_by_vector_similarity,
},
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
};
use futures::future::{try_join, try_join_all};
use std::collections::HashMap;
use surrealdb::{engine::any::Any, Surreal};
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
///
/// # Strategy
/// The function employs a three-pronged approach to knowledge retrieval:
/// 1. Direct vector similarity search on knowledge entities
/// 2. Text chunk similarity search with source entity lookup
/// 3. Graph relationship traversal from related entities
///
/// This combined approach ensures both semantic similarity matches and structurally
/// related content are included in the results.
///
/// # Arguments
/// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities
/// * `user_id` - The user id of the current user
///
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
///   knowledge entities (order unspecified), or an error if retrieval fails
pub async fn combined_knowledge_entity_retrieval(
    db_client: &Surreal<Any>,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    query: &str,
    user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
    // Run both vector searches concurrently.
    let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
        find_items_by_vector_similarity(
            10,
            query,
            db_client,
            "knowledge_entity",
            openai_client,
            user_id,
        ),
        find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
    )
    .await?;

    let source_ids: Vec<String> = closest_chunks
        .iter()
        .map(|chunk: &TextChunk| chunk.source_id.clone())
        .collect();

    let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
        find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;

    // Only the id is needed for graph traversal, so iterate by reference
    // instead of cloning the whole entity vector (embeddings are large).
    let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
        .iter()
        .map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
        .collect();

    let items_from_relationships = try_join_all(items_from_relationships_futures)
        .await?
        .into_iter()
        .flatten()
        .collect::<Vec<KnowledgeEntity>>();

    // Deduplicate by entity id; later sources overwrite earlier ones, which is
    // harmless since duplicates are the same record.
    let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
        .into_iter()
        .chain(items_from_text_chunk_similarity.into_iter())
        .chain(items_from_relationships.into_iter())
        .fold(HashMap::new(), |mut map, entity| {
            map.insert(entity.id.clone(), entity);
            map
        })
        .into_values()
        .collect();

    Ok(entities)
}

View File

@@ -0,0 +1,193 @@
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
ResponseFormat, ResponseFormatJsonSchema,
},
};
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::debug;
use crate::{
error::AppError,
retrieval::combined_knowledge_entity_retrieval,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
/// A single source reference (entity UUID) cited by the LLM answer.
#[derive(Debug, Deserialize)]
pub struct Reference {
    // Deserialized from the LLM's JSON output; read via the mapping in
    // `get_answer_with_references`.
    #[allow(dead_code)]
    pub reference: String,
}
/// Shape of the structured JSON the LLM must return, enforced by the JSON
/// schema in `get_query_response_schema`.
#[derive(Debug, Deserialize)]
pub struct LLMResponseFormat {
    pub answer: String,
    #[allow(dead_code)]
    pub references: Vec<Reference>,
}
// /// Orchestrator function that takes a query and clients and returns a answer with references
// ///
// /// # Arguments
// /// * `surreal_db_client` - Client for interacting with SurrealDn
// /// * `openai_client` - Client for interacting with openai
// /// * `query` - The query
// ///
// /// # Returns
// /// * `Result<(String, Vec<String>, ApiError)` - Will return the answer, and the list of references or Error
// pub async fn get_answer_with_references(
// surreal_db_client: &SurrealDbClient,
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
// query: &str,
// ) -> Result<(String, Vec<String>), ApiError> {
// let entities =
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
// // Format entities and create message
// let entities_json = format_entities_json(&entities);
// let user_message = create_user_message(&entities_json, query);
// // Create and send request
// let request = create_chat_request(user_message)?;
// let response = openai_client
// .chat()
// .create(request)
// .await
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
// // Process response
// let answer = process_llm_response(response).await?;
// let references: Vec<String> = answer
// .references
// .into_iter()
// .map(|reference| reference.reference)
// .collect();
// Ok((answer.answer, references))
// }
/// An answer generated from the knowledge base, together with the entity
/// UUIDs it cites.
///
/// Produced by `get_answer_with_references`, which orchestrates retrieval,
/// the chat-completion call and response parsing.
#[derive(Debug)]
pub struct Answer {
    /// The generated answer text.
    pub content: String,
    /// UUIDs of the knowledge entities the answer references.
    pub references: Vec<String>,
}
/// Orchestrates query processing and returns an answer with references.
///
/// # Arguments
///
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
/// * `user_id` - The user's id
///
/// # Errors
///
/// Returns an `AppError` when retrieval, the chat-completion call or
/// response parsing fails.
pub async fn get_answer_with_references(
    surreal_db_client: &SurrealDbClient,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    query: &str,
    user_id: &str,
) -> Result<Answer, AppError> {
    // Gather candidate entities via combined vector + graph retrieval.
    let entities =
        combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
            .await?;

    let context_json = format_entities_json(&entities);
    debug!("{:?}", context_json);

    // Build and send the chat request from the retrieved context.
    let message = create_user_message(&context_json, query);
    let request = create_chat_request(message)?;
    let response = openai_client.chat().create(request).await?;

    let parsed = process_llm_response(response).await?;
    let references: Vec<String> = parsed
        .references
        .into_iter()
        .map(|r| r.reference)
        .collect();

    Ok(Answer {
        content: parsed.answer,
        references,
    })
}
/// Serializes entities into the JSON array shape expected by the LLM prompt:
/// one `{"KnowledgeEntity": {id, name, description}}` object per entity.
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
    let formatted: Vec<Value> = entities
        .iter()
        .map(|entity| {
            json!({
                "KnowledgeEntity": {
                    "id": entity.id,
                    "name": entity.name,
                    "description": entity.description
                }
            })
        })
        .collect();
    json!(formatted)
}
/// Builds the user message by embedding the serialized context entities and
/// the raw question into a fixed prompt template.
// NOTE: the template's exact whitespace is part of the prompt sent to the
// model — do not reformat the raw string.
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
    format!(
        r#"
        Context Information:
        ==================
        {}

        User Question:
        ==================
        {}
        "#,
        entities_json, query
    )
}
/// Builds the chat-completion request with a strict JSON-schema response
/// format so the model must answer in `LLMResponseFormat` shape.
pub fn create_chat_request(
    user_message: String,
) -> Result<CreateChatCompletionRequest, OpenAIError> {
    // Strict schema forcing {answer, references[]} output.
    let json_schema = ResponseFormatJsonSchema {
        description: Some("Query answering AI".into()),
        name: "query_answering_with_uuids".into(),
        schema: Some(get_query_response_schema()),
        strict: Some(true),
    };

    CreateChatCompletionRequestArgs::default()
        .model("gpt-4o-mini")
        .temperature(0.2)
        .max_tokens(3048u32)
        .messages([
            ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT).into(),
            ChatCompletionRequestUserMessage::from(user_message).into(),
        ])
        .response_format(ResponseFormat::JsonSchema { json_schema })
        .build()
}
/// Extracts the first choice's content from the completion response and
/// parses it into `LLMResponseFormat`.
///
/// # Errors
/// * `AppError::LLMParsing` when the response has no content or the content
///   is not valid JSON for the expected schema.
pub async fn process_llm_response(
    response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> {
    let content = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_ref())
        .ok_or(AppError::LLMParsing(
            "No content found in LLM response".into(),
        ))?;

    serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
        AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
    })
}

View File

@@ -0,0 +1,48 @@
use serde_json::{json, Value};
/// System prompt for the query-answering model: instructs it to answer only
/// from the supplied knowledge entities and to cite entity UUIDs in the
/// `references` array of its structured output.
pub static QUERY_SYSTEM_PROMPT: &str = r#"
You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.

Your task is to:
1. Carefully analyze the provided knowledge entities in the context
2. Answer user questions based on this information
3. Provide clear, concise, and accurate responses
4. When referencing information, briefly mention which knowledge entity it came from
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
6. If only partial information is available, explain what you can answer and what information is missing
7. Avoid making assumptions or providing information not supported by the context
8. Output the references to the documents. Use the UUIDs and make sure they are correct!

Remember:
- Be direct and honest about the limitations of your knowledge
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
- If you need to combine information from multiple entities, explain how they connect
- Don't speculate beyond what's provided in the context

Example response formats:
"Based on [Entity Name], [answer...]"
"I found relevant information in multiple entries: [explanation...]"
"I apologize, but the provided context doesn't contain information about [topic]"
"#;
/// JSON schema enforced on the model's output: an object with a string
/// `answer` and a `references` array of `{reference: string}` objects.
/// Must stay in sync with `LLMResponseFormat` in `query_helper`.
pub fn get_query_response_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "answer": { "type": "string" },
            "references": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "reference": { "type": "string" },
                    },
                    "required": ["reference"],
                    "additionalProperties": false,
                }
            }
        },
        "required": ["answer", "references"],
        "additionalProperties": false
    })
}

View File

@@ -0,0 +1,47 @@
use surrealdb::{engine::any::Any, Surreal};
use crate::{error::AppError, utils::embedding::generate_embedding};
/// Compares vectors and retrieves a number of items from the specified table.
///
/// This function generates an embedding for the input text, runs a KNN query
/// against the table's HNSW index, and deserializes the closest matches into
/// the specified type `T`.
///
/// # Arguments
///
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user.
///
/// # Returns
///
/// A vector of type `T` containing the closest matches to the input text, or
/// an `AppError` if embedding generation or the query fails.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
    take: u8,
    input_text: &str,
    db_client: &Surreal<Any>,
    table: &str,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    user_id: &str,
) -> Result<Vec<T>, AppError>
where
    T: for<'de> serde::Deserialize<'de>,
{
    // Generate embeddings
    let input_embedding = generate_embedding(openai_client, input_text).await?;

    // The KNN operator `<|k,ef|>` cannot take bound parameters, so `take`
    // (a plain integer, not attacker-controlled text) is interpolated.
    // Everything user-influenced — table, embedding, user_id — is bound,
    // which avoids query injection and keeps the query plan cacheable.
    let closest_query = format!(
        "SELECT *, vector::distance::knn() AS distance FROM type::table($table) \
         WHERE embedding <|{},40|> $embedding AND user_id = $user_id ORDER BY distance",
        take
    );

    let closest_entities: Vec<T> = db_client
        .query(closest_query)
        .bind(("table", table.to_owned()))
        .bind(("embedding", input_embedding))
        .bind(("user_id", user_id.to_owned()))
        .await?
        .take(0)?;

    Ok(closest_entities)
}

View File

@@ -0,0 +1,180 @@
use crate::error::AppError;
use super::types::{analytics::Analytics, system_settings::SystemSettings, StoredObject};
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use std::ops::Deref;
use surrealdb::{
engine::any::{connect, Any},
opt::auth::Root,
Error, Surreal,
};
/// Thin wrapper around a connected `Surreal<Any>` handle; cloning is cheap
/// (the inner client is reference-counted) and `Deref` exposes the full
/// SurrealDB API directly.
#[derive(Clone)]
pub struct SurrealDbClient {
    pub client: Surreal<Any>,
}
impl SurrealDbClient {
    /// Initialize a new database client.
    ///
    /// # Arguments
    /// * `address` - Endpoint to connect to (e.g. `ws://…`, `mem://`)
    /// * `username` / `password` - Root credentials
    /// * `namespace` / `database` - Namespace and database to select
    ///
    /// # Returns
    /// * `SurrealDbClient` connected, authenticated and scoped
    pub async fn new(
        address: &str,
        username: &str,
        password: &str,
        namespace: &str,
        database: &str,
    ) -> Result<Self, Error> {
        let db = connect(address).await?;

        // Sign in to database
        db.signin(Root { username, password }).await?;

        // Set namespace
        db.use_ns(namespace).use_db(database).await?;

        Ok(SurrealDbClient { client: db })
    }

    /// Creates the axum-session store backed by this database connection.
    pub async fn create_session_store(
        &self,
    ) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
        SessionStore::new(
            Some(self.client.clone().into()),
            SessionConfig::default()
                .with_table_name("test_session_table")
                .with_secure(true),
        )
        .await
    }

    /// Runs all one-time setup: indexes, auth tables and singleton records.
    pub async fn ensure_initialized(&self) -> Result<(), AppError> {
        // Idiomatic method-call syntax instead of `Self::f(&self)`.
        self.build_indexes().await?;
        self.setup_auth().await?;
        Analytics::ensure_initialized(self).await?;
        SystemSettings::ensure_initialized(self).await?;
        Ok(())
    }

    /// Defines the `user` table and the record-based signup/signin access.
    pub async fn setup_auth(&self) -> Result<(), Error> {
        self.client.query(
            "DEFINE TABLE user SCHEMALESS;
            DEFINE INDEX unique_name ON TABLE user FIELDS email UNIQUE;
            DEFINE ACCESS account ON DATABASE TYPE RECORD
            SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id)
            SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );",
        )
        .await?;
        Ok(())
    }

    /// Defines the HNSW vector indexes and the job-table lookup indexes.
    pub async fn build_indexes(&self) -> Result<(), Error> {
        self.client.query("DEFINE INDEX idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536").await?;
        self.client.query("DEFINE INDEX idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536").await?;
        self.client
            .query("DEFINE INDEX idx_job_status ON job FIELDS status")
            .await?;
        self.client
            .query("DEFINE INDEX idx_job_user ON job FIELDS user_id")
            .await?;
        self.client
            .query("DEFINE INDEX idx_job_created ON job FIELDS created_at")
            .await?;
        Ok(())
    }

    /// Rebuilds the HNSW vector indexes (e.g. after bulk imports).
    pub async fn rebuild_indexes(&self) -> Result<(), Error> {
        self.client
            .query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
            .await?;
        // Fixed: `build_indexes` defines this index as `idx_embedding_entities`;
        // the previous name `idx_embeddings_entities` never matched, so the
        // entity index was silently never rebuilt.
        self.client
            .query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
            .await?;
        Ok(())
    }

    /// Deletes every record in `T`'s table, returning the removed records.
    pub async fn drop_table<T>(&self) -> Result<Vec<T>, Error>
    where
        T: StoredObject + Send + Sync + 'static,
    {
        self.client.delete(T::table_name()).await
    }
}
// Lets callers invoke `Surreal` methods directly on the wrapper
// (e.g. `db.query(..)` instead of `db.client.query(..)`).
impl Deref for SurrealDbClient {
    type Target = Surreal<Any>;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}
/// Operation to store an object in SurrealDB; requires the struct to implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `item` - The item to be stored
///
/// # Returns
/// * `Result<Option<T>, Error>` - The created record as stored, or an error
///   (e.g. when a record with the same id already exists)
pub async fn store_item<T>(db_client: &Surreal<Any>, item: T) -> Result<Option<T>, Error>
where
    T: StoredObject + Send + Sync + 'static,
{
    db_client
        .create((T::table_name(), item.get_id()))
        .content(item)
        .await
}
/// Operation to retrieve all objects from a certain table; requires the struct
/// to implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
///
/// # Returns
/// * `Result<Vec<T>, Error>` - All records in `T`'s table, or an error
pub async fn get_all_stored_items<T>(db_client: &Surreal<Any>) -> Result<Vec<T>, Error>
where
    // Was `T: for<'de> StoredObject` — the higher-ranked binder bound a
    // lifetime the trait never uses, so it was pure noise.
    T: StoredObject,
{
    db_client.select(T::table_name()).await
}
/// Operation to retrieve a single object by its ID; requires the struct to
/// implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `id` - The ID of the item to retrieve
///
/// # Returns
/// * `Result<Option<T>, Error>` - The found item (`None` when absent) or an error
pub async fn get_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
where
    // Was `T: for<'de> StoredObject` — the binder's lifetime was never used.
    T: StoredObject,
{
    db_client.select((T::table_name(), id)).await
}
/// Operation to delete a single object by its ID; requires the struct to
/// implement `StoredObject`.
///
/// # Arguments
/// * `db_client` - An initialized database client
/// * `id` - The ID of the item to delete
///
/// # Returns
/// * `Result<Option<T>, Error>` - The deleted item (`None` when absent) or an error
pub async fn delete_item<T>(db_client: &Surreal<Any>, id: &str) -> Result<Option<T>, Error>
where
    // Was `T: for<'de> StoredObject` — the binder's lifetime was never used.
    T: StoredObject,
{
    db_client.delete((T::table_name(), id)).await
}

View File

@@ -0,0 +1,2 @@
pub mod db;
pub mod types;

View File

@@ -0,0 +1,78 @@
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
use serde::{Deserialize, Serialize};
use crate::{error::AppError, storage::db::SurrealDbClient};
/// Site-wide counters stored as the singleton record `analytics:current`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Analytics {
    // Record id; always "current" for the singleton.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    pub page_loads: i64,
    pub visitors: i64,
}
impl Analytics {
    /// Ensures the singleton `analytics:current` record exists, creating it
    /// with zeroed counters on first run, and returns it.
    pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
        let existing: Option<Self> = db.select(("analytics", "current")).await?;
        if let Some(analytics) = existing {
            return Ok(analytics);
        }

        let created: Option<Analytics> = db
            .create(("analytics", "current"))
            .content(Analytics {
                id: "current".to_string(),
                visitors: 0,
                page_loads: 0,
            })
            .await?;
        // Fixed copy-pasted message: this is the analytics record, not settings.
        created.ok_or(AppError::Validation("Failed to initialize analytics".into()))
    }

    /// Fetches the current analytics record.
    ///
    /// # Errors
    /// * `AppError::NotFound` when the singleton has not been initialized.
    pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
        let analytics: Option<Self> = db
            .client
            .query("SELECT * FROM type::thing('analytics', 'current')")
            .await?
            .take(0)?;
        analytics.ok_or(AppError::NotFound("Analytics not found".into()))
    }

    /// Atomically increments the visitor counter and returns the new record.
    pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
        let updated: Option<Self> = db
            .client
            .query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
            .await?
            .take(0)?;
        updated.ok_or(AppError::Validation("Failed to update analytics".into()))
    }

    /// Atomically increments the page-load counter and returns the new record.
    pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
        let updated: Option<Self> = db
            .client
            .query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
            .await?
            .take(0)?;
        updated.ok_or(AppError::Validation("Failed to update analytics".into()))
    }

    /// Counts the registered users; returns 0 when the user table is empty.
    pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
        // Local helper type for the aggregate row shape.
        #[derive(Debug, Deserialize)]
        struct CountResult {
            count: i64,
        }

        let result: Option<CountResult> = db
            .client
            .query("SELECT count() as count FROM type::table($table) GROUP ALL")
            .bind(("table", User::table_name()))
            .await?
            .take(0)?;

        Ok(result.map(|r| r.count).unwrap_or(0))
    }
}

View File

@@ -0,0 +1,52 @@
use uuid::Uuid;
use crate::{
error::AppError,
storage::db::{get_item, SurrealDbClient},
stored_object,
};
use super::message::Message;
// Declares the `Conversation` record stored in the `conversation` table.
// NOTE(review): `stored_object!` presumably also adds the common `id`,
// `created_at` and `updated_at` fields plus serde derives — confirm against
// the macro definition.
stored_object!(Conversation, "conversation", {
    user_id: String,
    title: String
});
impl Conversation {
    /// Creates a new conversation owned by `user_id` with a generated UUID.
    pub fn new(user_id: String, title: String) -> Self {
        let timestamp = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            created_at: timestamp,
            updated_at: timestamp,
            user_id,
            title,
        }
    }

    /// Loads a conversation together with all of its messages, ordered by
    /// `updated_at`.
    ///
    /// # Errors
    /// * `AppError::NotFound` when the conversation does not exist
    /// * `AppError::Auth` when it belongs to a different user
    pub async fn get_complete_conversation(
        conversation_id: &str,
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<(Self, Vec<Message>), AppError> {
        let conversation: Conversation = get_item(db, conversation_id)
            .await?
            .ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;

        // Ownership check before exposing any messages.
        if conversation.user_id != user_id {
            return Err(AppError::Auth(
                "You don't have access to this conversation".to_string(),
            ));
        }

        let messages: Vec<Message> = db
            .client
            .query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at")
            .bind(("table_name", Message::table_name()))
            .bind(("conversation_id", conversation_id.to_string()))
            .await?
            .take(0)?;

        Ok((conversation, messages))
    }
}

View File

@@ -0,0 +1,248 @@
use axum_typed_multipart::FieldData;
use mime_guess::from_path;
use sha2::{Digest, Sha256};
use std::{
io::{BufReader, Read},
path::{Path, PathBuf},
};
use tempfile::NamedTempFile;
use thiserror::Error;
use tokio::fs::remove_dir_all;
use tracing::info;
use uuid::Uuid;
use crate::{
storage::db::{delete_item, get_item, store_item, SurrealDbClient},
stored_object,
};
/// Errors that can occur while storing, looking up, or deleting files.
#[derive(Error, Debug)]
pub enum FileError {
    // Also used for SHA256 misses in `get_by_sha`, despite the message text.
    #[error("File not found for UUID: {0}")]
    FileNotFound(String),
    #[error("IO error occurred: {0}")]
    Io(#[from] std::io::Error),
    #[error("Duplicate file detected with SHA256: {0}")]
    DuplicateFile(String),
    #[error("SurrealDB error: {0}")]
    SurrealError(#[from] surrealdb::Error),
    // Raised when moving the temp file to its final location fails.
    #[error("Failed to persist file: {0}")]
    PersistError(#[from] tempfile::PersistError),
    #[error("File name missing in metadata")]
    MissingFileName,
}
// Declares the `FileInfo` record stored in the `file` table.
// NOTE(review): `stored_object!` presumably also adds `id`, `created_at`,
// `updated_at` and serde/Clone derives — confirm against the macro definition.
stored_object!(FileInfo, "file", {
    sha256: String,
    path: String,
    file_name: String,
    mime_type: String
});
impl FileInfo {
pub async fn new(
field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
user_id: &str,
) -> Result<Self, FileError> {
let file = field_data.contents;
let file_name = field_data
.metadata
.file_name
.ok_or(FileError::MissingFileName)?;
// Calculate SHA256
let sha256 = Self::get_sha(&file).await?;
// Early return if file already exists
match Self::get_by_sha(&sha256, db_client).await {
Ok(existing_file) => {
info!("File already exists with SHA256: {}", sha256);
return Ok(existing_file);
}
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
Err(e) => return Err(e), // Propagate unexpected errors
}
// Generate UUID and prepare paths
let uuid = Uuid::new_v4();
let sanitized_file_name = Self::sanitize_file_name(&file_name);
let now = Utc::now();
// Create new FileInfo instance
let file_info = Self {
id: uuid.to_string(),
created_at: now,
updated_at: now,
file_name,
sha256,
path: Self::persist_file(&uuid, file, &sanitized_file_name, user_id)
.await?
.to_string_lossy()
.into(),
mime_type: Self::guess_mime_type(Path::new(&sanitized_file_name)),
};
// Store in database
store_item(&db_client.client, file_info.clone()).await?;
Ok(file_info)
}
/// Guesses the MIME type based on the file extension.
///
/// # Arguments
/// * `path` - The path to the file.
///
/// # Returns
/// * `String` - The guessed MIME type as a string.
fn guess_mime_type(path: &Path) -> String {
from_path(path)
.first_or(mime::APPLICATION_OCTET_STREAM)
.to_string()
}
/// Calculates the SHA256 hash of the given file.
///
/// # Arguments
/// * `file` - The file to hash.
///
/// # Returns
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
let mut reader = BufReader::new(file.as_file());
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192]; // 8KB buffer
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
hasher.update(&buffer[..n]);
}
let digest = hasher.finalize();
Ok(format!("{:x}", digest))
}
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
fn sanitize_file_name(file_name: &str) -> String {
file_name
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
c
} else {
'_'
}
})
.collect()
}
/// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`.
///
/// # Arguments
/// * `uuid` - The UUID of the file.
/// * `file` - The temporary file to persist.
/// * `file_name` - The sanitized file name.
/// * `user-id` - User id
///
/// # Returns
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
async fn persist_file(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
user_id: &str,
) -> Result<PathBuf, FileError> {
let base_dir = Path::new("./data");
let user_dir = base_dir.join(user_id); // Create the user directory
let uuid_dir = user_dir.join(uuid.to_string()); // Create the UUID directory under the user directory
// Create the user and UUID directories if they don't exist
tokio::fs::create_dir_all(&uuid_dir)
.await
.map_err(FileError::Io)?;
// Define the final file path
let final_path = uuid_dir.join(file_name);
info!("Final path: {:?}", final_path);
// Persist the temporary file to the final path
file.persist(&final_path)?;
info!("Persisted file to {:?}", final_path);
Ok(final_path)
}
/// Retrieves a `FileInfo` by SHA256.
///
/// # Arguments
/// * `sha256` - The SHA256 hash string.
/// * `db_client` - Reference to the SurrealDbClient.
///
/// # Returns
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(sha256.to_string()))
}
/// Removes FileInfo from database and file from disk
///
/// # Arguments
/// * `id` - Id of the FileInfo
/// * `db_client` - Reference to SurrealDbClient
///
/// # Returns
/// `Result<(), FileError>`
pub async fn delete_by_id(id: &str, db_client: &SurrealDbClient) -> Result<(), FileError> {
// Get the FileInfo from the database
let file_info = match get_item::<FileInfo>(db_client, id).await? {
Some(info) => info,
None => {
return Err(FileError::FileNotFound(format!(
"File with id {} was not found",
id
)))
}
};
// Remove the file and its parent directory
let file_path = Path::new(&file_info.path);
if file_path.exists() {
// Get the parent directory of the file
if let Some(parent_dir) = file_path.parent() {
// Remove the entire directory containing the file
remove_dir_all(parent_dir).await?;
info!("Removed directory {:?} and its contents", parent_dir);
} else {
return Err(FileError::FileNotFound(
"File has no parent directory".to_string(),
));
}
} else {
return Err(FileError::FileNotFound(format!(
"File at path {:?} was not found",
file_path
)));
}
// Delete the FileInfo from the database
delete_item::<FileInfo>(db_client, id).await?;
Ok(())
}
}

View File

@@ -0,0 +1,36 @@
use uuid::Uuid;
use crate::{ingress::ingress_object::IngressObject, stored_object};
/// Lifecycle state of an ingress job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum JobStatus {
    /// Freshly enqueued, not yet picked up by a worker.
    Created,
    /// Currently being processed; carries retry bookkeeping.
    InProgress {
        attempts: u32,
        last_attempt: DateTime<Utc>,
    },
    /// Finished successfully.
    Completed,
    /// Failed permanently; the string holds the final error message.
    Error(String),
    /// Manually cancelled.
    Cancelled,
}
// Declares the `Job` record stored in the `job` table.
// NOTE(review): `stored_object!` presumably also adds `id`, `created_at`,
// `updated_at` and serde derives — confirm against the macro definition.
stored_object!(Job, "job", {
    content: IngressObject,
    status: JobStatus,
    user_id: String
});
impl Job {
    /// Creates a new job in the `Created` state with a fresh UUID.
    // NOTE(review): this fn is `async` although it awaits nothing; kept async
    // because callers already `.await` it (see `JobQueue::enqueue`).
    pub async fn new(content: IngressObject, user_id: String) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            content,
            status: JobStatus::Created,
            created_at: now,
            updated_at: now,
            user_id,
        }
    }
}

View File

@@ -0,0 +1,121 @@
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use uuid::Uuid;
/// Classification of a knowledge-graph entity.
///
/// Keep `variants()` and the `From<String>` mapping in sync when adding a
/// new variant.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum KnowledgeEntityType {
    Idea,
    Project,
    Document,
    Page,
    TextSnippet,
    // Add more types as needed
}
impl KnowledgeEntityType {
    /// Display names for every variant, e.g. for populating form dropdowns.
    pub fn variants() -> &'static [&'static str] {
        const NAMES: [&str; 5] = ["Idea", "Project", "Document", "Page", "TextSnippet"];
        &NAMES
    }
}
impl From<String> for KnowledgeEntityType {
fn from(s: String) -> Self {
match s.to_lowercase().as_str() {
"idea" => KnowledgeEntityType::Idea,
"project" => KnowledgeEntityType::Project,
"document" => KnowledgeEntityType::Document,
"page" => KnowledgeEntityType::Page,
"textsnippet" => KnowledgeEntityType::TextSnippet,
_ => KnowledgeEntityType::Document, // Default case
}
}
}
// Persisted knowledge-graph node. `stored_object!` adds `id`, `created_at`,
// `updated_at` and a `StoredObject` impl for the "knowledge_entity" table.
stored_object!(KnowledgeEntity, "knowledge_entity", {
    source_id: String, // id of the source document this entity was extracted from
    name: String,
    description: String,
    entity_type: KnowledgeEntityType,
    metadata: Option<serde_json::Value>, // free-form extra data
    embedding: Vec<f32>, // vector embedding of name/description/type
    user_id: String // owning user's id
});
impl KnowledgeEntity {
    /// Constructs a new entity with a fresh UUID and identical
    /// created/updated timestamps.
    pub fn new(
        source_id: String,
        name: String,
        description: String,
        entity_type: KnowledgeEntityType,
        metadata: Option<serde_json::Value>,
        embedding: Vec<f32>,
        user_id: String,
    ) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id,
            name,
            description,
            entity_type,
            metadata,
            embedding,
            user_id,
        }
    }
    /// Deletes every entity extracted from the given source document.
    ///
    /// Uses parameter binding (consistent with `patch` below) instead of
    /// string interpolation, so a `source_id` containing quotes cannot
    /// alter the query.
    pub async fn delete_by_source_id(
        source_id: &str,
        db_client: &SurrealDbClient,
    ) -> Result<(), AppError> {
        db_client
            .client
            .query("DELETE type::table($table) WHERE source_id = $source_id")
            .bind(("table", Self::table_name()))
            .bind(("source_id", source_id.to_string()))
            .await?;
        Ok(())
    }
    /// Updates name/description/type of an existing entity and regenerates
    /// its embedding so vector search stays consistent with the new text.
    pub async fn patch(
        id: &str,
        name: &str,
        description: &str,
        entity_type: &KnowledgeEntityType,
        db_client: &SurrealDbClient,
        ai_client: &Client<OpenAIConfig>,
    ) -> Result<(), AppError> {
        // Embed the same "name/description/type" shape used elsewhere so
        // patched entities stay comparable with freshly created ones.
        let embedding_input = format!(
            "name: {}, description: {}, type: {:?}",
            name, description, entity_type
        );
        let embedding = generate_embedding(ai_client, &embedding_input).await?;
        db_client
            .client
            .query(
                "UPDATE type::thing($table, $id)
                SET name = $name,
                description = $description,
                updated_at = $updated_at,
                entity_type = $entity_type,
                embedding = $embedding
                RETURN AFTER",
            )
            .bind(("table", Self::table_name()))
            .bind(("id", id.to_string()))
            .bind(("name", name.to_string()))
            .bind(("updated_at", Utc::now()))
            .bind(("entity_type", entity_type.to_owned()))
            .bind(("embedding", embedding))
            .bind(("description", description.to_string()))
            .await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,86 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Metadata carried on every knowledge-graph edge.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RelationshipMetadata {
    // Owning user's id.
    pub user_id: String,
    // Id of the source document the relationship was extracted from.
    pub source_id: String,
    // Free-form label describing the relationship (e.g. "mentions").
    pub relationship_type: String,
}
/// A directed edge between two knowledge entities, stored in SurrealDB's
/// "relates_to" relation table. `in`/`out` follow SurrealDB's RELATE naming.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct KnowledgeRelationship {
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    // Source entity id ("in" is a Rust keyword, hence the rename + trailing underscore).
    #[serde(rename = "in", deserialize_with = "deserialize_flexible_id")]
    pub in_: String,
    // Target entity id.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub out: String,
    pub metadata: RelationshipMetadata,
}
impl KnowledgeRelationship {
    /// Builds an edge `in_` -> `out`, tagged with the owning user and the
    /// originating source document. Not persisted until
    /// `store_relationship` is called.
    pub fn new(
        in_: String,
        out: String,
        user_id: String,
        source_id: String,
        relationship_type: String,
    ) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            in_,
            out,
            metadata: RelationshipMetadata {
                user_id,
                source_id,
                relationship_type,
            },
        }
    }
    /// Persists the edge via RELATE, pinning the relation record id to
    /// `self.id` (the backtick-quoted segment after `relates_to:`).
    ///
    /// NOTE(review): the statement is built with `format!`, so ids and
    /// metadata values are interpolated verbatim — a value containing
    /// quotes or backticks could break or alter the query. Consider
    /// parameter binding if these values can come from user input; confirm
    /// RELATE's binding support first.
    pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
        let query = format!(
            r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
            SET
                metadata.user_id = '{}',
                metadata.source_id = '{}',
                metadata.relationship_type = '{}'"#,
            self.in_,
            self.id,
            self.out,
            self.metadata.user_id,
            self.metadata.source_id,
            self.metadata.relationship_type
        );
        db_client.query(query).await?;
        Ok(())
    }
    /// Deletes all edges extracted from a given source document.
    ///
    /// NOTE(review): same `format!` interpolation caveat as above.
    pub async fn delete_relationships_by_source_id(
        source_id: &str,
        db_client: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = format!(
            "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
            source_id
        );
        db_client.query(query).await?;
        Ok(())
    }
    /// Deletes a single edge by its relation record id.
    ///
    /// NOTE(review): same `format!` interpolation caveat as above.
    pub async fn delete_relationship_by_id(
        id: &str,
        db_client: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = format!("DELETE relates_to:`{}`", id);
        db_client.query(query).await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,37 @@
use uuid::Uuid;
use crate::stored_object;
/// Author of a chat message within a conversation.
#[derive(Deserialize, Debug, Clone, Serialize)]
pub enum MessageRole {
    /// Written by the human user.
    User,
    /// Generated by the assistant/LLM.
    AI,
    /// Injected system/prompt message.
    System,
}
// Persisted chat message. `stored_object!` adds `id`, `created_at`,
// `updated_at` and a `StoredObject` impl for the "message" table.
stored_object!(Message, "message", {
    conversation_id: String, // conversation this message belongs to
    role: MessageRole,
    content: String,
    references: Option<Vec<String>> // ids of sources cited by the message, if any
});
impl Message {
pub fn new(
conversation_id: String,
role: MessageRole,
content: String,
references: Option<Vec<String>>,
) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
conversation_id,
role,
content,
references,
}
}
}

View File

@@ -0,0 +1,110 @@
use axum::async_trait;
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
pub mod file_info;
pub mod job;
pub mod knowledge_entity;
pub mod knowledge_relationship;
pub mod message;
pub mod system_settings;
pub mod text_chunk;
pub mod text_content;
pub mod user;
/// Common interface implemented (via the `stored_object!` macro) by every
/// record type persisted in SurrealDB.
#[async_trait]
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
    /// Name of the SurrealDB table the type is stored in.
    fn table_name() -> &'static str;
    /// The record's id in raw string form (without the table prefix).
    fn get_id(&self) -> &str;
}
#[macro_export]
/// Declares a persisted record type.
///
/// `stored_object!(Name, "table", { field: Ty, ... })` expands to a struct
/// with `id`, `created_at` and `updated_at` plus the listed fields, the
/// serde helpers below, and a [`StoredObject`] impl bound to `"table"`.
///
/// NOTE(review): each expansion also emits module-level items
/// (`FlexibleIdVisitor`, `deserialize_flexible_id`, datetime helpers and
/// `use` statements), so the macro can only be invoked once per module —
/// confirm before adding a second invocation to the same file.
macro_rules! stored_object {
    ($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
        use axum::async_trait;
        use serde::{Deserialize, Deserializer, Serialize};
        use surrealdb::sql::Thing;
        use $crate::storage::types::StoredObject;
        use serde::de::{self, Visitor};
        use std::fmt;
        use chrono::{DateTime, Utc };
        // Visitor accepting either a plain string id or a SurrealDB `Thing`
        // record id, yielding the raw string form in both cases.
        struct FlexibleIdVisitor;
        impl<'de> Visitor<'de> for FlexibleIdVisitor {
            type Value = String;
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string or a Thing")
            }
            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(value.to_string())
            }
            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(value)
            }
            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
            where
                A: de::MapAccess<'de>,
            {
                // Try to deserialize as Thing
                let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
                Ok(thing.id.to_raw())
            }
        }
        // Deserializes an id that may arrive either as `"abc"` or as a
        // record object `{ tb: "...", id: "abc" }`.
        pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_any(FlexibleIdVisitor)
        }
        // Round-trips chrono datetimes through SurrealDB's native datetime type.
        fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
        }
        fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
            Ok(DateTime::<Utc>::from(dt))
        }
        #[derive(Debug, Clone, Serialize, Deserialize)]
        pub struct $name {
            #[serde(deserialize_with = "deserialize_flexible_id")]
            pub id: String,
            // `default` lets rows without timestamps deserialize (epoch default).
            #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
            pub created_at: DateTime<Utc>,
            #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
            pub updated_at: DateTime<Utc>,
            $(pub $field: $ty),*
        }
        #[async_trait]
        impl StoredObject for $name {
            fn table_name() -> &'static str {
                $table
            }
            fn get_id(&self) -> &str {
                &self.id
            }
        }
    };
}

View File

@@ -0,0 +1,56 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use serde::{Deserialize, Serialize};
use crate::{error::AppError, storage::db::SurrealDbClient};
/// Application-wide settings, stored as a single row with the fixed id
/// `system_settings:current`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SystemSettings {
    // Always "current" — there is exactly one settings record.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    // Whether new accounts may register.
    pub registrations_enabled: bool,
    // Whether new accounts must verify their email address.
    pub require_email_verification: bool,
}
impl SystemSettings {
    /// Returns the singleton settings row, creating it with defaults on
    /// first boot.
    pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
        let existing: Option<SystemSettings> = db.select(("system_settings", "current")).await?;
        match existing {
            Some(current) => Ok(current),
            None => {
                // First run: persist the defaults and hand them back.
                let created: Option<SystemSettings> = db
                    .create(("system_settings", "current"))
                    .content(SystemSettings {
                        id: "current".to_string(),
                        registrations_enabled: true,
                        require_email_verification: false,
                    })
                    .await?;
                created.ok_or(AppError::Validation("Failed to initialize settings".into()))
            }
        }
    }
    /// Reads the current settings; errors if the row was never initialized.
    pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
        let row: Option<Self> = db
            .client
            .query("SELECT * FROM type::thing('system_settings', 'current')")
            .await?
            .take(0)?;
        row.ok_or(AppError::NotFound("System settings not found".into()))
    }
    /// Merges `changes` into the singleton row and returns the new state.
    pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
        let merged: Option<Self> = db
            .client
            .query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
            .bind(("changes", changes))
            .await?
            .take(0)?;
        merged.ok_or(AppError::Validation(
            "Something went wrong updating the settings".into(),
        ))
    }
}

View File

@@ -0,0 +1,38 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use uuid::Uuid;
// Persisted chunk of a larger text, used for vector search.
// `stored_object!` adds `id`, `created_at`, `updated_at` and a
// `StoredObject` impl for the "text_chunk" table.
stored_object!(TextChunk, "text_chunk", {
    source_id: String, // id of the source document the chunk came from
    chunk: String, // the chunk text itself
    embedding: Vec<f32>, // vector embedding of the chunk
    user_id: String // owning user's id
});
impl TextChunk {
    /// Builds a chunk record with a fresh UUID and identical
    /// created/updated timestamps.
    pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id,
            chunk,
            embedding,
            user_id,
        }
    }
    /// Deletes every chunk belonging to the given source document.
    ///
    /// Uses parameter binding instead of string interpolation so a
    /// `source_id` containing quotes cannot alter the query.
    pub async fn delete_by_source_id(
        source_id: &str,
        db_client: &SurrealDbClient,
    ) -> Result<(), AppError> {
        db_client
            .client
            .query("DELETE type::table($table) WHERE source_id = $source_id")
            .bind(("table", Self::table_name()))
            .bind(("source_id", source_id.to_string()))
            .await?;
        Ok(())
    }
}

View File

@@ -0,0 +1,38 @@
use uuid::Uuid;
use crate::stored_object;
use super::file_info::FileInfo;
// Persisted piece of ingested text plus its origin (file upload or URL).
// `stored_object!` adds `id`, `created_at`, `updated_at` and a
// `StoredObject` impl for the "text_content" table.
stored_object!(TextContent, "text_content", {
    text: String, // the ingested text itself
    file_info: Option<FileInfo>, // set when the text came from an uploaded file
    url: Option<String>, // set when the text came from a URL
    instructions: String, // user-supplied processing instructions
    category: String, // user-assigned category label
    user_id: String // owning user's id
});
impl TextContent {
pub fn new(
text: String,
instructions: String,
category: String,
file_info: Option<FileInfo>,
url: Option<String>,
user_id: String,
) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
text,
file_info,
url,
instructions,
category,
user_id,
}
}
}

View File

@@ -0,0 +1,352 @@
use crate::{
error::AppError,
storage::db::{get_item, SurrealDbClient},
stored_object,
};
use axum_session_auth::Authentication;
use surrealdb::{engine::any::Any, Surreal};
use uuid::Uuid;
use super::{
conversation::Conversation, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
text_content::TextContent,
};
/// Row shape returned by the `SELECT category ... GROUP BY category` query
/// in `User::get_user_categories`.
#[derive(Deserialize)]
pub struct CategoryResponse {
    category: String,
}
// Persisted user account. `stored_object!` adds `id`, `created_at`,
// `updated_at` and a `StoredObject` impl for the "user" table.
stored_object!(User, "user", {
    email: String,
    password: String, // argon2 hash (hashing happens inside the database)
    anonymous: bool, // true for guest/session-only users
    api_key: Option<String>, // "sk_..." token; None when no key is issued
    admin: bool,
    // `default` so legacy rows without a timezone deserialize as "".
    #[serde(default)]
    timezone: String
});
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
    /// Loads the user backing a session id.
    ///
    /// Returns an error — rather than panicking — when no database handle
    /// was supplied or the user row no longer exists (e.g. deleted account
    /// with a stale session).
    async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
        let db = db.ok_or_else(|| anyhow::anyhow!("no database connection supplied"))?;
        get_item::<Self>(db, userid.as_str())
            .await?
            .ok_or_else(|| anyhow::anyhow!("user '{}' not found", userid))
    }
    fn is_authenticated(&self) -> bool {
        !self.anonymous
    }
    fn is_active(&self) -> bool {
        !self.anonymous
    }
    fn is_anonymous(&self) -> bool {
        self.anonymous
    }
}
/// Validates an IANA timezone identifier, falling back to "UTC" (with a
/// warning) when the input does not parse.
fn validate_timezone(input: &str) -> String {
    use chrono_tz::Tz;
    if input.parse::<Tz>().is_ok() {
        input.to_owned()
    } else {
        tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
        "UTC".to_owned()
    }
}
impl User {
    /// Registers a new account.
    ///
    /// Rejects the request when registrations are disabled in
    /// [`SystemSettings`]. The password is hashed with argon2 inside the
    /// database, and the very first account created becomes an admin.
    pub async fn create_new(
        email: String,
        password: String,
        db: &SurrealDbClient,
        timezone: String,
    ) -> Result<Self, AppError> {
        // verify that the application allows new creations
        let systemsettings = SystemSettings::get_current(db).await?;
        if !systemsettings.registrations_enabled {
            return Err(AppError::Auth("Registration is not allowed".into()));
        }
        let validated_tz = validate_timezone(&timezone);
        let now = Utc::now();
        let id = Uuid::new_v4().to_string();
        // take(1): statement 0 is the LET, statement 1 the CREATE result.
        let user: Option<User> = db
            .client
            .query(
                "LET $count = (SELECT count() FROM type::table($table))[0].count;
            CREATE type::thing('user', $id) SET
                email = $email,
                password = crypto::argon2::generate($password),
                admin = $count < 1, // Changed from == 0 to < 1
                anonymous = false,
                created_at = $created_at,
                updated_at = $updated_at,
                timezone = $timezone",
            )
            .bind(("table", "user"))
            .bind(("id", id))
            .bind(("email", email))
            .bind(("password", password))
            .bind(("created_at", now))
            .bind(("updated_at", now))
            .bind(("timezone", validated_tz))
            .await?
            .take(1)?;
        user.ok_or(AppError::Auth("User failed to create".into()))
    }
    /// Verifies email + password against the stored argon2 hash.
    pub async fn authenticate(
        email: String,
        password: String,
        db: &SurrealDbClient,
    ) -> Result<Self, AppError> {
        let user: Option<User> = db
            .client
            .query(
                "SELECT * FROM user
            WHERE email = $email
            AND crypto::argon2::compare(password, $password)",
            )
            .bind(("email", email))
            .bind(("password", password))
            .await?
            .take(0)?;
        user.ok_or(AppError::Auth("User failed to authenticate".into()))
    }
    /// Looks a user up by email address; `Ok(None)` when no match exists.
    pub async fn find_by_email(
        email: &str,
        db: &SurrealDbClient,
    ) -> Result<Option<Self>, AppError> {
        let user: Option<User> = db
            .client
            .query("SELECT * FROM user WHERE email = $email LIMIT 1")
            .bind(("email", email.to_string()))
            .await?
            .take(0)?;
        Ok(user)
    }
    /// Looks a user up by API key; `Ok(None)` when no match exists.
    pub async fn find_by_api_key(
        api_key: &str,
        db: &SurrealDbClient,
    ) -> Result<Option<Self>, AppError> {
        let user: Option<User> = db
            .client
            .query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
            .bind(("api_key", api_key.to_string()))
            .await?
            .take(0)?;
        Ok(user)
    }
    /// Generates a fresh `sk_…` API key, stores it on the user and returns
    /// it. Any previous key is overwritten.
    pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
        // Generate a secure random API key
        let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
        // Update the user record with the new API key
        let user: Option<User> = db
            .client
            .query(
                "UPDATE type::thing('user', $id)
             SET api_key = $api_key
             RETURN AFTER",
            )
            .bind(("id", id.to_owned()))
            .bind(("api_key", api_key.clone()))
            .await?
            .take(0)?;
        // If the user was found and updated, return the API key
        if user.is_some() {
            Ok(api_key)
        } else {
            Err(AppError::Auth("User not found".into()))
        }
    }
    /// Clears the user's API key (sets it to NULL).
    pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
        let user: Option<User> = db
            .client
            .query(
                "UPDATE type::thing('user', $id)
             SET api_key = NULL
             RETURN AFTER",
            )
            .bind(("id", id.to_owned()))
            .await?
            .take(0)?;
        if user.is_some() {
            Ok(())
        } else {
            Err(AppError::Auth("User was not found".into()))
        }
    }
    /// All knowledge entities owned by the user.
    pub async fn get_knowledge_entities(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<KnowledgeEntity>, AppError> {
        let entities: Vec<KnowledgeEntity> = db
            .client
            .query("SELECT * FROM type::table($table) WHERE user_id = $user_id")
            .bind(("table", KnowledgeEntity::table_name()))
            .bind(("user_id", user_id.to_owned()))
            .await?
            .take(0)?;
        Ok(entities)
    }
    /// All graph edges owned by the user (stored in "relates_to", keyed by
    /// `metadata.user_id`).
    pub async fn get_knowledge_relationships(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<KnowledgeRelationship>, AppError> {
        let relationships: Vec<KnowledgeRelationship> = db
            .client
            .query("SELECT * FROM type::table($table) WHERE metadata.user_id = $user_id")
            .bind(("table", "relates_to"))
            .bind(("user_id", user_id.to_owned()))
            .await?
            .take(0)?;
        Ok(relationships)
    }
    /// The user's five most recently created text contents.
    pub async fn get_latest_text_contents(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<TextContent>, AppError> {
        let items: Vec<TextContent> = db
            .client
            .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5")
            .bind(("user_id", user_id.to_owned()))
            .bind(("table_name", TextContent::table_name()))
            .await?
            .take(0)?;
        Ok(items)
    }
    /// All of the user's text contents, newest first.
    pub async fn get_text_contents(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<TextContent>, AppError> {
        let items: Vec<TextContent> = db
            .client
            .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC")
            .bind(("user_id", user_id.to_owned()))
            .bind(("table_name", TextContent::table_name()))
            .await?
            .take(0)?;
        Ok(items)
    }
    /// The user's five most recently created knowledge entities.
    pub async fn get_latest_knowledge_entities(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<KnowledgeEntity>, AppError> {
        let items: Vec<KnowledgeEntity> = db
            .client
            .query(
                "SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5",
            )
            .bind(("user_id", user_id.to_owned()))
            .bind(("table_name", KnowledgeEntity::table_name()))
            .await?
            .take(0)?;
        Ok(items)
    }
    /// Sets the user's timezone (caller is expected to validate it first).
    pub async fn update_timezone(
        user_id: &str,
        timezone: &str,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        // The query only uses $user_id and $timezone; the previously bound,
        // unused $table_name parameter has been removed.
        db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
            .bind(("user_id", user_id.to_string()))
            .bind(("timezone", timezone.to_string()))
            .await?;
        Ok(())
    }
    /// Distinct category labels across the user's text contents.
    pub async fn get_user_categories(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<String>, AppError> {
        // Query to select distinct categories for the user
        let response: Vec<CategoryResponse> = db
            .client
            .query("SELECT category FROM type::table($table_name) WHERE user_id = $user_id GROUP BY category")
            .bind(("user_id", user_id.to_owned()))
            .bind(("table_name", TextContent::table_name()))
            .await?
            .take(0)?;
        // Extract the categories from the response
        let categories: Vec<String> = response.into_iter().map(|item| item.category).collect();
        Ok(categories)
    }
    /// Fetches a knowledge entity and verifies the user owns it.
    ///
    /// Returns `NotFound` when it doesn't exist and `Auth` when it belongs
    /// to someone else.
    pub async fn get_and_validate_knowledge_entity(
        id: &str,
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<KnowledgeEntity, AppError> {
        // `id` is already a &str; no extra borrow needed.
        let entity: KnowledgeEntity = get_item(db, id)
            .await?
            .ok_or_else(|| AppError::NotFound("Entity not found".into()))?;
        if entity.user_id != user_id {
            return Err(AppError::Auth("Access denied".into()));
        }
        Ok(entity)
    }
    /// Fetches a text content and verifies the user owns it.
    ///
    /// Returns `NotFound` when it doesn't exist and `Auth` when it belongs
    /// to someone else.
    pub async fn get_and_validate_text_content(
        id: &str,
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<TextContent, AppError> {
        let text_content: TextContent = get_item(db, id)
            .await?
            .ok_or_else(|| AppError::NotFound("Content not found".into()))?;
        if text_content.user_id != user_id {
            return Err(AppError::Auth("Access denied".into()));
        }
        Ok(text_content)
    }
    /// All conversations owned by the user.
    pub async fn get_user_conversations(
        user_id: &str,
        db: &SurrealDbClient,
    ) -> Result<Vec<Conversation>, AppError> {
        let conversations: Vec<Conversation> = db
            .client
            .query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
            .bind(("table_name", Conversation::table_name()))
            .bind(("user_id", user_id.to_string()))
            .await?
            .take(0)?;
        Ok(conversations)
    }
}

View File

@@ -0,0 +1,30 @@
use config::{Config, ConfigError, File};
/// Runtime configuration loaded from the on-disk `config` file: SMTP
/// credentials for outgoing mail and SurrealDB connection settings.
#[derive(Clone, Debug)]
pub struct AppConfig {
    pub smtp_username: String,
    pub smtp_password: String,
    // Hostname of the SMTP relay server.
    pub smtp_relayer: String,
    pub surrealdb_address: String,
    pub surrealdb_username: String,
    pub surrealdb_password: String,
    pub surrealdb_namespace: String,
    pub surrealdb_database: String,
}
/// Loads [`AppConfig`] from a `config.*` file in the working directory.
///
/// Fails with `ConfigError` when the file is missing or any key is absent.
pub fn get_config() -> Result<AppConfig, ConfigError> {
    let settings = Config::builder()
        .add_source(File::with_name("config"))
        .build()?;
    // Small helper so each field reads as one line below.
    let read = |key: &str| settings.get_string(key);
    Ok(AppConfig {
        smtp_username: read("SMTP_USERNAME")?,
        smtp_password: read("SMTP_PASSWORD")?,
        smtp_relayer: read("SMTP_RELAYER")?,
        surrealdb_address: read("SURREALDB_ADDRESS")?,
        surrealdb_username: read("SURREALDB_USERNAME")?,
        surrealdb_password: read("SURREALDB_PASSWORD")?,
        surrealdb_namespace: read("SURREALDB_NAMESPACE")?,
        surrealdb_database: read("SURREALDB_DATABASE")?,
    })
}

View File

@@ -0,0 +1,48 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use crate::error::AppError;
/// Generates an embedding vector for the given input text using OpenAI's
/// `text-embedding-3-small` model.
///
/// The resulting vector can be used for semantic similarity comparisons,
/// vector search, and other NLP tasks.
///
/// # Arguments
///
/// * `client`: The OpenAI client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
///
/// * `Ok(Vec<f32>)`: The embedding as 32-bit floats.
/// * `Err(AppError)`: When the request cannot be built, the API call fails,
///   or the response contains no embedding data.
pub async fn generate_embedding(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    input: &str,
) -> Result<Vec<f32>, AppError> {
    let request = CreateEmbeddingRequestArgs::default()
        .model("text-embedding-3-small")
        .input([input])
        .build()?;
    let response = client.embeddings().create(request).await?;
    // The API returns one datum per input; we sent exactly one.
    match response.data.first() {
        Some(datum) => Ok(datum.embedding.clone()),
        None => Err(AppError::LLMParsing("No embedding data received".into())),
    }
}

View File

@@ -0,0 +1,81 @@
use lettre::address::AddressError;
use lettre::message::MultiPart;
use lettre::{transport::smtp::authentication::Credentials, SmtpTransport};
use lettre::{Message, Transport};
use minijinja::context;
use minijinja_autoreload::AutoReloader;
use thiserror::Error;
use tracing::info;
/// Thin wrapper around a configured SMTP transport for sending
/// application emails.
pub struct Mailer {
    pub mailer: SmtpTransport,
}
/// Errors that can occur while composing, templating, or sending an email.
#[derive(Error, Debug)]
pub enum EmailError {
    /// An address string failed to parse.
    #[error("Email construction error: {0}")]
    EmailParsingError(#[from] AddressError),
    /// The SMTP transport rejected or failed the send.
    #[error("Email sending error: {0}")]
    SendingError(#[from] lettre::transport::smtp::Error),
    /// The message body could not be assembled.
    #[error("Body constructing error: {0}")]
    BodyError(#[from] lettre::error::Error),
    /// A MiniJinja template failed to load or render.
    #[error("Templating error: {0}")]
    TemplatingError(#[from] minijinja::Error),
}
impl Mailer {
    /// Builds an SMTP transport for `relayer`, authenticated with the given
    /// username and password.
    pub fn new(
        username: &str,
        relayer: &str,
        password: &str,
    ) -> Result<Self, lettre::transport::smtp::Error> {
        let creds = Credentials::new(username.to_owned(), password.to_owned());
        // `relayer` is already a &str; the previous `&relayer` produced a
        // needless `&&str` borrow.
        let mailer = SmtpTransport::relay(relayer)?.credentials(creds).build();
        Ok(Mailer { mailer })
    }
    /// Sends the address-verification email (plain-text + HTML multipart).
    ///
    /// The recipient's display name is derived from the local part of the
    /// address with its first character upper-cased (e.g. "jane.doe@x.org"
    /// -> "Jane.doe").
    pub fn send_email_verification(
        &self,
        email_to: &str,
        verification_code: &str,
        templates: &AutoReloader,
    ) -> Result<(), EmailError> {
        let name = email_to
            .split('@')
            .next()
            .unwrap_or("User")
            .chars()
            .enumerate()
            .map(|(i, c)| if i == 0 { c.to_ascii_uppercase() } else { c })
            .collect::<String>();
        let context = context! {
            name => name,
            verification_code => verification_code
        };
        // Render both alternatives from the autoreloaded template set.
        let env = templates.acquire_env()?;
        let html = env
            .get_template("email/email_verification.html")?
            .render(&context)?;
        let plain = env
            .get_template("email/email_verification.txt")?
            .render(&context)?;
        let email = Message::builder()
            .from("Admin <minne@starks.cloud>".parse()?)
            .reply_to("Admin <minne@starks.cloud>".parse()?)
            .to(format!("{} <{}>", name, email_to).parse()?)
            .subject("Verify Your Email Address")
            .multipart(MultiPart::alternative_plain_html(plain, html))?;
        info!("Sending email to: {}", email_to);
        self.mailer.send(&email)?;
        Ok(())
    }
}

View File

@@ -0,0 +1,3 @@
pub mod config;
pub mod embedding;
pub mod mailer;