wip: query handler — parameterize source-id lookup, publish ingress objects concurrently, merge vector- and chunk-based entity results

This commit is contained in:
Per Stark
2024-11-23 22:27:51 +01:00
parent d731e69bf9
commit 4dbd517bf6
4 changed files with 72 additions and 18 deletions

View File

@@ -42,10 +42,10 @@ impl RabbitMQProducer {
/// * `Result<Confirmation, RabbitMQError>` - Confirmation of sent message or error
pub async fn publish(
&self,
ingress_object: &IngressObject,
ingress_object: IngressObject,
) -> Result<Confirmation, RabbitMQError> {
// Serialize IngressObject to JSON
let payload = serde_json::to_vec(ingress_object).map_err(|e| {
let payload = serde_json::to_vec(&ingress_object).map_err(|e| {
error!("Serialization Error: {}", e);
RabbitMQError::PublishError(format!("Serialization Error: {}", e))
})?;

View File

@@ -46,20 +46,22 @@ use crate::error::ProcessingError;
/// &db_client
/// ).await?;
/// ```
pub async fn find_entities_by_source_id<T>(
source_id: String,
pub async fn find_entities_by_source_ids<T>(
source_id: Vec<String>,
table_name: String,
db_client: &Surreal<Client>,
) -> Result<Vec<T>, ProcessingError>
where
T: for<'de> serde::Deserialize<'de>,
{
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
table_name, source_id
);
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
let matching_entities: Vec<T> = db_client.query(query).await?.take(0)?;
let matching_entities: Vec<T> = db_client
.query(query)
.bind(("table", table_name))
.bind(("source_ids", source_id))
.await?
.take(0)?;
Ok(matching_entities)
}

View File

@@ -5,6 +5,7 @@ use crate::{
storage::db::SurrealDbClient,
};
use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
use futures::future::try_join_all;
use std::sync::Arc;
use tracing::info;
@@ -17,9 +18,12 @@ pub async fn ingress_handler(
let ingress_objects = create_ingress_objects(input, &db_client).await?;
for object in ingress_objects {
producer.publish(&object).await?;
}
let futures: Vec<_> = ingress_objects
.into_iter()
.map(|object| producer.publish(object))
.collect();
try_join_all(futures).await?;
Ok(StatusCode::OK)
}

View File

@@ -1,10 +1,15 @@
use crate::{
error::ApiError,
retrieval::vector::find_items_by_vector_similarity,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
retrieval::{graph::find_entities_by_source_ids, vector::find_items_by_vector_similarity},
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use axum::{response::IntoResponse, Extension, Json};
use serde::Deserialize;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::info;
@@ -20,14 +25,57 @@ pub async fn query_handler(
info!("Received input: {:?}", query);
let openai_client = async_openai::Client::new();
let closest_items: Vec<KnowledgeEntity> = find_items_by_vector_similarity(
10,
let items_from_knowledge_entity_similarity: Vec<KnowledgeEntity> =
find_items_by_vector_similarity(
10,
query.query.to_string(),
&db_client,
"knowledge_entity".to_string(),
&openai_client,
)
.await?;
let closest_chunks: Vec<TextChunk> = find_items_by_vector_similarity(
5,
query.query,
&db_client,
"knowledge_entity".to_string(),
"text_chunk".to_string(),
&openai_client,
)
.await?;
Ok(format!("{:?}", closest_items))
let source_ids = closest_chunks
.iter()
.map(|chunk| chunk.source_id.clone())
.collect::<Vec<String>>();
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), &db_client).await?;
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
.into_iter()
.chain(items_from_text_chunk_similarity.into_iter())
.fold(HashMap::new(), |mut map, entity| {
map.insert(entity.id.clone(), entity);
map
})
.into_values()
.collect();
let entities_json = json!(entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>());
info!("{} Entities\n{:#?}", entities.len(), entities_json);
Ok("we got some stuff".to_string())
}