user restricted to own objects

This commit is contained in:
Per Stark
2024-12-15 22:52:34 +01:00
parent 646792291c
commit cf6078eceb
18 changed files with 109 additions and 28 deletions
+2
View File
@@ -69,6 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
) )
.await?; .await?;
app_state.surreal_db_client.build_indexes().await?;
// Create Axum router // Create Axum router
let app = Router::new() let app = Router::new()
.nest("/api/v1", api_routes_v1(&app_state)) .nest("/api/v1", api_routes_v1(&app_state))
+3
View File
@@ -66,6 +66,8 @@ pub enum ApiError {
UserAlreadyExists, UserAlreadyExists,
#[error("User was not found")] #[error("User was not found")]
UserNotFound, UserNotFound,
#[error("You must provide valid credentials")]
AuthRequired,
} }
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
@@ -78,6 +80,7 @@ impl IntoResponse for ApiError {
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::IngressContentError(_) => { ApiError::IngressContentError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) (StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
+10 -2
View File
@@ -37,9 +37,10 @@ impl<'a> IngressAnalyzer<'a> {
category: &str, category: &str,
instructions: &str, instructions: &str,
text: &str, text: &str,
user_id: &str,
) -> Result<LLMGraphAnalysisResult, ProcessingError> { ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
let similar_entities = self let similar_entities = self
.find_similar_entities(category, instructions, text) .find_similar_entities(category, instructions, text, user_id)
.await?; .await?;
let llm_request = let llm_request =
self.prepare_llm_request(category, instructions, text, &similar_entities)?; self.prepare_llm_request(category, instructions, text, &similar_entities)?;
@@ -51,13 +52,20 @@ impl<'a> IngressAnalyzer<'a> {
category: &str, category: &str,
instructions: &str, instructions: &str,
text: &str, text: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
let input_text = format!( let input_text = format!(
"content: {}, category: {}, user_instructions: {}", "content: {}, category: {}, user_instructions: {}",
text, category, instructions text, category, instructions
); );
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await combined_knowledge_entity_retrieval(
self.db_client,
self.openai_client,
&input_text,
user_id,
)
.await
} }
fn prepare_llm_request( fn prepare_llm_request(
@@ -53,6 +53,7 @@ impl LLMGraphAnalysisResult {
pub async fn to_database_entities( pub async fn to_database_entities(
&self, &self,
source_id: &str, source_id: &str,
user_id: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> { ) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
// Create mapper and pre-assign IDs // Create mapper and pre-assign IDs
@@ -60,7 +61,7 @@ impl LLMGraphAnalysisResult {
// Process entities // Process entities
let entities = self let entities = self
.process_entities(source_id, Arc::clone(&mapper), openai_client) .process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
.await?; .await?;
// Process relationships // Process relationships
@@ -83,6 +84,7 @@ impl LLMGraphAnalysisResult {
async fn process_entities( async fn process_entities(
&self, &self,
source_id: &str, source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>, mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
@@ -93,10 +95,12 @@ impl LLMGraphAnalysisResult {
let mapper = Arc::clone(&mapper); let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone(); let openai_client = openai_client.clone();
let source_id = source_id.to_string(); let source_id = source_id.to_string();
let user_id = user_id.to_string();
let entity = entity.clone(); let entity = entity.clone();
task::spawn(async move { task::spawn(async move {
create_single_entity(&entity, &source_id, mapper, &openai_client).await create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
.await
}) })
}) })
.collect(); .collect();
@@ -135,6 +139,7 @@ impl LLMGraphAnalysisResult {
async fn create_single_entity( async fn create_single_entity(
llm_entity: &LLMKnowledgeEntity, llm_entity: &LLMKnowledgeEntity,
source_id: &str, source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>, mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, ProcessingError> { ) -> Result<KnowledgeEntity, ProcessingError> {
@@ -168,5 +173,6 @@ async fn create_single_entity(
source_id: source_id.to_string(), source_id: source_id.to_string(),
metadata: None, metadata: None,
embedding, embedding,
user_id: user_id.into(),
}) })
} }
+13 -3
View File
@@ -49,7 +49,7 @@ impl ContentProcessor {
// Convert analysis to objects // Convert analysis to objects
let (entities, relationships) = analysis let (entities, relationships) = analysis
.to_database_entities(&content.id, &self.openai_client) .to_database_entities(&content.id, &content.user_id, &self.openai_client)
.await?; .await?;
// Store everything // Store everything
@@ -68,7 +68,12 @@ impl ContentProcessor {
) -> Result<LLMGraphAnalysisResult, ProcessingError> { ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client); let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
analyser analyser
.analyze_content(&content.category, &content.instructions, &content.text) .analyze_content(
&content.category,
&content.instructions,
&content.text,
&content.user_id,
)
.await .await
} }
@@ -102,7 +107,12 @@ impl ContentProcessor {
// Could potentially process chunks in parallel with a bounded concurrent limit // Could potentially process chunks in parallel with a bounded concurrent limit
for chunk in chunks { for chunk in chunks {
let embedding = generate_embedding(&self.openai_client, chunk).await?; let embedding = generate_embedding(&self.openai_client, chunk).await?;
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding); let text_chunk = TextChunk::new(
content.id.to_string(),
chunk.to_string(),
embedding,
content.user_id.to_string(),
);
store_item(&self.db_client, text_chunk).await?; store_item(&self.db_client, text_chunk).await?;
} }
+4
View File
@@ -56,6 +56,7 @@ pub enum IngressContentError {
pub async fn create_ingress_objects( pub async fn create_ingress_objects(
input: IngressInput, input: IngressInput,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<IngressObject>, IngressContentError> { ) -> Result<Vec<IngressObject>, IngressContentError> {
// Initialize list // Initialize list
let mut object_list = Vec::new(); let mut object_list = Vec::new();
@@ -69,6 +70,7 @@ pub async fn create_ingress_objects(
url: url.to_string(), url: url.to_string(),
instructions: input.instructions.clone(), instructions: input.instructions.clone(),
category: input.category.clone(), category: input.category.clone(),
user_id: user_id.into(),
}); });
} }
Err(_) => { Err(_) => {
@@ -77,6 +79,7 @@ pub async fn create_ingress_objects(
text: input_content.to_string(), text: input_content.to_string(),
instructions: input.instructions.clone(), instructions: input.instructions.clone(),
category: input.category.clone(), category: input.category.clone(),
user_id: user_id.into(),
}); });
} }
} }
@@ -90,6 +93,7 @@ pub async fn create_ingress_objects(
file_info, file_info,
instructions: input.instructions.clone(), instructions: input.instructions.clone(),
category: input.category.clone(), category: input.category.clone(),
user_id: user_id.into(),
}); });
} else { } else {
info!("No file with id: {}", id); info!("No file with id: {}", id);
+9
View File
@@ -10,16 +10,19 @@ pub enum IngressObject {
url: String, url: String,
instructions: String, instructions: String,
category: String, category: String,
user_id: String,
}, },
Text { Text {
text: String, text: String,
instructions: String, instructions: String,
category: String, category: String,
user_id: String,
}, },
File { File {
file_info: FileInfo, file_info: FileInfo,
instructions: String, instructions: String,
category: String, category: String,
user_id: String,
}, },
} }
@@ -37,6 +40,7 @@ impl IngressObject {
url, url,
instructions, instructions,
category, category,
user_id,
} => { } => {
let text = Self::fetch_text_from_url(url).await?; let text = Self::fetch_text_from_url(url).await?;
Ok(TextContent::new( Ok(TextContent::new(
@@ -44,22 +48,26 @@ impl IngressObject {
instructions.into(), instructions.into(),
category.into(), category.into(),
None, None,
user_id.into(),
)) ))
} }
IngressObject::Text { IngressObject::Text {
text, text,
instructions, instructions,
category, category,
user_id,
} => Ok(TextContent::new( } => Ok(TextContent::new(
text.into(), text.into(),
instructions.into(), instructions.into(),
category.into(), category.into(),
None, None,
user_id.into(),
)), )),
IngressObject::File { IngressObject::File {
file_info, file_info,
instructions, instructions,
category, category,
user_id,
} => { } => {
let text = Self::extract_text_from_file(file_info).await?; let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent::new( Ok(TextContent::new(
@@ -67,6 +75,7 @@ impl IngressObject {
instructions.into(), instructions.into(),
category.into(), category.into(),
Some(file_info.to_owned()), Some(file_info.to_owned()),
user_id.into(),
)) ))
} }
} }
+4
View File
@@ -29,6 +29,7 @@ use surrealdb::{engine::any::Any, Surreal};
/// * `db_client` - SurrealDB client for database operations /// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation /// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities /// * `query` - The search query string to find relevant knowledge entities
/// * `user_id` - The user id of the current user
/// ///
/// # Returns /// # Returns
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant /// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
@@ -37,6 +38,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client: &Surreal<Any>, db_client: &Surreal<Any>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str, query: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> { ) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
// info!("Received input: {:?}", query); // info!("Received input: {:?}", query);
@@ -47,6 +49,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client, db_client,
"knowledge_entity".to_string(), "knowledge_entity".to_string(),
openai_client, openai_client,
user_id,
), ),
find_items_by_vector_similarity( find_items_by_vector_similarity(
5, 5,
@@ -54,6 +57,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client, db_client,
"text_chunk".to_string(), "text_chunk".to_string(),
openai_client, openai_client,
user_id,
), ),
) )
.await?; .await?;
+9 -7
View File
@@ -9,11 +9,12 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `take`: The number of items to retrieve from the database. /// * `take` - The number of items to retrieve from the database.
/// * `input_text`: The text to generate embeddings for. /// * `input_text` - The text to generate embeddings for.
/// * `db_client`: The SurrealDB client to use for querying the database. /// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table`: The table to query in the database. /// * `table` - The table to query in the database.
/// * `openai_client`: The OpenAI client to use for generating embeddings. /// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user.
/// ///
/// # Returns /// # Returns
/// ///
@@ -21,13 +22,14 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
/// ///
/// # Type Parameters /// # Type Parameters
/// ///
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`. /// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>( pub async fn find_items_by_vector_similarity<T>(
take: u8, take: u8,
input_text: &str, input_text: &str,
db_client: &Surreal<Any>, db_client: &Surreal<Any>,
table: String, table: String,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
) -> Result<Vec<T>, ProcessingError> ) -> Result<Vec<T>, ProcessingError>
where where
T: for<'de> serde::Deserialize<'de>, T: for<'de> serde::Deserialize<'de>,
@@ -36,7 +38,7 @@ where
let input_embedding = generate_embedding(openai_client, input_text).await?; let input_embedding = generate_embedding(openai_client, input_text).await?;
// Construct the query // Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding); let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
// Perform query and deserialize to struct // Perform query and deserialize to struct
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?; let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
+1 -1
View File
@@ -13,7 +13,7 @@ pub async fn api_auth(
mut request: Request, mut request: Request,
next: Next, next: Next,
) -> Result<Response, ApiError> { ) -> Result<Response, ApiError> {
let api_key = extract_api_key(&request).ok_or(ApiError::UserNotFound)?; let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?;
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?; let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
let user = user.ok_or(ApiError::UserNotFound)?; let user = user.ok_or(ApiError::UserNotFound)?;
+4 -2
View File
@@ -2,18 +2,20 @@ use crate::{
error::ApiError, error::ApiError,
ingress::types::ingress_input::{create_ingress_objects, IngressInput}, ingress::types::ingress_input::{create_ingress_objects, IngressInput},
server::AppState, server::AppState,
storage::types::user::User,
}; };
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
use futures::future::try_join_all; use futures::future::try_join_all;
use tracing::info; use tracing::info;
pub async fn ingress_handler( pub async fn ingress_handler(
State(state): State<AppState>, State(state): State<AppState>,
Extension(user): Extension<User>,
Json(input): Json<IngressInput>, Json(input): Json<IngressInput>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input); info!("Received input: {:?}", input);
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?; let ingress_objects = create_ingress_objects(input, &state.surreal_db_client, &user.id).await?;
let futures: Vec<_> = ingress_objects let futures: Vec<_> = ingress_objects
.into_iter() .into_iter()
+7 -3
View File
@@ -33,9 +33,13 @@ pub async fn query_handler(
info!("Received input: {:?}", query); info!("Received input: {:?}", query);
info!("{:?}", user); info!("{:?}", user);
let answer = let answer = get_answer_with_references(
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query) &state.surreal_db_client,
.await?; &state.openai_client,
&query.query,
&user.id,
)
.await?;
Ok( Ok(
Json(serde_json::json!({"answer": answer.content, "references": answer.references})) Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
+6 -1
View File
@@ -4,6 +4,7 @@ use async_openai::types::{
ResponseFormat, ResponseFormatJsonSchema, ResponseFormat, ResponseFormatJsonSchema,
}; };
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::debug;
use crate::{ use crate::{
error::ApiError, error::ApiError,
@@ -66,6 +67,7 @@ use super::{
/// * `surreal_db_client` - Client for SurrealDB interactions /// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls /// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string /// * `query` - The user's query string
/// * `user_id` - The user's id
/// ///
/// # Returns /// # Returns
/// ///
@@ -80,11 +82,14 @@ pub async fn get_answer_with_references(
surreal_db_client: &SurrealDbClient, surreal_db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str, query: &str,
user_id: &str,
) -> Result<Answer, ApiError> { ) -> Result<Answer, ApiError> {
let entities = let entities =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?; combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
.await?;
let entities_json = format_entities_json(&entities); let entities_json = format_entities_json(&entities);
debug!("{:?}", entities_json);
let user_message = create_user_message(&entities_json, query); let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message)?; let request = create_chat_request(user_message)?;
+14 -3
View File
@@ -2,14 +2,18 @@ use axum::{
extract::{Query, State}, extract::{Query, State},
response::Html, response::Html,
}; };
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::Deserialize; use serde::Deserialize;
use serde_json::json; use serde_json::json;
use surrealdb::{engine::any::Any, Surreal};
use tera::Context; use tera::Context;
use tracing::info; use tracing::info;
use crate::{ use crate::{
error::ApiError, error::ApiError,
server::{routes::query::helper::get_answer_with_references, AppState}, server::{routes::query::helper::get_answer_with_references, AppState},
storage::types::user::User,
}; };
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct SearchParams { pub struct SearchParams {
@@ -19,12 +23,19 @@ pub struct SearchParams {
pub async fn search_result_handler( pub async fn search_result_handler(
State(state): State<AppState>, State(state): State<AppState>,
Query(query): Query<SearchParams>, Query(query): Query<SearchParams>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<Html<String>, ApiError> { ) -> Result<Html<String>, ApiError> {
info!("Displaying search results"); info!("Displaying search results");
let answer = let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id;
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
.await?; let answer = get_answer_with_references(
&state.surreal_db_client,
&state.openai_client,
&query.query,
&user_id,
)
.await?;
let output = state let output = state
.tera .tera
+4 -1
View File
@@ -30,7 +30,8 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
description: String, description: String,
entity_type: KnowledgeEntityType, entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>, metadata: Option<serde_json::Value>,
embedding: Vec<f32> embedding: Vec<f32>,
user_id: String
}); });
impl KnowledgeEntity { impl KnowledgeEntity {
@@ -41,6 +42,7 @@ impl KnowledgeEntity {
entity_type: KnowledgeEntityType, entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>, metadata: Option<serde_json::Value>,
embedding: Vec<f32>, embedding: Vec<f32>,
user_id: String,
) -> Self { ) -> Self {
Self { Self {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
@@ -50,6 +52,7 @@ impl KnowledgeEntity {
entity_type, entity_type,
metadata, metadata,
embedding, embedding,
user_id,
} }
} }
} }
+4 -2
View File
@@ -4,16 +4,18 @@ use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", { stored_object!(TextChunk, "text_chunk", {
source_id: String, source_id: String,
chunk: String, chunk: String,
embedding: Vec<f32> embedding: Vec<f32>,
user_id: String
}); });
impl TextChunk { impl TextChunk {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>) -> Self { pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
Self { Self {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
source_id, source_id,
chunk, chunk,
embedding, embedding,
user_id,
} }
} }
} }
+4 -1
View File
@@ -8,7 +8,8 @@ stored_object!(TextContent, "text_content", {
text: String, text: String,
file_info: Option<FileInfo>, file_info: Option<FileInfo>,
instructions: String, instructions: String,
category: String category: String,
user_id: String
}); });
impl TextContent { impl TextContent {
@@ -17,6 +18,7 @@ impl TextContent {
instructions: String, instructions: String,
category: String, category: String,
file_info: Option<FileInfo>, file_info: Option<FileInfo>,
user_id: String,
) -> Self { ) -> Self {
Self { Self {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
@@ -24,6 +26,7 @@ impl TextContent {
file_info, file_info,
instructions, instructions,
category, category,
user_id,
} }
} }
} }
+3
View File
@@ -0,0 +1,3 @@
\[x\] add user_id to ingress objects
\[x\] restrict retrieval to the user's own objects
\[\] web frontend stuff