refactor: better separation of dependencies to crates

node stuff to html crate only
This commit is contained in:
Per Stark
2025-04-04 12:50:38 +02:00
parent 20fc43638b
commit 5bc48fb30b
160 changed files with 231 additions and 337 deletions

36
common/src/error.rs Normal file
View File

@@ -0,0 +1,36 @@
use async_openai::error::OpenAIError;
use thiserror::Error;
use tokio::task::JoinError;
use crate::storage::types::file_info::FileError;
// Core internal errors
#[derive(Error, Debug)]
pub enum AppError {
#[error("Database error: {0}")]
Database(#[from] surrealdb::Error),
#[error("OpenAI error: {0}")]
OpenAI(#[from] OpenAIError),
#[error("File error: {0}")]
File(#[from] FileError),
#[error("Not found: {0}")]
NotFound(String),
#[error("Validation error: {0}")]
Validation(String),
#[error("Authorization error: {0}")]
Auth(String),
#[error("LLM parsing error: {0}")]
LLMParsing(String),
#[error("Task join error: {0}")]
Join(#[from] JoinError),
#[error("Graph mapper error: {0}")]
GraphMapper(String),
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Tiktoken error: {0}")]
Tiktoken(#[from] anyhow::Error),
#[error("Ingress Processing error: {0}")]
Processing(String),
}

3
common/src/lib.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod error;
pub mod storage;
pub mod utils;

191
common/src/storage/db.rs Normal file
View File

@@ -0,0 +1,191 @@
use crate::error::AppError;
use super::types::{analytics::Analytics, system_settings::SystemSettings, StoredObject};
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use futures::Stream;
use std::{ops::Deref, sync::Arc};
use surrealdb::{
engine::any::{connect, Any},
opt::auth::Root,
Error, Notification, Surreal,
};
#[derive(Clone)]
pub struct SurrealDbClient {
pub client: Surreal<Any>,
}
pub trait ProvidesDb {
fn db(&self) -> &Arc<SurrealDbClient>;
}
impl SurrealDbClient {
/// # Initialize a new datbase client
///
/// # Arguments
///
/// # Returns
/// * `SurrealDbClient` initialized
pub async fn new(
address: &str,
username: &str,
password: &str,
namespace: &str,
database: &str,
) -> Result<Self, Error> {
let db = connect(address).await?;
// Sign in to database
db.signin(Root { username, password }).await?;
// Set namespace
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
SessionStore::new(
Some(self.client.clone().into()),
SessionConfig::default()
.with_table_name("test_session_table")
.with_secure(true),
)
.await
}
pub async fn ensure_initialized(&self) -> Result<(), AppError> {
Self::build_indexes(self).await?;
Self::setup_auth(self).await?;
Analytics::ensure_initialized(self).await?;
SystemSettings::ensure_initialized(self).await?;
Ok(())
}
pub async fn setup_auth(&self) -> Result<(), Error> {
self.client.query(
"DEFINE TABLE user SCHEMALESS;
DEFINE INDEX unique_name ON TABLE user FIELDS email UNIQUE;
DEFINE ACCESS account ON DATABASE TYPE RECORD
SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id)
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );",
)
.await?;
Ok(())
}
pub async fn build_indexes(&self) -> Result<(), Error> {
self.client.query("DEFINE INDEX idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536").await?;
self.client.query("DEFINE INDEX idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536").await?;
self.client
.query("DEFINE INDEX idx_job_status ON job FIELDS status")
.await?;
self.client
.query("DEFINE INDEX idx_job_user ON job FIELDS user_id")
.await?;
self.client
.query("DEFINE INDEX idx_job_created ON job FIELDS created_at")
.await?;
Ok(())
}
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
.await?;
self.client
.query("REBUILD INDEX IF EXISTS idx_embeddings_entities ON knowledge_entity")
.await?;
Ok(())
}
pub async fn drop_table<T>(&self) -> Result<Vec<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
{
self.client.delete(T::table_name()).await
}
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
///
/// # Arguments
/// * `item` - The item to be stored
///
/// # Returns
/// * `Result` - Item or Error
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
{
self.client
.create((T::table_name(), item.get_id()))
.content(item)
.await
}
/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject
///
/// # Returns
/// * `Result` - Vec<T> or Error
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
where
T: for<'de> StoredObject,
{
self.client.select(T::table_name()).await
}
/// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject
///
/// # Arguments
/// * `id` - The ID of the item to retrieve
///
/// # Returns
/// * `Result<Option<T>, Error>` - The found item or Error
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
{
self.client.select((T::table_name(), id)).await
}
/// Operation to delete a single object by its ID, requires the struct to implement StoredObject
///
/// # Arguments
/// * `id` - The ID of the item to delete
///
/// # Returns
/// * `Result<Option<T>, Error>` - The deleted item or Error
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
{
self.client.delete((T::table_name(), id)).await
}
/// Operation to listen to a table for updates, requires the struct to implement StoredObject
///
/// # Returns
/// * `Result<Option<T>, Error>` - The deleted item or Error
pub async fn listen<T>(
&self,
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
where
T: for<'de> StoredObject + std::marker::Unpin,
{
self.client.select(T::table_name()).live().await
}
}
impl Deref for SurrealDbClient {
type Target = Surreal<Any>;
fn deref(&self) -> &Self::Target {
&self.client
}
}

View File

@@ -0,0 +1,2 @@
pub mod db;
pub mod types;

View File

@@ -0,0 +1,78 @@
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
use serde::{Deserialize, Serialize};
use crate::{error::AppError, storage::db::SurrealDbClient};
#[derive(Debug, Serialize, Deserialize)]
pub struct Analytics {
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
pub page_loads: i64,
pub visitors: i64,
}
impl Analytics {
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
let analytics = db.select(("analytics", "current")).await?;
if analytics.is_none() {
let created: Option<Analytics> = db
.create(("analytics", "current"))
.content(Analytics {
id: "current".to_string(),
visitors: 0,
page_loads: 0,
})
.await?;
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
};
analytics.ok_or(AppError::Validation("Failed to initialize settings".into()))
}
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
let analytics: Option<Self> = db
.client
.query("SELECT * FROM type::thing('analytics', 'current')")
.await?
.take(0)?;
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
}
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
let updated: Option<Self> = db
.client
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
.await?
.take(0)?;
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
}
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
let updated: Option<Self> = db
.client
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
.await?
.take(0)?;
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
}
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
#[derive(Debug, Deserialize)]
struct CountResult {
count: i64,
}
let result: Option<CountResult> = db
.client
.query("SELECT count() as count FROM type::table($table) GROUP ALL")
.bind(("table", User::table_name()))
.await?
.take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0))
}
}

