feat: text splitting and storage

This commit is contained in:
Per Stark
2024-11-20 12:10:23 +01:00
parent 8ba853a329
commit c3ccb8c034
6 changed files with 257 additions and 89 deletions

View File

@@ -18,7 +18,7 @@ pub struct KnowledgeEntity {
pub embedding: Option<Vec<f32>>,
}
fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
pub fn thing_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{

View File

@@ -38,7 +38,7 @@ impl IngressObject {
let text = Self::fetch_text_from_url(url).await?;
let id = Uuid::new_v4();
Ok(TextContent {
id,
id: id.to_string(),
text,
instructions: instructions.clone(),
category: category.clone(),
@@ -48,7 +48,7 @@ impl IngressObject {
IngressObject::Text { text, instructions, category } => {
let id = Uuid::new_v4();
Ok(TextContent {
id,
id: id.to_string(),
text: text.clone(),
instructions: instructions.clone(),
category: category.clone(),
@@ -59,7 +59,7 @@ impl IngressObject {
let id = Uuid::new_v4();
let text = Self::extract_text_from_file(file_info).await?;
Ok(TextContent {
id,
id: id.to_string(),
text,
instructions: instructions.clone(),
category: category.clone(),

View File

@@ -1,17 +1,32 @@
use crate::{
models::file_info::FileInfo,
surrealdb::{SurrealDbClient, SurrealError},
utils::llm::{create_json_ld, generate_embedding},
};
use async_openai::error::OpenAIError;
use serde::{Deserialize, Serialize};
use surrealdb::{engine::remote::ws::Client, Surreal};
use text_splitter::TextSplitter;
use thiserror::Error;
use tracing::{debug, info};
use uuid::Uuid;
use crate::{models::file_info::FileInfo, surrealdb::{SurrealDbClient, SurrealError}, utils::llm::create_json_ld};
use thiserror::Error;
use super::graph_entities::{KnowledgeEntity, KnowledgeRelationship};
use super::graph_entities::{thing_to_string, KnowledgeEntity, KnowledgeRelationship};
#[derive(Serialize, Deserialize, Debug)]
struct TextChunk {
#[serde(deserialize_with = "thing_to_string")]
id: String,
source_id: String,
chunk: String,
embedding: Vec<f32>,
}
/// Represents a single piece of text content extracted from various sources.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TextContent {
pub id: Uuid,
#[serde(deserialize_with = "thing_to_string")]
pub id: String,
pub text: String,
pub file_info: Option<FileInfo>,
pub instructions: String,
@@ -26,13 +41,13 @@ pub enum ProcessingError {
#[error("SurrealDB error: {0}")]
SurrealError(#[from] SurrealError),
#[error("SurrealDb error: {0}")]
SurrealDbError(#[from] surrealdb::Error),
#[error("Graph DB storage error: {0}")]
GraphDBError(String),
#[error("Vector DB storage error: {0}")]
VectorDBError(String),
@@ -43,39 +58,106 @@ pub enum ProcessingError {
OpenAIerror(#[from] OpenAIError),
}
async fn vector_comparison<T>(
take: u8,
input_text: String,
db_client: &Surreal<Client>,
table: String,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<Vec<T>, ProcessingError>
where
T: for<'de> serde::Deserialize<'de>, // Add this trait bound for deserialization
{
let input_embedding = generate_embedding(&openai_client, input_text).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} ORDER BY distance",table, take, input_embedding);
// Perform query and deserialize to struct
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
Ok(closest_entities)
}
async fn get_related_nodes(
id: String,
db_client: &Surreal<Client>,
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
let query = format!("SELECT -> knowledge_relationship -> knowledge_entity as related_nodes FROM knowledge_entity WHERE source_id = `{}`", id);
// let query = format!("SELECT * FROM knowledge_entity WHERE in OR out {}", id);
let related_nodes: Vec<KnowledgeEntity> = db_client.query(query).await?.take(0)?;
Ok(related_nodes)
}
impl TextContent {
/// Processes the `TextContent` by sending it to an LLM, storing in a graph DB, and vector DB.
pub async fn process(&self) -> Result<(), ProcessingError> {
// Store TextContent
let db_client = SurrealDbClient::new().await?;
let openai_client = async_openai::Client::new();
// let deleted: Vec<KnowledgeEntity> = db_client.delete("knowledge_entity").await?;
self.store_text_content(&db_client).await?;
let closest_text_content: Vec<TextChunk> = vector_comparison(
4,
self.text.clone(),
&db_client,
"text_chunk".to_string(),
&openai_client,
)
.await?;
for node in closest_text_content {
info!("{}-{}", node.id, node.source_id);
let related_nodes = get_related_nodes(node.source_id, &db_client).await?;
info!("{:?}", related_nodes);
}
panic!("STOPPING");
// let deleted: Vec<TextChunk> = db_client.delete("text_chunk").await?;
// info! {"{:?} KnowledgeEntities deleted", deleted.len()};
// let relationships_deleted: Vec<KnowledgeRelationship> =
// db_client.delete("knowledge_relationship").await?;
// info!("{:?} Relationships deleted", relationships_deleted.len());
// panic!("STOP");
// db_client.query("REMOVE INDEX embeddings ON knowledge_entity").await?;
// db_client.query("DEFINE INDEX embeddings ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536").await?;
db_client.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity").await?;
// db_client
// .query("DEFINE INDEX idx_embedding ON text_chunk FIELDS embedding HNSW DIMENSION 1536")
// .await?;
db_client
.query("REBUILD INDEX IF EXISTS idx_embedding ON text_chunk")
.await?;
db_client
.query("REBUILD INDEX IF EXISTS embeddings ON knowledge_entity")
.await?;
// Step 1: Send to LLM for analysis
let analysis = create_json_ld(&self.category, &self.instructions, &self.text, &db_client).await?;
let analysis = create_json_ld(
&self.category,
&self.instructions,
&self.text,
&db_client,
&openai_client,
)
.await?;
// info!("{:#?}", &analysis);
// Step 2: Convert LLM analysis to database entities
let (entities, relationships) = analysis.to_database_entities(&self.id).await?;
let (entities, relationships) = analysis
.to_database_entities(&self.id, &openai_client)
.await?;
// Step 3: Store in database
self.store_in_graph_db(entities, relationships, &db_client).await?;
self.store_in_graph_db(entities, relationships, &db_client)
.await?;
// Step 4: Split text and store in Vector DB
// self.store_in_vector_db().await?;
self.store_in_vector_db(&db_client, &openai_client).await?;
Ok(())
}
@@ -87,14 +169,17 @@ impl TextContent {
db_client: &Surreal<Client>,
) -> Result<(), ProcessingError> {
for entity in &entities {
info!("{:?}, {:?}, {:?}", &entity.id, &entity.name, &entity.description);
info!(
"{:?}, {:?}, {:?}",
&entity.id, &entity.name, &entity.description
);
let _created: Option<KnowledgeEntity> = db_client
.create(("knowledge_entity", &entity.id.to_string()))
.content(entity.clone())
.await?;
debug!("{:?}",_created);
debug!("{:?}", _created);
}
for relationship in &relationships {
@@ -105,13 +190,13 @@ impl TextContent {
.content(relationship.clone())
.await?;
debug!("{:?}",_created);
debug!("{:?}", _created);
}
// for relationship in &relationships {
// let in_entity: Option<KnowledgeEntity> = db_client.select(("knowledge_entity",relationship.in_.to_string())).await?;
// let out_entity: Option<KnowledgeEntity> = db_client.select(("knowledge_entity", relationship.out.to_string())).await?;
// if let (Some(in_), Some(out)) = (in_entity, out_entity) {
// info!("{} - {} is {} to {} - {}", in_.id, in_.name, relationship.relationship_type, out.id, out.name);
// }
@@ -120,24 +205,59 @@ impl TextContent {
// }
// }
info!("Inserted to database: {:?} entities, {:?} relationships", entities.len(), relationships.len());
info!(
"Inserted to database: {:?} entities, {:?} relationships",
entities.len(),
relationships.len()
);
Ok(())
}
/// Splits text and stores it in a vector database.
#[allow(dead_code)]
async fn store_in_vector_db(&self) -> Result<(), ProcessingError> {
// TODO: Implement text splitting and vector storage logic.
// Example:
/*
let chunks = text_splitter::split(&self.text);
let vector_db = VectorDB::new("http://vector-db:5000");
async fn store_in_vector_db(
&self,
db_client: &Surreal<Client>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(), ProcessingError> {
let max_characters = 500..2000;
let splitter = TextSplitter::new(max_characters);
let chunks = splitter.chunks(self.text.as_str());
for chunk in chunks {
vector_db.insert(chunk).await.map_err(|e| ProcessingError::VectorDBError(e.to_string()))?;
info!("Chunk: {}", chunk);
let embedding = generate_embedding(&openai_client, chunk.to_string()).await?;
let text_chunk = TextChunk {
id: Uuid::new_v4().to_string(),
source_id: self.id.clone(),
chunk: chunk.to_string(),
embedding,
};
info!("{:?}", text_chunk);
let _created: Option<TextChunk> = db_client
.create(("text_chunk", text_chunk.id.clone()))
.content(text_chunk)
.await?;
debug!("{:?}", _created);
}
*/
unimplemented!()
Ok(())
}
/// Stores text content in database
async fn store_text_content(&self, db_client: &Surreal<Client>) -> Result<(), ProcessingError> {
let _created: Option<TextContent> = db_client
.create(("text_content", self.id.clone()))
.content(self.clone())
.await?;
debug!("{:?}", _created);
Ok(())
}
}

View File

@@ -38,7 +38,7 @@ pub struct LLMGraphAnalysisResult {
pub relationships: Vec<LLMRelationship>,
}
async fn generate_embedding(
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: String,
) -> Result<Vec<f32>, ProcessingError> {
@@ -73,13 +73,15 @@ impl LLMGraphAnalysisResult {
/// # Arguments
///
/// * `source_id` - A UUID representing the source identifier.
/// * `openai_client` - OpenAI client for LLM calls.
///
/// # Returns
///
/// * `Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError>` - A tuple containing vectors of `KnowledgeEntity` and `KnowledgeRelationship`.
pub async fn to_database_entities(
&self,
source_id: &Uuid,
source_id: &String,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<(Vec<KnowledgeEntity>, Vec<KnowledgeRelationship>), ProcessingError> {
let mut mapper = GraphMapper::new();
@@ -88,7 +90,6 @@ impl LLMGraphAnalysisResult {
mapper.assign_id(&llm_entity.key);
}
let openai_client = async_openai::Client::new();
let mut entities = vec![];
@@ -154,15 +155,13 @@ pub async fn create_json_ld(
instructions: &str,
text: &str,
db_client: &Surreal<Client>,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<LLMGraphAnalysisResult, ProcessingError> {
// Initialize llm client
let client = async_openai::Client::new();
// Format the input for more cohesive comparison
let input_text = format!("content: {:?}, category: {:?}, user_instructions: {:?}", text, category, instructions);
// Generate embedding of the input
let input_embedding = generate_embedding(&client, input_text).await?;
let input_embedding = generate_embedding(&openai_client, input_text).await?;
let number_of_entities_to_get = 10;
@@ -276,6 +275,7 @@ pub async fn create_json_ld(
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
7. Only create relationships between existing KnowledgeEntities.
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
9. A new relationship MUST include a newly created KnowledgeEntity.
"#;
let user_message = format!(
@@ -297,7 +297,7 @@ pub async fn create_json_ld(
.map_err(|e| ProcessingError::LLMError(e.to_string()))?;
// Send the request to OpenAI
let response = client
let response = openai_client
.chat()
.create(request)
.await