mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-30 22:32:07 +02:00
wip: heavy refactoring html routers
This commit is contained in:
@@ -1,15 +1,4 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::error::OpenAIError;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{Html, IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use minijinja::context;
|
||||
use minijinja_autoreload::AutoReloader;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
@@ -49,206 +38,3 @@ pub enum AppError {
|
||||
#[error("Ingress Processing error: {0}")]
|
||||
Processing(String),
|
||||
}
|
||||
|
||||
// API-specific errors
|
||||
#[derive(Debug, Serialize)]
|
||||
pub enum ApiError {
|
||||
InternalError(String),
|
||||
ValidationError(String),
|
||||
NotFound(String),
|
||||
Unauthorized(String),
|
||||
}
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
fn from(err: AppError) -> Self {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) | AppError::Email(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
ApiError::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, body) = match self {
|
||||
ApiError::InternalError(message) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::ValidationError(message) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::NotFound(message) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
),
|
||||
ApiError::Unauthorized(message) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({
|
||||
"error": message,
|
||||
"status": "error"
|
||||
}),
|
||||
), // ... other matches
|
||||
};
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub type TemplateResult<T> = Result<T, HtmlError>;
|
||||
|
||||
// Helper trait for converting to HtmlError with templates
|
||||
pub trait IntoHtmlError {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError;
|
||||
}
|
||||
// // Implement for AppError
|
||||
impl IntoHtmlError for AppError {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
|
||||
HtmlError::new(self, templates)
|
||||
}
|
||||
}
|
||||
// // Implement for minijinja::Error directly
|
||||
impl IntoHtmlError for minijinja::Error {
|
||||
fn with_template(self, templates: Arc<AutoReloader>) -> HtmlError {
|
||||
HtmlError::from_template_error(self, templates)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ErrorContext {
|
||||
#[allow(dead_code)]
|
||||
templates: Arc<AutoReloader>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
pub fn new(templates: Arc<AutoReloader>) -> Self {
|
||||
Self { templates }
|
||||
}
|
||||
}
|
||||
|
||||
pub enum HtmlError {
|
||||
ServerError(Arc<AutoReloader>),
|
||||
NotFound(Arc<AutoReloader>),
|
||||
Unauthorized(Arc<AutoReloader>),
|
||||
BadRequest(String, Arc<AutoReloader>),
|
||||
Template(String, Arc<AutoReloader>),
|
||||
}
|
||||
|
||||
impl HtmlError {
|
||||
pub fn new(error: AppError, templates: Arc<AutoReloader>) -> Self {
|
||||
match error {
|
||||
AppError::NotFound(_msg) => HtmlError::NotFound(templates),
|
||||
AppError::Auth(_msg) => HtmlError::Unauthorized(templates),
|
||||
AppError::Validation(msg) => HtmlError::BadRequest(msg, templates),
|
||||
_ => {
|
||||
tracing::error!("Internal error: {:?}", error);
|
||||
HtmlError::ServerError(templates)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_template_error(error: minijinja::Error, templates: Arc<AutoReloader>) -> Self {
|
||||
tracing::error!("Template error: {:?}", error);
|
||||
HtmlError::Template(error.to_string(), templates)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for HtmlError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, context, templates) = match self {
|
||||
HtmlError::ServerError(templates) | HtmlError::Template(_, templates) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
context! {
|
||||
status_code => 500,
|
||||
title => "Internal Server Error",
|
||||
error => "Internal Server Error",
|
||||
description => "Something went wrong on our end."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::NotFound(templates) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
context! {
|
||||
status_code => 404,
|
||||
title => "Page Not Found",
|
||||
error => "Not Found",
|
||||
description => "The page you're looking for doesn't exist or was removed."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::Unauthorized(templates) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
context! {
|
||||
status_code => 401,
|
||||
title => "Unauthorized",
|
||||
error => "Access Denied",
|
||||
description => "You need to be logged in to access this page."
|
||||
},
|
||||
templates,
|
||||
),
|
||||
HtmlError::BadRequest(msg, templates) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
context! {
|
||||
status_code => 400,
|
||||
title => "Bad Request",
|
||||
error => "Bad Request",
|
||||
description => msg
|
||||
},
|
||||
templates,
|
||||
),
|
||||
};
|
||||
|
||||
let html = match templates.acquire_env() {
|
||||
Ok(env) => match env.get_template("errors/error.html") {
|
||||
Ok(tmpl) => match tmpl.render(context) {
|
||||
Ok(output) => output,
|
||||
Err(e) => {
|
||||
tracing::error!("Template render error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Template get error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Environment acquire error: {:?}", e);
|
||||
Self::fallback_html()
|
||||
}
|
||||
};
|
||||
|
||||
(status, Html(html)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl HtmlError {
|
||||
fn fallback_html() -> String {
|
||||
r#"
|
||||
<html>
|
||||
<body>
|
||||
<div class="container mx-auto p-4">
|
||||
<h1 class="text-4xl text-error">Error</h1>
|
||||
<p class="mt-4">Sorry, something went wrong displaying this page.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"#
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
use crate::{
|
||||
error::AppError,
|
||||
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
|
||||
retrieval::combined_knowledge_entity_retrieval,
|
||||
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
||||
};
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
|
||||
ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use serde_json::json;
|
||||
use tracing::debug;
|
||||
|
||||
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
|
||||
|
||||
pub struct IngressAnalyzer<'a> {
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
}
|
||||
|
||||
impl<'a> IngressAnalyzer<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn analyze_content(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||
let similar_entities = self
|
||||
.find_similar_entities(category, instructions, text, user_id)
|
||||
.await?;
|
||||
let llm_request =
|
||||
self.prepare_llm_request(category, instructions, text, &similar_entities)?;
|
||||
self.perform_analysis(llm_request).await
|
||||
}
|
||||
|
||||
async fn find_similar_entities(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_instructions: {}",
|
||||
text, category, instructions
|
||||
);
|
||||
|
||||
combined_knowledge_entity_retrieval(
|
||||
self.db_client,
|
||||
self.openai_client,
|
||||
&input_text,
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn prepare_llm_request(
|
||||
&self,
|
||||
category: &str,
|
||||
instructions: &str,
|
||||
text: &str,
|
||||
similar_entities: &[KnowledgeEntity],
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let entities_json = json!(similar_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>());
|
||||
|
||||
let user_message = format!(
|
||||
"Category:\n{}\nInstructions:\n{}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
|
||||
category, instructions, text, entities_json
|
||||
);
|
||||
|
||||
debug!("Prepared LLM request message: {}", user_message);
|
||||
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Structured analysis of the submitted content".into()),
|
||||
name: "content_analysis".into(),
|
||||
schema: Some(get_ingress_analysis_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
async fn perform_analysis(
|
||||
&self,
|
||||
request: CreateChatCompletionRequest,
|
||||
) -> Result<LLMGraphAnalysisResult, AppError> {
|
||||
let response = self.openai_client.chat().create(request).await?;
|
||||
debug!("Received LLM response: {:?}", response);
|
||||
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".to_string(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!(
|
||||
"Failed to parse LLM response into analysis: {}",
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod ingress_analyser;
|
||||
pub mod prompt;
|
||||
pub mod types;
|
||||
@@ -1,81 +0,0 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub static INGRESS_ANALYSIS_SYSTEM_MESSAGE: &str = r#"
|
||||
You are an AI assistant. You will receive a text content, along with user instructions and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users instructions and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
||||
|
||||
The JSON should have the following structure:
|
||||
|
||||
{
|
||||
"knowledge_entities": [
|
||||
{
|
||||
"key": "unique-key-1",
|
||||
"name": "Entity Name",
|
||||
"description": "A detailed description of the entity.",
|
||||
"entity_type": "TypeOfEntity"
|
||||
},
|
||||
// More entities...
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"type": "RelationshipType",
|
||||
"source": "unique-key-1 or UUID from existing database",
|
||||
"target": "unique-key-1 or UUID from existing database"
|
||||
},
|
||||
// More relationships...
|
||||
]
|
||||
}
|
||||
|
||||
Guidelines:
|
||||
1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.
|
||||
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
||||
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
||||
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
||||
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
|
||||
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
||||
7. Only create relationships between existing KnowledgeEntities.
|
||||
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
||||
9. A new relationship MUST include a newly created KnowledgeEntity.
|
||||
"#;
|
||||
|
||||
pub fn get_ingress_analysis_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"knowledge_entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": { "type": "string" },
|
||||
"name": { "type": "string" },
|
||||
"description": { "type": "string" },
|
||||
"entity_type": {
|
||||
"type": "string",
|
||||
"enum": ["idea", "project", "document", "page", "textsnippet"]
|
||||
}
|
||||
},
|
||||
"required": ["key", "name", "description", "entity_type"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"relationships": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["RelatedTo", "RelevantTo", "SimilarTo"]
|
||||
},
|
||||
"source": { "type": "string" },
|
||||
"target": { "type": "string" }
|
||||
},
|
||||
"required": ["type", "source", "target"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["knowledge_entities", "relationships"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Intermediate struct to hold mapping between LLM keys and generated IDs.
|
||||
#[derive(Clone)]
|
||||
pub struct GraphMapper {
|
||||
pub key_to_id: HashMap<String, Uuid>,
|
||||
}
|
||||
|
||||
impl Default for GraphMapper {
|
||||
fn default() -> Self {
|
||||
GraphMapper::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphMapper {
|
||||
pub fn new() -> Self {
|
||||
GraphMapper {
|
||||
key_to_id: HashMap::new(),
|
||||
}
|
||||
}
|
||||
/// Get ID, tries to parse UUID
|
||||
pub fn get_or_parse_id(&mut self, key: &str) -> Uuid {
|
||||
if let Ok(parsed_uuid) = Uuid::parse_str(key) {
|
||||
parsed_uuid
|
||||
} else {
|
||||
*self.key_to_id.get(key).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a new UUID for a given key.
|
||||
pub fn assign_id(&mut self, key: &str) -> Uuid {
|
||||
let id = Uuid::new_v4();
|
||||
self.key_to_id.insert(key.to_string(), id);
|
||||
id
|
||||
}
|
||||
|
||||
/// Retrieves the UUID for a given key.
|
||||
pub fn get_id(&self, key: &str) -> Option<&Uuid> {
|
||||
self.key_to_id.get(key)
|
||||
}
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::future::try_join_all;
|
||||
|
||||
use super::graph_mapper::GraphMapper; // For future parallelization
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMKnowledgeEntity {
|
||||
pub key: String, // Temporary identifier
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub entity_type: String, // Should match KnowledgeEntityType variants
|
||||
}
|
||||
|
||||
/// Represents a single relationship from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMRelationship {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String, // e.g., RelatedTo, RelevantTo
|
||||
pub source: String, // Key of the source entity
|
||||
pub target: String, // Key of the target entity
|
||||
}
|
||||
|
||||
/// Represents the entire graph analysis result from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMGraphAnalysisResult {
|
||||
pub knowledge_entities: Vec<LLMKnowledgeEntity>,
|
||||
pub relationships: Vec<LLMRelationship>,
|
||||
}
|
||||
|
||||
impl LLMGraphAnalysisResult {
|
||||
/// Converts the LLM graph analysis result into database entities and relationships.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - A UUID representing the source identifier.
|
||||
/// * `openai_client` - OpenAI client for LLM calls.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
|
||||
pub async fn to_database_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
|
||||
// Create mapper and pre-assign IDs
|
||||
let mapper = Arc::new(Mutex::new(self.create_mapper()?));
|
||||
|
||||
// Process entities
|
||||
let entities = self
|
||||
.process_entities(source_id, user_id, Arc::clone(&mapper), openai_client)
|
||||
.await?;
|
||||
|
||||
// Process relationships
|
||||
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?;
|
||||
|
||||
Ok((entities, relationships))
|
||||
}
|
||||
|
||||
fn create_mapper(&self) -> Result<GraphMapper, AppError> {
|
||||
let mut mapper = GraphMapper::new();
|
||||
|
||||
// Pre-assign all IDs
|
||||
for entity in &self.knowledge_entities {
|
||||
mapper.assign_id(&entity.key);
|
||||
}
|
||||
|
||||
Ok(mapper)
|
||||
}
|
||||
|
||||
async fn process_entities(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let futures: Vec<_> = self
|
||||
.knowledge_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
let mapper = Arc::clone(&mapper);
|
||||
let openai_client = openai_client.clone();
|
||||
let source_id = source_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
let entity = entity.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client)
|
||||
.await
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = try_join_all(futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn process_relationships(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
) -> Result<Vec<KnowledgeRelationship>, AppError> {
|
||||
let mut mapper_guard = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
self.relationships
|
||||
.iter()
|
||||
.map(|rel| {
|
||||
let source_db_id = mapper_guard.get_or_parse_id(&rel.source);
|
||||
let target_db_id = mapper_guard.get_or_parse_id(&rel.target);
|
||||
|
||||
Ok(KnowledgeRelationship::new(
|
||||
source_db_id.to_string(),
|
||||
target_db_id.to_string(),
|
||||
user_id.to_string(),
|
||||
source_id.to_string(),
|
||||
rel.type_.clone(),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
async fn create_single_entity(
|
||||
llm_entity: &LLMKnowledgeEntity,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
mapper: Arc<Mutex<GraphMapper>>,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<KnowledgeEntity, AppError> {
|
||||
let assigned_id = {
|
||||
let mapper = mapper
|
||||
.lock()
|
||||
.map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?;
|
||||
mapper
|
||||
.get_id(&llm_entity.key)
|
||||
.ok_or_else(|| {
|
||||
AppError::GraphMapper(format!("ID not found for key: {}", llm_entity.key))
|
||||
})?
|
||||
.to_string()
|
||||
};
|
||||
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {}",
|
||||
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = generate_embedding(openai_client, &embedding_input).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
Ok(KnowledgeEntity {
|
||||
id: assigned_id,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
name: llm_entity.name.to_string(),
|
||||
description: llm_entity.description.to_string(),
|
||||
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()),
|
||||
source_id: source_id.to_string(),
|
||||
metadata: None,
|
||||
embedding,
|
||||
user_id: user_id.into(),
|
||||
})
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
pub mod graph_mapper;
|
||||
pub mod llm_analysis_result;
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
pub mod analysis;
|
||||
pub mod content_processor;
|
||||
@@ -1,5 +1,3 @@
|
||||
pub mod error;
|
||||
pub mod ingress;
|
||||
pub mod retrieval;
|
||||
pub mod storage;
|
||||
pub mod utils;
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
use surrealdb::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
/// This function queries the database for all records in a specified table that have
|
||||
/// a matching `source_id` field. It's commonly used to find related entities or
|
||||
/// track the origin of database entries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
|
||||
/// * `Err(Error)` - An error if the database query fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return a `Error` if:
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
|
||||
db.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: String,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
|
||||
debug!("{}", query);
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
pub mod graph;
|
||||
pub mod query_helper;
|
||||
pub mod query_helper_prompt;
|
||||
pub mod vector;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
retrieval::{
|
||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||
vector::find_items_by_vector_similarity,
|
||||
},
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
pub async fn combined_knowledge_entity_retrieval(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
// info!("Received input: {:?}", query);
|
||||
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
retrieval::combined_knowledge_entity_retrieval,
|
||||
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
||||
};
|
||||
|
||||
use super::query_helper_prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Reference {
|
||||
#[allow(dead_code)]
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLMResponseFormat {
|
||||
pub answer: String,
|
||||
#[allow(dead_code)]
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
// /// Orchestrator function that takes a query and clients and returns a answer with references
|
||||
// ///
|
||||
// /// # Arguments
|
||||
// /// * `surreal_db_client` - Client for interacting with SurrealDn
|
||||
// /// * `openai_client` - Client for interacting with openai
|
||||
// /// * `query` - The query
|
||||
// ///
|
||||
// /// # Returns
|
||||
// /// * `Result<(String, Vec<String>, ApiError)` - Will return the answer, and the list of references or Error
|
||||
// pub async fn get_answer_with_references(
|
||||
// surreal_db_client: &SurrealDbClient,
|
||||
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
// query: &str,
|
||||
// ) -> Result<(String, Vec<String>), ApiError> {
|
||||
// let entities =
|
||||
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
|
||||
|
||||
// // Format entities and create message
|
||||
// let entities_json = format_entities_json(&entities);
|
||||
// let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
// // Create and send request
|
||||
// let request = create_chat_request(user_message)?;
|
||||
// let response = openai_client
|
||||
// .chat()
|
||||
// .create(request)
|
||||
// .await
|
||||
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
|
||||
|
||||
// // Process response
|
||||
// let answer = process_llm_response(response).await?;
|
||||
|
||||
// let references: Vec<String> = answer
|
||||
// .references
|
||||
// .into_iter()
|
||||
// .map(|reference| reference.reference)
|
||||
// .collect();
|
||||
|
||||
// Ok((answer.answer, references))
|
||||
// }
|
||||
|
||||
/// Orchestrates query processing and returns an answer with references
|
||||
///
|
||||
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||
/// * `openai_client` - Client for OpenAI API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a tuple of the answer and its references, or an API error
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, AppError> {
|
||||
let entities =
|
||||
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query, user_id)
|
||||
.await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
debug!("{:?}", entities_json);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message)?;
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
let llm_response = process_llm_response(response).await?;
|
||||
|
||||
Ok(Answer {
|
||||
content: llm_response.answer,
|
||||
references: llm_response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
entities_json, query
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_chat_request(
|
||||
user_message: String,
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Query answering AI".into()),
|
||||
name: "query_answering_with_uuids".into(),
|
||||
schema: Some(get_query_response_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(QUERY_SYSTEM_PROMPT).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub static QUERY_SYSTEM_PROMPT: &str = r#"
|
||||
You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
||||
|
||||
Your task is to:
|
||||
1. Carefully analyze the provided knowledge entities in the context
|
||||
2. Answer user questions based on this information
|
||||
3. Provide clear, concise, and accurate responses
|
||||
4. When referencing information, briefly mention which knowledge entity it came from
|
||||
5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this
|
||||
6. If only partial information is available, explain what you can answer and what information is missing
|
||||
7. Avoid making assumptions or providing information not supported by the context
|
||||
8. Output the references to the documents. Use the UUIDs and make sure they are correct!
|
||||
|
||||
Remember:
|
||||
- Be direct and honest about the limitations of your knowledge
|
||||
- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array
|
||||
- If you need to combine information from multiple entities, explain how they connect
|
||||
- Don't speculate beyond what's provided in the context
|
||||
|
||||
Example response formats:
|
||||
"Based on [Entity Name], [answer...]"
|
||||
"I found relevant information in multiple entries: [explanation...]"
|
||||
"I apologize, but the provided context doesn't contain information about [topic]"
|
||||
"#;
|
||||
|
||||
pub fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use crate::{error::AppError, utils::embedding::generate_embedding};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: &str,
|
||||
db_client: &Surreal<Any>,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
Reference in New Issue
Block a user