View File

@@ -0,0 +1,49 @@
use uuid::Uuid;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::message::Message;
stored_object!(Conversation, "conversation", {
user_id: String,
title: String
});
impl Conversation {
pub fn new(user_id: String, title: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
user_id,
title,
}
}
pub async fn get_complete_conversation(
conversation_id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<(Self, Vec<Message>), AppError> {
let conversation: Conversation = db
.get_item(conversation_id)
.await?
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
if conversation.user_id != user_id {
return Err(AppError::Auth(
"You don't have access to this conversation".to_string(),
));
}
let messages:Vec<Message> = db.client.
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
bind(("table_name", Message::table_name())).
bind(("conversation_id", conversation_id.to_string()))
.await?
.take(0)?;
Ok((conversation, messages))
}
}

View File

@@ -0,0 +1,245 @@
use axum_typed_multipart::FieldData;
use mime_guess::from_path;
use sha2::{Digest, Sha256};
use std::{
io::{BufReader, Read},
path::{Path, PathBuf},
};
use tempfile::NamedTempFile;
use thiserror::Error;
use tokio::fs::remove_dir_all;
use tracing::info;
use uuid::Uuid;
use crate::{storage::db::SurrealDbClient, stored_object};
#[derive(Error, Debug)]
pub enum FileError {
#[error("File not found for UUID: {0}")]
FileNotFound(String),
#[error("IO error occurred: {0}")]
Io(#[from] std::io::Error),
#[error("Duplicate file detected with SHA256: {0}")]
DuplicateFile(String),
#[error("SurrealDB error: {0}")]
SurrealError(#[from] surrealdb::Error),
#[error("Failed to persist file: {0}")]
PersistError(#[from] tempfile::PersistError),
#[error("File name missing in metadata")]
MissingFileName,
}
stored_object!(FileInfo, "file", {
sha256: String,
path: String,
file_name: String,
mime_type: String
});
impl FileInfo {
pub async fn new(
field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
user_id: &str,
) -> Result<Self, FileError> {
let file = field_data.contents;
let file_name = field_data
.metadata
.file_name
.ok_or(FileError::MissingFileName)?;
// Calculate SHA256
let sha256 = Self::get_sha(&file).await?;
// Early return if file already exists
match Self::get_by_sha(&sha256, db_client).await {
Ok(existing_file) => {
info!("File already exists with SHA256: {}", sha256);
return Ok(existing_file);
}
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
Err(e) => return Err(e), // Propagate unexpected errors
}
// Generate UUID and prepare paths
let uuid = Uuid::new_v4();
let sanitized_file_name = Self::sanitize_file_name(&file_name);
let now = Utc::now();
// Create new FileInfo instance
let file_info = Self {
id: uuid.to_string(),
created_at: now,
updated_at: now,
file_name,
sha256,
path: Self::persist_file(&uuid, file, &sanitized_file_name, user_id)
.await?
.to_string_lossy()
.into(),
mime_type: Self::guess_mime_type(Path::new(&sanitized_file_name)),
};
// Store in database
db_client.store_item(file_info.clone()).await?;
Ok(file_info)
}
/// Guesses the MIME type based on the file extension.
///
/// # Arguments
/// * `path` - The path to the file.
///
/// # Returns
/// * `String` - The guessed MIME type as a string.
fn guess_mime_type(path: &Path) -> String {
from_path(path)
.first_or(mime::APPLICATION_OCTET_STREAM)
.to_string()
}
/// Calculates the SHA256 hash of the given file.
///
/// # Arguments
/// * `file` - The file to hash.
///
/// # Returns
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
let mut reader = BufReader::new(file.as_file());
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192]; // 8KB buffer
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
hasher.update(&buffer[..n]);
}
let digest = hasher.finalize();
Ok(format!("{:x}", digest))
}
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
fn sanitize_file_name(file_name: &str) -> String {
file_name
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '.' || c == '_' {
c
} else {
'_'
}
})
.collect()
}
/// Persists the file to the filesystem under `./data/{user_id}/{uuid}/{file_name}`.
///
/// # Arguments
/// * `uuid` - The UUID of the file.
/// * `file` - The temporary file to persist.
/// * `file_name` - The sanitized file name.
/// * `user-id` - User id
///
/// # Returns
/// * `Result<PathBuf, FileError>` - The persisted file path or an error.
async fn persist_file(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
user_id: &str,
) -> Result<PathBuf, FileError> {
let base_dir = Path::new("./data");
let user_dir = base_dir.join(user_id); // Create the user directory
let uuid_dir = user_dir.join(uuid.to_string()); // Create the UUID directory under the user directory
// Create the user and UUID directories if they don't exist
tokio::fs::create_dir_all(&uuid_dir)
.await
.map_err(FileError::Io)?;
// Define the final file path
let final_path = uuid_dir.join(file_name);
info!("Final path: {:?}", final_path);
// Persist the temporary file to the final path
file.persist(&final_path)?;
info!("Persisted file to {:?}", final_path);
Ok(final_path)
}
/// Retrieves a `FileInfo` by SHA256.
///
/// # Arguments
/// * `sha256` - The SHA256 hash string.
/// * `db_client` - Reference to the SurrealDbClient.
///
/// # Returns
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
response
.into_iter()
.next()
.ok_or(FileError::FileNotFound(sha256.to_string()))
}
/// Removes FileInfo from database and file from disk
///
/// # Arguments
/// * `id` - Id of the FileInfo
/// * `db_client` - Reference to SurrealDbClient
///
/// # Returns
/// `Result<(), FileError>`
pub async fn delete_by_id(id: &str, db_client: &SurrealDbClient) -> Result<(), FileError> {
// Get the FileInfo from the database
let file_info = match db_client.get_item::<FileInfo>(id).await? {
Some(info) => info,
None => {
return Err(FileError::FileNotFound(format!(
"File with id {} was not found",
id
)))
}
};
// Remove the file and its parent directory
let file_path = Path::new(&file_info.path);
if file_path.exists() {
// Get the parent directory of the file
if let Some(parent_dir) = file_path.parent() {
// Remove the entire directory containing the file
remove_dir_all(parent_dir).await?;
info!("Removed directory {:?} and its contents", parent_dir);
} else {
return Err(FileError::FileNotFound(
"File has no parent directory".to_string(),
));
}
} else {
return Err(FileError::FileNotFound(format!(
"File at path {:?} was not found",
file_path
)));
}
// Delete the FileInfo from the database
db_client.delete_item::<FileInfo>(id).await?;
Ok(())
}
}

