user restricted to own objects

This commit is contained in:
Per Stark
2024-12-15 22:52:34 +01:00
parent ae4781363f
commit 291c473d00
18 changed files with 109 additions and 28 deletions

View File

@@ -69,6 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.await?;
app_state.surreal_db_client.build_indexes().await?;
// Create Axum router
let app = Router::new()
.nest("/api/v1", api_routes_v1(&app_state))

View File

@@ -66,6 +66,8 @@ pub enum ApiError {
UserAlreadyExists,
#[error("User was not found")]
UserNotFound,
#[error("You must provide valid credentials")]
AuthRequired,
}
impl IntoResponse for ApiError {
@@ -78,6 +80,7 @@ impl IntoResponse for ApiError {
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::IngressContentError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())

View File

@@ -37,9 +37,10 @@ impl<'a> IngressAnalyzer<'a> {
category: &str,
instructions: &str,
text: &str,
user_id: &str,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
let similar_entities = self
.find_similar_entities(category, instructions, text)
.find_similar_entities(category, instructions, text, user_id)
.await?;
let llm_request =
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
@@ -51,13 +52,20 @@ impl<'a> IngressAnalyzer<'a> {
category: &str,
instructions: &str,
text: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
let input_text = format!(
"content: {}, category: {}, user_instructions: {}",
text, category, instructions
);
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await
combined_knowledge_entity_retrieval(
self.db_client,
self.openai_client,
&input_text,
user_id,
)
.await
}
fn prepare_llm_request(

View File

@@ -53,6 +53,7 @@ impl LLMGraphAnalysisResult {
pub async fn to_database_entities(
&self,
source_id: &str,
user_id: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
// Create mapper and pre-assign IDs
@@ -60,7 +61,7 @@ impl LLMGraphAnalysisResult {
// Process entities
let entities = self
.process_entities(source_id, Arc::clone(&mapper), openai_client)
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
.await?;
// Process relationships
@@ -83,6 +84,7 @@ impl LLMGraphAnalysisResult {
async fn process_entities(
&self,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
@@ -93,10 +95,12 @@ impl LLMGraphAnalysisResult {
let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone();
let source_id = source_id.to_string();
let user_id = user_id.to_string();
let entity = entity.clone();
task::spawn(async move {
create_single_entity(&entity, &source_id, mapper, &openai_client).await
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
.await
})
})
.collect();
@@ -135,6 +139,7 @@ impl LLMGraphAnalysisResult {
async fn create_single_entity(
llm_entity: &LLMKnowledgeEntity,
source_id: &str,
user_id: &str,
mapper: Arc<Mutex<GraphMapper>>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<KnowledgeEntity, ProcessingError> {
@@ -168,5 +173,6 @@ async fn create_single_entity(
source_id: source_id.to_string(),
metadata: None,
embedding,
user_id: user_id.into(),
})
}

View File

@@ -49,7 +49,7 @@ impl ContentProcessor {
// Convert analysis to objects
let (entities, relationships) = analysis
.to_database_entities(&content.id, &self.openai_client)
.to_database_entities(&content.id, &content.user_id, &self.openai_client)
.await?;
// Store everything
@@ -68,7 +68,12 @@ impl ContentProcessor {
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
analyser
.analyze_content(&content.category, &content.instructions, &content.text)
.analyze_content(
&content.category,
&content.instructions,
&content.text,
&content.user_id,
)
.await
}
@@ -102,7 +107,12 @@ impl ContentProcessor {
// Could potentially process chunks in parallel with a bounded concurrent limit
for chunk in chunks {
let embedding = generate_embedding(&self.openai_client, chunk).await?;
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
let text_chunk = TextChunk::new(
content.id.to_string(),
chunk.to_string(),
embedding,
content.user_id.to_string(),
);
store_item(&self.db_client, text_chunk).await?;
}

View File

@@ -56,6 +56,7 @@ pub enum IngressContentError {
pub async fn create_ingress_objects(
input: IngressInput,
db_client: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<IngressObject>, IngressContentError> {
// Initialize list
let mut object_list = Vec::new();
@@ -69,6 +70,7 @@ pub async fn create_ingress_objects(
url: url.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
Err(_) => {
@@ -77,6 +79,7 @@ pub async fn create_ingress_objects(
text: input_content.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
}
@@ -90,6 +93,7 @@ pub async fn create_ingress_objects(
file_info,
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
} else {
info!("No file with id: {}", id);

View File

@@ -10,16 +10,19 @@ pub enum IngressObject {
url: String,
instructions: String,
category: String,
user_id: String,
},
Text {
text: String,
instructions: String,
category: String,
user_id: String,
},
File {
file_info: FileInfo,
instructions: String,
category: String,
user_id: String,
},
}
@@ -37,6 +40,7 @@ impl IngressObject {
url,
instructions,
category,
user_id,
} => {
let text = Self::fetch_text_from_url(url).await?;
Ok(TextContent::new(
@@ -44,22 +48,26 @@ impl IngressObject {
instructions.into(),
category.into(),
None,
user_id.into(),
))
}
IngressObject::Text {
text,
instructions,
category,
user_id,
} => Ok(TextContent::new(
text.into(),
instructions.into(),
category.into(),
None,
user_id.into(),
)),
IngressObject::File {
file_info,
instructions,
category,
user_id,
} => {
let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent::new(
@@ -67,6 +75,7 @@ impl IngressObject {
instructions.into(),
category.into(),
Some(file_info.to_owned()),
user_id.into(),
))
}
}

View File

@@ -29,6 +29,7 @@ use surrealdb::{engine::any::Any, Surreal};
/// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities
/// * `user_id` - The user id of the current user
///
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
@@ -37,6 +38,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client: &Surreal<Any>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
// info!("Received input: {:?}", query);
@@ -47,6 +49,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client,
"knowledge_entity".to_string(),
openai_client,
user_id,
),
find_items_by_vector_similarity(
5,
@@ -54,6 +57,7 @@ pub async fn combined_knowledge_entity_retrieval(
db_client,
"text_chunk".to_string(),
openai_client,
user_id,
),
)
.await?;

View File

@@ -9,11 +9,12 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
///
/// # Arguments
///
/// * `take`: The number of items to retrieve from the database.
/// * `input_text`: The text to generate embeddings for.
/// * `db_client`: The SurrealDB client to use for querying the database.
/// * `table`: The table to query in the database.
/// * `openai_client`: The OpenAI client to use for generating embeddings.
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user.
///
/// # Returns
///
@@ -21,13 +22,14 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
///
/// # Type Parameters
///
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
take: u8,
input_text: &str,
db_client: &Surreal<Any>,
table: String,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
) -> Result<Vec<T>, ProcessingError>
where
T: for<'de> serde::Deserialize<'de>,
@@ -36,7 +38,7 @@ where
let input_embedding = generate_embedding(openai_client, input_text).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding);
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
// Perform query and deserialize to struct
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;

View File

@@ -13,7 +13,7 @@ pub async fn api_auth(
mut request: Request,
next: Next,
) -> Result<Response, ApiError> {
let api_key = extract_api_key(&request).ok_or(ApiError::UserNotFound)?;
let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?;
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
let user = user.ok_or(ApiError::UserNotFound)?;

View File

@@ -2,18 +2,20 @@ use crate::{
error::ApiError,
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
server::AppState,
storage::types::user::User,
};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
use futures::future::try_join_all;
use tracing::info;
pub async fn ingress_handler(
State(state): State<AppState>,
Extension(user): Extension<User>,
Json(input): Json<IngressInput>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?;
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client, &user.id).await?;
let futures: Vec<_> = ingress_objects
.into_iter()

View File

@@ -33,9 +33,13 @@ pub async fn query_handler(
info!("Received input: {:?}", query);
info!("{:?}", user);
let answer =
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
.await?;
let answer = get_answer_with_references(
&state.surreal_db_client,
&state.openai_client,
&query.query,
&user.id,
)
.await?;
Ok(
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))

View File

@@ -4,6 +4,7 @@ use async_openai::types::{
ResponseFormat, ResponseFormatJsonSchema,
};
use serde_json::{json, Value};
use tracing::debug;
use crate::{
error::ApiError,
@@ -66,6 +67,7 @@ use super::{
/// * `surreal_db_client` - Client for SurrealDB interactions
/// * `openai_client` - Client for OpenAI API calls
/// * `query` - The user's query string
/// * `user_id` - The user's id
///
/// # Returns
///
@@ -80,11 +82,14 @@ pub async fn get_answer_with_references(
surreal_db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Answer, ApiError> {
let entities =
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?;
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
.await?;
let entities_json = format_entities_json(&entities);
debug!("{:?}", entities_json);
let user_message = create_user_message(&entities_json, query);
let request = create_chat_request(user_message)?;

View File

@@ -2,14 +2,18 @@ use axum::{
extract::{Query, State},
response::Html,
};
use axum_session_auth::AuthSession;
use axum_session_surreal::SessionSurrealPool;
use serde::Deserialize;
use serde_json::json;
use surrealdb::{engine::any::Any, Surreal};
use tera::Context;
use tracing::info;
use crate::{
error::ApiError,
server::{routes::query::helper::get_answer_with_references, AppState},
storage::types::user::User,
};
#[derive(Deserialize)]
pub struct SearchParams {
@@ -19,12 +23,19 @@ pub struct SearchParams {
pub async fn search_result_handler(
State(state): State<AppState>,
Query(query): Query<SearchParams>,
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
) -> Result<Html<String>, ApiError> {
info!("Displaying search results");
let answer =
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
.await?;
let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id;
let answer = get_answer_with_references(
&state.surreal_db_client,
&state.openai_client,
&query.query,
&user_id,
)
.await?;
let output = state
.tera

View File

@@ -30,7 +30,8 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>
embedding: Vec<f32>,
user_id: String
});
impl KnowledgeEntity {
@@ -41,6 +42,7 @@ impl KnowledgeEntity {
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String,
) -> Self {
Self {
id: Uuid::new_v4().to_string(),
@@ -50,6 +52,7 @@ impl KnowledgeEntity {
entity_type,
metadata,
embedding,
user_id,
}
}
}

View File

@@ -4,16 +4,18 @@ use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
source_id: String,
chunk: String,
embedding: Vec<f32>
embedding: Vec<f32>,
user_id: String
});
impl TextChunk {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>) -> Self {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
Self {
id: Uuid::new_v4().to_string(),
source_id,
chunk,
embedding,
user_id,
}
}
}

View File

@@ -8,7 +8,8 @@ stored_object!(TextContent, "text_content", {
text: String,
file_info: Option<FileInfo>,
instructions: String,
category: String
category: String,
user_id: String
});
impl TextContent {
@@ -17,6 +18,7 @@ impl TextContent {
instructions: String,
category: String,
file_info: Option<FileInfo>,
user_id: String,
) -> Self {
Self {
id: Uuid::new_v4().to_string(),
@@ -24,6 +26,7 @@ impl TextContent {
file_info,
instructions,
category,
user_id,
}
}
}

3
todo.md Normal file
View File

@@ -0,0 +1,3 @@
- [x] add user_id to ingress objects
- [x] restrict retrieval to user's own objects
- [ ] web frontend stuff