mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-12 17:24:26 +02:00
user restricted to own objects
This commit is contained in:
@@ -69,6 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
app_state.surreal_db_client.build_indexes().await?;
|
||||||
|
|
||||||
// Create Axum router
|
// Create Axum router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.nest("/api/v1", api_routes_v1(&app_state))
|
.nest("/api/v1", api_routes_v1(&app_state))
|
||||||
|
|||||||
@@ -66,6 +66,8 @@ pub enum ApiError {
|
|||||||
UserAlreadyExists,
|
UserAlreadyExists,
|
||||||
#[error("User was not found")]
|
#[error("User was not found")]
|
||||||
UserNotFound,
|
UserNotFound,
|
||||||
|
#[error("You must provide valid credentials")]
|
||||||
|
AuthRequired,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiError {
|
||||||
@@ -78,6 +80,7 @@ impl IntoResponse for ApiError {
|
|||||||
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
ApiError::OpenAIerror(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
||||||
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
ApiError::QueryError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||||
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
|
ApiError::UserAlreadyExists => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||||
|
ApiError::AuthRequired => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||||
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
|
ApiError::UserNotFound => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||||
ApiError::IngressContentError(_) => {
|
ApiError::IngressContentError(_) => {
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
|
||||||
|
|||||||
@@ -37,9 +37,10 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
category: &str,
|
category: &str,
|
||||||
instructions: &str,
|
instructions: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||||
let similar_entities = self
|
let similar_entities = self
|
||||||
.find_similar_entities(category, instructions, text)
|
.find_similar_entities(category, instructions, text, user_id)
|
||||||
.await?;
|
.await?;
|
||||||
let llm_request =
|
let llm_request =
|
||||||
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
|
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
|
||||||
@@ -51,13 +52,20 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
category: &str,
|
category: &str,
|
||||||
instructions: &str,
|
instructions: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||||
let input_text = format!(
|
let input_text = format!(
|
||||||
"content: {}, category: {}, user_instructions: {}",
|
"content: {}, category: {}, user_instructions: {}",
|
||||||
text, category, instructions
|
text, category, instructions
|
||||||
);
|
);
|
||||||
|
|
||||||
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await
|
combined_knowledge_entity_retrieval(
|
||||||
|
self.db_client,
|
||||||
|
self.openai_client,
|
||||||
|
&input_text,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_llm_request(
|
fn prepare_llm_request(
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
pub async fn to_database_entities(
|
pub async fn to_database_entities(
|
||||||
&self,
|
&self,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
|
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
|
||||||
// Create mapper and pre-assign IDs
|
// Create mapper and pre-assign IDs
|
||||||
@@ -60,7 +61,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
|
|
||||||
// Process entities
|
// Process entities
|
||||||
let entities = self
|
let entities = self
|
||||||
.process_entities(source_id, Arc::clone(&mapper), openai_client)
|
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Process relationships
|
// Process relationships
|
||||||
@@ -83,6 +84,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
async fn process_entities(
|
async fn process_entities(
|
||||||
&self,
|
&self,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
mapper: Arc<Mutex<GraphMapper>>,
|
mapper: Arc<Mutex<GraphMapper>>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||||
@@ -93,10 +95,12 @@ impl LLMGraphAnalysisResult {
|
|||||||
let mapper = Arc::clone(&mapper);
|
let mapper = Arc::clone(&mapper);
|
||||||
let openai_client = openai_client.clone();
|
let openai_client = openai_client.clone();
|
||||||
let source_id = source_id.to_string();
|
let source_id = source_id.to_string();
|
||||||
|
let user_id = user_id.to_string();
|
||||||
let entity = entity.clone();
|
let entity = entity.clone();
|
||||||
|
|
||||||
task::spawn(async move {
|
task::spawn(async move {
|
||||||
create_single_entity(&entity, &source_id, mapper, &openai_client).await
|
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
|
||||||
|
.await
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -135,6 +139,7 @@ impl LLMGraphAnalysisResult {
|
|||||||
async fn create_single_entity(
|
async fn create_single_entity(
|
||||||
llm_entity: &LLMKnowledgeEntity,
|
llm_entity: &LLMKnowledgeEntity,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
mapper: Arc<Mutex<GraphMapper>>,
|
mapper: Arc<Mutex<GraphMapper>>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
) -> Result<KnowledgeEntity, ProcessingError> {
|
) -> Result<KnowledgeEntity, ProcessingError> {
|
||||||
@@ -168,5 +173,6 @@ async fn create_single_entity(
|
|||||||
source_id: source_id.to_string(),
|
source_id: source_id.to_string(),
|
||||||
metadata: None,
|
metadata: None,
|
||||||
embedding,
|
embedding,
|
||||||
|
user_id: user_id.into(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ impl ContentProcessor {
|
|||||||
|
|
||||||
// Convert analysis to objects
|
// Convert analysis to objects
|
||||||
let (entities, relationships) = analysis
|
let (entities, relationships) = analysis
|
||||||
.to_database_entities(&content.id, &self.openai_client)
|
.to_database_entities(&content.id, &content.user_id, &self.openai_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Store everything
|
// Store everything
|
||||||
@@ -68,7 +68,12 @@ impl ContentProcessor {
|
|||||||
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
|
||||||
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
let analyser = IngressAnalyzer::new(&self.db_client, &self.openai_client);
|
||||||
analyser
|
analyser
|
||||||
.analyze_content(&content.category, &content.instructions, &content.text)
|
.analyze_content(
|
||||||
|
&content.category,
|
||||||
|
&content.instructions,
|
||||||
|
&content.text,
|
||||||
|
&content.user_id,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +107,12 @@ impl ContentProcessor {
|
|||||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
let embedding = generate_embedding(&self.openai_client, chunk).await?;
|
let embedding = generate_embedding(&self.openai_client, chunk).await?;
|
||||||
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
let text_chunk = TextChunk::new(
|
||||||
|
content.id.to_string(),
|
||||||
|
chunk.to_string(),
|
||||||
|
embedding,
|
||||||
|
content.user_id.to_string(),
|
||||||
|
);
|
||||||
store_item(&self.db_client, text_chunk).await?;
|
store_item(&self.db_client, text_chunk).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ pub enum IngressContentError {
|
|||||||
pub async fn create_ingress_objects(
|
pub async fn create_ingress_objects(
|
||||||
input: IngressInput,
|
input: IngressInput,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<Vec<IngressObject>, IngressContentError> {
|
) -> Result<Vec<IngressObject>, IngressContentError> {
|
||||||
// Initialize list
|
// Initialize list
|
||||||
let mut object_list = Vec::new();
|
let mut object_list = Vec::new();
|
||||||
@@ -69,6 +70,7 @@ pub async fn create_ingress_objects(
|
|||||||
url: url.to_string(),
|
url: url.to_string(),
|
||||||
instructions: input.instructions.clone(),
|
instructions: input.instructions.clone(),
|
||||||
category: input.category.clone(),
|
category: input.category.clone(),
|
||||||
|
user_id: user_id.into(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
@@ -77,6 +79,7 @@ pub async fn create_ingress_objects(
|
|||||||
text: input_content.to_string(),
|
text: input_content.to_string(),
|
||||||
instructions: input.instructions.clone(),
|
instructions: input.instructions.clone(),
|
||||||
category: input.category.clone(),
|
category: input.category.clone(),
|
||||||
|
user_id: user_id.into(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -90,6 +93,7 @@ pub async fn create_ingress_objects(
|
|||||||
file_info,
|
file_info,
|
||||||
instructions: input.instructions.clone(),
|
instructions: input.instructions.clone(),
|
||||||
category: input.category.clone(),
|
category: input.category.clone(),
|
||||||
|
user_id: user_id.into(),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
info!("No file with id: {}", id);
|
info!("No file with id: {}", id);
|
||||||
|
|||||||
@@ -10,16 +10,19 @@ pub enum IngressObject {
|
|||||||
url: String,
|
url: String,
|
||||||
instructions: String,
|
instructions: String,
|
||||||
category: String,
|
category: String,
|
||||||
|
user_id: String,
|
||||||
},
|
},
|
||||||
Text {
|
Text {
|
||||||
text: String,
|
text: String,
|
||||||
instructions: String,
|
instructions: String,
|
||||||
category: String,
|
category: String,
|
||||||
|
user_id: String,
|
||||||
},
|
},
|
||||||
File {
|
File {
|
||||||
file_info: FileInfo,
|
file_info: FileInfo,
|
||||||
instructions: String,
|
instructions: String,
|
||||||
category: String,
|
category: String,
|
||||||
|
user_id: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,6 +40,7 @@ impl IngressObject {
|
|||||||
url,
|
url,
|
||||||
instructions,
|
instructions,
|
||||||
category,
|
category,
|
||||||
|
user_id,
|
||||||
} => {
|
} => {
|
||||||
let text = Self::fetch_text_from_url(url).await?;
|
let text = Self::fetch_text_from_url(url).await?;
|
||||||
Ok(TextContent::new(
|
Ok(TextContent::new(
|
||||||
@@ -44,22 +48,26 @@ impl IngressObject {
|
|||||||
instructions.into(),
|
instructions.into(),
|
||||||
category.into(),
|
category.into(),
|
||||||
None,
|
None,
|
||||||
|
user_id.into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
IngressObject::Text {
|
IngressObject::Text {
|
||||||
text,
|
text,
|
||||||
instructions,
|
instructions,
|
||||||
category,
|
category,
|
||||||
|
user_id,
|
||||||
} => Ok(TextContent::new(
|
} => Ok(TextContent::new(
|
||||||
text.into(),
|
text.into(),
|
||||||
instructions.into(),
|
instructions.into(),
|
||||||
category.into(),
|
category.into(),
|
||||||
None,
|
None,
|
||||||
|
user_id.into(),
|
||||||
)),
|
)),
|
||||||
IngressObject::File {
|
IngressObject::File {
|
||||||
file_info,
|
file_info,
|
||||||
instructions,
|
instructions,
|
||||||
category,
|
category,
|
||||||
|
user_id,
|
||||||
} => {
|
} => {
|
||||||
let text = Self::extract_text_from_file(file_info).await?;
|
let text = Self::extract_text_from_file(file_info).await?;
|
||||||
Ok(TextContent::new(
|
Ok(TextContent::new(
|
||||||
@@ -67,6 +75,7 @@ impl IngressObject {
|
|||||||
instructions.into(),
|
instructions.into(),
|
||||||
category.into(),
|
category.into(),
|
||||||
Some(file_info.to_owned()),
|
Some(file_info.to_owned()),
|
||||||
|
user_id.into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ use surrealdb::{engine::any::Any, Surreal};
|
|||||||
/// * `db_client` - SurrealDB client for database operations
|
/// * `db_client` - SurrealDB client for database operations
|
||||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||||
/// * `query` - The search query string to find relevant knowledge entities
|
/// * `query` - The search query string to find relevant knowledge entities
|
||||||
|
/// * 'user_id' - The user id of the current user
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
|
/// * `Result<Vec<KnowledgeEntity>, ProcessingError>` - A deduplicated vector of relevant
|
||||||
@@ -37,6 +38,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
|||||||
db_client: &Surreal<Any>,
|
db_client: &Surreal<Any>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: &str,
|
query: &str,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||||
// info!("Received input: {:?}", query);
|
// info!("Received input: {:?}", query);
|
||||||
|
|
||||||
@@ -47,6 +49,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
|||||||
db_client,
|
db_client,
|
||||||
"knowledge_entity".to_string(),
|
"knowledge_entity".to_string(),
|
||||||
openai_client,
|
openai_client,
|
||||||
|
user_id,
|
||||||
),
|
),
|
||||||
find_items_by_vector_similarity(
|
find_items_by_vector_similarity(
|
||||||
5,
|
5,
|
||||||
@@ -54,6 +57,7 @@ pub async fn combined_knowledge_entity_retrieval(
|
|||||||
db_client,
|
db_client,
|
||||||
"text_chunk".to_string(),
|
"text_chunk".to_string(),
|
||||||
openai_client,
|
openai_client,
|
||||||
|
user_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
|||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `take`: The number of items to retrieve from the database.
|
/// * `take` - The number of items to retrieve from the database.
|
||||||
/// * `input_text`: The text to generate embeddings for.
|
/// * `input_text` - The text to generate embeddings for.
|
||||||
/// * `db_client`: The SurrealDB client to use for querying the database.
|
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||||
/// * `table`: The table to query in the database.
|
/// * `table` - The table to query in the database.
|
||||||
/// * `openai_client`: The OpenAI client to use for generating embeddings.
|
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||||
|
/// * 'user_id`- The user id of the current user.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
@@ -21,13 +22,14 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
|||||||
///
|
///
|
||||||
/// # Type Parameters
|
/// # Type Parameters
|
||||||
///
|
///
|
||||||
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||||
pub async fn find_items_by_vector_similarity<T>(
|
pub async fn find_items_by_vector_similarity<T>(
|
||||||
take: u8,
|
take: u8,
|
||||||
input_text: &str,
|
input_text: &str,
|
||||||
db_client: &Surreal<Any>,
|
db_client: &Surreal<Any>,
|
||||||
table: String,
|
table: String,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<Vec<T>, ProcessingError>
|
) -> Result<Vec<T>, ProcessingError>
|
||||||
where
|
where
|
||||||
T: for<'de> serde::Deserialize<'de>,
|
T: for<'de> serde::Deserialize<'de>,
|
||||||
@@ -36,7 +38,7 @@ where
|
|||||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||||
|
|
||||||
// Construct the query
|
// Construct the query
|
||||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding);
|
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||||
|
|
||||||
// Perform query and deserialize to struct
|
// Perform query and deserialize to struct
|
||||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ pub async fn api_auth(
|
|||||||
mut request: Request,
|
mut request: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Result<Response, ApiError> {
|
) -> Result<Response, ApiError> {
|
||||||
let api_key = extract_api_key(&request).ok_or(ApiError::UserNotFound)?;
|
let api_key = extract_api_key(&request).ok_or(ApiError::AuthRequired)?;
|
||||||
|
|
||||||
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
let user = User::find_by_api_key(&api_key, &state.surreal_db_client).await?;
|
||||||
let user = user.ok_or(ApiError::UserNotFound)?;
|
let user = user.ok_or(ApiError::UserNotFound)?;
|
||||||
|
|||||||
@@ -2,18 +2,20 @@ use crate::{
|
|||||||
error::ApiError,
|
error::ApiError,
|
||||||
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
ingress::types::ingress_input::{create_ingress_objects, IngressInput},
|
||||||
server::AppState,
|
server::AppState,
|
||||||
|
storage::types::user::User,
|
||||||
};
|
};
|
||||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
pub async fn ingress_handler(
|
pub async fn ingress_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
|
Extension(user): Extension<User>,
|
||||||
Json(input): Json<IngressInput>,
|
Json(input): Json<IngressInput>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
info!("Received input: {:?}", input);
|
info!("Received input: {:?}", input);
|
||||||
|
|
||||||
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client).await?;
|
let ingress_objects = create_ingress_objects(input, &state.surreal_db_client, &user.id).await?;
|
||||||
|
|
||||||
let futures: Vec<_> = ingress_objects
|
let futures: Vec<_> = ingress_objects
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|||||||
@@ -33,9 +33,13 @@ pub async fn query_handler(
|
|||||||
info!("Received input: {:?}", query);
|
info!("Received input: {:?}", query);
|
||||||
info!("{:?}", user);
|
info!("{:?}", user);
|
||||||
|
|
||||||
let answer =
|
let answer = get_answer_with_references(
|
||||||
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
&state.surreal_db_client,
|
||||||
.await?;
|
&state.openai_client,
|
||||||
|
&query.query,
|
||||||
|
&user.id,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(
|
Ok(
|
||||||
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
|
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use async_openai::types::{
|
|||||||
ResponseFormat, ResponseFormatJsonSchema,
|
ResponseFormat, ResponseFormatJsonSchema,
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
@@ -66,6 +67,7 @@ use super::{
|
|||||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||||
/// * `openai_client` - Client for OpenAI API calls
|
/// * `openai_client` - Client for OpenAI API calls
|
||||||
/// * `query` - The user's query string
|
/// * `query` - The user's query string
|
||||||
|
/// * `user_id` - The user's id
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
@@ -80,11 +82,14 @@ pub async fn get_answer_with_references(
|
|||||||
surreal_db_client: &SurrealDbClient,
|
surreal_db_client: &SurrealDbClient,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: &str,
|
query: &str,
|
||||||
|
user_id: &str,
|
||||||
) -> Result<Answer, ApiError> {
|
) -> Result<Answer, ApiError> {
|
||||||
let entities =
|
let entities =
|
||||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?;
|
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let entities_json = format_entities_json(&entities);
|
let entities_json = format_entities_json(&entities);
|
||||||
|
debug!("{:?}", entities_json);
|
||||||
let user_message = create_user_message(&entities_json, query);
|
let user_message = create_user_message(&entities_json, query);
|
||||||
|
|
||||||
let request = create_chat_request(user_message)?;
|
let request = create_chat_request(user_message)?;
|
||||||
|
|||||||
@@ -2,14 +2,18 @@ use axum::{
|
|||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
response::Html,
|
response::Html,
|
||||||
};
|
};
|
||||||
|
use axum_session_auth::AuthSession;
|
||||||
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use surrealdb::{engine::any::Any, Surreal};
|
||||||
use tera::Context;
|
use tera::Context;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
server::{routes::query::helper::get_answer_with_references, AppState},
|
server::{routes::query::helper::get_answer_with_references, AppState},
|
||||||
|
storage::types::user::User,
|
||||||
};
|
};
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct SearchParams {
|
pub struct SearchParams {
|
||||||
@@ -19,12 +23,19 @@ pub struct SearchParams {
|
|||||||
pub async fn search_result_handler(
|
pub async fn search_result_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(query): Query<SearchParams>,
|
Query(query): Query<SearchParams>,
|
||||||
|
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||||
) -> Result<Html<String>, ApiError> {
|
) -> Result<Html<String>, ApiError> {
|
||||||
info!("Displaying search results");
|
info!("Displaying search results");
|
||||||
|
|
||||||
let answer =
|
let user_id = auth.current_user.ok_or_else(|| ApiError::AuthRequired)?.id;
|
||||||
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
|
||||||
.await?;
|
let answer = get_answer_with_references(
|
||||||
|
&state.surreal_db_client,
|
||||||
|
&state.openai_client,
|
||||||
|
&query.query,
|
||||||
|
&user_id,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let output = state
|
let output = state
|
||||||
.tera
|
.tera
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
|||||||
description: String,
|
description: String,
|
||||||
entity_type: KnowledgeEntityType,
|
entity_type: KnowledgeEntityType,
|
||||||
metadata: Option<serde_json::Value>,
|
metadata: Option<serde_json::Value>,
|
||||||
embedding: Vec<f32>
|
embedding: Vec<f32>,
|
||||||
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
impl KnowledgeEntity {
|
impl KnowledgeEntity {
|
||||||
@@ -41,6 +42,7 @@ impl KnowledgeEntity {
|
|||||||
entity_type: KnowledgeEntityType,
|
entity_type: KnowledgeEntityType,
|
||||||
metadata: Option<serde_json::Value>,
|
metadata: Option<serde_json::Value>,
|
||||||
embedding: Vec<f32>,
|
embedding: Vec<f32>,
|
||||||
|
user_id: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
@@ -50,6 +52,7 @@ impl KnowledgeEntity {
|
|||||||
entity_type,
|
entity_type,
|
||||||
metadata,
|
metadata,
|
||||||
embedding,
|
embedding,
|
||||||
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,16 +4,18 @@ use uuid::Uuid;
|
|||||||
stored_object!(TextChunk, "text_chunk", {
|
stored_object!(TextChunk, "text_chunk", {
|
||||||
source_id: String,
|
source_id: String,
|
||||||
chunk: String,
|
chunk: String,
|
||||||
embedding: Vec<f32>
|
embedding: Vec<f32>,
|
||||||
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
impl TextChunk {
|
impl TextChunk {
|
||||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>) -> Self {
|
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
source_id,
|
source_id,
|
||||||
chunk,
|
chunk,
|
||||||
embedding,
|
embedding,
|
||||||
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ stored_object!(TextContent, "text_content", {
|
|||||||
text: String,
|
text: String,
|
||||||
file_info: Option<FileInfo>,
|
file_info: Option<FileInfo>,
|
||||||
instructions: String,
|
instructions: String,
|
||||||
category: String
|
category: String,
|
||||||
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
impl TextContent {
|
impl TextContent {
|
||||||
@@ -17,6 +18,7 @@ impl TextContent {
|
|||||||
instructions: String,
|
instructions: String,
|
||||||
category: String,
|
category: String,
|
||||||
file_info: Option<FileInfo>,
|
file_info: Option<FileInfo>,
|
||||||
|
user_id: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
@@ -24,6 +26,7 @@ impl TextContent {
|
|||||||
file_info,
|
file_info,
|
||||||
instructions,
|
instructions,
|
||||||
category,
|
category,
|
||||||
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user