View File

@@ -0,0 +1,95 @@
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
use url::Url;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngestionPayload {
Url {
url: String,
instructions: String,
category: String,
user_id: String,
},
Text {
text: String,
instructions: String,
category: String,
user_id: String,
},
File {
file_info: FileInfo,
instructions: String,
category: String,
user_id: String,
},
}
impl IngestionPayload {
/// Creates ingestion payloads from the provided content, instructions, and files.
///
/// # Arguments
/// * `content` - Optional textual content to be ingressed
/// * `instructions` - Instructions for processing the ingress content
/// * `category` - Category to classify the ingressed content
/// * `files` - Vector of `FileInfo` objects containing information about uploaded files
/// * `user_id` - Identifier of the user performing the ingress operation
///
/// # Returns
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
/// (one per file/content type). On failure, returns an `AppError`.
pub fn create_ingestion_payload(
content: Option<String>,
instructions: String,
category: String,
files: Vec<FileInfo>,
user_id: &str,
) -> Result<Vec<IngestionPayload>, AppError> {
// Initialize list
let mut object_list = Vec::new();
// Create a IngestionPayload from content if it exists, checking for URL or text
if let Some(input_content) = content {
match Url::parse(&input_content) {
Ok(url) => {
info!("Detected URL: {}", url);
object_list.push(IngestionPayload::Url {
url: url.to_string(),
instructions: instructions.clone(),
category: category.clone(),
user_id: user_id.into(),
});
}
Err(_) => {
if input_content.len() > 2 {
info!("Treating input as plain text");
object_list.push(IngestionPayload::Text {
text: input_content.to_string(),
instructions: instructions.clone(),
category: category.clone(),
user_id: user_id.into(),
});
}
}
}
}
for file in files {
object_list.push(IngestionPayload::File {
file_info: file,
instructions: instructions.clone(),
category: category.clone(),
user_id: user_id.into(),
})
}
// If no objects are constructed, we return Err
if object_list.is_empty() {
return Err(AppError::NotFound(
"No valid content or files provided".into(),
));
}
Ok(object_list)
}
}

