mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-29 22:01:59 +02:00
refactor: better separation of dependencies to crates
node stuff to html crate only
This commit is contained in:
182
composite-retrieval/src/answer_retrieval.rs
Normal file
182
composite-retrieval/src/answer_retrieval.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
|
||||
ResponseFormat, ResponseFormatJsonSchema,
|
||||
},
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::retrieve_entities;
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Reference {
|
||||
#[allow(dead_code)]
|
||||
pub reference: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LLMResponseFormat {
|
||||
pub answer: String,
|
||||
#[allow(dead_code)]
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
/// Orchestrates query processing and returns an answer with references
|
||||
///
|
||||
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||
/// * `openai_client` - Client for OpenAI API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a tuple of the answer and its references, or an API error
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, AppError> {
|
||||
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
||||
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message, &settings)?;
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
let llm_response = process_llm_response(response).await?;
|
||||
|
||||
Ok(Answer {
|
||||
content: llm_response.answer,
|
||||
references: llm_response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
entities_json, query
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_user_message_with_history(
|
||||
entities_json: &Value,
|
||||
history: &[Message],
|
||||
query: &str,
|
||||
) -> String {
|
||||
format!(
|
||||
r#"
|
||||
Chat history:
|
||||
==================
|
||||
{}
|
||||
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
format_history(history),
|
||||
entities_json,
|
||||
query
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_chat_request(
|
||||
user_message: String,
|
||||
settings: &SystemSettings,
|
||||
) -> Result<CreateChatCompletionRequest, OpenAIError> {
|
||||
let response_format = ResponseFormat::JsonSchema {
|
||||
json_schema: ResponseFormatJsonSchema {
|
||||
description: Some("Query answering AI".into()),
|
||||
name: "query_answering_with_uuids".into(),
|
||||
schema: Some(get_query_response_schema()),
|
||||
strict: Some(true),
|
||||
},
|
||||
};
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model(&settings.query_model)
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(settings.query_system_prompt.clone()).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
])
|
||||
.response_format(response_format)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.ok_or(AppError::LLMParsing(
|
||||
"No content found in LLM response".into(),
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
|
||||
})
|
||||
})
|
||||
}
|
||||
26
composite-retrieval/src/answer_retrieval_helper.rs
Normal file
26
composite-retrieval/src/answer_retrieval_helper.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use common::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub static QUERY_SYSTEM_PROMPT: &str = DEFAULT_QUERY_SYSTEM_PROMPT;
|
||||
|
||||
pub fn get_query_response_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": { "type": "string" },
|
||||
"references": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reference": { "type": "string" },
|
||||
},
|
||||
"required": ["reference"],
|
||||
"additionalProperties": false,
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["answer", "references"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
63
composite-retrieval/src/graph.rs
Normal file
63
composite-retrieval/src/graph.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use surrealdb::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
/// This function queries the database for all records in a specified table that have
|
||||
/// a matching `source_id` field. It's commonly used to find related entities or
|
||||
/// track the origin of database entries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
|
||||
/// * `Err(Error)` - An error if the database query fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return a `Error` if:
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
|
||||
db.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: String,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
|
||||
debug!("{}", query);
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
}
|
||||
90
composite-retrieval/src/lib.rs
Normal file
90
composite-retrieval/src/lib.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod graph;
|
||||
pub mod vector;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
||||
use std::collections::HashMap;
|
||||
use vector::find_items_by_vector_similarity;
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
47
composite-retrieval/src/vector.rs
Normal file
47
composite-retrieval/src/vector.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::AppError, utils::embedding::generate_embedding};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: &str,
|
||||
db_client: &Surreal<Any>,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
Reference in New Issue
Block a user