mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-17 23:14:08 +01:00
refactoring: continuing to break stuff out
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
pub mod error;
|
||||
pub mod models;
|
||||
pub mod rabbitmq;
|
||||
pub mod retrieval;
|
||||
pub mod routes;
|
||||
pub mod storage;
|
||||
pub mod surrealdb;
|
||||
|
||||
@@ -1,50 +1,17 @@
|
||||
use crate::retrieval::graph::find_entities_by_source_id;
|
||||
use crate::retrieval::vector::find_items_by_vector_similarity;
|
||||
use crate::storage::db::store_item;
|
||||
use crate::storage::types::knowledge_entity::KnowledgeEntity;
|
||||
use crate::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::storage::types::text_content::TextContent;
|
||||
use crate::{
|
||||
error::ProcessingError,
|
||||
surrealdb::SurrealDbClient,
|
||||
utils::llm::{create_json_ld, generate_embedding},
|
||||
};
|
||||
use crate::storage::types::StoredObject;
|
||||
use crate::utils::embedding::generate_embedding;
|
||||
use crate::{error::ProcessingError, surrealdb::SurrealDbClient, utils::llm::create_json_ld};
|
||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||
use text_splitter::TextSplitter;
|
||||
use tracing::{debug, info};
|
||||
|
||||
async fn vector_comparison<T>(
|
||||
take: u8,
|
||||
input_text: String,
|
||||
db_client: &Surreal<Client>,
|
||||
table: String,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<Vec<T>, ProcessingError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance",table, take, input_embedding);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
|
||||
async fn get_related_nodes(
|
||||
id: String,
|
||||
db_client: &Surreal<Client>,
|
||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||
let query = format!("SELECT * FROM knowledge_entity WHERE source_id = '{}'", id);
|
||||
|
||||
// let query = format!("SELECT * FROM knowledge_entity WHERE in OR out {}", id);
|
||||
let related_nodes: Vec<KnowledgeEntity> = db_client.query(query).await?.take(0)?;
|
||||
|
||||
Ok(related_nodes)
|
||||
}
|
||||
|
||||
impl TextContent {
|
||||
/// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
|
||||
pub async fn process(&self) -> Result<(), ProcessingError> {
|
||||
@@ -56,7 +23,7 @@ impl TextContent {
|
||||
info!("{:?}", create_operation);
|
||||
|
||||
// Get related nodes
|
||||
let closest_text_content: Vec<TextChunk> = vector_comparison(
|
||||
let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity(
|
||||
3,
|
||||
self.text.clone(),
|
||||
&db_client,
|
||||
@@ -66,7 +33,12 @@ impl TextContent {
|
||||
.await?;
|
||||
|
||||
for node in closest_text_content {
|
||||
let related_nodes = get_related_nodes(node.source_id, &db_client).await?;
|
||||
let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id(
|
||||
node.source_id.to_owned(),
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db_client,
|
||||
)
|
||||
.await?;
|
||||
for related_node in related_nodes {
|
||||
info!("{:?}", related_node.name);
|
||||
}
|
||||
@@ -86,12 +58,7 @@ impl TextContent {
|
||||
// db_client
|
||||
// .query("DEFINE INDEX idx_embedding ON text_chunk FIELDS embedding HNSW DIMENSION 1536")
|
||||
// .await?;
|
||||
db_client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
|
||||
.await?;
|
||||
db_client
|
||||
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
|
||||
.await?;
|
||||
db_client.rebuild_indexes().await?;
|
||||
|
||||
// Step 1: Send to LLM for analysis
|
||||
let analysis = create_json_ld(
|
||||
|
||||
65
src/retrieval/graph.rs
Normal file
65
src/retrieval/graph.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||
|
||||
use crate::error::ProcessingError;
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
/// This function queries the database for all records in a specified table that have
|
||||
/// a matching `source_id` field. It's commonly used to find related entities or
|
||||
/// track the origin of database entries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
|
||||
/// * `Err(ProcessingError)` - An error if the database query fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return a `ProcessingError` if:
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// #[derive(serde::Deserialize)]
|
||||
/// struct KnowledgeEntity {
|
||||
/// id: String,
|
||||
/// source_id: String,
|
||||
/// // ... other fields
|
||||
/// }
|
||||
///
|
||||
/// let results = find_entities_by_source_id::<KnowledgeEntity>(
|
||||
/// "source123".to_string(),
|
||||
/// "knowledge_entity".to_string(),
|
||||
/// &db_client
|
||||
/// ).await?;
|
||||
/// ```
|
||||
pub async fn find_entities_by_source_id<T>(
|
||||
source_id: String,
|
||||
table_name: String,
|
||||
db_client: &Surreal<Client>,
|
||||
) -> Result<Vec<T>, ProcessingError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
table_name, source_id
|
||||
);
|
||||
|
||||
let matching_entities: Vec<T> = db_client.query(query).await?.take(0)?;
|
||||
|
||||
Ok(matching_entities)
|
||||
}
|
||||
2
src/retrieval/mod.rs
Normal file
2
src/retrieval/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod graph;
|
||||
pub mod vector;
|
||||
45
src/retrieval/vector.rs
Normal file
45
src/retrieval/vector.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||
|
||||
use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take`: The number of items to retrieve from the database.
|
||||
/// * `input_text`: The text to generate embeddings for.
|
||||
/// * `db_client`: The SurrealDB client to use for querying the database.
|
||||
/// * `table`: The table to query in the database.
|
||||
/// * `openai_client`: The OpenAI client to use for generating embeddings.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: String,
|
||||
db_client: &Surreal<Client>,
|
||||
table: String,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
) -> Result<Vec<T>, ProcessingError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ use crate::stored_object;
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(KnowledgeRelationship, "knowledge_relationship", {
|
||||
#[serde(rename = "in")]
|
||||
in_: String,
|
||||
out: String,
|
||||
relationship_type: String,
|
||||
|
||||
57
src/utils/embedding.rs
Normal file
57
src/utils/embedding.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||
|
||||
use crate::error::ProcessingError;
|
||||
|
||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// comparisons, vector search, and other natural language processing tasks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client`: The OpenAI client instance used to make API requests.
|
||||
/// * `input`: The text string to generate embeddings for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
|
||||
/// * `Err(ProcessingError)`: An error if the embedding generation fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can return a `ProcessingError` in the following cases:
|
||||
/// * If the OpenAI API request fails
|
||||
/// * If the request building fails
|
||||
/// * If no embedding data is received in the response
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// let client = async_openai::Client::new();
|
||||
/// let text = String::from("Hello, world!");
|
||||
/// let embedding = generate_embedding(&client, text).await?;
|
||||
/// ```
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: String,
|
||||
) -> Result<Vec<f32>, ProcessingError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model("text-embedding-3-small")
|
||||
.input(&[input])
|
||||
.build()?;
|
||||
|
||||
// Send the request to OpenAI
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
// Extract the embedding vector
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
@@ -17,6 +17,8 @@ use surrealdb::Surreal;
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::embedding::generate_embedding;
|
||||
|
||||
/// Represents a single knowledge entity from the LLM.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct LLMKnowledgeEntity {
|
||||
@@ -42,29 +44,6 @@ pub struct LLMGraphAnalysisResult {
|
||||
pub relationships: Vec<LLMRelationship>,
|
||||
}
|
||||
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: String,
|
||||
) -> Result<Vec<f32>, ProcessingError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model("text-embedding-3-small")
|
||||
.input(&[input])
|
||||
.build()?;
|
||||
|
||||
// Send the request to OpenAI
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
// Extract the embedding vector
|
||||
let embedding: Vec<f32> = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
impl LLMGraphAnalysisResult {
|
||||
/// Converts the LLM graph analysis result into database entities and relationships.
|
||||
/// Processes embeddings sequentially for simplicity.
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod embedding;
|
||||
pub mod llm;
|
||||
|
||||
Reference in New Issue
Block a user