View File

@@ -0,0 +1,102 @@
use futures::Stream;
use surrealdb::{opt::PatchOp, Notification};
use uuid::Uuid;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::ingestion_payload::IngestionPayload;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IngestionTaskStatus {
Created,
InProgress {
attempts: u32,
last_attempt: DateTime<Utc>,
},
Completed,
Error(String),
Cancelled,
}
stored_object!(IngestionTask, "job", {
content: IngestionPayload,
status: IngestionTaskStatus,
user_id: String
});
pub const MAX_ATTEMPTS: u32 = 3;
impl IngestionTask {
pub async fn new(content: IngestionPayload, user_id: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
content,
status: IngestionTaskStatus::Created,
created_at: now,
updated_at: now,
user_id,
}
}
/// Creates a new job and stores it in the database
pub async fn create_and_add_to_db(
content: IngestionPayload,
user_id: String,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let job = Self::new(content, user_id).await;
db.store_item(job).await?;
Ok(())
}
// Update job status
pub async fn update_status(
id: &str,
status: IngestionTaskStatus,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let _job: Option<Self> = db
.update((Self::table_name(), id))
.patch(PatchOp::replace("/status", status))
.patch(PatchOp::replace(
"/updated_at",
surrealdb::sql::Datetime::default(),
))
.await?;
Ok(())
}
/// Listen for new jobs
pub async fn listen_for_tasks(
db: &SurrealDbClient,
) -> Result<impl Stream<Item = Result<Notification<Self>, surrealdb::Error>>, surrealdb::Error>
{
db.listen::<Self>().await
}
/// Get all unfinished tasks, ie newly created and in progress up two times
pub async fn get_unfinished_tasks(db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
let jobs: Vec<Self> = db
.query(
"SELECT * FROM type::table($table)
WHERE
status = 'Created'
OR (
status.InProgress != NONE
AND status.InProgress.attempts < $max_attempts
)
ORDER BY created_at ASC",
)
.bind(("table", Self::table_name()))
.bind(("max_attempts", MAX_ATTEMPTS))
.await?
.take(0)?;
Ok(jobs)
}
}

View File

@@ -0,0 +1,121 @@
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum KnowledgeEntityType {
Idea,
Project,
Document,
Page,
TextSnippet,
// Add more types as needed
}
impl KnowledgeEntityType {
pub fn variants() -> &'static [&'static str] {
&["Idea", "Project", "Document", "Page", "TextSnippet"]
}
}
impl From<String> for KnowledgeEntityType {
fn from(s: String) -> Self {
match s.to_lowercase().as_str() {
"idea" => KnowledgeEntityType::Idea,
"project" => KnowledgeEntityType::Project,
"document" => KnowledgeEntityType::Document,
"page" => KnowledgeEntityType::Page,
"textsnippet" => KnowledgeEntityType::TextSnippet,
_ => KnowledgeEntityType::Document, // Default case
}
}
}
stored_object!(KnowledgeEntity, "knowledge_entity", {
source_id: String,
name: String,
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String
});
impl KnowledgeEntity {
pub fn new(
source_id: String,
name: String,
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id,
name,
description,
entity_type,
metadata,
embedding,
user_id,
}
}
pub async fn delete_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
Ok(())
}
pub async fn patch(
id: &str,
name: &str,
description: &str,
entity_type: &KnowledgeEntityType,
db_client: &SurrealDbClient,
ai_client: &Client<OpenAIConfig>,
) -> Result<(), AppError> {
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input).await?;
db_client
.client
.query(
"UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type,
embedding = $embedding
RETURN AFTER",
)
.bind(("table", Self::table_name()))
.bind(("id", id.to_string()))
.bind(("name", name.to_string()))
.bind(("updated_at", Utc::now()))
.bind(("entity_type", entity_type.to_owned()))
.bind(("embedding", embedding))
.bind(("description", description.to_string()))
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,86 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RelationshipMetadata {
pub user_id: String,
pub source_id: String,
pub relationship_type: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct KnowledgeRelationship {
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
#[serde(rename = "in", deserialize_with = "deserialize_flexible_id")]
pub in_: String,
#[serde(deserialize_with = "deserialize_flexible_id")]
pub out: String,
pub metadata: RelationshipMetadata,
}
impl KnowledgeRelationship {
pub fn new(
in_: String,
out: String,
user_id: String,
source_id: String,
relationship_type: String,
) -> Self {
Self {
id: Uuid::new_v4().to_string(),
in_,
out,
metadata: RelationshipMetadata {
user_id,
source_id,
relationship_type,
},
}
}
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
let query = format!(
r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
SET
metadata.user_id = '{}',
metadata.source_id = '{}',
metadata.relationship_type = '{}'"#,
self.in_,
self.id,
self.out,
self.metadata.user_id,
self.metadata.source_id,
self.metadata.relationship_type
);
db_client.query(query).await?;
Ok(())
}
pub async fn delete_relationships_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
source_id
);
db_client.query(query).await?;
Ok(())
}
pub async fn delete_relationship_by_id(
id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!("DELETE relates_to:`{}`", id);
db_client.query(query).await?;
Ok(())
}
}

