refactoring: continuing to break stuff out

This commit is contained in:
Per Stark
2024-11-21 12:03:31 +01:00
parent 53b94c3569
commit d9707f21a5
9 changed files with 187 additions and 69 deletions

65
src/retrieval/graph.rs Normal file
View File

@@ -0,0 +1,65 @@
use surrealdb::{engine::remote::ws::Client, Surreal};
use crate::error::ProcessingError;
/// Retrieves all records in a table whose `source_id` field equals the given value.
///
/// The statement sent to the database is a fixed SurrealQL string; `table_name` and
/// `source_id` are supplied as bound parameters rather than being spliced into the
/// query text. This keeps values containing quotes (or any other SurrealQL syntax)
/// from altering the statement.
///
/// # Arguments
///
/// * `source_id` - The identifier to match against each record's `source_id` field
/// * `table_name` - The name of the table to search in
/// * `db_client` - The SurrealDB client instance used to run the query
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<T>)` - The matching records deserialized into type `T`
/// * `Err(ProcessingError)` - An error converted from the failing database call
///
/// # Errors
///
/// This function will return a `ProcessingError` if:
/// * The database query fails to execute
/// * The first statement's result cannot be deserialized into `Vec<T>`
///
/// # Example
///
/// ```rust,ignore
/// #[derive(serde::Deserialize)]
/// struct KnowledgeEntity {
///     id: String,
///     source_id: String,
///     // ... other fields
/// }
///
/// let results = find_entities_by_source_id::<KnowledgeEntity>(
///     "source123".to_string(),
///     "knowledge_entity".to_string(),
///     &db_client,
/// )
/// .await?;
/// ```
pub async fn find_entities_by_source_id<T>(
    source_id: String,
    table_name: String,
    db_client: &Surreal<Client>,
) -> Result<Vec<T>, ProcessingError>
where
    T: for<'de> serde::Deserialize<'de>,
{
    // Bind the inputs instead of formatting them into the statement: a quote inside
    // `source_id` (or a crafted `table_name`) must never be parsed as SurrealQL.
    // `type::table()` turns the bound string into a table reference.
    let mut response = db_client
        .query("SELECT * FROM type::table($table_name) WHERE source_id = $source_id")
        .bind(("table_name", table_name))
        .bind(("source_id", source_id))
        .await?;

    // Index 0 selects the result set of the first (and only) statement.
    let matching_entities: Vec<T> = response.take(0)?;

    Ok(matching_entities)
}

2
src/retrieval/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
/// Lookup of stored records by their `source_id` field (`find_entities_by_source_id`).
pub mod graph;
/// Embedding-based nearest-neighbour lookup (`find_items_by_vector_similarity`).
pub mod vector;

45
src/retrieval/vector.rs Normal file
View File

@@ -0,0 +1,45 @@
use surrealdb::{engine::remote::ws::Client, Surreal};
use crate::{error::ProcessingError, utils::embedding::generate_embedding};
/// Finds the records in `table` whose stored embedding is closest to the embedding of a text.
///
/// The function performs three steps:
///
/// 1. Calls `generate_embedding` with `openai_client` and `input_text`.
/// 2. Builds a SurrealQL statement that filters on the `embedding` field with the
///    `<|take,40|>` nearest-neighbour operator and orders rows by
///    `vector::distance::knn()`.
/// 3. Runs the statement and deserializes the first result set into `Vec<T>`.
///
/// # Arguments
///
/// * `take`: The number of items to retrieve from the database.
/// * `input_text`: The text to generate an embedding for.
/// * `db_client`: The SurrealDB client used to run the query.
/// * `table`: The table to query. It is inserted verbatim into the statement, so it
///   should come from trusted code rather than from user input.
/// * `openai_client`: The OpenAI client passed on to `generate_embedding`.
///
/// # Returns
///
/// `Ok(Vec<T>)` with the matching rows ordered by ascending `distance`, or a
/// `ProcessingError` if embedding generation, query execution or deserialization fails.
///
/// # Type Parameters
///
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
    take: u8,
    input_text: String,
    db_client: &Surreal<Client>,
    table: String,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<T>, ProcessingError>
where
    T: for<'de> serde::Deserialize<'de>,
{
    // Turn the search text into its embedding.
    let embedding = generate_embedding(openai_client, input_text).await?;

    // The embedding is rendered through its `Debug` representation and placed
    // directly into the statement text, together with the table name and `take`.
    let knn_statement = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance", table, take, embedding);

    // Execute, then pull the rows of the first (only) statement out of the response.
    let mut response = db_client.query(knn_statement).await?;
    let nearest_items: Vec<T> = response.take(0)?;

    Ok(nearest_items)
}