mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-25 02:41:27 +01:00
user restricted to own objects
This commit is contained in:
@@ -69,6 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
app_state.surreal_db_client.build_indexes().await?;
|
||||
|
||||
// Create Axum router
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", api_routes_v1(&app_state))
|
||||
|
||||
@@ -66,6 +66,8 @@ pub enum ApiError {
|
||||
UserAlreadyExists,
|
||||
#[error("User was not found")]
|
||||
UserNotFound,
|
||||
#[error("You must provide valid credentials")]
|
||||
AuthRequired,
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
@@ -78,6 +80,7 @@ impl IntoResponse for ApiError {
|
||||
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
||||
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
ApiError::IngressContentError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
||||
|
||||
@@ -37,9 +37,10 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||
let similar_entities = self
|
||||
.find_similar_entities(category, instructions, text)
|
||||
.find_similar_entities(category, instructions, text, user_id)
|
||||
.await?;
|
||||
let llm_request =
|
||||
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
|
||||
@@ -51,13 +52,20 @@ impl<'a> IngressAnalyzer<'a> {
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_instructions: {}",
|
||||
text, category, instructions
|
||||
);
|
||||
|
||||
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await
|
||||
combined_knowledge_entity_retrieval(
|
||||
self.db_client,
|
||||
self.openai_client,
|
||||
&input_text,
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn prepare_llm_request(
|
||||
|
||||
@@ -53,6 +53,7 @@ impl LLMGraphAnalysisResult {
|
||||
pub async fn to_database_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
|
||||
// Create mapper and pre-assign IDs
|
||||
@@ -60,7 +61,7 @@ impl LLMGraphAnalysisResult {
|
||||
|
||||
// Process entities
|
||||
let entities = self
|
||||
.process_entities(source_id, Arc::clone(&mapper), openai_client)
|
||||
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
|
||||
.await?;
|
||||
|
||||
// Process relationships
|
||||
@@ -83,6 +84,7 @@ impl LLMGraphAnalysisResult {
|
||||
async fn process_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||
@@ -93,10 +95,12 @@ impl LLMGraphAnalysisResult {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let entity = entity.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
create_single_entity(&entity, &source_id, mapper, &openai_client).await
|
||||
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
|
||||
.await
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -135,6 +139,7 @@ impl LLMGraphAnalysisResult {
|
||||
async fn create_single_entity(
|
||||
llm_entity: &LLMKnowledgeEntity,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<KnowledgeEntity, ProcessingError> {
|
||||
@@ -168,5 +173,6 @@ async fn create_single_entity(
|
||||
source_id: source_id.to_string(),
|
||||
metadata: None,
|
||||
embedding,
|
||||
user_id: user_id.into(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ impl ContentProcessor {
|
||||
|
||||
// Convert analysis to objects
|
||||
let (entities, relationships) = analysis
|
||||
.to_database_entities(&content.id, &self.openai_client)
|
||||
.to_database_entities(&content.id, &content.user_id, &self.openai_client)
|
||||
.await?;
|
||||
|
||||
// Store everything
|
||||
@@ -68,7 +68,12 @@ impl ContentProcessor {
|
||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||
analyser
|
||||
.analyze_content(&content.category, &content.instructions, &content.text)
|
||||
.analyze_content(
|
||||
&content.category,
|
||||
&content.instructions,
|
||||
&content.text,
|
||||
&content.user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -102,7 +107,12 @@ impl ContentProcessor {
|
||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||
for chunk in chunks {
|
||||
let embedding = generate_embedding(&self.openai_client, chunk).await?;
|
||||
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
||||
let text_chunk = TextChunk::new(
|
||||
content.id.to_string(),
|
||||
chunk.to_string(),
|
||||
embedding,
|
||||
content.user_id.to_string(),
|
||||
);
|
||||
store_item(&self.db_client, text_chunk).await?;
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ pub enum IngressContentError {
|
||||
pub async fn create_ingress_objects(
|
||||
input: IngressInput,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<IngressObject>, IngressContentError> {
|
||||
// Initialize list
|
||||
let mut object_list = Vec::new();
|
||||
@@ -69,6 +70,7 @@ pub async fn create_ingress_objects(
|
||||
url: url.to_string(),
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
@@ -77,6 +79,7 @@ pub async fn create_ingress_objects(
|
||||
text: input_content.to_string(),
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -90,6 +93,7 @@ pub async fn create_ingress_objects(
|
||||
file_info,
|
||||
instructions: input.instructions.clone(),
|
||||
category: input.category.clone(),
|
||||
user_id: user_id.into(),
|
||||
});
|
||||
} else {
|
||||
info!("No file with id: {}", id);
|
||||
|
||||
@@ -10,16 +10,19 @@ pub enum IngressObject {
|
||||
url: String,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
Text {
|
||||
text: String,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
File {
|
||||
file_info: FileInfo,
|
||||
instructions: String,
|
||||
category: String,
|
||||
user_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -37,6 +40,7 @@ impl IngressObject {
|
||||
url,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let text = Self::fetch_text_from_url(url).await?;
|
||||
Ok(TextContent::new(
|
||||
@@ -44,22 +48,26 @@ impl IngressObject {
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
None,
|
||||
user_id.into(),
|
||||
))
|
||||
}
|
||||
IngressObject::Text {
|
||||
text,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => Ok(TextContent::new(
|
||||
text.into(),
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
None,
|
||||
user_id.into(),
|
||||
)),
|
||||
IngressObject::File {
|
||||
file_info,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
} => {
|
||||
let text = Self::extract_text_from_file(file_info).await?;
|
||||
Ok(TextContent::new(
|
||||
@@ -67,6 +75,7 @@ impl IngressObject {
|
||||
instructions.into(),
|
||||
category.into(),
|
||||
Some(file_info.to_owned()),
|
||||
user_id.into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ use surrealdb::{engine::any::Any, Surreal};
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
|
||||
@@ -37,6 +38,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
||||
db_client: &Surreal<Any>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||
// info!("Received input: {:?}", query);
|
||||
|
||||
@@ -47,6 +49,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
||||
db_client,
|
||||
"knowledge_entity".to_string(),
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(
|
||||
5,
|
||||
@@ -54,6 +57,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
||||
db_client,
|
||||
"text_chunk".to_string(),
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -9,11 +9,12 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take`: The number of items to retrieve from the database.
|
||||
/// * `input_text`: The text to generate embeddings for.
|
||||
/// * `db_client`: The SurrealDB client to use for querying the database.
|
||||
/// * `table`: The table to query in the database.
|
||||
/// * `openai_client`: The OpenAI client to use for generating embeddings.
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
@@ -21,13 +22,14 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: &str,
|
||||
db_client: &Surreal<Any>,
|
||||
table: String,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, ProcessingError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
@@ -36,7 +38,7 @@ where
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding);
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
@@ -13,7 +13,7 @@ pub async fn api_auth(
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::UserNotFound)?;
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
||||
let user = user.ok_or(ApiError::UserNotFound)?;
|
||||
|
||||
@@ -2,18 +2,20 @@ use crate::{
|
||||
error::ApiError,
|
||||
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
||||
server::AppState,
|
||||
storage::types::user::User,
|
||||
};
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||
use futures::future::try_join_all;
|
||||
use tracing::info;
|
||||
|
||||
pub async fn ingress_handler(
|
||||
State(state): State<AppState>,
|
||||
Extension(user): Extension<User>,
|
||||
Json(input): Json<IngressInput>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
info!("Received input: {:?}", input);
|
||||
|
||||
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?;
|
||||
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client, &user.id).await?;
|
||||
|
||||
let futures: Vec<_> = ingress_objects
|
||||
.into_iter()
|
||||
|
||||
@@ -33,9 +33,13 @@ pub async fn query_handler(
|
||||
info!("Received input: {:?}", query);
|
||||
info!("{:?}", user);
|
||||
|
||||
let answer =
|
||||
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
||||
.await?;
|
||||
let answer = get_answer_with_references(
|
||||
&state.surreal_db_client,
|
||||
&state.openai_client,
|
||||
&query.query,
|
||||
&user.id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(
|
||||
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
|
||||
|
||||
@@ -4,6 +4,7 @@ use async_openai::types::{
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::ApiError,
|
||||
@@ -66,6 +67,7 @@ use super::{
|
||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||
/// * `openai_client` - Client for OpenAI API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
@@ -80,11 +82,14 @@ pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, ApiError> {
|
||||
let entities =
|
||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?;
|
||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
||||
.await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
debug!("{:?}", entities_json);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message)?;
|
||||
|
||||
@@ -2,14 +2,18 @@ use axum::{
|
||||
extract::{Query, State},
|
||||
response::Html,
|
||||
};
|
||||
use axum_session_auth::AuthSession;
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use tera::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
error::ApiError,
|
||||
server::{routes::query::helper::get_answer_with_references, AppState},
|
||||
storage::types::user::User,
|
||||
};
|
||||
#[derive(Deserialize)]
|
||||
pub struct SearchParams {
|
||||
@@ -19,12 +23,19 @@ pub struct SearchParams {
|
||||
pub async fn search_result_handler(
|
||||
State(state): State<AppState>,
|
||||
Query(query): Query<SearchParams>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
) -> Result<Html<String>, ApiError> {
|
||||
info!("Displaying search results");
|
||||
|
||||
let answer =
|
||||
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
||||
.await?;
|
||||
let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id;
|
||||
|
||||
let answer = get_answer_with_references(
|
||||
&state.surreal_db_client,
|
||||
&state.openai_client,
|
||||
&query.query,
|
||||
&user_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let output = state
|
||||
.tera
|
||||
|
||||
@@ -30,7 +30,8 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntity {
|
||||
@@ -41,6 +42,7 @@ impl KnowledgeEntity {
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -50,6 +52,7 @@ impl KnowledgeEntity {
|
||||
entity_type,
|
||||
metadata,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ use uuid::Uuid;
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
source_id: String,
|
||||
chunk: String,
|
||||
embedding: Vec<f32>
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunk {
|
||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>) -> Self {
|
||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
source_id,
|
||||
chunk,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ stored_object!(TextContent, "text_content", {
|
||||
text: String,
|
||||
file_info: Option<FileInfo>,
|
||||
instructions: String,
|
||||
category: String
|
||||
category: String,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextContent {
|
||||
@@ -17,6 +18,7 @@ impl TextContent {
|
||||
instructions: String,
|
||||
category: String,
|
||||
file_info: Option<FileInfo>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -24,6 +26,7 @@ impl TextContent {
|
||||
file_info,
|
||||
instructions,
|
||||
category,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user