View File

@@ -0,0 +1,62 @@
use uuid::Uuid;
use crate::stored_object;
#[derive(Deserialize, Debug, Clone, Serialize)]
pub enum MessageRole {
User,
AI,
System,
}
stored_object!(Message, "message", {
conversation_id: String,
role: MessageRole,
content: String,
references: Option<Vec<String>>
});
impl Message {
pub fn new(
conversation_id: String,
role: MessageRole,
content: String,
references: Option<Vec<String>>,
) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
conversation_id,
role,
content,
references,
}
}
}
impl fmt::Display for MessageRole {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MessageRole::User => write!(f, "User"),
MessageRole::AI => write!(f, "AI"),
MessageRole::System => write!(f, "System"),
}
}
}
impl fmt::Display for Message {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {}", self.role, self.content)
}
}
// helper function to format a vector of messages
pub fn format_history(history: &[Message]) -> String {
history
.iter()
.map(|msg| format!("{}", msg))
.collect::<Vec<String>>()
.join("\n")
}

View File

@@ -0,0 +1,112 @@
use axum::async_trait;
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
pub mod file_info;
pub mod ingestion_payload;
pub mod ingestion_task;
pub mod knowledge_entity;
pub mod knowledge_relationship;
pub mod message;
pub mod system_prompts;
pub mod system_settings;
pub mod text_chunk;
pub mod text_content;
pub mod user;
#[async_trait]
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
fn table_name() -> &'static str;
fn get_id(&self) -> &str;
}
#[macro_export]
macro_rules! stored_object {
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
use axum::async_trait;
use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject;
use serde::de::{self, Visitor};
use std::fmt;
use chrono::{DateTime, Utc };
struct FlexibleIdVisitor;
impl<'de> Visitor<'de> for FlexibleIdVisitor {
type Value = String;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string or a Thing")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value.to_string())
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value)
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
// Try to deserialize as Thing
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(thing.id.to_raw())
}
}
pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(FlexibleIdVisitor)
}
fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
}
fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
D: serde::Deserializer<'de>,
{
let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
Ok(DateTime::<Utc>::from(dt))
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct $name {
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub created_at: DateTime<Utc>,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub updated_at: DateTime<Utc>,
$(pub $field: $ty),*
}
#[async_trait]
impl StoredObject for $name {
fn table_name() -> &'static str {
$table
}
fn get_id(&self) -> &str {
&self.id
}
}
};
}

View File

@@ -0,0 +1,56 @@
pub static DEFAULT_QUERY_SYSTEM_PROMPT: &str = r#"You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
Your task is to:
1. Carefully analyze the provided knowledge entities in the context
2. Answer user questions based on this information
3. Provide clear, concise, and accurate responses
4. When referencing information, briefly mention which knowledge entity it came from
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
6. If only partial information is available, explain what you can answer and what information is missing
7. Avoid making assumptions or providing information not supported by the context
8. Output the references to the documents. Use the UUIDs and make sure they are correct!
Remember:
- Be direct and honest about the limitations of your knowledge
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
- If you need to combine information from multiple entities, explain how they connect
- Don't speculate beyond what's provided in the context
Example response formats:
"Based on [Entity Name], [answer...]"
"I found relevant information in multiple entries: [explanation...]"
"I apologize, but the provided context doesn't contain information about [topic]""#;
pub static DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT: &str = r#"You are an AI assistant. You will receive a text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users instructions and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
The JSON should have the following structure:
{
"knowledge_entities": [
{
"key": "unique-key-1",
"name": "Entity Name",
"description": "A detailed description of the entity.",
"entity_type": "TypeOfEntity"
},
// More entities...
],
"relationships": [
{
"type": "RelationshipType",
"source": "unique-key-1 or UUID from existing database",
"target": "unique-key-1 or UUID from existing database"
},
// More relationships...
]
}
Guidelines:
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
7. Only create relationships between existing KnowledgeEntities.
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
9. A new relationship MUST include a newly created KnowledgeEntity."#;

View File

@@ -0,0 +1,77 @@
use crate::storage::types::file_info::deserialize_flexible_id;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{error::AppError, storage::db::SurrealDbClient};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SystemSettings {
#[serde(deserialize_with = "deserialize_flexible_id")]
pub id: String,
pub registrations_enabled: bool,
pub require_email_verification: bool,
pub query_model: String,
pub processing_model: String,
pub query_system_prompt: String,
pub ingestion_system_prompt: String,
}
impl SystemSettings {
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
let settings = db.select(("system_settings", "current")).await?;
if settings.is_none() {
let created: Option<SystemSettings> = db
.create(("system_settings", "current"))
.content(SystemSettings {
id: "current".to_string(),
registrations_enabled: true,
require_email_verification: false,
query_model: "gpt-4o-mini".to_string(),
processing_model: "gpt-4o-mini".to_string(),
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
})
.await?;
return created.ok_or(AppError::Validation("Failed to initialize settings".into()));
};
settings.ok_or(AppError::Validation("Failed to initialize settings".into()))
}
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
let settings: Option<Self> = db
.client
.query("SELECT * FROM type::thing('system_settings', 'current')")
.await?
.take(0)?;
settings.ok_or(AppError::NotFound("System settings not found".into()))
}
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
let updated: Option<Self> = db
.client
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
.bind(("changes", changes))
.await?
.take(0)?;
updated.ok_or(AppError::Validation(
"Something went wrong updating the settings".into(),
))
}
pub fn new() -> Self {
Self {
id: Uuid::new_v4().to_string(),
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
ingestion_system_prompt: crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT.to_string(),
query_model: "gpt-4o-mini".to_string(),
processing_model: "gpt-4o-mini".to_string(),
registrations_enabled: true,
require_email_verification: false,
}
}
}

