diff --git a/src/bin/server.rs b/src/bin/server.rs index 1465b59..f9d9113 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -69,6 +69,8 @@ async fn main() -> Result<(), Box> { ) .await?; + app_state.surreal_db_client.build_indexes().await?; + // Create Axum router let app = Router::new() .nest("/api/v1", api_routes_v1(&app_state)) diff --git a/src/error.rs b/src/error.rs index 5c5b9b2..07cf92d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -66,6 +66,8 @@ pub enum ApiError { UserAlreadyExists, #[error("User was not found")] UserNotFound, + #[error("You must provide valid credentials")] + AuthRequired, } impl IntoResponse for ApiError { @@ -78,6 +80,7 @@ impl IntoResponse for ApiError { ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()), + ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()), ApiError::IngressContentError(_) => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) diff --git a/src/ingress/analysis/ingress_analyser.rs b/src/ingress/analysis/ingress_analyser.rs index 6d47f69..f10752c 100644 --- a/src/ingress/analysis/ingress_analyser.rs +++ b/src/ingress/analysis/ingress_analyser.rs @@ -37,9 +37,10 @@ impl<'a> IngressAnalyzer<'a> { category: &str, instructions: &str, text: &str, + user_id: &str, ) -> Result { let similar_entities = self - .find_similar_entities(category, instructions, text) + .find_similar_entities(category, instructions, text, user_id) .await?; let llm_request = self.prepare_llm_request(category, instructions, text, &similar_entities)?; @@ -51,13 +52,20 @@ impl<'a> IngressAnalyzer<'a> { category: &str, instructions: &str, text: &str, + user_id: &str, ) -> Result, ProcessingError> { let input_text = format!( "content: {}, category: {}, user_instructions: {}", 
text, category, instructions ); - combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await + combined_knowledge_entity_retrieval( + self.db_client, + self.openai_client, + &input_text, + user_id, + ) + .await } fn prepare_llm_request( diff --git a/src/ingress/analysis/types/llm_analysis_result.rs b/src/ingress/analysis/types/llm_analysis_result.rs index 69c7923..b8e3401 100644 --- a/src/ingress/analysis/types/llm_analysis_result.rs +++ b/src/ingress/analysis/types/llm_analysis_result.rs @@ -53,6 +53,7 @@ impl LLMGraphAnalysisResult { pub async fn to_database_entities( &self, source_id: &str, + user_id: &str, openai_client: &async_openai::Client, ) -> Result<(Vec, Vec), ProcessingError> { // Create mapper and pre-assign IDs @@ -60,7 +61,7 @@ impl LLMGraphAnalysisResult { // Process entities let entities = self - .process_entities(source_id, Arc::clone(&mapper), openai_client) + .process_entities(source_id, user_id, Arc::clone(&mapper), openai_client) .await?; // Process relationships @@ -83,6 +84,7 @@ impl LLMGraphAnalysisResult { async fn process_entities( &self, source_id: &str, + user_id: &str, mapper: Arc>, openai_client: &async_openai::Client, ) -> Result, ProcessingError> { @@ -93,10 +95,12 @@ impl LLMGraphAnalysisResult { let mapper = Arc::clone(&mapper); let openai_client = openai_client.clone(); let source_id = source_id.to_string(); + let user_id = user_id.to_string(); let entity = entity.clone(); task::spawn(async move { - create_single_entity(&entity, &source_id, mapper, &openai_client).await + create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client) + .await }) }) .collect(); @@ -135,6 +139,7 @@ impl LLMGraphAnalysisResult { async fn create_single_entity( llm_entity: &LLMKnowledgeEntity, source_id: &str, + user_id: &str, mapper: Arc>, openai_client: &async_openai::Client, ) -> Result { @@ -168,5 +173,6 @@ async fn create_single_entity( source_id: source_id.to_string(), metadata: None, embedding, + 
user_id: user_id.into(), }) } diff --git a/src/ingress/content_processor.rs b/src/ingress/content_processor.rs index 03dec71..045a961 100644 --- a/src/ingress/content_processor.rs +++ b/src/ingress/content_processor.rs @@ -49,7 +49,7 @@ impl ContentProcessor { // Convert analysis to objects let (entities, relationships) = analysis - .to_database_entities(&content.id, &self.openai_client) + .to_database_entities(&content.id, &content.user_id, &self.openai_client) .await?; // Store everything @@ -68,7 +68,12 @@ impl ContentProcessor { ) -> Result { let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client); analyser - .analyze_content(&content.category, &content.instructions, &content.text) + .analyze_content( + &content.category, + &content.instructions, + &content.text, + &content.user_id, + ) .await } @@ -102,7 +107,12 @@ impl ContentProcessor { // Could potentially process chunks in parallel with a bounded concurrent limit for chunk in chunks { let embedding = generate_embedding(&self.openai_client, chunk).await?; - let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding); + let text_chunk = TextChunk::new( + content.id.to_string(), + chunk.to_string(), + embedding, + content.user_id.to_string(), + ); store_item(&self.db_client, text_chunk).await?; } diff --git a/src/ingress/types/ingress_input.rs b/src/ingress/types/ingress_input.rs index 49ba4ff..ed09cfa 100644 --- a/src/ingress/types/ingress_input.rs +++ b/src/ingress/types/ingress_input.rs @@ -56,6 +56,7 @@ pub enum IngressContentError { pub async fn create_ingress_objects( input: IngressInput, db_client: &SurrealDbClient, + user_id: &str, ) -> Result, IngressContentError> { // Initialize list let mut object_list = Vec::new(); @@ -69,6 +70,7 @@ pub async fn create_ingress_objects( url: url.to_string(), instructions: input.instructions.clone(), category: input.category.clone(), + user_id: user_id.into(), }); } Err(_) => { @@ -77,6 +79,7 @@ pub async fn 
create_ingress_objects( text: input_content.to_string(), instructions: input.instructions.clone(), category: input.category.clone(), + user_id: user_id.into(), }); } } @@ -90,6 +93,7 @@ pub async fn create_ingress_objects( file_info, instructions: input.instructions.clone(), category: input.category.clone(), + user_id: user_id.into(), }); } else { info!("No file with id: {}", id); diff --git a/src/ingress/types/ingress_object.rs b/src/ingress/types/ingress_object.rs index bd45279..0c35f45 100644 --- a/src/ingress/types/ingress_object.rs +++ b/src/ingress/types/ingress_object.rs @@ -10,16 +10,19 @@ pub enum IngressObject { url: String, instructions: String, category: String, + user_id: String, }, Text { text: String, instructions: String, category: String, + user_id: String, }, File { file_info: FileInfo, instructions: String, category: String, + user_id: String, }, } @@ -37,6 +40,7 @@ impl IngressObject { url, instructions, category, + user_id, } => { let text = Self::fetch_text_from_url(url).await?; Ok(TextContent::new( @@ -44,22 +48,26 @@ impl IngressObject { instructions.into(), category.into(), None, + user_id.into(), )) } IngressObject::Text { text, instructions, category, + user_id, } => Ok(TextContent::new( text.into(), instructions.into(), category.into(), None, + user_id.into(), )), IngressObject::File { file_info, instructions, category, + user_id, } => { let text = Self::extract_text_from_file(file_info).await?; Ok(TextContent::new( @@ -67,6 +75,7 @@ impl IngressObject { instructions.into(), category.into(), Some(file_info.to_owned()), + user_id.into(), )) } } diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs index 8f9ba91..d9cde99 100644 --- a/src/retrieval/mod.rs +++ b/src/retrieval/mod.rs @@ -29,6 +29,7 @@ use surrealdb::{engine::any::Any, Surreal}; /// * `db_client` - SurrealDB client for database operations /// * `openai_client` - OpenAI client for vector embeddings generation /// * `query` - The search query string to find relevant knowledge 
entities +/// * `user_id` - The user id of the current user /// /// # Returns /// * `Result, ProcessingError>` - A deduplicated vector of relevant @@ -37,6 +38,7 @@ pub async fn combined_knowledge_entity_retrieval( db_client: &Surreal, openai_client: &async_openai::Client, query: &str, + user_id: &str, ) -> Result, ProcessingError> { // info!("Received input: {:?}", query); @@ -47,6 +49,7 @@ pub async fn combined_knowledge_entity_retrieval( db_client, "knowledge_entity".to_string(), openai_client, + user_id, ), find_items_by_vector_similarity( 5, @@ -54,6 +57,7 @@ pub async fn combined_knowledge_entity_retrieval( db_client, "text_chunk".to_string(), openai_client, + user_id, ), ) .await?; diff --git a/src/retrieval/vector.rs b/src/retrieval/vector.rs index 7a131f9..38ee87d 100644 --- a/src/retrieval/vector.rs +++ b/src/retrieval/vector.rs @@ -9,11 +9,12 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding}; /// /// # Arguments /// -/// * `take`: The number of items to retrieve from the database. -/// * `input_text`: The text to generate embeddings for. -/// * `db_client`: The SurrealDB client to use for querying the database. -/// * `table`: The table to query in the database. -/// * `openai_client`: The OpenAI client to use for generating embeddings. +/// * `take` - The number of items to retrieve from the database. +/// * `input_text` - The text to generate embeddings for. +/// * `db_client` - The SurrealDB client to use for querying the database. +/// * `table` - The table to query in the database. +/// * `openai_client` - The OpenAI client to use for generating embeddings. +/// * `user_id` - The user id of the current user. /// /// # Returns /// /// @@ -21,13 +22,14 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding}; /// /// # Type Parameters /// -/// * `T`: The type to deserialize the query results into. +/// * `T` - The type to deserialize the query results into.
Must implement `serde::Deserialize`. pub async fn find_items_by_vector_similarity( take: u8, input_text: &str, db_client: &Surreal, table: String, openai_client: &async_openai::Client, + user_id: &str, ) -> Result, ProcessingError> where T: for<'de> serde::Deserialize<'de>, @@ -36,7 +38,7 @@ where let input_embedding = generate_embedding(openai_client, input_text).await?; // Construct the query - let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding); + let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id); // Perform query and deserialize to struct let closest_entities: Vec = db_client.query(closest_query).await?.take(0)?; diff --git a/src/server/middleware_api_auth.rs b/src/server/middleware_api_auth.rs index 995a904..324ba07 100644 --- a/src/server/middleware_api_auth.rs +++ b/src/server/middleware_api_auth.rs @@ -13,7 +13,7 @@ pub async fn api_auth( mut request: Request, next: Next, ) -> Result { - let api_key = extract_api_key(&request).ok_or(ApiError::UserNotFound)?; + let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?; let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?; let user = user.ok_or(ApiError::UserNotFound)?; diff --git a/src/server/routes/ingress.rs b/src/server/routes/ingress.rs index cb80e1f..eb33f25 100644 --- a/src/server/routes/ingress.rs +++ b/src/server/routes/ingress.rs @@ -2,18 +2,20 @@ use crate::{ error::ApiError, ingress::types::ingress_input::{create_ingress_objects, IngressInput}, server::AppState, + storage::types::user::User, }; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use futures::future::try_join_all; use tracing::info; pub 
async fn ingress_handler( State(state): State, + Extension(user): Extension, Json(input): Json, ) -> Result { info!("Received input: {:?}", input); - let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?; + let ingress_objects = create_ingress_objects(input, &state.surreal_db_client, &user.id).await?; let futures: Vec<_> = ingress_objects .into_iter() diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs index 574941c..c4ef0f3 100644 --- a/src/server/routes/query.rs +++ b/src/server/routes/query.rs @@ -33,9 +33,13 @@ pub async fn query_handler( info!("Received input: {:?}", query); info!("{:?}", user); - let answer = - get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query) - .await?; + let answer = get_answer_with_references( + &state.surreal_db_client, + &state.openai_client, + &query.query, + &user.id, + ) + .await?; Ok( Json(serde_json::json!({"answer": answer.content, "references": answer.references})) diff --git a/src/server/routes/query/helper.rs b/src/server/routes/query/helper.rs index f281b5a..525706a 100644 --- a/src/server/routes/query/helper.rs +++ b/src/server/routes/query/helper.rs @@ -4,6 +4,7 @@ use async_openai::types::{ ResponseFormat, ResponseFormatJsonSchema, }; use serde_json::{json, Value}; +use tracing::debug; use crate::{ error::ApiError, @@ -66,6 +67,7 @@ use super::{ /// * `surreal_db_client` - Client for SurrealDB interactions /// * `openai_client` - Client for OpenAI API calls /// * `query` - The user's query string +/// * `user_id` - The user's id /// /// # Returns /// @@ -80,11 +82,14 @@ pub async fn get_answer_with_references( surreal_db_client: &SurrealDbClient, openai_client: &async_openai::Client, query: &str, + user_id: &str, ) -> Result { let entities = - combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?; + combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id) + .await?; let 
entities_json = format_entities_json(&entities); + debug!("{:?}", entities_json); let user_message = create_user_message(&entities_json, query); let request = create_chat_request(user_message)?; diff --git a/src/server/routes/search_result.rs b/src/server/routes/search_result.rs index 7809076..b348ce3 100644 --- a/src/server/routes/search_result.rs +++ b/src/server/routes/search_result.rs @@ -2,14 +2,18 @@ use axum::{ extract::{Query, State}, response::Html, }; +use axum_session_auth::AuthSession; +use axum_session_surreal::SessionSurrealPool; use serde::Deserialize; use serde_json::json; +use surrealdb::{engine::any::Any, Surreal}; use tera::Context; use tracing::info; use crate::{ error::ApiError, server::{routes::query::helper::get_answer_with_references, AppState}, + storage::types::user::User, }; #[derive(Deserialize)] pub struct SearchParams { @@ -19,12 +23,19 @@ pub struct SearchParams { pub async fn search_result_handler( State(state): State, Query(query): Query, + auth: AuthSession, Surreal>, ) -> Result, ApiError> { info!("Displaying search results"); - let answer = - get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query) - .await?; + let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id; + + let answer = get_answer_with_references( + &state.surreal_db_client, + &state.openai_client, + &query.query, + &user_id, + ) + .await?; let output = state .tera diff --git a/src/storage/types/knowledge_entity.rs b/src/storage/types/knowledge_entity.rs index 8419c6a..a6536f1 100644 --- a/src/storage/types/knowledge_entity.rs +++ b/src/storage/types/knowledge_entity.rs @@ -30,7 +30,8 @@ stored_object!(KnowledgeEntity, "knowledge_entity", { description: String, entity_type: KnowledgeEntityType, metadata: Option, - embedding: Vec + embedding: Vec, + user_id: String }); impl KnowledgeEntity { @@ -41,6 +42,7 @@ impl KnowledgeEntity { entity_type: KnowledgeEntityType, metadata: Option, embedding: Vec, + user_id: String, 
) -> Self { Self { id: Uuid::new_v4().to_string(), @@ -50,6 +52,7 @@ entity_type, metadata, embedding, + user_id, } } } diff --git a/src/storage/types/text_chunk.rs b/src/storage/types/text_chunk.rs index 691b15b..2983650 100644 --- a/src/storage/types/text_chunk.rs +++ b/src/storage/types/text_chunk.rs @@ -4,16 +4,18 @@ use uuid::Uuid; stored_object!(TextChunk, "text_chunk", { source_id: String, chunk: String, - embedding: Vec + embedding: Vec, + user_id: String }); impl TextChunk { - pub fn new(source_id: String, chunk: String, embedding: Vec) -> Self { + pub fn new(source_id: String, chunk: String, embedding: Vec, user_id: String) -> Self { Self { id: Uuid::new_v4().to_string(), source_id, chunk, embedding, + user_id, } } } diff --git a/src/storage/types/text_content.rs b/src/storage/types/text_content.rs index b1a86f4..abb7a96 100644 --- a/src/storage/types/text_content.rs +++ b/src/storage/types/text_content.rs @@ -8,7 +8,8 @@ stored_object!(TextContent, "text_content", { text: String, file_info: Option, instructions: String, - category: String + category: String, + user_id: String }); impl TextContent { @@ -17,6 +18,7 @@ instructions: String, category: String, file_info: Option, + user_id: String, ) -> Self { Self { id: Uuid::new_v4().to_string(), @@ -24,6 +26,7 @@ file_info, instructions, category, + user_id, } } } diff --git a/todo.md b/todo.md new file mode 100644 index 0000000..9124ad1 --- /dev/null +++ b/todo.md @@ -0,0 +1,3 @@ +- [x] add user_id to ingress objects +- [x] restrict retrieval to users own objects +- [ ] web frontend stuff