diff --git a/ingestion-pipeline/src/types/llm_enrichment_result.rs b/ingestion-pipeline/src/types/llm_enrichment_result.rs index 8b845d7..5898b8e 100644 --- a/ingestion-pipeline/src/types/llm_enrichment_result.rs +++ b/ingestion-pipeline/src/types/llm_enrichment_result.rs @@ -6,9 +6,12 @@ use tokio::task; use common::{ error::AppError, - storage::types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - knowledge_relationship::KnowledgeRelationship, + storage::{ + db::SurrealDbClient, + types::{ + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + knowledge_relationship::KnowledgeRelationship, + }, }, utils::embedding::generate_embedding, }; @@ -56,13 +59,20 @@ impl LLMEnrichmentResult { source_id: &str, user_id: &str, openai_client: &async_openai::Client, + db_client: &SurrealDbClient, ) -> Result<(Vec, Vec), AppError> { // Create mapper and pre-assign IDs let mapper = Arc::new(Mutex::new(self.create_mapper()?)); // Process entities let entities = self - .process_entities(source_id, user_id, Arc::clone(&mapper), openai_client) + .process_entities( + source_id, + user_id, + Arc::clone(&mapper), + openai_client, + db_client, + ) .await?; // Process relationships @@ -88,6 +98,7 @@ impl LLMEnrichmentResult { user_id: &str, mapper: Arc>, openai_client: &async_openai::Client, + db_client: &SurrealDbClient, ) -> Result, AppError> { let futures: Vec<_> = self .knowledge_entities @@ -98,10 +109,18 @@ impl LLMEnrichmentResult { let source_id = source_id.to_string(); let user_id = user_id.to_string(); let entity = entity.clone(); + let db_client = db_client.clone(); task::spawn(async move { - create_single_entity(&entity, &source_id, &user_id, mapper, &openai_client) - .await + create_single_entity( + &entity, + &source_id, + &user_id, + mapper, + &openai_client, + &db_client.clone(), + ) + .await }) }) .collect(); @@ -120,14 +139,14 @@ impl LLMEnrichmentResult { user_id: &str, mapper: Arc>, ) -> Result, AppError> { - let mut mapper_guard = mapper + let mapper_guard = mapper .lock() .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?; self.relationships .iter() .map(|rel| { - let source_db_id = mapper_guard.get_or_parse_id(&rel.source); - let target_db_id = mapper_guard.get_or_parse_id(&rel.target); + let source_db_id = mapper_guard.get_or_parse_id(&rel.source)?; + let target_db_id = mapper_guard.get_or_parse_id(&rel.target)?; Ok(KnowledgeRelationship::new( source_db_id.to_string(), @@ -146,17 +165,13 @@ async fn create_single_entity( user_id: &str, mapper: Arc>, openai_client: &async_openai::Client, + db_client: &SurrealDbClient, ) -> Result { let assigned_id = { let mapper = mapper .lock() .map_err(|_| AppError::GraphMapper("Failed to lock mapper".into()))?; - mapper - .get_id(&llm_entity.key) - .ok_or_else(|| { - AppError::GraphMapper(format!("ID not found for key: {}", llm_entity.key)) - })? - .to_string() + mapper.get_id(&llm_entity.key)?.to_string() }; let embedding_input = format!( @@ -164,7 +179,7 @@ async fn create_single_entity( llm_entity.name, llm_entity.description, llm_entity.entity_type ); - let embedding = generate_embedding(openai_client, &embedding_input).await?; + let embedding = generate_embedding(openai_client, &embedding_input, db_client).await?; let now = Utc::now(); Ok(KnowledgeEntity { diff --git a/ingestion-pipeline/src/utils/mod.rs b/ingestion-pipeline/src/utils/mod.rs index 210dfba..7e7964c 100644 --- a/ingestion-pipeline/src/utils/mod.rs +++ b/ingestion-pipeline/src/utils/mod.rs @@ -1,5 +1,6 @@ pub mod llm_instructions; +use common::error::AppError; use std::collections::HashMap; use uuid::Uuid; @@ -21,24 +22,39 @@ impl GraphMapper { key_to_id: HashMap::new(), } } - /// Get ID, tries to parse UUID - pub fn get_or_parse_id(&mut self, key: &str) -> Uuid { + /// Tries to get an ID by first parsing the key as a UUID, + /// and if that fails, looking it up in the internal map. + pub fn get_or_parse_id(&self, key: &str) -> Result { + // First, try to parse the key as a UUID. if let Ok(parsed_uuid) = Uuid::parse_str(key) { - parsed_uuid - } else { - *self.key_to_id.get(key).unwrap() + return Ok(parsed_uuid); } + + // If parsing fails, look it up in the map. + self.key_to_id + .get(key) + .map(|id| *id) // Dereference the &Uuid to get Uuid + // If `get` returned None, create and return an error. + .ok_or_else(|| { + AppError::GraphMapper(format!( + "Key '{}' is not a valid UUID and was not found in the map.", + key + )) + }) } - /// Assigns a new UUID for a given key. + /// Assigns a new UUID for a given key. (No changes needed here) pub fn assign_id(&mut self, key: &str) -> Uuid { let id = Uuid::new_v4(); self.key_to_id.insert(key.to_string(), id); id } - /// Retrieves the UUID for a given key. - pub fn get_id(&self, key: &str) -> Option<&Uuid> { - self.key_to_id.get(key) + /// Retrieves the UUID for a given key, returning a Result for consistency. + pub fn get_id(&self, key: &str) -> Result { + self.key_to_id + .get(key) + .map(|id| *id) + .ok_or_else(|| AppError::GraphMapper(format!("Key '{}' not found in map.", key))) } }