View File

@@ -0,0 +1,38 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
source_id: String,
chunk: String,
embedding: Vec<f32>,
user_id: String
});
impl TextChunk {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id,
chunk,
embedding,
user_id,
}
}
pub async fn delete_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
Ok(())
}
}

View File

@@ -0,0 +1,59 @@
use surrealdb::opt::PatchOp;
use uuid::Uuid;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::file_info::FileInfo;
stored_object!(TextContent, "text_content", {
text: String,
file_info: Option<FileInfo>,
url: Option<String>,
instructions: String,
category: String,
user_id: String
});
impl TextContent {
pub fn new(
text: String,
instructions: String,
category: String,
file_info: Option<FileInfo>,
url: Option<String>,
user_id: String,
) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
text,
file_info,
url,
instructions,
category,
user_id,
}
}
pub async fn patch(
id: &str,
instructions: &str,
category: &str,
text: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let now = Utc::now();
let _res: Option<Self> = db
.update((Self::table_name(), id))
.patch(PatchOp::replace("/instructions", instructions))
.patch(PatchOp::replace("/category", category))
.patch(PatchOp::replace("/text", text))
.patch(PatchOp::replace("/updated_at", now))
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,416 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use axum_session_auth::Authentication;
use surrealdb::{engine::any::Any, Surreal};
use uuid::Uuid;
use super::{
conversation::Conversation, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
text_content::TextContent,
};
#[derive(Deserialize)]
pub struct CategoryResponse {
category: String,
}
stored_object!(User, "user", {
email: String,
password: String,
anonymous: bool,
api_key: Option<String>,
admin: bool,
#[serde(default)]
timezone: String
});
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
let db = db.unwrap();
Ok(db
.select((Self::table_name(), userid.as_str()))
.await?
.unwrap())
}
fn is_authenticated(&self) -> bool {
!self.anonymous
}
fn is_active(&self) -> bool {
!self.anonymous
}
fn is_anonymous(&self) -> bool {
self.anonymous
}
}
fn validate_timezone(input: &str) -> String {
use chrono_tz::Tz;
// Check if it's a valid IANA timezone identifier
match input.parse::<Tz>() {
Ok(_) => input.to_owned(),
Err(_) => {
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
"UTC".to_owned()
}
}
}
impl User {
pub async fn create_new(
email: String,
password: String,
db: &SurrealDbClient,
timezone: String,
) -> Result<Self, AppError> {
// verify that the application allows new creations
let systemsettings = SystemSettings::get_current(db).await?;
if !systemsettings.registrations_enabled {
return Err(AppError::Auth("Registration is not allowed".into()));
}
let validated_tz = validate_timezone(&timezone);
let now = Utc::now();
let id = Uuid::new_v4().to_string();
let user: Option<User> = db
.client
.query(
"LET $count = (SELECT count() FROM type::table($table))[0].count;
CREATE type::thing('user', $id) SET
email = $email,
password = crypto::argon2::generate($password),
admin = $count < 1,
anonymous = false,
created_at = $created_at,
updated_at = $updated_at,
timezone = $timezone",
)
.bind(("table", "user"))
.bind(("id", id))
.bind(("email", email))
.bind(("password", password))
.bind(("created_at", now))
.bind(("updated_at", now))
.bind(("timezone", validated_tz))
.await?
.take(1)?;
user.ok_or(AppError::Auth("User failed to create".into()))
}
pub async fn patch_password(
email: &str,
password: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
db.client
.query(
"UPDATE user
SET password = crypto::argon2::generate($password)
WHERE email = $email",
)
.bind(("email", email.to_owned()))
.bind(("password", password.to_owned()))
.await?;
Ok(())
}
pub async fn authenticate(
email: &str,
password: &str,
db: &SurrealDbClient,
) -> Result<Self, AppError> {
let user: Option<User> = db
.client
.query(
"SELECT * FROM user
WHERE email = $email
AND crypto::argon2::compare(password, $password)",
)
.bind(("email", email.to_owned()))
.bind(("password", password.to_owned()))
.await?
.take(0)?;
user.ok_or(AppError::Auth("User failed to authenticate".into()))
}
pub async fn find_by_email(
email: &str,
db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
let user: Option<User> = db
.client
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
.bind(("email", email.to_string()))
.await?
.take(0)?;
Ok(user)
}
pub async fn find_by_api_key(
api_key: &str,
db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
let user: Option<User> = db
.client
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
.bind(("api_key", api_key.to_string()))
.await?
.take(0)?;
Ok(user)
}
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
// Generate a secure random API key
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
// Update the user record with the new API key
let user: Option<User> = db
.client
.query(
"UPDATE type::thing('user', $id)
SET api_key = $api_key
RETURN AFTER",
)
.bind(("id", id.to_owned()))
.bind(("api_key", api_key.clone()))
.await?
.take(0)?;
// If the user was found and updated, return the API key
if user.is_some() {
Ok(api_key)
} else {
Err(AppError::Auth("User not found".into()))
}
}
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
let user: Option<User> = db
.client
.query(
"UPDATE type::thing('user', $id)
SET api_key = NULL
RETURN AFTER",
)
.bind(("id", id.to_owned()))
.await?
.take(0)?;
if user.is_some() {
Ok(())
} else {
Err(AppError::Auth("User was not found".into()))
}
}
pub async fn get_knowledge_entities(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let entities: Vec<KnowledgeEntity> = db
.client
.query("SELECT * FROM type::table($table) WHERE user_id = $user_id")
.bind(("table", KnowledgeEntity::table_name()))
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)?;
Ok(entities)
}
pub async fn get_knowledge_relationships(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<KnowledgeRelationship>, AppError> {
let relationships: Vec<KnowledgeRelationship> = db
.client
.query("SELECT * FROM type::table($table) WHERE metadata.user_id = $user_id")
.bind(("table", "relates_to"))
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)?;
Ok(relationships)
}
pub async fn get_latest_text_contents(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<TextContent>, AppError> {
let items: Vec<TextContent> = db
.client
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5")
.bind(("user_id", user_id.to_owned()))
.bind(("table_name", TextContent::table_name()))
.await?
.take(0)?;
Ok(items)
}
pub async fn get_text_contents(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<TextContent>, AppError> {
let items: Vec<TextContent> = db
.client
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC")
.bind(("user_id", user_id.to_owned()))
.bind(("table_name", TextContent::table_name()))
.await?
.take(0)?;
Ok(items)
}
pub async fn get_latest_knowledge_entities(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let items: Vec<KnowledgeEntity> = db
.client
.query(
"SELECT * FROM type::table($table_name) WHERE user_id = $user_id ORDER BY created_at DESC LIMIT 5",
)
.bind(("user_id", user_id.to_owned()))
.bind(("table_name", KnowledgeEntity::table_name()))
.await?
.take(0)?;
Ok(items)
}
pub async fn update_timezone(
user_id: &str,
timezone: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
.bind(("table_name", User::table_name()))
.bind(("user_id", user_id.to_string()))
.bind(("timezone", timezone.to_string()))
.await?;
Ok(())
}
pub async fn get_user_categories(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<String>, AppError> {
// Query to select distinct categories for the user
let response: Vec<CategoryResponse> = db
.client
.query("SELECT category FROM type::table($table_name) WHERE user_id = $user_id GROUP BY category")
.bind(("user_id", user_id.to_owned()))
.bind(("table_name", TextContent::table_name()))
.await?
.take(0)?;
// Extract the categories from the response
let categories: Vec<String> = response.into_iter().map(|item| item.category).collect();
Ok(categories)
}
pub async fn get_and_validate_knowledge_entity(
id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<KnowledgeEntity, AppError> {
let entity: KnowledgeEntity = db
.get_item(id)
.await?
.ok_or_else(|| AppError::NotFound("Entity not found".into()))?;
if entity.user_id != user_id {
return Err(AppError::Auth("Access denied".into()));
}
Ok(entity)
}
pub async fn get_and_validate_text_content(
id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<TextContent, AppError> {
let text_content: TextContent = db
.get_item(id)
.await?
.ok_or_else(|| AppError::NotFound("Content not found".into()))?;
if text_content.user_id != user_id {
return Err(AppError::Auth("Access denied".into()));
}
Ok(text_content)
}
pub async fn get_user_conversations(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<Conversation>, AppError> {
let conversations: Vec<Conversation> = db
.client
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id")
.bind(("table_name", Conversation::table_name()))
.bind(("user_id", user_id.to_string()))
.await?
.take(0)?;
Ok(conversations)
}
/// Gets all active ingestion tasks for the specified user
pub async fn get_unfinished_ingestion_tasks(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<IngestionTask>, AppError> {
let jobs: Vec<IngestionTask> = db
.query(
"SELECT * FROM type::table($table)
WHERE user_id = $user_id
AND (
status = 'Created'
OR (
status.InProgress != NONE
AND status.InProgress.attempts < $max_attempts
)
)
ORDER BY created_at DESC",
)
.bind(("table", IngestionTask::table_name()))
.bind(("user_id", user_id.to_owned()))
.bind(("max_attempts", 3))
.await?
.take(0)?;
Ok(jobs)
}
/// Validate and delete job
pub async fn validate_and_delete_job(
id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
db.get_item::<IngestionTask>(id)
.await?
.filter(|job| job.user_id == user_id)
.ok_or_else(|| AppError::Auth("Not authorized to delete this job".into()))?;
db.delete_item::<IngestionTask>(id)
.await
.map_err(AppError::Database)?;
Ok(())
}
}

View File

@@ -0,0 +1,24 @@
use config::{Config, ConfigError, File};
#[derive(Clone, Debug)]
pub struct AppConfig {
pub surrealdb_address: String,
pub surrealdb_username: String,
pub surrealdb_password: String,
pub surrealdb_namespace: String,
pub surrealdb_database: String,
}
pub fn get_config() -> Result<AppConfig, ConfigError> {
let config = Config::builder()
.add_source(File::with_name("config"))
.build()?;
Ok(AppConfig {
surrealdb_address: config.get_string("SURREALDB_ADDRESS")?,
surrealdb_username: config.get_string("SURREALDB_USERNAME")?,
surrealdb_password: config.get_string("SURREALDB_PASSWORD")?,
surrealdb_namespace: config.get_string("SURREALDB_NAMESPACE")?,
surrealdb_database: config.get_string("SURREALDB_DATABASE")?,
})
}

View File

@@ -0,0 +1,48 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use crate::error::AppError;
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
/// comparisons, vector search, and other natural language processing tasks.
///
/// # Arguments
///
/// * `client`: The OpenAI client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
/// * `Err(ProcessingError)`: An error if the embedding generation fails
///
/// # Errors
///
/// This function can return a `AppError` in the following cases:
/// * If the OpenAI API request fails
/// * If the request building fails
/// * If no embedding data is received in the response
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,
) -> Result<Vec<f32>, AppError> {
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-3-small")
.input([input])
.build()?;
// Send the request to OpenAI
let response = client.embeddings().create(request).await?;
// Extract the embedding vector
let embedding: Vec<f32> = response
.data
.first()
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
.embedding
.clone();
Ok(embedding)
}

3
common/src/utils/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod config;
pub mod embedding;
pub mod template_engine;

View File

@@ -0,0 +1,96 @@
pub use minijinja::{path_loader, Environment, Value};
pub use minijinja_autoreload::AutoReloader;
pub use minijinja_contrib;
pub use minijinja_embed;
use std::sync::Arc;
pub trait ProvidesTemplateEngine {
fn template_engine(&self) -> &Arc<TemplateEngine>;
}
#[derive(Clone)]
pub enum TemplateEngine {
// Use AutoReload for debug builds (debug_assertions is true)
#[cfg(debug_assertions)]
AutoReload(Arc<AutoReloader>),
// Use Embedded for release builds (debug_assertions is false)
#[cfg(not(debug_assertions))]
Embedded(Arc<Environment<'static>>),
}
#[macro_export]
macro_rules! create_template_engine {
// Macro takes the relative path to the templates dir as input
($relative_path:expr) => {{
// Code for debug builds (AutoReload)
#[cfg(debug_assertions)]
{
// These lines execute in the CALLING crate's context
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let template_path = crate_dir.join($relative_path);
let reloader = $crate::utils::template_engine::AutoReloader::new(move |notifier| {
let mut env = $crate::utils::template_engine::Environment::new();
env.set_loader($crate::utils::template_engine::path_loader(&template_path));
notifier.set_fast_reload(true);
notifier.watch_path(&template_path, true);
// Add contrib filters/functions
$crate::utils::template_engine::minijinja_contrib::add_to_environment(&mut env);
Ok(env)
});
$crate::utils::template_engine::TemplateEngine::AutoReload(std::sync::Arc::new(
reloader,
))
}
// Code for release builds (Embedded)
#[cfg(not(debug_assertions))]
{
// These lines also execute in the CALLING crate's context
let mut env = $crate::utils::template_engine::Environment::new();
$crate::utils::template_engine::minijinja_embed::load_templates!(&mut env);
// Add contrib filters/functions
$crate::utils::template_engine::minijinja_contrib::add_to_environment(&mut env);
$crate::utils::template_engine::TemplateEngine::Embedded(std::sync::Arc::new(env))
}
}};
}
impl TemplateEngine {
pub fn render(&self, name: &str, ctx: &Value) -> Result<String, minijinja::Error> {
match self {
// Only compile this arm for debug builds
#[cfg(debug_assertions)]
TemplateEngine::AutoReload(reloader) => {
let env = reloader.acquire_env()?;
env.get_template(name)?.render(ctx)
}
// Only compile this arm for release builds
#[cfg(not(debug_assertions))]
TemplateEngine::Embedded(env) => env.get_template(name)?.render(ctx),
}
}
pub fn render_block(
&self,
template_name: &str,
block_name: &str,
context: &Value,
) -> Result<String, minijinja::Error> {
match self {
// Only compile this arm for debug builds
#[cfg(debug_assertions)]
TemplateEngine::AutoReload(reloader) => {
let env = reloader.acquire_env()?;
let template = env.get_template(template_name)?;
let mut state = template.eval_to_state(context)?;
state.render_block(block_name)
}
// Only compile this arm for release builds
#[cfg(not(debug_assertions))]
TemplateEngine::Embedded(env) => {
let template = env.get_template(template_name)?;
let mut state = template.eval_to_state(context)?;
state.render_block(block_name)
}
}
}
}