refactoring: working macro and generics

Per Stark
2024-11-20 22:44:30 +01:00
parent 7222223c31
commit 41134cfa49
11 changed files with 198 additions and 167 deletions


@@ -1,10 +1,12 @@
-use crate::models::graph_entities::{
-    GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
+use crate::{
+    error::ProcessingError,
+    models::graph_entities::{
+        GraphMapper, KnowledgeEntity, KnowledgeEntityType, KnowledgeRelationship,
+    },
 };
-use crate::models::text_content::ProcessingError;
 use async_openai::types::{
     ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
-    CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs
+    CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs,
 };
 use serde::{Deserialize, Serialize};
 use serde_json::json;
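
The bare `?` operators introduced throughout this file only compile if the new `crate::error::ProcessingError` carries `From` conversions for the library errors involved. A minimal sketch of what that module might look like, assuming the "working macro" in the commit message refers to thiserror's derive macro; only the variant names visible in this diff come from the source, everything else is assumption:

// src/error.rs (hypothetical reconstruction, not from this commit)
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ProcessingError {
    // #[from] generates the From impl that lets `.build()?` and
    // `.create(request).await?` convert async_openai errors implicitly.
    #[error("LLM request failed: {0}")]
    LLMError(#[from] async_openai::error::OpenAIError),

    #[error("embedding error: {0}")]
    EmbeddingError(String),

    #[error("graph processing error: {0}")]
    GraphProcessingError(String),

    #[error("failed to parse LLM response: {0}")]
    LLMParsingError(String),

    // `db_client.query(...).await?.take(0)?` below also needs a conversion
    // for the database error type (assumed here to be SurrealDB's).
    #[error("database error: {0}")]
    DatabaseError(#[from] surrealdb::Error),
}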
@@ -45,21 +47,16 @@ pub async fn generate_embedding(
     let request = CreateEmbeddingRequestArgs::default()
         .model("text-embedding-3-small")
         .input(&[input])
-        .build()
-        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;
+        .build()?;
     // Send the request to OpenAI
-    let response = client
-        .embeddings()
-        .create(request)
-        .await
-        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;
+    let response = client.embeddings().create(request).await?;
     // Extract the embedding vector
     let embedding: Vec<f32> = response
         .data
         .first()
-        .ok_or_else(|| ProcessingError::LLMError("No embedding data received".into()))?
+        .ok_or_else(|| ProcessingError::EmbeddingError("No embedding data received".into()))?
         .embedding
         .clone();
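
For context, a hedged call-site sketch of `generate_embedding`; the signature (client reference and owned `String` in, `Vec<f32>` out) is inferred from this hunk and may not match the real one exactly:

// Hypothetical usage, not from the source.
let client = async_openai::Client::new();
let embedding = generate_embedding(&client, "knowledge graphs".to_string()).await?;
// text-embedding-3-small returns 1536-dimensional vectors by default.
assert_eq!(embedding.len(), 1536);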
@@ -90,7 +87,6 @@ impl LLMGraphAnalysisResult {
             mapper.assign_id(&llm_entity.key);
         }
-
        let mut entities = vec![];
        // Step 2: Process each knowledge entity sequentially
@@ -99,7 +95,10 @@ impl LLMGraphAnalysisResult {
             let assigned_id = mapper
                 .get_id(&llm_entity.key)
                 .ok_or_else(|| {
-                    ProcessingError::LLMError(format!("ID not found for key: {}", llm_entity.key))
+                    ProcessingError::GraphProcessingError(format!(
+                        "ID not found for key: {}",
+                        llm_entity.key
+                    ))
                 })?
                 .clone();
@@ -158,37 +157,46 @@ pub async fn create_json_ld(
     openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
 ) -> Result<LLMGraphAnalysisResult, ProcessingError> {
     // Format the input for more cohesive comparison
-    let input_text = format!("content: {:?}, category: {:?}, user_instructions: {:?}", text, category, instructions);
+    let input_text = format!(
+        "content: {:?}, category: {:?}, user_instructions: {:?}",
+        text, category, instructions
+    );
     // Generate embedding of the input
     let input_embedding = generate_embedding(&openai_client, input_text).await?;
     let number_of_entities_to_get = 10;
     // Construct the query
     let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM knowledge_entity WHERE embedding <|{},40|> {:?} ORDER BY distance",number_of_entities_to_get, input_embedding);
     // Perform query and deserialize to struct
     let closest_entities: Vec<KnowledgeEntity> = db_client.query(closest_query).await?.take(0)?;
     #[allow(dead_code)]
     #[derive(Debug)]
     struct KnowledgeEntityToLLM {
         id: String,
         name: String,
-        description: String
+        description: String,
     }
-    info!("Number of KnowledgeEntities sent as context: {}", closest_entities.len());
+    info!(
+        "Number of KnowledgeEntities sent as context: {}",
+        closest_entities.len()
+    );
     // Only keep most relevant information
-    let closest_entities_to_llm: Vec<KnowledgeEntityToLLM> = closest_entities.clone().into_iter().map(|entity| KnowledgeEntityToLLM {
-        id: entity.id,
-        name: entity.name,
-        description: entity.description
-    }).collect();
+    let closest_entities_to_llm: Vec<KnowledgeEntityToLLM> = closest_entities
+        .clone()
+        .into_iter()
+        .map(|entity| KnowledgeEntityToLLM {
+            id: entity.id,
+            name: entity.name,
+            description: entity.description,
+        })
+        .collect();
     debug!("{:?}", closest_entities_to_llm);
     let schema = json!({
         "type": "object",
@@ -293,32 +301,26 @@ pub async fn create_json_ld(
             ChatCompletionRequestUserMessage::from(user_message).into(),
         ])
         .response_format(response_format)
-        .build()
-        .map_err(|e| ProcessingError::LLMError(e.to_string()))?;
+        .build()?;
     // Send the request to OpenAI
-    let response = openai_client
-        .chat()
-        .create(request)
-        .await
-        .map_err(|e| ProcessingError::LLMError(format!("OpenAI API request failed: {}", e)))?;
+    let response = openai_client.chat().create(request).await?;
     debug!("{:?}", response);
     // Extract and parse the response
-    for choice in response.choices {
-        if let Some(content) = choice.message.content {
-            let analysis: LLMGraphAnalysisResult = serde_json::from_str(&content).map_err(|e| {
-                ProcessingError::LLMError(format!(
-                    "Failed to parse LLM response into analysis: {}",
-                    e
-                ))
-            })?;
-            return Ok(analysis);
-        }
-    }
-    Err(ProcessingError::LLMError(
-        "No content found in LLM response".into(),
-    ))
+    response
+        .choices
+        .first()
+        .and_then(|choice| choice.message.content.as_ref())
+        .ok_or(ProcessingError::LLMParsingError(
+            "No content found in LLM response".into(),
+        ))
+        .and_then(|content| {
+            serde_json::from_str(content).map_err(|e| {
+                ProcessingError::LLMParsingError(format!(
+                    "Failed to parse LLM response into analysis: {}",
+                    e
+                ))
+            })
+        })
 }
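
The rewritten tail expression replaces the old loop-and-early-return with a single Option-to-Result chain. A self-contained sketch of the same pattern, with placeholder types and values not taken from the source:

// Option -> and_then -> ok_or -> Result -> and_then, as in the hunk above.
fn parse_first(values: &[Option<String>]) -> Result<i64, String> {
    values
        .first()                             // Option<&Option<String>>
        .and_then(|v| v.as_ref())            // flatten to Option<&String>
        .ok_or("no content".to_string())     // Option -> Result
        .and_then(|s| {
            // parse, converting the error type like the map_err above
            s.parse::<i64>().map_err(|e| format!("parse failed: {}", e))
        })
}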