Removed the IngressInput struct and made naming consistent

This commit is contained in:
Per Stark
2025-03-06 10:59:34 +01:00
parent 4ab5d3b551
commit ef1478547e
9 changed files with 100 additions and 124 deletions

View File

@@ -2,7 +2,7 @@ use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use common::{
error::{ApiError, AppError},
ingress::ingress_input::{create_ingress_objects, IngressInput},
ingress::ingress_object::IngressObject,
storage::types::{file_info::FileInfo, job::Job, user::User},
};
use futures::{future::try_join_all, TryFutureExt};
@@ -38,13 +38,11 @@ pub async fn ingress_data(
debug!("Got file infos");
let ingress_objects = create_ingress_objects(
IngressInput {
content: input.content,
instructions: input.instructions,
category: input.category,
files: file_infos,
},
let ingress_objects = IngressObject::create_ingress_objects(
input.content,
input.instructions,
input.category,
file_infos,
user.id.as_str(),
)?;
debug!("Got ingress objects");

View File

@@ -2,7 +2,7 @@ use crate::{
error::AppError,
ingress::analysis::prompt::{get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE},
retrieval::combined_knowledge_entity_retrieval,
storage::types::knowledge_entity::KnowledgeEntity,
storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity},
};
use async_openai::{
error::OpenAIError,
@@ -13,20 +13,18 @@ use async_openai::{
},
};
use serde_json::json;
use surrealdb::engine::any::Any;
use surrealdb::Surreal;
use tracing::debug;
use super::types::llm_analysis_result::LLMGraphAnalysisResult;
pub struct IngressAnalyzer<'a> {
db_client: &'a Surreal<Any>,
db_client: &'a SurrealDbClient,
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl<'a> IngressAnalyzer<'a> {
pub fn new(
db_client: &'a Surreal<Any>,
db_client: &'a SurrealDbClient,
openai_client: &'a async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Self {
Self {

View File

@@ -1,74 +0,0 @@
use super::ingress_object::IngressObject;
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
use url::Url;
/// Struct defining the expected body when ingressing content.
///
/// Consumed by `create_ingress_objects`, which turns it into one
/// `IngressObject` per piece of content / file.
#[derive(Serialize, Deserialize, Debug)]
pub struct IngressInput {
/// Optional free-form content. `create_ingress_objects` treats it as a URL
/// when it parses with `Url::parse`, otherwise as plain text if it is longer
/// than two bytes.
pub content: Option<String>,
/// Instructions string; cloned onto every ingress object that is created.
pub instructions: String,
/// Category string; cloned onto every ingress object that is created.
pub category: String,
/// Files to ingress; each entry becomes one `IngressObject::File`.
pub files: Vec<FileInfo>,
}
/// Function to create ingress objects from input.
///
/// # Arguments
/// * `input` - IngressInput containing information needed to ingress content.
/// * `user_id` - User id of the ingressing user
///
/// # Returns
/// * `Vec<IngressObject>` - An array containing the ingressed objects, one file/contenttype per object.
pub fn create_ingress_objects(
input: IngressInput,
user_id: &str,
) -> Result<Vec<IngressObject>, AppError> {
// Initialize list
let mut object_list = Vec::new();
// Create a IngressObject from input.content if it exists, checking for URL or text
if let Some(input_content) = input.content {
match Url::parse(&input_content) {
Ok(url) => {
info!("Detected URL: {}", url);
object_list.push(IngressObject::Url {
url: url.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
Err(_) => {
if input_content.len() > 2 {
info!("Treating input as plain text");
object_list.push(IngressObject::Text {
text: input_content.to_string(),
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
});
}
}
}
}
for file in input.files {
object_list.push(IngressObject::File {
file_info: file,
instructions: input.instructions.clone(),
category: input.category.clone(),
user_id: user_id.into(),
})
}
// If no objects are constructed, we return Err
if object_list.is_empty() {
return Err(AppError::NotFound(
"No valid content or files provided".into(),
));
}
Ok(object_list)
}

View File

@@ -13,6 +13,8 @@ use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use tiktoken_rs::{o200k_base, CoreBPE};
use tracing::info;
use url::Url;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum IngressObject {
@@ -37,6 +39,72 @@ pub enum IngressObject {
}
impl IngressObject {
/// Creates ingress objects from the provided content, instructions, and files.
///
/// `content`, when present, yields at most one object: an `IngressObject::Url`
/// if the string parses with `Url::parse`, otherwise an `IngressObject::Text`
/// if it is longer than two bytes. Every entry in `files` yields one
/// `IngressObject::File`. Each object carries its own copy of `instructions`,
/// `category` and `user_id`.
///
/// # Arguments
/// * `content` - Optional textual content to be ingressed
/// * `instructions` - Instructions for processing the ingress content
/// * `category` - Category to classify the ingressed content
/// * `files` - Vector of `FileInfo` objects containing information about uploaded files
/// * `user_id` - Identifier of the user performing the ingress operation
///
/// # Returns
/// * `Result<Vec<IngressObject>, AppError>` - On success, returns a vector of ingress objects
///   (one per file/content type).
///
/// # Errors
/// Returns `AppError::NotFound` when neither `content` nor `files` produced an object.
pub fn create_ingress_objects(
    content: Option<String>,
    instructions: String,
    category: String,
    files: Vec<FileInfo>,
    user_id: &str,
) -> Result<Vec<IngressObject>, AppError> {
    let mut object_list = Vec::new();

    // Create an IngressObject from content if it exists, checking for URL or text.
    if let Some(input_content) = content {
        // NOTE(review): `Url::parse` is not limited to http(s); text shaped like
        // `scheme:rest` (e.g. "note: buy milk") may be classified as a URL.
        // Confirm whether that is intended.
        match Url::parse(&input_content) {
            Ok(url) => {
                info!("Detected URL: {}", url);
                object_list.push(IngressObject::Url {
                    url: url.to_string(),
                    instructions: instructions.clone(),
                    category: category.clone(),
                    user_id: user_id.into(),
                });
            }
            // `len()` counts bytes, not characters; very short input is dropped.
            Err(_) if input_content.len() > 2 => {
                info!("Treating input as plain text");
                object_list.push(IngressObject::Text {
                    // The string is already owned, so move it instead of cloning.
                    text: input_content,
                    instructions: instructions.clone(),
                    category: category.clone(),
                    user_id: user_id.into(),
                });
            }
            Err(_) => {}
        }
    }

    // One File object per uploaded file.
    object_list.extend(files.into_iter().map(|file_info| IngressObject::File {
        file_info,
        instructions: instructions.clone(),
        category: category.clone(),
        user_id: user_id.into(),
    }));

    // If no objects are constructed, we return Err
    if object_list.is_empty() {
        return Err(AppError::NotFound(
            "No valid content or files provided".into(),
        ));
    }

    Ok(object_list)
}
/// Creates a new `TextContent` instance from a `IngressObject`.
///
/// # Arguments

View File

@@ -1,4 +1,3 @@
pub mod analysis;
pub mod content_processor;
pub mod ingress_input;
pub mod ingress_object;

View File

@@ -1,7 +1,7 @@
use surrealdb::{engine::any::Any, Error, Surreal};
use surrealdb::Error;
use tracing::debug;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject};
use crate::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
/// Retrieves database entries that match a specific source identifier.
///
@@ -33,15 +33,14 @@ use crate::storage::types::{knowledge_entity::KnowledgeEntity, StoredObject};
pub async fn find_entities_by_source_ids<T>(
source_id: Vec<String>,
table_name: String,
db_client: &Surreal<Any>,
db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
T: for<'de> serde::Deserialize<'de>,
{
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
db_client
.query(query)
db.query(query)
.bind(("table", table_name))
.bind(("source_ids", source_id))
.await?
@@ -50,7 +49,7 @@ where
/// Find entities by their relationship to the id
pub async fn find_entities_by_relationship_by_id(
db_client: &Surreal<Any>,
db: &SurrealDbClient,
entity_id: String,
) -> Result<Vec<KnowledgeEntity>, Error> {
let query = format!(
@@ -60,15 +59,5 @@ pub async fn find_entities_by_relationship_by_id(
debug!("{}", query);
db_client.query(query).await?.take(0)
}
/// Get a specific KnowledgeEntity by its id
pub async fn get_entity_by_id(
db_client: &Surreal<Any>,
entity_id: &str,
) -> Result<Option<KnowledgeEntity>, Error> {
db_client
.select((KnowledgeEntity::table_name(), entity_id))
.await
db.query(query).await?.take(0)
}

View File

@@ -9,11 +9,13 @@ use crate::{
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
vector::find_items_by_vector_similarity,
},
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use futures::future::{try_join, try_join_all};
use std::collections::HashMap;
use surrealdb::{engine::any::Any, Surreal};
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
@@ -37,7 +39,7 @@ use surrealdb::{engine::any::Any, Surreal};
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
/// knowledge entities, or an error if the retrieval process fails
pub async fn combined_knowledge_entity_retrieval(
db_client: &Surreal<Any>,
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,

View File

@@ -12,7 +12,7 @@ use tracing::info;
use common::{
error::{AppError, HtmlError, IntoHtmlError},
ingress::ingress_input::{create_ingress_objects, IngressInput},
ingress::ingress_object::IngressObject,
storage::types::{file_info::FileInfo, job::Job, user::User},
};
@@ -112,13 +112,11 @@ pub async fn process_ingress_form(
}))
.await?;
let ingress_objects = create_ingress_objects(
IngressInput {
content: input.content,
instructions: input.instructions,
category: input.category,
files: file_infos,
},
let ingress_objects = IngressObject::create_ingress_objects(
input.content,
input.instructions,
input.category,
file_infos,
user.id.as_str(),
)
.map_err(|e| HtmlError::new(e, state.templates.clone()))?;

View File

@@ -24,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = get_config()?;
let surreal_db_client = Arc::new(
let db = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
@@ -37,12 +37,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let openai_client = Arc::new(async_openai::Client::new());
let content_processor =
ContentProcessor::new(surreal_db_client.clone(), openai_client.clone()).await?;
let content_processor = ContentProcessor::new(db.clone(), openai_client.clone()).await?;
loop {
// First, check for any unfinished jobs
let unfinished_jobs = Job::get_unfinished_jobs(&surreal_db_client).await?;
let unfinished_jobs = Job::get_unfinished_jobs(&db).await?;
if !unfinished_jobs.is_empty() {
info!("Found {} unfinished jobs", unfinished_jobs.len());
@@ -54,7 +53,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// If no unfinished jobs, start listening for new ones
info!("Listening for new jobs...");
let mut job_stream = Job::listen_for_jobs(&surreal_db_client).await?;
let mut job_stream = Job::listen_for_jobs(&db).await?;
while let Some(notification) = job_stream.next().await {
match notification {
@@ -80,9 +79,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
JobStatus::InProgress { attempts, .. } => {
// Only process if this is a retry after an error, not our own update
if let Ok(Some(current_job)) = surreal_db_client
.get_item::<Job>(&notification.data.id)
.await
if let Ok(Some(current_job)) =
db.get_item::<Job>(&notification.data.id).await
{
match current_job.status {
JobStatus::Error(_)