mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-24 09:48:32 +02:00
refactor: add openai client and improve reference handling
This commit is contained in:
@@ -42,6 +42,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?),
|
rabbitmq_consumer: Arc::new(RabbitMQConsumer::new(&config, false).await?),
|
||||||
surreal_db_client: Arc::new(SurrealDbClient::new().await?),
|
surreal_db_client: Arc::new(SurrealDbClient::new().await?),
|
||||||
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
|
tera: Arc::new(Tera::new("src/server/templates/**/*.html").unwrap()),
|
||||||
|
openai_client: Arc::new(async_openai::Client::new()),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create Axum router
|
// Create Axum router
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ impl<'a> IngressAnalyzer<'a> {
|
|||||||
text, category, instructions
|
text, category, instructions
|
||||||
);
|
);
|
||||||
|
|
||||||
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, input_text).await
|
combined_knowledge_entity_retrieval(self.db_client, self.openai_client, &input_text).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_llm_request(
|
fn prepare_llm_request(
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ async fn create_single_entity(
|
|||||||
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
||||||
);
|
);
|
||||||
|
|
||||||
let embedding = generate_embedding(openai_client, embedding_input).await?;
|
let embedding = generate_embedding(openai_client, &embedding_input).await?;
|
||||||
|
|
||||||
Ok(KnowledgeEntity {
|
Ok(KnowledgeEntity {
|
||||||
id: assigned_id,
|
id: assigned_id,
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ impl ContentProcessor {
|
|||||||
|
|
||||||
// Could potentially process chunks in parallel with a bounded concurrent limit
|
// Could potentially process chunks in parallel with a bounded concurrent limit
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
let embedding = generate_embedding(&self.openai_client, chunk.to_string()).await?;
|
let embedding = generate_embedding(&self.openai_client, chunk).await?;
|
||||||
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
let text_chunk = TextChunk::new(content.id.to_string(), chunk.to_string(), embedding);
|
||||||
store_item(&self.db_client, text_chunk).await?;
|
store_item(&self.db_client, text_chunk).await?;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ use crate::{
|
|||||||
use futures::future::{try_join, try_join_all};
|
use futures::future::{try_join, try_join_all};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use surrealdb::{engine::remote::ws::Client, Surreal};
|
use surrealdb::{engine::remote::ws::Client, Surreal};
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||||
/// to find the most relevant entities for a given query.
|
/// to find the most relevant entities for a given query.
|
||||||
@@ -37,14 +36,14 @@ use tracing::info;
|
|||||||
pub async fn combined_knowledge_entity_retrieval(
|
pub async fn combined_knowledge_entity_retrieval(
|
||||||
db_client: &Surreal<Client>,
|
db_client: &Surreal<Client>,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
query: String,
|
query: &str,
|
||||||
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
) -> Result<Vec<KnowledgeEntity>, ProcessingError> {
|
||||||
info!("Received input: {:?}", query);
|
// info!("Received input: {:?}", query);
|
||||||
|
|
||||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||||
find_items_by_vector_similarity(
|
find_items_by_vector_similarity(
|
||||||
10,
|
10,
|
||||||
query.clone(),
|
query,
|
||||||
db_client,
|
db_client,
|
||||||
"knowledge_entity".to_string(),
|
"knowledge_entity".to_string(),
|
||||||
openai_client,
|
openai_client,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use crate::{error::ProcessingError, utils::embedding::generate_embedding};
|
|||||||
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
/// * `T`: The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||||
pub async fn find_items_by_vector_similarity<T>(
|
pub async fn find_items_by_vector_similarity<T>(
|
||||||
take: u8,
|
take: u8,
|
||||||
input_text: String,
|
input_text: &str,
|
||||||
db_client: &Surreal<Client>,
|
db_client: &Surreal<Client>,
|
||||||
table: String,
|
table: String,
|
||||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
|
|||||||
@@ -12,4 +12,5 @@ pub struct AppState {
|
|||||||
pub rabbitmq_consumer: Arc<RabbitMQConsumer>,
|
pub rabbitmq_consumer: Arc<RabbitMQConsumer>,
|
||||||
pub surreal_db_client: Arc<SurrealDbClient>,
|
pub surreal_db_client: Arc<SurrealDbClient>,
|
||||||
pub tera: Arc<Tera>,
|
pub tera: Arc<Tera>,
|
||||||
|
pub openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
pub mod helper;
|
pub mod helper;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
|
|
||||||
use crate::{error::ApiError, retrieval::combined_knowledge_entity_retrieval, server::AppState};
|
use crate::{error::ApiError, server::AppState};
|
||||||
use axum::{extract::State, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use helper::{
|
use helper::get_answer_with_references;
|
||||||
create_chat_request, create_user_message, format_entities_json, process_llm_response,
|
|
||||||
};
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tracing::{debug, info};
|
use tracing::info;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct QueryInput {
|
pub struct QueryInput {
|
||||||
@@ -32,42 +30,13 @@ pub async fn query_handler(
|
|||||||
Json(query): Json<QueryInput>,
|
Json(query): Json<QueryInput>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
info!("Received input: {:?}", query);
|
info!("Received input: {:?}", query);
|
||||||
let openai_client = async_openai::Client::new();
|
|
||||||
|
|
||||||
// Retrieve entities
|
let answer =
|
||||||
let entities = combined_knowledge_entity_retrieval(
|
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
||||||
&state.surreal_db_client,
|
.await?;
|
||||||
&openai_client,
|
|
||||||
query.query.clone(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Format entities and create message
|
|
||||||
let entities_json = format_entities_json(&entities);
|
|
||||||
let user_message = create_user_message(&entities_json, &query.query);
|
|
||||||
debug!("{:?}", user_message);
|
|
||||||
|
|
||||||
// Create and send request
|
|
||||||
let request = create_chat_request(user_message)?;
|
|
||||||
let response = openai_client
|
|
||||||
.chat()
|
|
||||||
.create(request)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ApiError::QueryError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Process response
|
|
||||||
let answer = process_llm_response(response).await?;
|
|
||||||
debug!("{:?}", answer);
|
|
||||||
|
|
||||||
let references: Vec<String> = answer
|
|
||||||
.references
|
|
||||||
.into_iter()
|
|
||||||
.map(|reference| reference.reference)
|
|
||||||
.collect();
|
|
||||||
info!("{:?}", references);
|
|
||||||
|
|
||||||
Ok(
|
Ok(
|
||||||
Json(serde_json::json!({"answer": answer.answer, "references": references}))
|
Json(serde_json::json!({"answer": answer.content, "references": answer.references}))
|
||||||
.into_response(),
|
.into_response(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,107 @@ use async_openai::types::{
|
|||||||
};
|
};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
use crate::{error::ApiError, storage::types::knowledge_entity::KnowledgeEntity};
|
use crate::{
|
||||||
|
error::ApiError,
|
||||||
|
retrieval::combined_knowledge_entity_retrieval,
|
||||||
|
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
|
||||||
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT},
|
prompt::{get_query_response_schema, QUERY_SYSTEM_PROMPT},
|
||||||
LLMResponseFormat,
|
LLMResponseFormat,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// /// Orchestrator function that takes a query and clients and returns an answer with references
|
||||||
|
// ///
|
||||||
|
// /// # Arguments
|
||||||
|
// /// * `surreal_db_client` - Client for interacting with SurrealDB
|
||||||
|
// /// * `openai_client` - Client for interacting with openai
|
||||||
|
// /// * `query` - The query
|
||||||
|
// ///
|
||||||
|
// /// # Returns
|
||||||
|
// /// * `Result<(String, Vec<String>), ApiError>` - Returns the answer and the list of references, or an error
|
||||||
|
// pub async fn get_answer_with_references(
|
||||||
|
// surreal_db_client: &SurrealDbClient,
|
||||||
|
// openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
|
// query: &str,
|
||||||
|
// ) -> Result<(String, Vec<String>), ApiError> {
|
||||||
|
// let entities =
|
||||||
|
// combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query.into()).await?;
|
||||||
|
|
||||||
|
// // Format entities and create message
|
||||||
|
// let entities_json = format_entities_json(&entities);
|
||||||
|
// let user_message = create_user_message(&entities_json, query);
|
||||||
|
|
||||||
|
// // Create and send request
|
||||||
|
// let request = create_chat_request(user_message)?;
|
||||||
|
// let response = openai_client
|
||||||
|
// .chat()
|
||||||
|
// .create(request)
|
||||||
|
// .await
|
||||||
|
// .map_err(|e| ApiError::QueryError(e.to_string()))?;
|
||||||
|
|
||||||
|
// // Process response
|
||||||
|
// let answer = process_llm_response(response).await?;
|
||||||
|
|
||||||
|
// let references: Vec<String> = answer
|
||||||
|
// .references
|
||||||
|
// .into_iter()
|
||||||
|
// .map(|reference| reference.reference)
|
||||||
|
// .collect();
|
||||||
|
|
||||||
|
// Ok((answer.answer, references))
|
||||||
|
// }
|
||||||
|
|
||||||
|
/// Orchestrates query processing and returns an answer with references
|
||||||
|
///
|
||||||
|
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||||
|
/// * `openai_client` - Client for OpenAI API calls
|
||||||
|
/// * `query` - The user's query string
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns an `Answer` holding the answer text and its references, or an API error
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Answer {
|
||||||
|
pub content: String,
|
||||||
|
pub references: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_answer_with_references(
|
||||||
|
surreal_db_client: &SurrealDbClient,
|
||||||
|
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
|
query: &str,
|
||||||
|
) -> Result<Answer, ApiError> {
|
||||||
|
let entities =
|
||||||
|
combined_knowledge_entity_retrieval(surreal_db_client, openai_client, query).await?;
|
||||||
|
|
||||||
|
let entities_json = format_entities_json(&entities);
|
||||||
|
let user_message = create_user_message(&entities_json, query);
|
||||||
|
|
||||||
|
let request = create_chat_request(user_message)?;
|
||||||
|
let response = openai_client
|
||||||
|
.chat()
|
||||||
|
.create(request)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::QueryError(e.to_string()))?;
|
||||||
|
|
||||||
|
let llm_response = process_llm_response(response).await?;
|
||||||
|
|
||||||
|
Ok(Answer {
|
||||||
|
content: llm_response.answer,
|
||||||
|
references: llm_response
|
||||||
|
.references
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| r.reference)
|
||||||
|
.collect(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||||
json!(entities
|
json!(entities
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -9,13 +9,7 @@ use tracing::info;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
retrieval::combined_knowledge_entity_retrieval,
|
server::{routes::query::helper::get_answer_with_references, AppState},
|
||||||
server::{
|
|
||||||
routes::query::helper::{
|
|
||||||
create_chat_request, create_user_message, format_entities_json, process_llm_response,
|
|
||||||
},
|
|
||||||
AppState,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct SearchParams {
|
pub struct SearchParams {
|
||||||
@@ -28,43 +22,18 @@ pub async fn search_result_handler(
|
|||||||
) -> Result<Html<String>, ApiError> {
|
) -> Result<Html<String>, ApiError> {
|
||||||
info!("Displaying search results");
|
info!("Displaying search results");
|
||||||
|
|
||||||
let openai_client = async_openai::Client::new();
|
let answer =
|
||||||
|
get_answer_with_references(&state.surreal_db_client, &state.openai_client, &query.query)
|
||||||
// Retrieve entities
|
.await?;
|
||||||
let entities = combined_knowledge_entity_retrieval(
|
|
||||||
&state.surreal_db_client,
|
|
||||||
&openai_client,
|
|
||||||
query.query.clone(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Format entities and create message
|
|
||||||
let entities_json = format_entities_json(&entities);
|
|
||||||
let user_message = create_user_message(&entities_json, &query.query);
|
|
||||||
|
|
||||||
// Create and send request
|
|
||||||
let request = create_chat_request(user_message)?;
|
|
||||||
let response = openai_client
|
|
||||||
.chat()
|
|
||||||
.create(request)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ApiError::QueryError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Process response
|
|
||||||
let answer = process_llm_response(response).await?;
|
|
||||||
|
|
||||||
let references: Vec<String> = answer
|
|
||||||
.references
|
|
||||||
.into_iter()
|
|
||||||
.map(|reference| reference.reference)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let output = state
|
let output = state
|
||||||
.tera
|
.tera
|
||||||
.render(
|
.render(
|
||||||
"search_result.html",
|
"search_result.html",
|
||||||
&Context::from_value(json!({"result": answer.answer, "references": references}))
|
&Context::from_value(
|
||||||
.unwrap(),
|
json!({"result": answer.content, "references": answer.references}),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -27,11 +27,11 @@ use crate::error::ProcessingError;
|
|||||||
/// * If no embedding data is received in the response
|
/// * If no embedding data is received in the response
|
||||||
pub async fn generate_embedding(
|
pub async fn generate_embedding(
|
||||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||||
input: String,
|
input: &str,
|
||||||
) -> Result<Vec<f32>, ProcessingError> {
|
) -> Result<Vec<f32>, ProcessingError> {
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model("text-embedding-3-small")
|
.model("text-embedding-3-small")
|
||||||
.input(&[input])
|
.input([input])
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
// Send the request to OpenAI
|
// Send the request to OpenAI
|
||||||
|
|||||||
Reference in New Issue
Block a user