diff --git a/src/lib.rs b/src/lib.rs index eceb2a4..046b5f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod error; pub mod models; pub mod rabbitmq; +pub mod retrieval; pub mod routes; pub mod storage; pub mod surrealdb; diff --git a/src/models/text_content.rs b/src/models/text_content.rs index 89df5f2..21a74f3 100644 --- a/src/models/text_content.rs +++ b/src/models/text_content.rs @@ -1,50 +1,17 @@ +use crate::retrieval::graph::find_entities_by_source_id; +use crate::retrieval::vector::find_items_by_vector_similarity; use crate::storage::db::store_item; use crate::storage::types::knowledge_entity::KnowledgeEntity; use crate::storage::types::knowledge_relationship::KnowledgeRelationship; use crate::storage::types::text_chunk::TextChunk; use crate::storage::types::text_content::TextContent; -use crate::{ - error::ProcessingError, - surrealdb::SurrealDbClient, - utils::llm::{create_json_ld, generate_embedding}, -}; +use crate::storage::types::StoredObject; +use crate::utils::embedding::generate_embedding; +use crate::{error::ProcessingError, surrealdb::SurrealDbClient, utils::llm::create_json_ld}; use surrealdb::{engine::remote::ws::Client, Surreal}; use text_splitter::TextSplitter; use tracing::{debug, info}; -async fn vector_comparison<T>( - take: u8, - input_text: String, - db_client: &Surreal<Client>, - table: String, - openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, -) -> Result<Vec<T>, ProcessingError> -where - T: for<'de> serde::Deserialize<'de>, -{ - let input_embedding = generate_embedding(openai_client, input_text).await?; - - // Construct the query - let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance",table, take, input_embedding); - - // Perform query and deserialize to struct - let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?; - - Ok(closest_entities) -} - -async fn get_related_nodes( - id: String, - db_client: &Surreal<Client>, -) -> Result<Vec<KnowledgeEntity>, ProcessingError> { - let 
query = format!("SELECT * FROM knowledge_entity WHERE source_id = '{}'", id); - - // let query = format!("SELECT * FROM knowledge_entity WHERE in OR out {}", id); - let related_nodes: Vec<KnowledgeEntity> = db_client.query(query).await?.take(0)?; - - Ok(related_nodes) -} - impl TextContent { /// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB. pub async fn process(&self) -> Result<(), ProcessingError> { @@ -56,7 +23,7 @@ impl TextContent { info!("{:?}", create_operation); // Get related nodes - let closest_text_content: Vec<TextChunk> = vector_comparison( + let closest_text_content: Vec<TextChunk> = find_items_by_vector_similarity( 3, self.text.clone(), &db_client, @@ -66,7 +33,12 @@ .await?; for node in closest_text_content { - let related_nodes = get_related_nodes(node.source_id, &db_client).await?; + let related_nodes: Vec<KnowledgeEntity> = find_entities_by_source_id( + node.source_id.to_owned(), + KnowledgeEntity::table_name().to_string(), + &db_client, + ) + .await?; for related_node in related_nodes { info!("{:?}", related_node.name); } @@ -86,12 +58,7 @@ impl TextContent { // db_client // .query("DEFINE INDEX idx_embedding ON text_chunk FIELDS embedding HNSW DIMENSION 1536") // .await?; - db_client - .query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk") - .await?; - db_client - .query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity") - .await?; + db_client.rebuild_indexes().await?; // Step 1: Send to LLM for analysis let analysis = create_json_ld( diff --git a/src/retrieval/graph.rs b/src/retrieval/graph.rs new file mode 100644 index 0000000..e2b3338 --- /dev/null +++ b/src/retrieval/graph.rs @@ -0,0 +1,65 @@ +use surrealdb::{engine::remote::ws::Client, Surreal}; + +use crate::error::ProcessingError; + +/// Retrieves database entries that match a specific source identifier. +/// +/// This function queries the database for all records in a specified table that have +/// a matching `source_id` field. 
It's commonly used to find related entities or +/// track the origin of database entries. +/// +/// # Arguments +/// +/// * `source_id` - The identifier to search for in the database +/// * `table_name` - The name of the table to search in +/// * `db_client` - The SurrealDB client instance for database operations +/// +/// # Type Parameters +/// +/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize` +/// +/// # Returns +/// +/// Returns a `Result` containing either: +/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T` +/// * `Err(ProcessingError)` - An error if the database query fails +/// +/// # Errors +/// +/// This function will return a `ProcessingError` if: +/// * The database query fails to execute +/// * The results cannot be deserialized into type `T` +/// +/// # Example +/// +/// ```rust +/// #[derive(serde::Deserialize)] +/// struct KnowledgeEntity { +/// id: String, +/// source_id: String, +/// // ... other fields +/// } +/// +/// let results = find_entities_by_source_id::<KnowledgeEntity>( +/// "source123".to_string(), +/// "knowledge_entity".to_string(), +/// &db_client +/// ).await?; +/// ``` +pub async fn find_entities_by_source_id<T>( + source_id: String, + table_name: String, + db_client: &Surreal<Client>, +) -> Result<Vec<T>, ProcessingError> +where + T: for<'de> serde::Deserialize<'de>, +{ + let query = format!( + "SELECT * FROM {} WHERE source_id = '{}'", + table_name, source_id + ); + + let matching_entities: Vec<T> = db_client.query(query).await?.take(0)?; + + Ok(matching_entities) +} diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs new file mode 100644 index 0000000..a9e0919 --- /dev/null +++ b/src/retrieval/mod.rs @@ -0,0 +1,2 @@ +pub mod graph; +pub mod vector; diff --git a/src/retrieval/vector.rs b/src/retrieval/vector.rs new file mode 100644 index 0000000..0adb63a --- /dev/null +++ b/src/retrieval/vector.rs @@ -0,0 +1,45 @@ +use surrealdb::{engine::remote::ws::Client, Surreal}; + +use 
crate::{error::ProcessingError, utils::embedding::generate_embedding}; + +/// Compares vectors and retrieves a number of items from the specified table. +/// +/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database, +/// and then deserializes the results into the specified type `T`. +/// +/// # Arguments +/// +/// * `take`: The number of items to retrieve from the database. +/// * `input_text`: The text to generate embeddings for. +/// * `db_client`: The SurrealDB client to use for querying the database. +/// * `table`: The table to query in the database. +/// * `openai_client`: The OpenAI client to use for generating embeddings. +/// +/// # Returns +/// +/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs. +/// +/// # Type Parameters +/// +/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`. +pub async fn find_items_by_vector_similarity<T>( + take: u8, + input_text: String, + db_client: &Surreal<Client>, + table: String, + openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>, +) -> Result<Vec<T>, ProcessingError> +where + T: for<'de> serde::Deserialize<'de>, +{ + // Generate embeddings + let input_embedding = generate_embedding(openai_client, input_text).await?; + + // Construct the query + let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, input_embedding); + + // Perform query and deserialize to struct + let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?; + + Ok(closest_entities) +} diff --git a/src/storage/types/knowledge_relationship.rs b/src/storage/types/knowledge_relationship.rs index a588721..1311a88 100644 --- a/src/storage/types/knowledge_relationship.rs +++ b/src/storage/types/knowledge_relationship.rs @@ -2,6 +2,7 @@ use crate::stored_object; use uuid::Uuid; 
stored_object!(KnowledgeRelationship, "knowledge_relationship", { + #[serde(rename = "in")] in_: String, out: String, relationship_type: String, diff --git a/src/utils/embedding.rs b/src/utils/embedding.rs new file mode 100644 index 0000000..617751f --- /dev/null +++ b/src/utils/embedding.rs @@ -0,0 +1,57 @@ +use async_openai::types::CreateEmbeddingRequestArgs; + +use crate::error::ProcessingError; + +/// Generates an embedding vector for the given input text using OpenAI's embedding model. +/// +/// This function takes a text input and converts it into a numerical vector representation (embedding) +/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity +/// comparisons, vector search, and other natural language processing tasks. +/// +/// # Arguments +/// +/// * `client`: The OpenAI client instance used to make API requests. +/// * `input`: The text string to generate embeddings for. +/// +/// # Returns +/// +/// Returns a `Result` containing either: +/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding +/// * `Err(ProcessingError)`: An error if the embedding generation fails +/// +/// # Errors
/// +/// This function can return a `ProcessingError` in the following cases: +/// * If the OpenAI API request fails +/// * If the request building fails +/// * If no embedding data is received in the response +/// +/// # Example +/// +/// ```rust +/// let client = async_openai::Client::new(); +/// let text = String::from("Hello, world!"); +/// let embedding = generate_embedding(&client, text).await?; +/// ``` +pub async fn generate_embedding( + client: &async_openai::Client<async_openai::config::OpenAIConfig>, + input: String, +) -> Result<Vec<f32>, ProcessingError> { + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-3-small") + .input(&[input]) + .build()?; + + // Send the request to OpenAI + let response = client.embeddings().create(request).await?; + + // Extract the embedding vector + let embedding: 
Vec<f32> = response + .data + .first() + .ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))? + .embedding + .clone(); + + Ok(embedding) +} diff --git a/src/utils/llm.rs b/src/utils/llm.rs index dbedc1f..905918c 100644 --- a/src/utils/llm.rs +++ b/src/utils/llm.rs @@ -17,6 +17,8 @@ use surrealdb::Surreal; use tracing::{debug, info}; use uuid::Uuid; +use super::embedding::generate_embedding; + /// Represents a single knowledge entity from the LLM. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LLMKnowledgeEntity { @@ -42,29 +44,6 @@ pub struct LLMGraphAnalysisResult { pub relationships: Vec<LLMKnowledgeRelationship>, } -pub async fn generate_embedding( - client: &async_openai::Client<async_openai::config::OpenAIConfig>, - input: String, -) -> Result<Vec<f32>, ProcessingError> { - let request = CreateEmbeddingRequestArgs::default() - .model("text-embedding-3-small") - .input(&[input]) - .build()?; - - // Send the request to OpenAI - let response = client.embeddings().create(request).await?; - - // Extract the embedding vector - let embedding: Vec<f32> = response - .data - .first() - .ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))? - .embedding - .clone(); - - Ok(embedding) -} - impl LLMGraphAnalysisResult { /// Converts the LLM graph analysis result into database entities and relationships. /// Processes embeddings sequentially for simplicity. diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 214bbef..4bc235f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1 +1,2 @@ +pub mod embedding; pub mod llm;