chore: ingestion-pipeline refactor, sort technical debt, rustfmt

This commit is contained in:
Per Stark
2026-05-31 19:37:34 +02:00
parent 5c2d2e24d3
commit 3897345ab3
47 changed files with 1729 additions and 1343 deletions
+2
View File
@@ -59,6 +59,8 @@ base64 = "0.22.1"
object_store = { version = "0.11.2", features = ["aws"] } object_store = { version = "0.11.2", features = ["aws"] }
bytes = "1.7.1" bytes = "1.7.1"
state-machines = "0.2.0" state-machines = "0.2.0"
pdf-extract = "0.9"
lopdf = "0.32"
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] } fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
[profile.dist] [profile.dist]
+8 -6
View File
@@ -28,7 +28,11 @@ pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize)
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations). /// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
#[must_use] #[must_use]
pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String { pub fn hnsw_index_redefine_transaction_sql(
index_name: &str,
table: &str,
dimension: usize,
) -> String {
format!( format!(
"BEGIN TRANSACTION; "BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table}; REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
@@ -204,9 +208,7 @@ pub async fn ensure_runtime(
/// ///
/// Returns `AppError::InternalError` if any index rebuild operation fails. /// Returns `AppError::InternalError` if any index rebuild operation fails.
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> { pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_inner(db) rebuild_inner(db).await.map_err(AppError::internal)
.await
.map_err(AppError::internal)
} }
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> { async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
@@ -297,8 +299,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
return Ok("unknown".to_string()); return Ok("unknown".to_string());
}; };
let parsed: IndexInfoForIndex = serde_json::from_value(info) let parsed: IndexInfoForIndex =
.context("deserializing INFO FOR INDEX response")?; serde_json::from_value(info).context("deserializing INFO FOR INDEX response")?;
Ok(parsed.building_status()) Ok(parsed.building_status())
} }
+5 -6
View File
@@ -518,8 +518,9 @@ mod tests {
#[test] #[test]
fn test_sidebar_conversation_deserializes_plain_string_id() { fn test_sidebar_conversation_deserializes_plain_string_id() {
let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#) let item: SidebarConversation =
.expect("valid sidebar conversation json"); serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
.expect("valid sidebar conversation json");
assert_eq!(item.id, "conv-plain"); assert_eq!(item.id, "conv-plain");
assert_eq!(item.title, "My chat"); assert_eq!(item.title, "My chat");
} }
@@ -568,8 +569,7 @@ mod tests {
)) ))
.await?; .await?;
let owner_messages = let owner_messages = fetch_messages_for_owner(&db, &conversation_id, owner).await?;
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
assert_eq!(owner_messages.len(), 1); assert_eq!(owner_messages.len(), 1);
assert_eq!( assert_eq!(
owner_messages owner_messages
@@ -579,8 +579,7 @@ mod tests {
"secret message" "secret message"
); );
let intruder_messages = let intruder_messages = fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
assert!( assert!(
intruder_messages.is_empty(), intruder_messages.is_empty(),
"SQL owner filter must not return messages for a non-owner user_id" "SQL owner filter must not return messages for a non-owner user_id"
+6 -13
View File
@@ -205,13 +205,9 @@ impl FileInfo {
let now = Utc::now(); let now = Utc::now();
let storage_prefix = format!("{user_id}/{uuid}"); let storage_prefix = format!("{user_id}/{uuid}");
let path = Self::persist_bytes_with_storage( let path =
&storage_prefix, Self::persist_bytes_with_storage(&storage_prefix, &sanitized_file_name, bytes, storage)
&sanitized_file_name, .await?;
bytes,
storage,
)
.await?;
let file_info = FileInfo { let file_info = FileInfo {
id: uuid.to_string(), id: uuid.to_string(),
@@ -262,8 +258,8 @@ impl FileInfo {
}; };
// Remove the object's parent prefix in the object store // Remove the object's parent prefix in the object store
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path) let (parent_prefix, _file_name) =
.map_err(AppError::internal)?; store::split_object_path(&file_info.path).map_err(AppError::internal)?;
storage storage
.delete_prefix(&parent_prefix) .delete_prefix(&parent_prefix)
.await .await
@@ -290,10 +286,7 @@ impl FileInfo {
&self, &self,
storage: &StorageManager, storage: &StorageManager,
) -> Result<bytes::Bytes, AppError> { ) -> Result<bytes::Bytes, AppError> {
storage storage.get(&self.path).await.map_err(AppError::Storage)
.get(&self.path)
.await
.map_err(AppError::Storage)
} }
/// Persist bytes to storage using StorageManager. /// Persist bytes to storage using StorageManager.
+19 -2
View File
@@ -26,6 +26,19 @@ pub enum IngestionPayload {
}, },
} }
impl Default for IngestionPayload {
    /// Builds the placeholder payload: a `Text` variant whose every field is an
    /// empty string.
    ///
    /// This is the value left in a task once its real content has been moved out
    /// (see [`crate::storage::types::ingestion_task::IngestionTask::take_content`]).
    /// Empty strings do not allocate, so constructing it is cheap.
    fn default() -> Self {
        let (text, context, category, user_id) = (
            String::default(),
            String::default(),
            String::default(),
            String::default(),
        );
        Self::Text {
            text,
            context,
            category,
            user_id,
        }
    }
}
/// Shared ingest metadata moved or cloned into each payload variant. /// Shared ingest metadata moved or cloned into each payload variant.
struct IngestFields { struct IngestFields {
context: String, context: String,
@@ -440,7 +453,8 @@ mod tests {
} }
#[test] #[test]
fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()> { fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()>
{
let context = "ctx"; let context = "ctx";
let category = "cat"; let category = "cat";
let user_id = "user123"; let user_id = "user123";
@@ -501,7 +515,10 @@ mod tests {
)?; )?;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
assert!(matches!(result.first(), Some(IngestionPayload::File { .. }))); assert!(matches!(
result.first(),
Some(IngestionPayload::File { .. })
));
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. }))); assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
Ok(()) Ok(())
} }
+23 -2
View File
@@ -211,6 +211,16 @@ impl IngestionTask {
self.attempts < self.max_attempts self.attempts < self.max_attempts
} }
/// Moves the payload out of the task, leaving an empty placeholder behind.
///
/// The task's `content` is only needed while driving the pipeline; the
/// terminal `user_id`, `state`, and bookkeeping fields are stored separately,
/// so replacing it with the default placeholder avoids cloning large payloads.
#[must_use]
pub fn take_content(&mut self) -> IngestionPayload {
std::mem::take(&mut self.content)
}
#[must_use] #[must_use]
pub fn lease_duration(&self) -> Duration { pub fn lease_duration(&self) -> Duration {
Duration::from_secs(u64::try_from(self.lease_duration_secs.max(0)).unwrap_or(0)) Duration::from_secs(u64::try_from(self.lease_duration_secs.max(0)).unwrap_or(0))
@@ -654,6 +664,18 @@ mod tests {
Ok(()) Ok(())
} }
/// `take_content` must hand back the exact payload the task was created with
/// and leave the default placeholder in the task's `content` field.
#[test]
fn test_take_content_moves_payload_and_leaves_default() {
    let owner = "user123";
    let expected = create_payload(owner);
    let mut task = IngestionTask::new(expected.clone(), owner.to_string());

    let moved_out = task.take_content();

    // The caller receives the original payload unchanged...
    assert_eq!(moved_out, expected);
    // ...and the task is left holding only the empty placeholder.
    assert_eq!(task.content, IngestionPayload::default());
}
#[tokio::test] #[tokio::test]
async fn test_create_all_and_add_to_db_empty() -> anyhow::Result<()> { async fn test_create_all_and_add_to_db_empty() -> anyhow::Result<()> {
let db = memory_db().await?; let db = memory_db().await?;
@@ -676,8 +698,7 @@ mod tests {
}, },
]; ];
let created = let created = IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
assert_eq!(created.len(), 2); assert_eq!(created.len(), 2);
let first = created.first().expect("expected first task"); let first = created.first().expect("expected first task");
+24 -40
View File
@@ -1,16 +1,12 @@
#![allow( #![allow(clippy::missing_docs_in_private_items, clippy::module_name_repetitions)]
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
)]
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Write; use std::fmt::Write;
use crate::{ use crate::{
error::AppError, storage::db::SurrealDbClient, error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql,
storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings, stored_object, storage::types::system_settings::SystemSettings, stored_object,
utils::embedding::generate_embedding, utils::embedding::{generate_embedding_with_params, generate_embedding_with_provider, EmbeddingProvider},
}; };
use async_openai::{config::OpenAIConfig, Client}; use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{ use tokio_retry::{
@@ -315,12 +311,12 @@ impl KnowledgeEntity {
description: &str, description: &str,
entity_type: &KnowledgeEntityType, entity_type: &KnowledgeEntityType,
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
ai_client: &Client<OpenAIConfig>, embedding_provider: &EmbeddingProvider,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let embedding_input = format!( let embedding_input =
"name: {name}, description: {description}, type: {entity_type:?}", format!("name: {name}, description: {description}, type: {entity_type:?}",);
); let embedding =
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?; generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
let entity: KnowledgeEntity = db_client let entity: KnowledgeEntity = db_client
.get_item(id) .get_item(id)
@@ -334,12 +330,7 @@ impl KnowledgeEntity {
settings.embedding_dimensions as usize, settings.embedding_dimensions as usize,
)?; )?;
let emb = KnowledgeEntityEmbedding::new( let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
id,
entity.source_id,
embedding,
entity.user_id,
);
let now = Utc::now(); let now = Utc::now();
@@ -411,7 +402,7 @@ impl KnowledgeEntity {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || { let embedding = Retry::spawn(retry_strategy, || {
crate::utils::embedding::generate_embedding_with_params( generate_embedding_with_params(
openai_client, openai_client,
&embedding_input, &embedding_input,
new_model, new_model,
@@ -431,11 +422,7 @@ impl KnowledgeEntity {
} }
new_embeddings.insert( new_embeddings.insert(
entity.id.clone(), entity.id.clone(),
( (embedding, entity.user_id.clone(), entity.source_id.clone()),
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
); );
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
@@ -446,9 +433,8 @@ impl KnowledgeEntity {
// Add all update statements to the embedding table // Add all update statements to the embedding table
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding)
AppError::internal(format!("embedding serialization failed: {e}")) .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
})?;
write!( write!(
transaction_query, transaction_query,
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \ "UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -526,9 +512,7 @@ impl KnowledgeEntity {
entity.name, entity.description, entity.entity_type entity.name, entity.description, entity.entity_type
); );
let embedding = provider let embedding = provider.embed(&embedding_input).await?;
.embed(&embedding_input)
.await?;
// Safety check: ensure the generated embedding has the correct dimension. // Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions { if embedding.len() != new_dimensions {
@@ -541,11 +525,7 @@ impl KnowledgeEntity {
} }
new_embeddings.insert( new_embeddings.insert(
entity.id.clone(), entity.id.clone(),
( (embedding, entity.user_id.clone(), entity.source_id.clone()),
embedding,
entity.user_id.clone(),
entity.source_id.clone(),
),
); );
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
@@ -580,9 +560,8 @@ impl KnowledgeEntity {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding)
AppError::internal(format!("embedding serialization failed: {e}")) .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
})?;
write!( write!(
transaction_query, transaction_query,
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \ "CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
@@ -833,7 +812,9 @@ mod tests {
.await .await
.expect("Failed to apply migrations"); .expect("Failed to apply migrations");
configure_embedding_dimension(&db, 3).await.expect("configure dim"); configure_embedding_dimension(&db, 3)
.await
.expect("configure dim");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await .await
.expect("Failed to redefine index length"); .expect("Failed to redefine index length");
@@ -1105,7 +1086,10 @@ mod tests {
.await .await
.with_context(|| "store entity with embedding".to_string())?; .with_context(|| "store entity with embedding".to_string())?;
let query = format!("DELETE type::thing('knowledge_entity', '{id}')", id = entity.id); let query = format!(
"DELETE type::thing('knowledge_entity', '{id}')",
id = entity.id
);
db.client db.client
.query(query) .query(query)
.await .await
@@ -51,12 +51,7 @@ impl KnowledgeEntityEmbedding {
/// ///
/// The embedding record id equals `entity_id` so each entity has at most one embedding row. /// The embedding record id equals `entity_id` so each entity has at most one embedding row.
#[must_use] #[must_use]
pub fn new( pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
entity_id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now(); let now = Utc::now();
Self { Self {
id: entity_id.to_owned(), id: entity_id.to_owned(),
@@ -300,8 +295,7 @@ mod tests {
let db = prepare_knowledge_entity_test_db(3).await?; let db = prepare_knowledge_entity_test_db(3).await?;
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim"); let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
let result = let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
assert!(matches!(result, Err(AppError::Validation(_)))); assert!(matches!(result, Err(AppError::Validation(_))));
@@ -1,5 +1,5 @@
use crate::utils::serde_helpers::deserialize_flexible_id;
use crate::storage::types::user::User; use crate::storage::types::user::User;
use crate::utils::serde_helpers::deserialize_flexible_id;
use crate::{error::AppError, storage::db::SurrealDbClient}; use crate::{error::AppError, storage::db::SurrealDbClient};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
@@ -43,18 +43,10 @@ impl KnowledgeRelationship {
} }
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> { pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
User::get_and_validate_knowledge_entity( User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
&self.in_, .await?;
&self.metadata.user_id, User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
db_client, .await?;
)
.await?;
User::get_and_validate_knowledge_entity(
&self.out,
&self.metadata.user_id,
db_client,
)
.await?;
db_client db_client
.client .client
+4 -1
View File
@@ -520,7 +520,10 @@ mod tests {
.with_context(|| "Failed to patch query prompt".to_string())?; .with_context(|| "Failed to patch query prompt".to_string())?;
assert_eq!(patched.query_system_prompt, sentinel); assert_eq!(patched.query_system_prompt, sentinel);
assert_eq!(patched.ingestion_system_prompt, original.ingestion_system_prompt); assert_eq!(
patched.ingestion_system_prompt,
original.ingestion_system_prompt
);
assert_eq!(patched.query_model, original.query_model); assert_eq!(patched.query_model, original.query_model);
assert_eq!( assert_eq!(
patched.registrations_enabled, patched.registrations_enabled,
+10 -19
View File
@@ -74,10 +74,7 @@ impl TextChunk {
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let settings = SystemSettings::get_current(db).await?; let settings = SystemSettings::get_current(db).await?;
TextChunkEmbedding::validate_dimension( TextChunkEmbedding::validate_dimension(&embedding, settings.embedding_dimensions as usize)?;
&embedding,
settings.embedding_dimensions as usize,
)?;
let chunk_id = chunk.id.clone(); let chunk_id = chunk.id.clone();
let emb = TextChunkEmbedding::new( let emb = TextChunkEmbedding::new(
@@ -308,9 +305,8 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding)
AppError::internal(format!("embedding serialization failed: {e}")) .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
})?;
let id = surql_json_string(&id)?; let id = surql_json_string(&id)?;
let user_id = surql_json_string(&user_id)?; let user_id = surql_json_string(&user_id)?;
let source_id = surql_json_string(&source_id)?; let source_id = surql_json_string(&source_id)?;
@@ -382,9 +378,7 @@ impl TextChunk {
info!(progress = i, total = total_chunks, "Re-embedding progress"); info!(progress = i, total = total_chunks, "Re-embedding progress");
} }
let embedding = provider let embedding = provider.embed(&chunk.chunk).await?;
.embed(&chunk.chunk)
.await?;
// Safety check: ensure the generated embedding has the correct dimension. // Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions { if embedding.len() != new_dimensions {
@@ -429,9 +423,8 @@ impl TextChunk {
let mut transaction_query = String::from("BEGIN TRANSACTION;"); let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings { for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding = serde_json::to_string(&embedding).map_err(|e| { let embedding = serde_json::to_string(&embedding)
AppError::internal(format!("embedding serialization failed: {e}")) .map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
})?;
let id = surql_json_string(&id)?; let id = surql_json_string(&id)?;
let user_id = surql_json_string(&user_id)?; let user_id = surql_json_string(&user_id)?;
let source_id = surql_json_string(&source_id)?; let source_id = surql_json_string(&source_id)?;
@@ -662,7 +655,9 @@ mod tests {
.await .await
.expect("Failed to start in-memory surrealdb"); .expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.expect("migrations");
configure_embedding_dimension(&db, 5).await.expect("configure dim"); configure_embedding_dimension(&db, 5)
.await
.expect("configure dim");
TextChunkEmbedding::redefine_hnsw_index(&db, 5) TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await .await
.expect("redefine index"); .expect("redefine index");
@@ -1040,11 +1035,7 @@ mod tests {
.with_context(|| "migrations".to_string())?; .with_context(|| "migrations".to_string())?;
configure_embedding_dimension(&db, 3).await?; configure_embedding_dimension(&db, 3).await?;
let chunk = TextChunk::new( let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
"src".to_string(),
"body".to_string(),
"user".to_string(),
);
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db) let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db)
.await .await
@@ -299,12 +299,7 @@ mod tests {
("chunk-s2", source_id, vec![0.2]), ("chunk-s2", source_id, vec![0.2]),
("chunk-other", other_source, vec![0.3]), ("chunk-other", other_source, vec![0.3]),
] { ] {
let emb = TextChunkEmbedding::new( let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
key,
src.to_string(),
vec,
user_id.to_string(),
);
db.upsert_item(emb) db.upsert_item(emb)
.await .await
.with_context(|| format!("store embedding for {key}"))?; .with_context(|| format!("store embedding for {key}"))?;
+4 -8
View File
@@ -118,9 +118,7 @@ impl TextContent {
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
if updated.is_none() { if updated.is_none() {
return Err(AppError::NotFound(format!( return Err(AppError::NotFound(format!("text content {id} not found")));
"text content {id} not found"
)));
} }
Ok(()) Ok(())
@@ -142,7 +140,8 @@ impl TextContent {
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::Database)?; let existing: Option<surrealdb::sql::Thing> =
response.take(0).map_err(AppError::Database)?;
Ok(existing.is_some()) Ok(existing.is_some())
} }
@@ -254,10 +253,7 @@ impl TextContent {
for content in contents { for content in contents {
let label = build_source_label(&content); let label = build_source_label(&content);
labels.insert(content.id.clone(), label.clone()); labels.insert(content.id.clone(), label.clone());
labels.insert( labels.insert(format!("{}:{}", Self::table_name(), content.id), label);
format!("{}:{}", Self::table_name(), content.id),
label,
);
} }
Ok(labels) Ok(labels)
+2 -5
View File
@@ -8,8 +8,7 @@ use crate::storage::{
db::SurrealDbClient, db::SurrealDbClient,
indexes::{ensure_runtime, rebuild}, indexes::{ensure_runtime, rebuild},
types::{ types::{
knowledge_entity_embedding::KnowledgeEntityEmbedding, knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
system_settings::SystemSettings,
text_chunk_embedding::TextChunkEmbedding, text_chunk_embedding::TextChunkEmbedding,
}, },
}; };
@@ -27,9 +26,7 @@ pub async fn setup_test_db() -> Result<SurrealDbClient> {
.await .await
.context("start in-memory surrealdb")?; .context("start in-memory surrealdb")?;
db.apply_migrations() db.apply_migrations().await.context("apply migrations")?;
.await
.context("apply migrations")?;
Ok(db) Ok(db)
} }
+21 -23
View File
@@ -59,9 +59,7 @@ async fn run_fastembed(
texts: Vec<String>, texts: Vec<String>,
) -> Result<Vec<Vec<f32>>, EmbeddingError> { ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> { match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = model let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?;
.lock()
.map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed) guard.embed(texts, None).map_err(EmbeddingError::fastembed)
}) })
.await .await
@@ -215,21 +213,22 @@ impl EmbeddingProvider {
let model_name_for_task = model_name.clone(); let model_name_for_task = model_name.clone();
let model_name_code = model_name.to_string(); let model_name_code = model_name.to_string();
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> { let (model, dimension) =
let model = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?; let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| { let info =
EmbeddingError::Config(format!( EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
"fastembed model metadata missing for {model_name_code}" EmbeddingError::Config(format!(
)) "fastembed model metadata missing for {model_name_code}"
})?; ))
Ok((model, info.dim)) })?;
}) Ok((model, info.dim))
.await })
{ .await
Ok(result) => result?, {
Err(join_error) => return Err(EmbeddingError::from(join_error)), Ok(result) => result?,
}; Err(join_error) => return Err(EmbeddingError::from(join_error)),
};
Ok(EmbeddingProvider { Ok(EmbeddingProvider {
inner: EmbeddingInner::FastEmbed { inner: EmbeddingInner::FastEmbed {
@@ -440,10 +439,7 @@ mod tests {
#[test] #[test]
fn embedding_backend_defaults_to_fastembed() { fn embedding_backend_defaults_to_fastembed() {
assert_eq!( assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
EmbeddingBackend::default(),
EmbeddingBackend::FastEmbed
);
} }
#[test] #[test]
@@ -472,7 +468,9 @@ mod tests {
#[test] #[test]
fn embedding_backend_from_str_accepts_aliases() { fn embedding_backend_from_str_accepts_aliases() {
assert_eq!( assert_eq!(
"fast-embed".parse::<EmbeddingBackend>().expect("fast-embed"), "fast-embed"
.parse::<EmbeddingBackend>()
.expect("fast-embed"),
EmbeddingBackend::FastEmbed EmbeddingBackend::FastEmbed
); );
assert_eq!( assert_eq!(
+1 -2
View File
@@ -9,8 +9,7 @@ pub use orchestrator::{
}; };
pub use store::{ pub use store::{
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata, seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
ParagraphShardStore, MANIFEST_VERSION,
}; };
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig { pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
+17 -24
View File
@@ -33,8 +33,7 @@ use crate::{
use crate::corpus::{ use crate::corpus::{
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion, CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
MANIFEST_VERSION,
}; };
const INGESTION_SPEC_VERSION: u32 = 2; const INGESTION_SPEC_VERSION: u32 = 2;
@@ -273,10 +272,19 @@ pub async fn ensure_corpus(
.context("shard record missing after ingestion run")?; .context("shard record missing after ingestion run")?;
if cache.refresh_embeddings_only || shard_record.needs_reembed { if cache.refresh_embeddings_only || shard_record.needs_reembed {
// Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed // Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed
shard_record.shard.ingestion_fingerprint.clone_from(&ingestion_fingerprint); shard_record
.shard
.ingestion_fingerprint
.clone_from(&ingestion_fingerprint);
shard_record.shard.ingested_at = Utc::now(); shard_record.shard.ingested_at = Utc::now();
shard_record.shard.embedding_backend.clone_from(&embedding_backend_label); shard_record
shard_record.shard.embedding_model.clone_from(&embedding_model_code); .shard
.embedding_backend
.clone_from(&embedding_backend_label);
shard_record
.shard
.embedding_model
.clone_from(&embedding_model_code);
shard_record.shard.embedding_dimension = embedding_dimension; shard_record.shard.embedding_dimension = embedding_dimension;
shard_record.dirty = true; shard_record.dirty = true;
shard_record.needs_reembed = false; shard_record.needs_reembed = false;
@@ -543,31 +551,16 @@ async fn ingest_single_paragraph(
let task = IngestionTask::new(payload, user_id.to_string()); let task = IngestionTask::new(payload, user_id.to_string());
match pipeline.produce_artifacts(&task).await { match pipeline.produce_artifacts(&task).await {
Ok(artifacts) => { Ok(artifacts) => {
let entities: Vec<EmbeddedKnowledgeEntity> = artifacts // Artifacts already carry the shared `Embedded*` types and FastEmbed
.entities // embeddings, so they can be persisted to the shard without re-mapping.
.into_iter()
.map(|e| EmbeddedKnowledgeEntity {
entity: e.entity,
embedding: e.embedding,
})
.collect();
let chunks: Vec<EmbeddedTextChunk> = artifacts
.chunks
.into_iter()
.map(|c| EmbeddedTextChunk {
chunk: c.chunk,
embedding: c.embedding,
})
.collect();
// No need to reembed - pipeline now uses FastEmbed internally
let mut shard = ParagraphShard::new( let mut shard = ParagraphShard::new(
paragraph, paragraph,
request.shard_path, request.shard_path,
ingestion_fingerprint, ingestion_fingerprint,
artifacts.text_content, artifacts.text_content,
entities, artifacts.entities,
artifacts.relationships, artifacts.relationships,
chunks, artifacts.chunks,
&embedding_backend, &embedding_backend,
embedding_model.clone(), embedding_model.clone(),
embedding_dimension, embedding_dimension,
+3 -11
View File
@@ -54,17 +54,9 @@ fn default_chunk_only() -> bool {
false false
} }
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] // Reuse the pipeline's canonical embedded-artifact types so the on-disk corpus
pub struct EmbeddedKnowledgeEntity { // format and the ingestion output never drift apart.
pub entity: KnowledgeEntity, pub use ingestion_pipeline::{EmbeddedKnowledgeEntity, EmbeddedTextChunk};
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddedTextChunk {
pub chunk: TextChunk,
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, serde::Deserialize)] #[derive(Debug, Clone, serde::Deserialize)]
struct LegacyKnowledgeEntity { struct LegacyKnowledgeEntity {
+5 -1
View File
@@ -11,7 +11,11 @@ use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion}; use super::{ConvertedParagraph, ConvertedQuestion};
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects, clippy::cast_sign_loss)] #[allow(
clippy::too_many_lines,
clippy::arithmetic_side_effects,
clippy::cast_sign_loss
)]
pub fn convert_nq( pub fn convert_nq(
raw_path: &Path, raw_path: &Path,
include_unanswerable: bool, include_unanswerable: bool,
+1 -2
View File
@@ -166,8 +166,7 @@ async fn async_main() -> anyhow::Result<()> {
); );
if parsed.config.slice_grow.is_some() { if parsed.config.slice_grow.is_some() {
eval::grow_slice(&dataset, &parsed.config) eval::grow_slice(&dataset, &parsed.config).context("growing slice ledger")?;
.context("growing slice ledger")?;
return Ok(()); return Ok(());
} }
+20 -7
View File
@@ -14,7 +14,7 @@ use common::{
utils::embedding::EmbeddingProvider, utils::embedding::EmbeddingProvider,
}; };
use retrieval_pipeline::{ use retrieval_pipeline::{
pipeline::{StageTimings, RetrievalConfig}, pipeline::{RetrievalConfig, StageTimings},
reranking::RerankerPool, reranking::RerankerPool,
}; };
@@ -122,11 +122,15 @@ impl<'a> EvaluationContext<'a> {
} }
pub fn slice(&self) -> Result<&slice::ResolvedSlice<'a>> { pub fn slice(&self) -> Result<&slice::ResolvedSlice<'a>> {
self.slice.as_ref().ok_or_else(|| anyhow!("slice has not been prepared")) self.slice
.as_ref()
.ok_or_else(|| anyhow!("slice has not been prepared"))
} }
pub fn db(&self) -> Result<&SurrealDbClient> { pub fn db(&self) -> Result<&SurrealDbClient> {
self.db.as_ref().ok_or_else(|| anyhow!("database connection missing")) self.db
.as_ref()
.ok_or_else(|| anyhow!("database connection missing"))
} }
pub fn descriptor(&self) -> Result<&snapshot::Descriptor> { pub fn descriptor(&self) -> Result<&snapshot::Descriptor> {
@@ -142,15 +146,23 @@ impl<'a> EvaluationContext<'a> {
} }
pub fn openai_client(&self) -> Result<Arc<Client<async_openai::config::OpenAIConfig>>> { pub fn openai_client(&self) -> Result<Arc<Client<async_openai::config::OpenAIConfig>>> {
Ok(Arc::clone(self.openai_client.as_ref().ok_or_else(|| anyhow!("openai client missing"))?)) Ok(Arc::clone(
self.openai_client
.as_ref()
.ok_or_else(|| anyhow!("openai client missing"))?,
))
} }
pub fn corpus_handle(&self) -> Result<&corpus::CorpusHandle> { pub fn corpus_handle(&self) -> Result<&corpus::CorpusHandle> {
self.corpus_handle.as_ref().ok_or_else(|| anyhow!("corpus handle missing")) self.corpus_handle
.as_ref()
.ok_or_else(|| anyhow!("corpus handle missing"))
} }
pub fn evaluation_user(&self) -> Result<&User> { pub fn evaluation_user(&self) -> Result<&User> {
self.eval_user.as_ref().ok_or_else(|| anyhow!("evaluation user missing")) self.eval_user
.as_ref()
.ok_or_else(|| anyhow!("evaluation user missing"))
} }
#[allow(clippy::arithmetic_side_effects)] #[allow(clippy::arithmetic_side_effects)]
@@ -168,7 +180,8 @@ impl<'a> EvaluationContext<'a> {
} }
pub fn into_summary(self) -> Result<EvaluationSummary> { pub fn into_summary(self) -> Result<EvaluationSummary> {
self.summary.ok_or_else(|| anyhow!("evaluation summary missing")) self.summary
.ok_or_else(|| anyhow!("evaluation summary missing"))
} }
} }
+15 -16
View File
@@ -10,7 +10,7 @@ use crate::eval::{
CaseSummary, RetrievedSummary, CaseSummary, RetrievedSummary,
}; };
use retrieval_pipeline::{ use retrieval_pipeline::{
pipeline::{self, StageTimings, RetrievalConfig}, pipeline::{self, RetrievalConfig, StageTimings},
reranking::RerankerPool, reranking::RerankerPool,
}; };
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
@@ -169,10 +169,10 @@ pub(crate) async fn run_queries(
let query_start = Instant::now(); let query_start = Instant::now();
debug!(question_id = %question_id, "Evaluating query"); debug!(question_id = %question_id, "Evaluating query");
let query_embedding = let query_embedding = embedding_provider
embedding_provider.embed(&question).await.with_context(|| { .embed(&question)
format!("generating embedding for question {question_id}") .await
})?; .with_context(|| format!("generating embedding for question {question_id}"))?;
let reranker = match rerank_pool.as_ref() { let reranker = match rerank_pool.as_ref() {
Some(pool) => pool.checkout().await, Some(pool) => pool.checkout().await,
None => None, None => None,
@@ -204,8 +204,10 @@ pub(crate) async fn run_queries(
let mut match_rank = None; let mut match_rank = None;
let answers_lower: Vec<String> = let answers_lower: Vec<String> =
answers.iter().map(|ans| ans.to_ascii_lowercase()).collect(); answers.iter().map(|ans| ans.to_ascii_lowercase()).collect();
let expected_chunk_ids_set: HashSet<&str> = let expected_chunk_ids_set: HashSet<&str> = expected_chunk_ids
expected_chunk_ids.iter().map(std::string::String::as_str).collect(); .iter()
.map(std::string::String::as_str)
.collect();
let chunk_id_required = has_verified_chunks; let chunk_id_required = has_verified_chunks;
let mut entity_hit = false; let mut entity_hit = false;
let mut chunk_text_hit = false; let mut chunk_text_hit = false;
@@ -304,15 +306,12 @@ pub(crate) async fn run_queries(
None None
}; };
Ok::< Ok::<(usize, CaseSummary, Option<CaseDiagnostics>, StageTimings), anyhow::Error>((
( idx,
usize, summary,
CaseSummary, diagnostics,
Option<CaseDiagnostics>, stage_timings,
StageTimings, ))
),
anyhow::Error,
>((idx, summary, diagnostics, stage_timings))
} }
}) })
.buffer_unordered(concurrency) .buffer_unordered(concurrency)
+5 -1
View File
@@ -13,7 +13,11 @@ use super::super::{
}; };
use super::{map_guard_error, StageResult}; use super::{map_guard_error, StageResult};
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects, clippy::cast_precision_loss)] #[allow(
clippy::too_many_lines,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss
)]
pub(crate) async fn summarize( pub(crate) async fn summarize(
machine: EvaluationMachine<(), QueriesFinished>, machine: EvaluationMachine<(), QueriesFinished>,
ctx: &mut EvaluationContext<'_>, ctx: &mut EvaluationContext<'_>,
+147 -26
View File
@@ -403,11 +403,20 @@ pub fn write_reports(
}) })
} }
#[allow(clippy::too_many_lines, clippy::write_with_newline, clippy::unwrap_used)] #[allow(
clippy::too_many_lines,
clippy::write_with_newline,
clippy::unwrap_used
)]
fn render_markdown(report: &EvaluationReport) -> String { fn render_markdown(report: &EvaluationReport) -> String {
let mut md = String::new(); let mut md = String::new();
write!(md, "# Retrieval Evaluation (k={})\\n\\n", report.retrieval.k).unwrap(); write!(
md,
"# Retrieval Evaluation (k={})\\n\\n",
report.retrieval.k
)
.unwrap();
md.push_str("## Overview\\n\\n"); md.push_str("## Overview\\n\\n");
md.push_str("| Metric | Value |\\n| --- | --- |\\n"); md.push_str("| Metric | Value |\\n| --- | --- |\\n");
@@ -424,34 +433,94 @@ fn render_markdown(report: &EvaluationReport) -> String {
) )
.unwrap(); .unwrap();
write!(md, "| Total Cases | {} |\\n", report.overview.total_cases).unwrap(); write!(md, "| Total Cases | {} |\\n", report.overview.total_cases).unwrap();
write!(md, "| Filtered Questions | {} |\\n", report.overview.filtered_questions).unwrap(); write!(
md,
"| Filtered Questions | {} |\\n",
report.overview.filtered_questions
)
.unwrap();
md.push_str("\\n## Dataset & Slice\\n\\n"); md.push_str("\\n## Dataset & Slice\\n\\n");
md.push_str("| Metric | Value |\\n| --- | --- |\\n"); md.push_str("| Metric | Value |\\n| --- | --- |\\n");
write!(md, "| Dataset | {} (`{}`) |\\n", report.dataset.label, report.dataset.id).unwrap(); write!(
md,
"| Dataset | {} (`{}`) |\\n",
report.dataset.label, report.dataset.id
)
.unwrap();
write!(md, "| Dataset Source | {} |\\n", report.dataset.source).unwrap(); write!(md, "| Dataset Source | {} |\\n", report.dataset.source).unwrap();
write!(md, "| Includes Unanswerable | {} |\\n", bool_badge(report.dataset.includes_unanswerable)).unwrap(); write!(
write!(md, "| Require Verified Chunks | {} |\\n", bool_badge(report.dataset.require_verified_chunks)).unwrap(); md,
"| Includes Unanswerable | {} |\\n",
bool_badge(report.dataset.includes_unanswerable)
)
.unwrap();
write!(
md,
"| Require Verified Chunks | {} |\\n",
bool_badge(report.dataset.require_verified_chunks)
)
.unwrap();
let embedding_label = if let Some(model) = report.dataset.embedding_model.as_ref() { let embedding_label = if let Some(model) = report.dataset.embedding_model.as_ref() {
format!("{} ({model})", report.dataset.embedding_backend) format!("{} ({model})", report.dataset.embedding_backend)
} else { } else {
report.dataset.embedding_backend.clone() report.dataset.embedding_backend.clone()
}; };
write!(md, "| Embedding | {embedding_label} |\\n").unwrap(); write!(md, "| Embedding | {embedding_label} |\\n").unwrap();
write!(md, "| Embedding Dim | {} |\\n", report.dataset.embedding_dimension).unwrap(); write!(
md,
"| Embedding Dim | {} |\\n",
report.dataset.embedding_dimension
)
.unwrap();
write!(md, "| Slice ID | `{}` |\\n", report.slice.id).unwrap(); write!(md, "| Slice ID | `{}` |\\n", report.slice.id).unwrap();
write!(md, "| Slice Seed | {} |\\n", report.slice.seed).unwrap(); write!(md, "| Slice Seed | {} |\\n", report.slice.seed).unwrap();
write!(md, "| Slice Window (offset/length) | {}/{} |\\n", report.slice.window_offset, report.slice.window_length).unwrap(); write!(
write!(md, "| Slice Questions (window/ledger) | {}/{} |\\n", report.slice.slice_cases, report.slice.ledger_total_cases).unwrap(); md,
write!(md, "| Slice Positives / Negatives | {}/{} |\\n", report.slice.positives, report.slice.negatives).unwrap(); "| Slice Window (offset/length) | {}/{} |\\n",
write!(md, "| Slice Paragraphs | {} |\\n", report.slice.total_paragraphs).unwrap(); report.slice.window_offset, report.slice.window_length
write!(md, "| Negative Multiplier | {:.2} |\\n", report.slice.negative_multiplier).unwrap(); )
.unwrap();
write!(
md,
"| Slice Questions (window/ledger) | {}/{} |\\n",
report.slice.slice_cases, report.slice.ledger_total_cases
)
.unwrap();
write!(
md,
"| Slice Positives / Negatives | {}/{} |\\n",
report.slice.positives, report.slice.negatives
)
.unwrap();
write!(
md,
"| Slice Paragraphs | {} |\\n",
report.slice.total_paragraphs
)
.unwrap();
write!(
md,
"| Negative Multiplier | {:.2} |\\n",
report.slice.negative_multiplier
)
.unwrap();
md.push_str("\\n## Retrieval Metrics\\n\\n"); md.push_str("\\n## Retrieval Metrics\\n\\n");
md.push_str("| Metric | Value |\\n| --- | --- |\\n"); md.push_str("| Metric | Value |\\n| --- | --- |\\n");
write!(md, "| Cases | {} |\\n", report.retrieval.cases).unwrap(); write!(md, "| Cases | {} |\\n", report.retrieval.cases).unwrap();
write!(md, "| Correct@{} | {}/{} |\\n", report.retrieval.k, report.retrieval.correct, report.retrieval.cases).unwrap(); write!(
write!(md, "| Precision@{} | {:.3} |\\n", report.retrieval.k, report.retrieval.precision).unwrap(); md,
"| Correct@{} | {}/{} |\\n",
report.retrieval.k, report.retrieval.correct, report.retrieval.cases
)
.unwrap();
write!(
md,
"| Precision@{} | {:.3} |\\n",
report.retrieval.k, report.retrieval.precision
)
.unwrap();
write!( write!(
md, md,
"| Precision@1/2/3 | {:.3} / {:.3} / {:.3} |\\n", "| Precision@1/2/3 | {:.3} / {:.3} / {:.3} |\\n",
@@ -462,7 +531,12 @@ fn render_markdown(report: &EvaluationReport) -> String {
.unwrap(); .unwrap();
write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap(); write!(md, "| MRR | {:.3} |\\n", report.retrieval.mrr).unwrap();
write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap(); write!(md, "| NDCG | {:.3} |\\n", report.retrieval.average_ndcg).unwrap();
write!(md, "| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n", report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95).unwrap(); write!(
md,
"| Latency Avg / P50 / P95 (ms) | {:.1} / {} / {} |\\n",
report.retrieval.latency.avg, report.retrieval.latency.p50, report.retrieval.latency.p95
)
.unwrap();
write!( write!(
md, md,
"| Resolve entities | {} |\\n", "| Resolve entities | {} |\\n",
@@ -473,8 +547,14 @@ fn render_markdown(report: &EvaluationReport) -> String {
if report.retrieval.rerank_enabled { if report.retrieval.rerank_enabled {
let pool = report let pool = report
.retrieval .retrieval
.rerank_pool_size.map_or_else(|| "?".into(), |size| size.to_string()); .rerank_pool_size
write!(md, "| Rerank | enabled (pool {pool}, keep top {}) |\\n", report.retrieval.rerank_keep_top).unwrap(); .map_or_else(|| "?".into(), |size| size.to_string());
write!(
md,
"| Rerank | enabled (pool {pool}, keep top {}) |\\n",
report.retrieval.rerank_keep_top
)
.unwrap();
} else { } else {
md.push_str("| Rerank | disabled |\\n"); md.push_str("| Rerank | disabled |\\n");
} }
@@ -489,8 +569,18 @@ fn render_markdown(report: &EvaluationReport) -> String {
md.push_str("\\n## Performance\\n\\n"); md.push_str("\\n## Performance\\n\\n");
md.push_str("| Metric | Value |\\n| --- | --- |\\n"); md.push_str("| Metric | Value |\\n| --- | --- |\\n");
write!(md, "| OpenAI Base URL | {} |\\n", report.performance.openai_base_url).unwrap(); write!(
write!(md, "| Ingestion Duration | {} ms |\\n", report.performance.ingestion_ms).unwrap(); md,
"| OpenAI Base URL | {} |\\n",
report.performance.openai_base_url
)
.unwrap();
write!(
md,
"| Ingestion Duration | {} ms |\\n",
report.performance.ingestion_ms
)
.unwrap();
if let Some(seed) = report.performance.namespace_seed_ms { if let Some(seed) = report.performance.namespace_seed_ms {
write!(md, "| Namespace Seed | {seed} ms |\\n").unwrap(); write!(md, "| Namespace Seed | {seed} ms |\\n").unwrap();
} }
@@ -504,14 +594,44 @@ fn render_markdown(report: &EvaluationReport) -> String {
} }
) )
.unwrap(); .unwrap();
write!(md, "| Corpus Paragraphs | {} |\\n", report.performance.corpus_paragraphs).unwrap(); write!(
md,
"| Corpus Paragraphs | {} |\\n",
report.performance.corpus_paragraphs
)
.unwrap();
if report.detailed_report { if report.detailed_report {
write!(md, "| Ingestion Cache | `{}` |\\n", report.performance.ingestion_cache_path).unwrap(); write!(
write!(md, "| Ingestion Reused | {} |\\n", bool_badge(report.performance.ingestion_reused)).unwrap(); md,
write!(md, "| Embeddings Reused | {} |\\n", bool_badge(report.performance.embeddings_reused)).unwrap(); "| Ingestion Cache | `{}` |\\n",
report.performance.ingestion_cache_path
)
.unwrap();
write!(
md,
"| Ingestion Reused | {} |\\n",
bool_badge(report.performance.ingestion_reused)
)
.unwrap();
write!(
md,
"| Embeddings Reused | {} |\\n",
bool_badge(report.performance.embeddings_reused)
)
.unwrap();
} }
write!(md, "| Positives Cached | {} |\\n", report.performance.positive_paragraphs_reused).unwrap(); write!(
write!(md, "| Negatives Cached | {} |\\n", report.performance.negative_paragraphs_reused).unwrap(); md,
"| Positives Cached | {} |\\n",
report.performance.positive_paragraphs_reused
)
.unwrap();
write!(
md,
"| Negatives Cached | {} |\\n",
report.performance.negative_paragraphs_reused
)
.unwrap();
md.push_str("\\n## Retrieval Stage Timings\\n\\n"); md.push_str("\\n## Retrieval Stage Timings\\n\\n");
md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n"); md.push_str("| Stage | Avg (ms) | P50 (ms) | P95 (ms) |\\n| --- | --- | --- | --- |\\n");
@@ -583,7 +703,8 @@ fn render_markdown(report: &EvaluationReport) -> String {
for case in &report.llm_cases { for case in &report.llm_cases {
let retrieved = render_retrieved(&case.retrieved); let retrieved = render_retrieved(&case.retrieved);
let rank = case let rank = case
.match_rank.map_or_else(|| "-".into(), |rank| rank.to_string()); .match_rank
.map_or_else(|| "-".into(), |rank| rank.to_string());
write!( write!(
md, md,
"| `{}` | {} | {} | {} |\\n", "| `{}` | {} | {} | {} |\\n",
+31 -19
View File
@@ -99,10 +99,13 @@ fn sanitize_identifier(input: &str) -> String {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(input.as_bytes()); hasher.update(input.as_bytes());
let digest = hasher.finalize(); let digest = hasher.finalize();
digest.iter().take(6).fold(String::with_capacity(12), |mut s, b| { digest
let _ = write!(s, "{b:02x}"); .iter()
s .take(6)
}) .fold(String::with_capacity(12), |mut s, b| {
let _ = write!(s, "{b:02x}");
s
})
} else { } else {
trimmed trimmed
} }
@@ -127,7 +130,9 @@ pub struct SliceWindow<'a> {
impl SliceWindow<'_> { impl SliceWindow<'_> {
pub fn positive_ids(&self) -> impl Iterator<Item = &str> { pub fn positive_ids(&self) -> impl Iterator<Item = &str> {
self.positive_paragraph_ids.iter().map(std::string::String::as_str) self.positive_paragraph_ids
.iter()
.map(std::string::String::as_str)
} }
} }
@@ -169,7 +174,10 @@ impl DatasetIndex {
.paragraph_by_id .paragraph_by_id
.get(id) .get(id)
.ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?; .ok_or_else(|| anyhow!("slice references unknown paragraph '{id}'"))?;
dataset.paragraphs.get(*idx).ok_or_else(|| anyhow!("paragraph index out of bounds")) dataset
.paragraphs
.get(*idx)
.ok_or_else(|| anyhow!("paragraph index out of bounds"))
} }
fn question<'a>( fn question<'a>(
@@ -181,7 +189,9 @@ impl DatasetIndex {
.question_by_id .question_by_id
.get(question_id) .get(question_id)
.ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?; .ok_or_else(|| anyhow!("slice references unknown question '{question_id}'"))?;
let paragraph = dataset.paragraphs.get(*p_idx) let paragraph = dataset
.paragraphs
.get(*p_idx)
.ok_or_else(|| anyhow!("paragraph index out of bounds for question '{question_id}'"))?; .ok_or_else(|| anyhow!("paragraph index out of bounds for question '{question_id}'"))?;
let question = paragraph let question = paragraph
.questions .questions
@@ -318,9 +328,7 @@ pub fn resolve_slice<'a>(
.is_some_and(|manifest| manifest.version != SLICE_VERSION) .is_some_and(|manifest| manifest.version != SLICE_VERSION)
{ {
warn!( warn!(
slice = manifest slice = manifest.as_ref().map_or("unknown", |m| m.slice_id.as_str()),
.as_ref()
.map_or("unknown", |m| m.slice_id.as_str()),
found = manifest.as_ref().map_or(0, |m| m.version), found = manifest.as_ref().map_or(0, |m| m.version),
expected = SLICE_VERSION, expected = SLICE_VERSION,
"Slice manifest version mismatch; regenerating" "Slice manifest version mismatch; regenerating"
@@ -919,7 +927,11 @@ fn ensure_shard_paths(manifest: &mut SliceManifest) -> bool {
changed changed
} }
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss)] #[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
fn desired_negative_target( fn desired_negative_target(
positive_count: usize, positive_count: usize,
requested_corpus: usize, requested_corpus: usize,
@@ -1007,10 +1019,13 @@ fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(payload); hasher.update(payload);
let digest = hasher.finalize(); let digest = hasher.finalize();
Ok(digest.iter().take(16).fold(String::with_capacity(32), |mut s, b| { Ok(digest
let _ = write!(s, "{b:02x}"); .iter()
s .take(16)
})) .fold(String::with_capacity(32), |mut s, b| {
let _ = write!(s, "{b:02x}");
s
}))
} }
#[allow(clippy::indexing_slicing)] #[allow(clippy::indexing_slicing)]
@@ -1050,10 +1065,7 @@ impl<'a> From<&'a Config> for SliceConfig<'a> {
} }
} }
pub fn slice_config_with_limit( pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
config: &Config,
limit_override: Option<usize>,
) -> SliceConfig<'_> {
SliceConfig { SliceConfig {
cache_dir: config.cache_dir.as_path(), cache_dir: config.cache_dir.as_path(),
force_convert: config.force_convert, force_convert: config.force_convert,
+4 -1
View File
@@ -409,7 +409,10 @@ pub fn build_case_diagnostics(
candidates: &[EvaluationCandidate], candidates: &[EvaluationCandidate],
pipeline_stats: Option<Diagnostics>, pipeline_stats: Option<Diagnostics>,
) -> CaseDiagnostics { ) -> CaseDiagnostics {
let expected_set: HashSet<&str> = expected_chunk_ids.iter().map(std::string::String::as_str).collect(); let expected_set: HashSet<&str> = expected_chunk_ids
.iter()
.map(std::string::String::as_str)
.collect();
let mut seen_chunks: HashSet<String> = HashSet::new(); let mut seen_chunks: HashSet<String> = HashSet::new();
let mut attached_chunk_ids = Vec::new(); let mut attached_chunk_ids = Vec::new();
let mut entity_diagnostics = Vec::new(); let mut entity_diagnostics = Vec::new();
+4 -3
View File
@@ -21,7 +21,7 @@ use common::{
knowledge_relationship::KnowledgeRelationship, knowledge_relationship::KnowledgeRelationship,
user::User, user::User,
}, },
utils::embedding::generate_embedding, utils::embedding::generate_embedding_with_provider,
}; };
use retrieval_pipeline; use retrieval_pipeline;
use tracing::debug; use tracing::debug;
@@ -183,7 +183,8 @@ pub async fn create_knowledge_entity(
let embedding_input = let embedding_input =
format!("name: {name}, description: {description}, type: {entity_type:?}"); format!("name: {name}, description: {description}, type: {entity_type:?}");
let embedding = generate_embedding(&state.openai_client, &embedding_input, &state.db).await?; let embedding =
generate_embedding_with_provider(&state.embedding_provider, &embedding_input).await?;
let source_id = format!("manual::{}", Uuid::new_v4()); let source_id = format!("manual::{}", Uuid::new_v4());
let new_entity = KnowledgeEntity::new( let new_entity = KnowledgeEntity::new(
@@ -914,7 +915,7 @@ pub async fn patch_knowledge_entity(
&form.description, &form.description,
&entity_type, &entity_type,
&state.db, &state.db,
&state.openai_client, &state.embedding_provider,
) )
.await?; .await?;
+2 -2
View File
@@ -27,8 +27,8 @@ url = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
headless_chrome = { workspace = true } headless_chrome = { workspace = true }
base64 = { workspace = true } base64 = { workspace = true }
pdf-extract = "0.9" pdf-extract = { workspace = true }
lopdf = "0.32" lopdf = { workspace = true }
bytes = { workspace = true } bytes = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
state-machines = { workspace = true } state-machines = { workspace = true }
+16 -4
View File
@@ -8,19 +8,28 @@ use common::storage::{
db::SurrealDbClient, db::SurrealDbClient,
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS}, types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS},
}; };
pub use pipeline::{IngestionConfig, IngestionPipeline, IngestionTuning}; pub use pipeline::{
EmbeddedKnowledgeEntity, EmbeddedTextChunk, IngestionConfig, IngestionPipeline,
IngestionTuning, PipelineArtifacts,
};
use std::sync::Arc; use std::sync::Arc;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use uuid::Uuid; use uuid::Uuid;
/// How long the worker sleeps when no task is ready to claim.
const WORKER_IDLE_BACKOFF_MS: u64 = 500;
/// How long the worker sleeps after a transient claim error before retrying.
const WORKER_CLAIM_ERROR_BACKOFF_MS: u64 = 1_000;
pub async fn run_worker_loop( pub async fn run_worker_loop(
db: Arc<SurrealDbClient>, db: Arc<SurrealDbClient>,
ingestion_pipeline: Arc<IngestionPipeline>, ingestion_pipeline: Arc<IngestionPipeline>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let worker_id = format!("ingestion-worker-{}", Uuid::new_v4()); let worker_id = format!("ingestion-worker-{}", Uuid::new_v4());
let lease_duration = Duration::from_secs(DEFAULT_LEASE_SECS as u64); let lease_duration = Duration::from_secs(DEFAULT_LEASE_SECS as u64);
let idle_backoff = Duration::from_millis(500); let idle_backoff = Duration::from_millis(WORKER_IDLE_BACKOFF_MS);
let claim_error_backoff = Duration::from_millis(WORKER_CLAIM_ERROR_BACKOFF_MS);
loop { loop {
match IngestionTask::claim_next_ready(&db, &worker_id, Utc::now(), lease_duration).await { match IngestionTask::claim_next_ready(&db, &worker_id, Utc::now(), lease_duration).await {
@@ -41,8 +50,11 @@ pub async fn run_worker_loop(
} }
Err(err) => { Err(err) => {
error!(%worker_id, error = %err, "failed to claim ingestion task"); error!(%worker_id, error = %err, "failed to claim ingestion task");
warn!("Backing off for 1s after claim error"); warn!(
sleep(Duration::from_secs(1)).await; backoff_ms = WORKER_CLAIM_ERROR_BACKOFF_MS,
"Backing off after claim error"
);
sleep(claim_error_backoff).await;
} }
} }
} }
+4 -2
View File
@@ -9,8 +9,10 @@ pub struct IngestionTuning {
pub chunk_min_tokens: usize, pub chunk_min_tokens: usize,
pub chunk_max_tokens: usize, pub chunk_max_tokens: usize,
pub chunk_overlap_tokens: usize, pub chunk_overlap_tokens: usize,
pub chunk_insert_concurrency: usize,
pub entity_embedding_concurrency: usize, pub entity_embedding_concurrency: usize,
/// Maximum characters of content body used to build the similarity-search query
/// during retrieval. Longer bodies are truncated to keep embedding inputs bounded.
pub embedding_query_char_limit: usize,
} }
impl Default for IngestionTuning { impl Default for IngestionTuning {
@@ -25,8 +27,8 @@ impl Default for IngestionTuning {
chunk_min_tokens: 256, chunk_min_tokens: 256,
chunk_max_tokens: 512, chunk_max_tokens: 512,
chunk_overlap_tokens: 50, chunk_overlap_tokens: 50,
chunk_insert_concurrency: 8,
entity_embedding_concurrency: 4, entity_embedding_concurrency: 4,
embedding_query_char_limit: 12_000,
} }
} }
} }
+3 -2
View File
@@ -12,19 +12,20 @@ use common::{
}, },
}; };
use retrieval_pipeline::RetrievedEntity; use retrieval_pipeline::RetrievedEntity;
use serde::{Deserialize, Serialize};
use tracing::error; use tracing::error;
use super::enrichment_result::LLMEnrichmentResult; use super::enrichment_result::LLMEnrichmentResult;
use super::{config::IngestionConfig, services::PipelineServices}; use super::{config::IngestionConfig, services::PipelineServices};
#[derive(Debug, Clone)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedKnowledgeEntity { pub struct EmbeddedKnowledgeEntity {
pub entity: KnowledgeEntity, pub entity: KnowledgeEntity,
pub embedding: Vec<f32>, pub embedding: Vec<f32>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedTextChunk { pub struct EmbeddedTextChunk {
pub chunk: TextChunk, pub chunk: TextChunk,
pub embedding: Vec<f32>, pub embedding: Vec<f32>,
@@ -4,7 +4,6 @@ use chrono::Utc;
use futures::stream::{self, StreamExt, TryStreamExt}; use futures::stream::{self, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use common::{ use common::{
error::AppError, error::AppError,
storage::{ storage::{
@@ -178,3 +177,98 @@ async fn create_single_entity(
Ok(EmbeddedKnowledgeEntity { entity, embedding }) Ok(EmbeddedKnowledgeEntity { entity, embedding })
} }
#[cfg(test)]
mod tests {
    // `expect` is acceptable here: in a test, a panic with a message is the failure signal.
    #![allow(clippy::expect_used)]

    use super::*;
    use uuid::Uuid;

    /// Builds an `LLMKnowledgeEntity` fixture whose name and description are derived
    /// from `key`, with a fixed `"Idea"` entity type.
    fn entity(key: &str) -> LLMKnowledgeEntity {
        LLMKnowledgeEntity {
            key: key.to_string(),
            name: format!("name-{key}"),
            description: format!("desc-{key}"),
            entity_type: "Idea".to_string(),
        }
    }

    /// Builds an `LLMRelationship` fixture of type `type_` from `source` to `target`.
    fn relationship(type_: &str, source: &str, target: &str) -> LLMRelationship {
        LLMRelationship {
            type_: type_.to_string(),
            source: source.to_string(),
            target: target.to_string(),
        }
    }

    /// Every entity key must be resolvable through the mapper, and two different
    /// keys must resolve to different ids.
    #[test]
    fn create_mapper_assigns_id_per_entity_key() {
        let result = LLMEnrichmentResult {
            knowledge_entities: vec![entity("k1"), entity("k2")],
            relationships: Vec::new(),
        };

        let mapper = result.create_mapper();

        assert!(mapper.get_id("k1").is_ok());
        assert!(mapper.get_id("k2").is_ok());
        assert_ne!(
            mapper.get_id("k1").expect("k1"),
            mapper.get_id("k2").expect("k2")
        );
    }

    /// Relationship endpoints given as entity keys are replaced by the ids the
    /// mapper assigned, and the relationship metadata carries the type, source id
    /// and user id that were passed in.
    #[test]
    fn process_relationships_resolves_keys_to_assigned_ids() {
        let result = LLMEnrichmentResult {
            knowledge_entities: vec![entity("k1"), entity("k2")],
            relationships: vec![relationship("relates_to", "k1", "k2")],
        };
        let mapper = result.create_mapper();

        let relationships = result
            .process_relationships("source-1", "user-1", &mapper)
            .expect("relationships resolve");

        assert_eq!(relationships.len(), 1);
        let rel = relationships.first().expect("one relationship");
        // `in_` is the source endpoint and `out` the target endpoint of the edge.
        assert_eq!(rel.in_, mapper.get_id("k1").expect("k1").to_string());
        assert_eq!(rel.out, mapper.get_id("k2").expect("k2").to_string());
        assert_eq!(rel.metadata.relationship_type, "relates_to");
        assert_eq!(rel.metadata.source_id, "source-1");
        assert_eq!(rel.metadata.user_id, "user-1");
    }

    /// An endpoint that is not a known entity key but is a UUID string is accepted
    /// and passed through unchanged.
    #[test]
    fn process_relationships_accepts_raw_uuid_endpoints() {
        let raw = Uuid::new_v4();
        let result = LLMEnrichmentResult {
            knowledge_entities: vec![entity("k1")],
            relationships: vec![relationship("relates_to", "k1", &raw.to_string())],
        };
        let mapper = result.create_mapper();

        let relationships = result
            .process_relationships("source-1", "user-1", &mapper)
            .expect("raw uuid target resolves");

        assert_eq!(
            relationships.first().expect("one relationship").out,
            raw.to_string()
        );
    }

    /// An endpoint that is neither a known entity key nor a UUID must fail with
    /// `AppError::GraphMapper`.
    #[test]
    fn process_relationships_errors_on_unknown_endpoint() {
        let result = LLMEnrichmentResult {
            knowledge_entities: vec![entity("k1")],
            relationships: vec![relationship("relates_to", "k1", "missing-key")],
        };
        let mapper = result.create_mapper();

        assert!(matches!(
            result.process_relationships("source-1", "user-1", &mapper),
            Err(AppError::GraphMapper(_))
        ));
    }
}
+58 -49
View File
@@ -1,12 +1,15 @@
mod config; mod config;
mod context; mod context;
mod enrichment_result; mod enrichment_result;
mod persistence;
mod preparation; mod preparation;
mod services; mod services;
mod stages; mod stages;
mod state; mod state;
pub use config::{IngestionConfig, IngestionTuning}; pub use config::{IngestionConfig, IngestionTuning};
#[allow(clippy::module_name_repetitions)]
pub use context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk, PipelineArtifacts};
pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship}; pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship};
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub use services::{DefaultPipelineServices, PipelineServices}; pub use services::{DefaultPipelineServices, PipelineServices};
@@ -33,11 +36,18 @@ use retrieval_pipeline::reranking::RerankerPool;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use self::{ use self::{
context::{PipelineArtifacts, PipelineContext}, context::PipelineContext,
stages::{enrich, persist, prepare_content, retrieve_related}, stages::{enrich, persist, prepare_content, retrieve_related},
state::ready, state::{ready, Enriched, IngestionMachine},
}; };
/// Wall-clock duration of each pre-persistence pipeline stage.
///
/// Filled in by `run_through_enrichment`, which starts a fresh timer immediately
/// before each stage and reads it immediately after the stage returns.
struct StageTimings {
    /// Time spent in the `prepare_content` stage.
    prepare: Duration,
    /// Time spent in the `retrieve_related` stage.
    retrieve: Duration,
    /// Time spent in the `enrich` stage.
    enrich: Duration,
}
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub struct IngestionPipeline { pub struct IngestionPipeline {
db: Arc<SurrealDbClient>, db: Arc<SurrealDbClient>,
@@ -81,6 +91,7 @@ impl IngestionPipeline {
reranker_pool, reranker_pool,
storage, storage,
embedding_provider, embedding_provider,
pipeline_config.tuning.embedding_query_char_limit,
); );
Self::with_services(db, pipeline_config, Arc::new(services)) Self::with_services(db, pipeline_config, Arc::new(services))
@@ -109,15 +120,7 @@ impl IngestionPipeline {
)] )]
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> { pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
let mut processing_task = task.mark_processing(&self.db).await?; let mut processing_task = task.mark_processing(&self.db).await?;
let payload = std::mem::replace( let payload = processing_task.take_content();
&mut processing_task.content,
IngestionPayload::Text {
text: String::new(),
context: String::new(),
category: String::new(),
user_id: processing_task.user_id.clone(),
},
);
match self match self
.drive_pipeline(&processing_task, payload) .drive_pipeline(&processing_task, payload)
@@ -191,6 +194,44 @@ impl IngestionPipeline {
u64::try_from(duration.as_millis()).unwrap_or(u64::MAX) u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
} }
/// Executes the `prepare → retrieve → enrich` prefix of the pipeline, timing each stage.
///
/// This prefix is common to the full task path ([`Self::drive_pipeline`]) and the
/// artifact-only path ([`Self::produce_artifacts`]); the two callers differ only in
/// the terminal step (persist vs. return artifacts).
///
/// Returns the state machine in its `Enriched` state together with the measured
/// per-stage durations.
///
/// # Errors
///
/// The first stage that fails has its error passed through `ctx.abort` and returned;
/// the remaining stages are not run.
async fn run_through_enrichment(
    &self,
    ctx: &mut PipelineContext<'_>,
    payload: IngestionPayload,
) -> Result<(IngestionMachine<(), Enriched>, StageTimings), AppError> {
    let initial = ready();

    // Stage 1: content preparation.
    let prepare_timer = Instant::now();
    let prepared = prepare_content(initial, ctx, payload)
        .await
        .map_err(|err| ctx.abort(err))?;
    let prepare_elapsed = prepare_timer.elapsed();

    // Stage 2: retrieval of related material.
    let retrieve_timer = Instant::now();
    let retrieved = retrieve_related(prepared, ctx)
        .await
        .map_err(|err| ctx.abort(err))?;
    let retrieve_elapsed = retrieve_timer.elapsed();

    // Stage 3: enrichment.
    let enrich_timer = Instant::now();
    let enriched = enrich(retrieved, ctx)
        .await
        .map_err(|err| ctx.abort(err))?;
    let enrich_elapsed = enrich_timer.elapsed();

    let timings = StageTimings {
        prepare: prepare_elapsed,
        retrieve: retrieve_elapsed,
        enrich: enrich_elapsed,
    };
    Ok((enriched, timings))
}
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id) fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id)
@@ -207,27 +248,8 @@ impl IngestionPipeline {
self.services.as_ref(), self.services.as_ref(),
); );
let machine = ready();
let pipeline_started = Instant::now(); let pipeline_started = Instant::now();
let (machine, timings) = self.run_through_enrichment(&mut ctx, payload).await?;
let stage_start = Instant::now();
let machine = prepare_content(machine, &mut ctx, payload)
.await
.map_err(|err| ctx.abort(err))?;
let prepare_duration = stage_start.elapsed();
let stage_start = Instant::now();
let machine = retrieve_related(machine, &mut ctx)
.await
.map_err(|err| ctx.abort(err))?;
let retrieve_duration = stage_start.elapsed();
let stage_start = Instant::now();
let machine = enrich(machine, &mut ctx)
.await
.map_err(|err| ctx.abort(err))?;
let enrich_duration = stage_start.elapsed();
let stage_start = Instant::now(); let stage_start = Instant::now();
let _machine = persist(machine, &mut ctx) let _machine = persist(machine, &mut ctx)
@@ -236,18 +258,14 @@ impl IngestionPipeline {
let persist_duration = stage_start.elapsed(); let persist_duration = stage_start.elapsed();
let total_duration = pipeline_started.elapsed(); let total_duration = pipeline_started.elapsed();
let prepare_ms = Self::duration_millis(prepare_duration);
let retrieve_ms = Self::duration_millis(retrieve_duration);
let enrich_ms = Self::duration_millis(enrich_duration);
let persist_ms = Self::duration_millis(persist_duration);
info!( info!(
task_id = %ctx.task_id, task_id = %ctx.task_id,
attempt = ctx.attempt, attempt = ctx.attempt,
total_ms = Self::duration_millis(total_duration), total_ms = Self::duration_millis(total_duration),
prepare_ms, prepare_ms = Self::duration_millis(timings.prepare),
retrieve_ms, retrieve_ms = Self::duration_millis(timings.retrieve),
enrich_ms, enrich_ms = Self::duration_millis(timings.enrich),
persist_ms, persist_ms = Self::duration_millis(persist_duration),
"ingestion pipeline finished" "ingestion pipeline finished"
); );
@@ -267,16 +285,7 @@ impl IngestionPipeline {
self.services.as_ref(), self.services.as_ref(),
); );
let machine = ready(); let (_machine, _timings) = self.run_through_enrichment(&mut ctx, payload).await?;
let machine = prepare_content(machine, &mut ctx, payload)
.await
.map_err(|err| ctx.abort(err))?;
let machine = retrieve_related(machine, &mut ctx)
.await
.map_err(|err| ctx.abort(err))?;
let _machine = enrich(machine, &mut ctx)
.await
.map_err(|err| ctx.abort(err))?;
ctx.build_artifacts().await.map_err(|err| ctx.abort(err)) ctx.build_artifacts().await.map_err(|err| ctx.abort(err))
} }
@@ -0,0 +1,127 @@
//! Low-level database write mechanics for the persist stage.
//!
//! This module owns *how* ingested artifacts reach `SurrealDB` (per-item store loops,
//! the relationship transaction, and conflict retry/backoff). The persist stage in
//! [`super::stages`] owns *what* gets written and in which order.
use std::sync::Arc;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
},
},
};
use tokio::time::{sleep, Duration};
use tracing::{debug, warn};
use super::{
config::IngestionTuning,
context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk},
};
/// SurrealQL script that writes every bound relationship inside one transaction.
///
/// For each element of the `relationships` parameter it resolves the `in` and `out`
/// values to `knowledge_entity` records and creates a `relates_to` edge between them,
/// carrying the relationship's `id` and `metadata`. The caller must bind
/// `relationships` before running the query.
const STORE_RELATIONSHIPS: &str = r"
BEGIN TRANSACTION;
LET $relationships = $relationships;
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
";
/// Persists chunk embeddings to the vector store.
///
/// Chunks are written serially on purpose. Concurrent/batched inserts were
/// trialed and did not reliably improve throughput; see `ingestion-pipeline/AGENTS.md`
/// for the rationale and as a candidate for future refactoring/benchmarking.
pub(super) async fn store_vector_chunks(
db: &SurrealDbClient,
task_id: &str,
chunks: &[EmbeddedTextChunk],
) -> Result<usize, AppError> {
for embedded in chunks {
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db)
.await?;
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}
Ok(chunks.len())
}
/// Stores knowledge entities, then links them with their relationships.
///
/// Entities are written one at a time (see `store_vector_chunks` and AGENTS.md for
/// why). If there are no relationships the function returns as soon as the entities
/// are stored. Otherwise all relationships are sent in a single `STORE_RELATIONSHIPS`
/// transaction, which is retried with exponential backoff when the query fails with a
/// read/write conflict.
///
/// Retry behaviour comes from `tuning`: `graph_store_attempts` bounds the number of
/// tries, `graph_initial_backoff_ms` is the first delay, and each subsequent delay
/// doubles up to `graph_max_backoff_ms`.
///
/// # Errors
///
/// Returns the first entity-store error, any query error that is not a retryable
/// conflict, or the conflict error from the final attempt. Returns
/// `AppError::InternalError` if the attempt loop ends without returning, which happens
/// when `graph_store_attempts` is zero.
pub(super) async fn store_graph_entities(
    db: &SurrealDbClient,
    tuning: &IngestionTuning,
    entities: Vec<EmbeddedKnowledgeEntity>,
    relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
    for item in entities {
        KnowledgeEntity::store_with_embedding(item.entity, item.embedding, db).await?;
    }

    if relationships.is_empty() {
        return Ok(());
    }

    // Shared so each retry can bind the same data without copying the vector.
    let shared = Arc::new(relationships);
    let max_attempts = tuning.graph_store_attempts;
    let final_attempt = max_attempts.saturating_sub(1);
    let mut delay_ms = tuning.graph_initial_backoff_ms;

    for attempt in 0..max_attempts {
        let outcome = db
            .client
            .query(STORE_RELATIONSHIPS)
            .bind(("relationships", Arc::clone(&shared)))
            .await;

        // NOTE(review): only the top-level query result is inspected here; errors
        // reported per statement inside the response are not checked. Confirm whether
        // the response needs an explicit check for conflicts to be seen.
        let err = match outcome {
            Ok(_) => return Ok(()),
            Err(err) => err,
        };

        let may_retry = is_retryable_conflict(&err) && attempt < final_attempt;
        if !may_retry {
            return Err(AppError::from(err));
        }

        warn!(
            attempt = attempt.saturating_add(1),
            "Transient SurrealDB conflict while storing graph data; retrying"
        );
        sleep(Duration::from_millis(delay_ms)).await;
        delay_ms = delay_ms.saturating_mul(2).min(tuning.graph_max_backoff_ms);
    }

    Err(AppError::InternalError(
        "Failed to store graph entities after retries".to_string(),
    ))
}
/// Reports whether `error` is a transaction read/write conflict, detected by looking
/// for the conflict message in the error's display text.
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
    const CONFLICT_MARKER: &str = "Failed to commit transaction due to a read or write conflict";
    let rendered = error.to_string();
    rendered.contains(CONFLICT_MARKER)
}
+70 -12
View File
@@ -3,7 +3,6 @@ use std::{
sync::{Arc, OnceLock}, sync::{Arc, OnceLock},
}; };
use async_openai::types::{ use async_openai::types::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat,
@@ -30,7 +29,6 @@ use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content
use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use crate::utils::llm_instructions::get_ingress_analysis_schema; use crate::utils::llm_instructions::get_ingress_analysis_schema;
const EMBEDDING_QUERY_CHAR_LIMIT: usize = 12_000;
#[async_trait] #[async_trait]
pub trait PipelineServices: Send + Sync { pub trait PipelineServices: Send + Sync {
async fn prepare_text_content( async fn prepare_text_content(
@@ -71,6 +69,7 @@ pub struct DefaultPipelineServices {
reranker_pool: Option<Arc<RerankerPool>>, reranker_pool: Option<Arc<RerankerPool>>,
storage: StorageManager, storage: StorageManager,
embedding_provider: Arc<EmbeddingProvider>, embedding_provider: Arc<EmbeddingProvider>,
embedding_query_char_limit: usize,
} }
impl DefaultPipelineServices { impl DefaultPipelineServices {
@@ -81,6 +80,7 @@ impl DefaultPipelineServices {
reranker_pool: Option<Arc<RerankerPool>>, reranker_pool: Option<Arc<RerankerPool>>,
storage: StorageManager, storage: StorageManager,
embedding_provider: Arc<EmbeddingProvider>, embedding_provider: Arc<EmbeddingProvider>,
embedding_query_char_limit: usize,
) -> Self { ) -> Self {
Self { Self {
db, db,
@@ -89,6 +89,7 @@ impl DefaultPipelineServices {
reranker_pool, reranker_pool,
storage, storage,
embedding_provider, embedding_provider,
embedding_query_char_limit,
} }
} }
@@ -169,7 +170,7 @@ impl PipelineServices for DefaultPipelineServices {
&self, &self,
content: &TextContent, content: &TextContent,
) -> Result<Vec<RetrievedEntity>, AppError> { ) -> Result<Vec<RetrievedEntity>, AppError> {
let truncated_body = truncate_for_embedding(&content.text, EMBEDDING_QUERY_CHAR_LIMIT); let truncated_body = truncate_for_embedding(&content.text, self.embedding_query_char_limit);
let input_text = format!( let input_text = format!(
"content: {}\n[truncated={}], category: {}, user_context: {:?}", "content: {}\n[truncated={}], category: {}, user_context: {:?}",
truncated_body, truncated_body,
@@ -250,7 +251,7 @@ impl PipelineServices for DefaultPipelineServices {
token_range: Range<usize>, token_range: Range<usize>,
overlap_tokens: usize, overlap_tokens: usize,
) -> Result<Vec<EmbeddedTextChunk>, AppError> { ) -> Result<Vec<EmbeddedTextChunk>, AppError> {
let chunk_candidates = prepare_chunks( let chunk_candidates = split_text_into_chunks(
&content.text, &content.text,
token_range.start, token_range.start,
token_range.end, token_range.end,
@@ -263,7 +264,9 @@ impl PipelineServices for DefaultPipelineServices {
.embedding_provider .embedding_provider
.embed(&chunk_text) .embed(&chunk_text)
.await .await
.map_err(|e| AppError::InternalError(format!("FastEmbed embedding for chunk failed: {e}")))?; .map_err(|e| {
AppError::InternalError(format!("FastEmbed embedding for chunk failed: {e}"))
})?;
let chunk_struct = TextChunk::new( let chunk_struct = TextChunk::new(
content.id().to_string(), content.id().to_string(),
chunk_text, chunk_text,
@@ -278,7 +281,7 @@ impl PipelineServices for DefaultPipelineServices {
} }
} }
fn prepare_chunks( fn split_text_into_chunks(
text: &str, text: &str,
min_tokens: usize, min_tokens: usize,
max_tokens: usize, max_tokens: usize,
@@ -352,9 +355,7 @@ mod tests {
use async_openai::{config::OpenAIConfig, types::ChatCompletionRequestMessage, Client}; use async_openai::{config::OpenAIConfig, types::ChatCompletionRequestMessage, Client};
use common::{ use common::{
storage::{ storage::{
db::SurrealDbClient, db::SurrealDbClient, store::StorageManager, types::system_settings::SystemSettingsPatch,
store::StorageManager,
types::system_settings::SystemSettingsPatch,
}, },
utils::{ utils::{
config::{AppConfig, StorageKind}, config::{AppConfig, StorageKind},
@@ -364,6 +365,8 @@ mod tests {
use uuid::Uuid; use uuid::Uuid;
use super::DefaultPipelineServices; use super::DefaultPipelineServices;
use crate::pipeline::IngestionTuning;
use common::error::AppError;
fn system_prompt_from_request( fn system_prompt_from_request(
request: &async_openai::types::CreateChatCompletionRequest, request: &async_openai::types::CreateChatCompletionRequest,
@@ -380,8 +383,8 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn prepare_llm_request_uses_ingestion_prompt_from_system_settings( async fn prepare_llm_request_uses_ingestion_prompt_from_system_settings() -> anyhow::Result<()>
) -> anyhow::Result<()> { {
const SENTINEL: &str = "ingestion-prompt-sentinel-from-db"; const SENTINEL: &str = "ingestion-prompt-sentinel-from-db";
let db = Arc::new( let db = Arc::new(
@@ -402,7 +405,9 @@ mod tests {
storage: StorageKind::Memory, storage: StorageKind::Memory,
..Default::default() ..Default::default()
}; };
let storage = StorageManager::new(&config).await.context("storage manager")?; let storage = StorageManager::new(&config)
.await
.context("storage manager")?;
let openai_client = Arc::new(Client::with_config(OpenAIConfig::default())); let openai_client = Arc::new(Client::with_config(OpenAIConfig::default()));
let embedding_provider = Arc::new(EmbeddingProvider::new_hashed(384)?); let embedding_provider = Arc::new(EmbeddingProvider::new_hashed(384)?);
@@ -413,6 +418,7 @@ mod tests {
None, None,
storage, storage,
embedding_provider, embedding_provider,
IngestionTuning::default().embedding_query_char_limit,
); );
let request = services let request = services
@@ -423,4 +429,56 @@ mod tests {
assert_eq!(system_prompt_from_request(&request)?, SENTINEL); assert_eq!(system_prompt_from_request(&request)?, SENTINEL);
Ok(()) Ok(())
} }
#[test]
fn split_text_into_chunks_rejects_zero_bounds() {
    // A zero minimum and a zero maximum must each be rejected on their own.
    for (min_tokens, max_tokens) in [(0, 10), (4, 0)] {
        let outcome = super::split_text_into_chunks("text", min_tokens, max_tokens, 0);
        assert!(matches!(outcome, Err(AppError::Validation(_))));
    }
}
#[test]
fn split_text_into_chunks_rejects_min_greater_than_max() {
    // min = 10 with max = 4 is an inverted range and must fail validation.
    let outcome = super::split_text_into_chunks("text", 10, 4, 0);
    assert!(matches!(outcome, Err(AppError::Validation(_))));
}
#[test]
fn split_text_into_chunks_rejects_overlap_at_or_above_min() {
    // With a minimum of 4, an overlap equal to it (4) and one above it (5) both fail.
    for overlap in [4, 5] {
        let outcome = super::split_text_into_chunks("text", 4, 10, overlap);
        assert!(matches!(outcome, Err(AppError::Validation(_))));
    }
}
#[test]
fn truncate_for_embedding_returns_short_text_unchanged() {
    // Below the limit: returned as-is.
    let below_limit = super::truncate_for_embedding("hello", 10);
    assert_eq!(below_limit, "hello");

    // Exactly at the limit is left untouched (no ellipsis appended).
    let at_limit = super::truncate_for_embedding("hello", 5);
    assert_eq!(at_limit, "hello");
}
#[test]
fn truncate_for_embedding_appends_ellipsis_when_over_limit() {
    // Input longer than the limit of 5: the first 5 chars are kept and "…" follows.
    let shortened = super::truncate_for_embedding("hello world", 5);
    assert_eq!(shortened, "hello…");
}
#[test]
fn truncate_for_embedding_respects_char_boundaries() {
    // Multibyte characters must not be split mid-byte: the limit counts chars.
    let shortened = super::truncate_for_embedding("héllo wörld", 4);
    let char_total = shortened.chars().count();
    assert_eq!(shortened, "héll…");
    // Four kept characters plus the ellipsis.
    assert_eq!(char_total, 5);
}
} }
@@ -1,42 +1,23 @@
use std::sync::Arc; //! State-machine stages of the ingestion pipeline.
//!
//! Each function advances the `IngestionMachine` by one transition
//! (`prepare → retrieve → enrich → persist`), mutating the shared
//! [`PipelineContext`]. Low-level database writes live in [`super::persistence`].
use common::{ use common::{
error::AppError, error::AppError,
storage::{ storage::{indexes::rebuild, types::ingestion_payload::IngestionPayload},
db::SurrealDbClient,
indexes::rebuild,
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
},
},
}; };
use state_machines::core::GuardError; use state_machines::core::GuardError;
use tokio::time::{sleep, Duration}; use tracing::{debug, instrument};
use tracing::{debug, instrument, warn};
use super::{ use super::{
context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk, PipelineArtifacts, PipelineContext}, context::{PipelineArtifacts, PipelineContext},
enrichment_result::LLMEnrichmentResult, enrichment_result::LLMEnrichmentResult,
persistence::{store_graph_entities, store_vector_chunks},
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
}; };
const STORE_RELATIONSHIPS: &str = r"
BEGIN TRANSACTION;
LET $relationships = $relationships;
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
";
#[instrument( #[instrument(
level = "trace", level = "trace",
skip_all, skip_all,
@@ -174,23 +155,9 @@ pub async fn persist(
let entity_count = entities.len(); let entity_count = entities.len();
let relationship_count = relationships.len(); let relationship_count = relationships.len();
debug!("Were storing chunks"); let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), &chunks).await?;
let chunk_count = store_vector_chunks(
ctx.db,
ctx.task_id.as_str(),
&chunks,
&ctx.pipeline_config.tuning,
)
.await?;
debug!("We stored chunks");
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?; store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
debug!("Stored graph entities");
ctx.db.store_item(text_content).await?; ctx.db.store_item(text_content).await?;
debug!("stored item");
rebuild(ctx.db).await?; rebuild(ctx.db).await?;
debug!( debug!(
@@ -212,107 +179,3 @@ fn map_guard_error(event: &str, guard: &GuardError) -> AppError {
"invalid ingestion pipeline transition during {event}: {guard:?}" "invalid ingestion pipeline transition during {event}: {guard:?}"
)) ))
} }
async fn store_graph_entities(
db: &SurrealDbClient,
tuning: &super::config::IngestionTuning,
entities: Vec<EmbeddedKnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
// Persist entities with embeddings first.
for embedded in entities {
KnowledgeEntity::store_with_embedding(embedded.entity, embedded.embedding, db).await?;
}
if relationships.is_empty() {
return Ok(());
}
let relationships = Arc::new(relationships);
let mut backoff_ms = tuning.graph_initial_backoff_ms;
let last_attempt = tuning.graph_store_attempts.saturating_sub(1);
for attempt in 0..tuning.graph_store_attempts {
let result = db
.client
.query(STORE_RELATIONSHIPS)
.bind(("relationships", Arc::clone(&relationships)))
.await;
match result {
Ok(_) => return Ok(()),
Err(err) => {
if is_retryable_conflict(&err) && attempt < last_attempt {
let next_attempt = attempt.saturating_add(1);
warn!(
attempt = next_attempt,
"Transient SurrealDB conflict while storing graph data; retrying"
);
sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = backoff_ms
.saturating_mul(2)
.min(tuning.graph_max_backoff_ms);
continue;
}
return Err(AppError::from(err));
}
}
}
Err(AppError::InternalError(
"Failed to store graph entities after retries".to_string(),
))
}
async fn store_vector_chunks(
db: &SurrealDbClient,
task_id: &str,
chunks: &[EmbeddedTextChunk],
tuning: &super::config::IngestionTuning,
) -> Result<usize, AppError> {
let chunk_count = chunks.len();
let batch_size = tuning.chunk_insert_concurrency.max(1);
for batch in chunks.chunks(batch_size) {
store_chunk_batch(db, batch, tuning, task_id).await?;
}
Ok(chunk_count)
}
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
error
.to_string()
.contains("Failed to commit transaction due to a read or write conflict")
}
async fn store_chunk_batch(
db: &SurrealDbClient,
batch: &[EmbeddedTextChunk],
_tuning: &super::config::IngestionTuning,
task_id: &str,
) -> Result<(), AppError> {
if batch.is_empty() {
return Ok(());
}
for embedded in batch {
TextChunk::store_with_embedding(
embedded.chunk.clone(),
embedded.embedding.clone(),
db,
)
.await?;
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}
Ok(())
}
+35 -37
View File
@@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::{self, Context};
use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use anyhow::{self, Context};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{Duration as ChronoDuration, Utc}; use chrono::{Duration as ChronoDuration, Utc};
use common::{ use common::{
@@ -279,7 +279,6 @@ fn pipeline_config() -> IngestionConfig {
tuning: IngestionTuning { tuning: IngestionTuning {
chunk_min_tokens: 4, chunk_min_tokens: 4,
chunk_max_tokens: 64, chunk_max_tokens: 64,
chunk_insert_concurrency: 4,
entity_embedding_concurrency: 2, entity_embedding_concurrency: 2,
..IngestionTuning::default() ..IngestionTuning::default()
}, },
@@ -302,18 +301,33 @@ async fn reserve_task(
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()> #[allow(clippy::duration_suboptimal_units)] // assertions mirror retry_delay's seconds-based config
{ async fn retry_delay_grows_exponentially_and_caps() -> anyhow::Result<()> {
use std::time::Duration;
let db = setup_db().await?;
let services: Arc<dyn PipelineServices> = Arc::new(MockServices::new("user-delay"));
let pipeline = IngestionPipeline::with_services(Arc::new(db), pipeline_config(), services)?;
// Defaults: base = 30s, cap exponent = 5, max = 900s.
assert_eq!(pipeline.retry_delay(0), Duration::from_secs(30));
assert_eq!(pipeline.retry_delay(1), Duration::from_secs(30));
assert_eq!(pipeline.retry_delay(2), Duration::from_secs(60));
assert_eq!(pipeline.retry_delay(3), Duration::from_secs(120));
// Beyond the cap exponent the delay clamps at the configured maximum.
assert_eq!(pipeline.retry_delay(7), Duration::from_secs(900));
Ok(())
}
#[tokio::test]
async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()> {
let db = setup_db().await?; let db = setup_db().await?;
let worker_id = "worker-happy"; let worker_id = "worker-happy";
let user_id = "user-123"; let user_id = "user-123";
let services = Arc::new(MockServices::new(user_id)); let services = Arc::new(MockServices::new(user_id));
let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services); let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
let pipeline = IngestionPipeline::with_services( let pipeline =
Arc::new(db.clone()), IngestionPipeline::with_services(Arc::new(db.clone()), pipeline_config(), services_clone)?;
pipeline_config(),
services_clone,
)?;
let task = reserve_task( let task = reserve_task(
&db, &db,
@@ -330,15 +344,11 @@ async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()>
pipeline.process_task(task.clone()).await?; pipeline.process_task(task.clone()).await?;
let stored_task: IngestionTask = db let stored_task: IngestionTask = db.get_item(&task.id).await?.context("task present")?;
.get_item(&task.id)
.await?
.context("task present")?;
assert_eq!(stored_task.state, TaskState::Succeeded); assert_eq!(stored_task.state, TaskState::Succeeded);
let stored_entities: Vec<KnowledgeEntity> = db let stored_entities: Vec<KnowledgeEntity> =
.get_all_stored_items::<KnowledgeEntity>() db.get_all_stored_items::<KnowledgeEntity>().await?;
.await?;
assert!(!stored_entities.is_empty(), "entities should be stored"); assert!(!stored_entities.is_empty(), "entities should be stored");
let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?; let stored_chunks: Vec<TextChunk> = db.get_all_stored_items::<TextChunk>().await?;
@@ -356,9 +366,9 @@ async fn ingestion_pipeline_happy_path_persists_entities() -> anyhow::Result<()>
call_log.get(0..4), call_log.get(0..4),
Some(&["prepare", "retrieve", "enrich", "convert"][..]) Some(&["prepare", "retrieve", "enrich", "convert"][..])
); );
assert!( assert!(call_log
call_log.get(4..).is_some_and(|tail| tail.iter().all(|entry| *entry == "chunk")) .get(4..)
); .is_some_and(|tail| tail.iter().all(|entry| *entry == "chunk")));
Ok(()) Ok(())
} }
@@ -371,11 +381,7 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> {
let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services); let services_clone: Arc<dyn PipelineServices> = Arc::<MockServices>::clone(&services);
let mut config = pipeline_config(); let mut config = pipeline_config();
config.chunk_only = true; config.chunk_only = true;
let pipeline = IngestionPipeline::with_services( let pipeline = IngestionPipeline::with_services(Arc::new(db.clone()), config, services_clone)?;
Arc::new(db.clone()),
config,
services_clone,
)?;
let task = reserve_task( let task = reserve_task(
&db, &db,
@@ -392,9 +398,8 @@ async fn ingestion_pipeline_chunk_only_skips_analysis() -> anyhow::Result<()> {
pipeline.process_task(task.clone()).await?; pipeline.process_task(task.clone()).await?;
let stored_entities: Vec<KnowledgeEntity> = db let stored_entities: Vec<KnowledgeEntity> =
.get_all_stored_items::<KnowledgeEntity>() db.get_all_stored_items::<KnowledgeEntity>().await?;
.await?;
assert!( assert!(
stored_entities.is_empty(), stored_entities.is_empty(),
"chunk-only ingestion should not persist entities" "chunk-only ingestion should not persist entities"
@@ -451,10 +456,7 @@ async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> {
"failure services should bubble error from pipeline" "failure services should bubble error from pipeline"
); );
let stored_task: IngestionTask = db let stored_task: IngestionTask = db.get_item(&task.id).await?.context("task present")?;
.get_item(&task.id)
.await?
.context("task present")?;
assert_eq!(stored_task.state, TaskState::Failed); assert_eq!(stored_task.state, TaskState::Failed);
assert!( assert!(
stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5), stored_task.scheduled_at > Utc::now() - ChronoDuration::seconds(5),
@@ -464,8 +466,7 @@ async fn ingestion_pipeline_failure_marks_retry() -> anyhow::Result<()> {
} }
#[tokio::test] #[tokio::test]
async fn ingestion_pipeline_validation_failure_dead_letters_task( async fn ingestion_pipeline_validation_failure_dead_letters_task() -> anyhow::Result<()> {
) -> anyhow::Result<()> {
let db = setup_db().await?; let db = setup_db().await?;
let worker_id = "worker-validation"; let worker_id = "worker-validation";
let user_id = "user-789"; let user_id = "user-789";
@@ -492,10 +493,7 @@ async fn ingestion_pipeline_validation_failure_dead_letters_task(
"validation failure should surface as error" "validation failure should surface as error"
); );
let stored_task: IngestionTask = db let stored_task: IngestionTask = db.get_item(&task.id).await?.context("task present")?;
.get_item(&task.id)
.await?
.context("task present")?;
assert_eq!(stored_task.state, TaskState::DeadLetter); assert_eq!(stored_task.state, TaskState::DeadLetter);
Ok(()) Ok(())
} }
+27
View File
@@ -0,0 +1,27 @@
use common::error::AppError;
use headless_chrome::Browser;
/// Starts a headless Chrome instance.
///
/// When the crate is built with the `docker` feature, the browser is launched with the
/// Chrome sandbox disabled (for container environments); otherwise the default launch
/// configuration is used.
///
/// This is the single place the crate spawns a browser. If the rendering backend is
/// ever swapped away from headless Chrome to something leaner, this function is the
/// seam to change; callers only depend on getting back a `Browser`.
///
/// # Errors
///
/// Returns `AppError::Processing` if the launch options cannot be built (`docker`
/// builds only) or if the browser fails to start.
pub(crate) fn launch_browser() -> Result<Browser, AppError> {
    #[cfg(feature = "docker")]
    let launched = {
        let options = headless_chrome::LaunchOptionsBuilder::default()
            .sandbox(false)
            .build()
            .map_err(|err| {
                AppError::Processing(format!("Failed to build headless browser options: {err}"))
            })?;
        Browser::new(options)
    };

    #[cfg(not(feature = "docker"))]
    let launched = Browser::default();

    // Both configurations report a failed start with the same message.
    launched
        .map_err(|err| AppError::Processing(format!("Failed to start headless browser: {err}")))
}
@@ -12,7 +12,7 @@ use uuid::Uuid;
use super::{ use super::{
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image, audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
pdf_ingestion::extract_pdf_content, pdf::extract_pdf_content,
}; };
struct TempPathGuard { struct TempPathGuard {
@@ -187,8 +187,8 @@ mod tests {
let openai_client = Client::with_config(OpenAIConfig::default()); let openai_client = Client::with_config(OpenAIConfig::default());
let text = extract_text_from_file(&file_info, &db, &openai_client, &config, &storage) let text =
.await?; extract_text_from_file(&file_info, &db, &openai_client, &config, &storage).await?;
assert_eq!(text, String::from_utf8_lossy(contents)); assert_eq!(text, String::from_utf8_lossy(contents));
Ok(()) Ok(())
@@ -51,3 +51,54 @@ impl GraphMapper {
.ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map."))) .ok_or_else(|| AppError::GraphMapper(format!("Key '{key}' not found in map.")))
} }
} }
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used)]
    use super::*;

    /// The id handed out by `assign_id` is what `get_id` returns for the same key.
    #[test]
    fn assign_then_get_returns_same_id() {
        let mut mapper = GraphMapper::new();
        let assigned = mapper.assign_id("entity-key");
        let looked_up = mapper.get_id("entity-key").expect("key present");
        assert_eq!(looked_up, assigned);
    }

    /// Looking up a key that was never assigned is an error, not a fresh id.
    #[test]
    fn get_id_for_unknown_key_errors() {
        let mapper = GraphMapper::new();
        let outcome = mapper.get_id("missing");
        assert!(matches!(outcome, Err(AppError::GraphMapper(_))));
    }

    /// A string that is already a UUID resolves to itself on an empty mapper.
    #[test]
    fn get_or_parse_id_parses_raw_uuid_without_lookup() {
        let mapper = GraphMapper::new();
        let raw = Uuid::new_v4();
        let raw_text = raw.to_string();
        let resolved = mapper.get_or_parse_id(&raw_text).expect("raw uuid parses");
        assert_eq!(resolved, raw);
    }

    /// A non-UUID string resolves through the key map when it has been assigned.
    #[test]
    fn get_or_parse_id_falls_back_to_map_for_keys() {
        let mut mapper = GraphMapper::new();
        let assigned = mapper.assign_id("alias");
        let resolved = mapper.get_or_parse_id("alias").expect("alias mapped");
        assert_eq!(resolved, assigned);
    }

    /// A string that is neither a UUID nor an assigned key is an error.
    #[test]
    fn get_or_parse_id_errors_for_unknown_non_uuid_key() {
        let mapper = GraphMapper::new();
        let outcome = mapper.get_or_parse_id("not-a-uuid-and-not-mapped");
        assert!(matches!(outcome, Err(AppError::GraphMapper(_))));
    }
}
+2 -1
View File
@@ -1,7 +1,8 @@
pub mod audio_transcription; pub mod audio_transcription;
pub mod browser;
pub mod file_text_extraction; pub mod file_text_extraction;
pub mod graph_mapper; pub mod graph_mapper;
pub mod image_parsing; pub mod image_parsing;
pub mod llm_instructions; pub mod llm_instructions;
pub mod pdf_ingestion; pub mod pdf;
pub mod url_text_retrieval; pub mod url_text_retrieval;
+55
View File
@@ -0,0 +1,55 @@
mod render;
mod text;
mod vision;
use std::path::Path;
use common::{error::AppError, storage::db::SurrealDbClient, utils::config::PdfIngestMode};
use self::{
render::{load_page_numbers, render_pdf_pages},
text::{post_process, try_fast_path},
vision::vision_markdown,
};
/// Upper bound on the number of pages handed to the vision model in a single document.
const MAX_VISION_PAGES: usize = 50;
/// Extracts the content of the PDF at `file_path` as text.
///
/// The steps, in order:
/// 1. Read the file and try the fast text path; if it yields a candidate, return it.
/// 2. If `mode` is `PdfIngestMode::Classic`, stop with an error instead of falling
///    back to the vision model.
/// 3. Otherwise discover the page numbers, render the pages, pass the renders to the
///    vision model, and post-process the combined markdown.
///
/// # Errors
///
/// Returns an error if the file cannot be read, or if any of the fast-path,
/// page-discovery, rendering, or vision steps fails. Returns `AppError::Processing`
/// when the fast path yields nothing in `Classic` mode, when the PDF has no pages, or
/// when it has more than `MAX_VISION_PAGES` pages.
pub async fn extract_pdf_content(
    file_path: &Path,
    db: &SurrealDbClient,
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    mode: &PdfIngestMode,
) -> Result<String, AppError> {
    let pdf_bytes = tokio::fs::read(file_path).await?;

    // The fast path consumes its input, and the buffer is still needed for page
    // discovery if the fast path yields nothing, so it receives a copy.
    if let Some(candidate) = try_fast_path(pdf_bytes.clone()).await? {
        return Ok(candidate);
    }

    if matches!(mode, PdfIngestMode::Classic) {
        return Err(AppError::Processing(
            "PDF text extraction failed and LLM-first mode is disabled".into(),
        ));
    }

    // Last use of the buffer: move it rather than copying the whole PDF again.
    let page_numbers = load_page_numbers(pdf_bytes).await?;
    if page_numbers.is_empty() {
        return Err(AppError::Processing("PDF appears to have no pages".into()));
    }

    if page_numbers.len() > MAX_VISION_PAGES {
        return Err(AppError::Processing(format!(
            "PDF has {} pages which exceeds the configured vision processing limit of {}",
            page_numbers.len(),
            MAX_VISION_PAGES
        )));
    }

    let rendered_pages = render_pdf_pages(file_path, &page_numbers).await?;
    let combined_markdown = vision_markdown(rendered_pages, db, client).await?;

    Ok(post_process(&combined_markdown))
}
+418
View File
@@ -0,0 +1,418 @@
//! Headless-Chrome rasterization of PDF pages into PNG screenshots.
//!
//! This is the only Chrome-dependent part of PDF ingestion. It depends on the
//! browser's internal PDF-viewer shadow DOM, so it is inherently fragile across
//! Chrome upgrades; a full-page-capture fallback guards the common failure modes.
use std::{
path::{Path, PathBuf},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use headless_chrome::protocol::cdp::{Emulation, Page, DOM};
use lopdf::Document;
use serde_json::Value;
use tracing::{debug, warn};
use common::error::AppError;
use crate::utils::browser::launch_browser;
/// Pause between successive checks for the initial page element after navigation.
const NAVIGATION_RETRY_INTERVAL_MS: u64 = 120;
/// Maximum number of checks for the initial page element before rendering is abandoned.
const NAVIGATION_RETRY_ATTEMPTS: usize = 10;
/// Screenshots smaller than this many bytes are logged as suspicious.
const MIN_PAGE_IMAGE_BYTES: usize = 1_024;
const DEFAULT_VIEWPORT_WIDTH: u32 = 1_248; // generous width to reduce horizontal clipping
const DEFAULT_VIEWPORT_HEIGHT: u32 = 1_800; // tall enough to capture full page at fit-to-width scale
/// Device scale factor sent with the device-metrics override.
const DEFAULT_DEVICE_SCALE_FACTOR: f64 = 1.0;
/// Number of polls for the page canvas rectangle before a full-page capture is used instead.
const CANVAS_VIEWPORT_ATTEMPTS: usize = 12;
/// Pause between canvas rectangle polls.
const CANVAS_VIEWPORT_WAIT_MS: u64 = 200;
/// Environment variable naming a directory that receives a copy of every captured screenshot.
const DEBUG_IMAGE_ENV_VAR: &str = "MINNE_PDF_DEBUG_DIR";
/// Reads the page tree of an in-memory PDF and returns its page numbers in ascending order.
///
/// Parsing is handed to Tokio's blocking pool so the async executor is not stalled.
///
/// # Errors
///
/// Returns `AppError::Processing` when the bytes cannot be parsed as a PDF, and propagates
/// a failure of the blocking task itself.
pub(super) async fn load_page_numbers(pdf_bytes: Vec<u8>) -> Result<Vec<u32>, AppError> {
    let parse_job = move || -> Result<Vec<u32>, AppError> {
        let parsed = match Document::load_mem(&pdf_bytes) {
            Ok(document) => document,
            Err(err) => {
                return Err(AppError::Processing(format!("Failed to parse PDF: {err}")));
            }
        };
        let mut numbers = parsed.get_pages().keys().copied().collect::<Vec<u32>>();
        numbers.sort_unstable();
        Ok(numbers)
    };

    // The first `?` surfaces a join failure; the inner parse result is returned as-is.
    tokio::task::spawn_blocking(parse_job).await?
}
/// Rasterizes the requested PDF pages into PNG images using headless Chrome.
///
/// The browser work runs on Tokio's blocking pool. Afterwards every capture is offered to
/// the optional debug dump; a failed dump is only logged and never fails the render.
///
/// # Errors
///
/// Propagates any error from the blocking render task or from joining it.
pub(super) async fn render_pdf_pages(
    file_path: &Path,
    pages: &[u32],
) -> Result<Vec<Vec<u8>>, AppError> {
    let owned_path = file_path.to_path_buf();
    let requested = pages.to_vec();
    let render_job = move || render_pdf_pages_inner(&owned_path, &requested);
    let captures = tokio::task::spawn_blocking(render_job).await??;

    for (page_number, png) in pages.iter().zip(&captures) {
        match maybe_dump_debug_image(*page_number, png).await {
            Ok(()) => {}
            Err(err) => {
                warn!(
                    page = page_number,
                    error = %err,
                    "Failed to write debug screenshot to disk"
                );
            }
        }
    }

    Ok(captures)
}
/// Blocking worker that renders every requested page in a single Chrome tab.
///
/// Returns one PNG byte buffer per entry of `pages`, in the same order. For each page the
/// tab navigates to the file URL with a `#page=N` fragment, waits for the viewer, tries to
/// locate the page canvas and captures a screenshot clipped to it; when the canvas cannot
/// be located or the clipped capture fails, a full-page capture is taken instead.
///
/// # Errors
///
/// Returns `AppError::Processing` when the file URL cannot be built, the tab cannot be
/// created, navigation fails, the viewer never appears, or the fallback capture fails.
/// Errors from `launch_browser`, `configure_tab` and `set_pdf_viewport` are propagated.
fn render_pdf_pages_inner(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> {
let file_url = url::Url::from_file_path(file_path)
.map_err(|()| AppError::Processing("Unable to construct PDF file URL".into()))?;
let browser = launch_browser()?;
let tab = browser
.new_tab()
.map_err(|err| AppError::Processing(format!("Failed to create Chrome tab: {err}")))?;
tab.set_default_timeout(Duration::from_secs(10));
configure_tab(&tab)?;
set_pdf_viewport(&tab)?;
let mut captures = Vec::with_capacity(pages.len());
for page in pages.iter().copied() {
// The URL fragment selects the page and asks the viewer to hide its chrome.
let target = format!("{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit");
tab.navigate_to(&target)
.map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))?
.wait_until_navigated()
.map_err(|err| AppError::Processing(format!("Navigation to PDF page failed: {err}")))?;
// Poll until any of the expected top-level elements exists, sleeping between attempts.
let mut loaded = false;
for attempt in 0..NAVIGATION_RETRY_ATTEMPTS {
if tab
.wait_for_element("embed, canvas, body")
.map(|_| ())
.is_ok()
{
loaded = true;
break;
}
// No sleep after the final attempt.
if attempt < NAVIGATION_RETRY_ATTEMPTS.saturating_sub(1) {
std::thread::sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS));
}
}
if !loaded {
return Err(AppError::Processing(
"Timed out waiting for Chrome to render PDF page".into(),
));
}
wait_for_pdf_ready(&tab, page)?;
// NOTE(review): fixed settle delay, presumably to let the viewer finish painting — confirm.
std::thread::sleep(Duration::from_millis(350));
// Best effort: the helper only logs its outcome.
prepare_pdf_viewer(&tab, page);
// Poll for the canvas rectangle; `None` means "not there yet", an error stops polling.
let mut viewport: Option<Page::Viewport> = None;
for attempt in 0..CANVAS_VIEWPORT_ATTEMPTS {
match canvas_viewport_for_page(&tab, page) {
Ok(Some(vp)) => {
viewport = Some(vp);
break;
}
Ok(None) => {
if attempt < CANVAS_VIEWPORT_ATTEMPTS.saturating_sub(1) {
std::thread::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS));
}
}
Err(err) => {
warn!(page, error = %err, "Failed to derive canvas viewport");
break;
}
}
}
// Prefer a capture clipped to the canvas; every failure path degrades to a full-page capture.
let png = if let Some(clip) = viewport {
match tab.call_method(Page::CaptureScreenshot {
format: Some(Page::CaptureScreenshotFormatOption::Png),
quality: None,
clip: Some(clip),
from_surface: Some(true),
capture_beyond_viewport: Some(true),
optimize_for_speed: Some(false),
}) {
// The screenshot payload arrives base64-encoded.
Ok(data) => match STANDARD.decode(data.data) {
Ok(bytes) => bytes,
Err(err) => {
warn!(error = %err, page, "Failed to decode clipped screenshot; falling back to full page capture");
capture_full_page_png(&tab)?
}
},
Err(err) => {
warn!(error = %err, page, "Clipped screenshot failed; falling back to full page capture");
capture_full_page_png(&tab)?
}
}
} else {
warn!(
page,
"Unable to determine canvas viewport; capturing full page"
);
capture_full_page_png(&tab)?
};
debug!(page, bytes = png.len(), "Captured PDF page screenshot");
// A tiny image is only reported; it is still returned to the caller.
if is_suspicious_image(png.len()) {
warn!(
page,
bytes = png.len(),
"Screenshot size below threshold; check rendering output"
);
}
captures.push(png);
}
Ok(captures)
}
/// Forces an opaque white default page background on the tab.
///
/// # Errors
///
/// Returns `AppError::Processing` when Chrome rejects the override.
fn configure_tab(tab: &headless_chrome::Tab) -> Result<(), AppError> {
    let opaque_white = DOM::RGBA {
        r: 255,
        g: 255,
        b: 255,
        a: Some(1.0),
    };
    let background_override = Emulation::SetDefaultBackgroundColorOverride {
        color: Some(opaque_white),
    };

    match tab.call_method(background_override) {
        Ok(_) => Ok(()),
        Err(err) => Err(AppError::Processing(format!(
            "Failed to configure Chrome page background: {err}"
        ))),
    }
}
/// Applies the fixed desktop viewport used for PDF rendering.
///
/// Two calls are issued in order: a device-metrics override, then a visible-size override,
/// both using the module's default width and height.
///
/// # Errors
///
/// Returns `AppError::Processing` when either call is rejected by Chrome.
fn set_pdf_viewport(tab: &headless_chrome::Tab) -> Result<(), AppError> {
    let device_metrics = Emulation::SetDeviceMetricsOverride {
        width: DEFAULT_VIEWPORT_WIDTH,
        height: DEFAULT_VIEWPORT_HEIGHT,
        device_scale_factor: DEFAULT_DEVICE_SCALE_FACTOR,
        mobile: false,
        scale: None,
        screen_width: Some(DEFAULT_VIEWPORT_WIDTH),
        screen_height: Some(DEFAULT_VIEWPORT_HEIGHT),
        position_x: None,
        position_y: None,
        dont_set_visible_size: Some(false),
        screen_orientation: None,
        viewport: None,
        display_feature: None,
        device_posture: None,
    };
    if let Err(err) = tab.call_method(device_metrics) {
        return Err(AppError::Processing(format!(
            "Failed to configure Chrome viewport: {err}"
        )));
    }

    let visible_size = Emulation::SetVisibleSize {
        width: DEFAULT_VIEWPORT_WIDTH,
        height: DEFAULT_VIEWPORT_HEIGHT,
    };
    if let Err(err) = tab.call_method(visible_size) {
        return Err(AppError::Processing(format!(
            "Failed to apply Chrome visible size: {err}"
        )));
    }

    Ok(())
}
/// Waits for the PDF `<embed>` element and returns it.
///
/// The typed selector is tried first and a bare `embed` selector second, each with an
/// eight second timeout. Scrolling the element into view is best effort.
///
/// # Errors
///
/// Returns `AppError::Processing` when neither selector matches in time.
fn wait_for_pdf_ready(
    tab: &headless_chrome::Tab,
    page_number: u32,
) -> Result<headless_chrome::Element<'_>, AppError> {
    let patience = Duration::from_secs(8);

    let located = match tab
        .wait_for_element_with_custom_timeout("embed[type='application/pdf']", patience)
    {
        Ok(found) => Ok(found),
        Err(_) => tab.wait_for_element_with_custom_timeout("embed", patience),
    };
    let element = match located {
        Ok(found) => found,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Timed out waiting for PDF content: {err}"
            )));
        }
    };

    if let Err(err) = element.scroll_into_view() {
        debug!("Failed to scroll PDF element into view: {err}");
    }
    debug!(page = page_number, "PDF viewer element located");

    Ok(element)
}
/// Runs a best-effort script inside the PDF viewer's shadow DOM before capture.
///
/// The script hides the viewer toolbar when it finds one, scrolls the requested
/// `viewer-page` into view, and evaluates to `true` when a canvas labelled
/// `Page {page_number}` exists. The outcome is only logged at debug level; this function
/// never fails and returns nothing.
fn prepare_pdf_viewer(tab: &headless_chrome::Tab, page_number: u32) {
// `{{` / `}}` are literal braces; `{page_number}` is interpolated by `format!`.
let script = format!(
r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return false;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return false;
const app = viewer.shadowRoot.querySelector('viewer-app');
if (app && app.shadowRoot) {{
const toolbar = app.shadowRoot.querySelector('#toolbar');
if (toolbar) {{ toolbar.style.display = 'none'; }}
}}
const page = viewer.shadowRoot.querySelector('viewer-page:nth-of-type({page_number})');
if (page && page.scrollIntoView) {{
page.scrollIntoView({{ block: 'start', inline: 'center' }});
}}
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page_number}"]');
return !!canvas;
}})()"#
);
match tab.evaluate(&script, false) {
Ok(result) => {
// A missing or non-boolean result is treated as "not ready".
let ready = result
.value
.as_ref()
.and_then(Value::as_bool)
.unwrap_or(false);
debug!(page = page_number, ready, "Prepared PDF viewer page");
}
Err(err) => {
// Script failures are tolerated; the caller proceeds to capture regardless.
debug!(page = page_number, error = %err, "Unable to run PDF viewer preparation script");
}
}
}
/// Looks up the on-screen rectangle of the canvas that renders `page_number`.
///
/// Returns `Ok(None)` when the script yields no value or `null`, or when the reported
/// rectangle has a non-positive width or height. Missing or non-numeric fields count as
/// `0.0`, and negative `x` / `y` are clamped to `0.0`.
///
/// # Errors
///
/// Returns `AppError::Processing` when the script cannot be evaluated in the tab.
fn canvas_viewport_for_page(
    tab: &headless_chrome::Tab,
    page_number: u32,
) -> Result<Option<Page::Viewport>, AppError> {
    let script = format!(
        r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return null;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return null;
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page_number}"]');
if (!canvas) return null;
const rect = canvas.getBoundingClientRect();
return {{ x: rect.x, y: rect.y, width: rect.width, height: rect.height }};
}})()"#
    );

    let evaluation = match tab.evaluate(&script, false) {
        Ok(outcome) => outcome,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to inspect PDF canvas: {err}"
            )));
        }
    };

    let rect = match evaluation.value {
        Some(payload) if !payload.is_null() => payload,
        _ => return Ok(None),
    };

    // Reads one numeric field of the rectangle, defaulting to 0.0.
    let field = |key: &str| rect.get(key).and_then(Value::as_f64).unwrap_or_default();

    let x = field("x").max(0.0);
    let y = field("y").max(0.0);
    let width = field("width");
    let height = field("height");

    if width <= 0.0 || height <= 0.0 {
        return Ok(None);
    }

    debug!(
        page = page_number,
        x, y, width, height, "Derived canvas viewport"
    );

    let clip = Page::Viewport {
        x,
        y,
        width,
        height,
        scale: 1.0,
    };
    Ok(Some(clip))
}
/// Captures the whole page as a PNG; used when a clipped capture is not possible.
///
/// # Errors
///
/// Returns `AppError::Processing` when the screenshot call fails or its base64 payload
/// cannot be decoded.
fn capture_full_page_png(tab: &headless_chrome::Tab) -> Result<Vec<u8>, AppError> {
    let request = Page::CaptureScreenshot {
        format: Some(Page::CaptureScreenshotFormatOption::Png),
        quality: None,
        clip: None,
        from_surface: Some(true),
        capture_beyond_viewport: Some(true),
        optimize_for_speed: Some(false),
    };

    let encoded = match tab.call_method(request) {
        Ok(screenshot) => screenshot.data,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to capture PDF page (fallback): {err}"
            )));
        }
    };

    match STANDARD.decode(encoded) {
        Ok(png) => Ok(png),
        Err(err) => Err(AppError::Processing(format!(
            "Failed to decode PDF screenshot (fallback): {err}"
        ))),
    }
}
/// Reports whether an image of `len` bytes is strictly smaller than the minimum expected
/// screenshot size.
const fn is_suspicious_image(len: usize) -> bool {
    MIN_PAGE_IMAGE_BYTES > len
}
/// Resolves the debug screenshot directory from the environment.
///
/// Returns `None` when the variable is unset, not valid Unicode, or blank after trimming;
/// otherwise returns the trimmed value as a path.
fn debug_dump_directory() -> Option<PathBuf> {
    let raw = std::env::var(DEBUG_IMAGE_ENV_VAR).ok()?;
    let cleaned = raw.trim();
    if cleaned.is_empty() {
        None
    } else {
        Some(PathBuf::from(cleaned.to_string()))
    }
}
/// Writes `bytes` to the debug directory when one is configured; otherwise does nothing.
///
/// The file is named `page-<index>-<millis since epoch>.png` and the directory is created
/// on demand.
///
/// # Errors
///
/// Propagates failures from creating the directory or writing the file.
async fn maybe_dump_debug_image(page_index: u32, bytes: &[u8]) -> Result<(), AppError> {
    let Some(dir) = debug_dump_directory() else {
        return Ok(());
    };

    tokio::fs::create_dir_all(&dir).await?;

    // A clock set before the epoch yields a zero timestamp instead of an error.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let timestamp = since_epoch.as_millis();

    let file_path = dir.join(format!("page-{page_index:04}-{timestamp}.png"));
    tokio::fs::write(&file_path, bytes).await?;
    debug!(?file_path, size = bytes.len(), "Wrote PDF debug screenshot");

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The directory is absent when the variable is unset or blank, and is the trimmed
    /// variable value otherwise.
    #[test]
    fn test_debug_dump_directory_env_var() -> anyhow::Result<()> {
        std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
        assert!(debug_dump_directory().is_none());

        // Whitespace-only values are treated as "not configured".
        std::env::set_var(DEBUG_IMAGE_ENV_VAR, "   ");
        assert!(debug_dump_directory().is_none());

        std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug");
        let dir =
            debug_dump_directory().ok_or_else(|| anyhow::anyhow!("expected debug directory"))?;
        assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug"));

        std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
        Ok(())
    }

    /// Sizes strictly below the threshold are suspicious; the threshold itself is not.
    #[test]
    fn test_is_suspicious_image_threshold() {
        assert!(is_suspicious_image(0));
        assert!(is_suspicious_image(MIN_PAGE_IMAGE_BYTES - 1));
        // Pin the exact boundary so a `<` / `<=` mix-up is caught.
        assert!(!is_suspicious_image(MIN_PAGE_IMAGE_BYTES));
        assert!(!is_suspicious_image(MIN_PAGE_IMAGE_BYTES + 1));
    }
}
+137
View File
@@ -0,0 +1,137 @@
//! Fast-path PDF text extraction and Markdown reflow heuristics.
//!
//! These are pure (non-IO, non-Chrome) helpers used before falling back to the
//! vision pipeline, plus the Markdown normalization applied to both paths.
use common::error::AppError;
/// Minimum byte length of extracted text for the fast path to be accepted.
const FAST_PATH_MIN_LEN: usize = 150;
/// Minimum share of ASCII characters in extracted text for the fast path to be accepted.
const FAST_PATH_MIN_ASCII_RATIO: f64 = 0.7;
/// Extracts the PDF's text layer with `pdf-extract` and accepts it only if it looks usable.
///
/// Extraction runs on Tokio's blocking pool. Returns `Ok(None)` when the trimmed text is
/// empty or fails the quality heuristic; otherwise returns the normalized text.
///
/// # Errors
///
/// Returns `AppError::Processing` when `pdf-extract` reports an error, and propagates a
/// failure of the blocking task itself.
pub(super) async fn try_fast_path(pdf_bytes: Vec<u8>) -> Result<Option<String>, AppError> {
    let extract_job = move || {
        pdf_extract::extract_text_from_mem(&pdf_bytes).map(|raw| raw.trim().to_string())
    };

    let text = match tokio::task::spawn_blocking(extract_job).await? {
        Ok(text) => text,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to extract text from PDF: {err}"
            )));
        }
    };

    let usable = !text.is_empty() && looks_good_enough(&text);
    Ok(usable.then(|| normalize_fast_text(&text)))
}
/// Decides whether extracted text is plausible prose rather than extraction noise.
///
/// The text must be at least `FAST_PATH_MIN_LEN` bytes long, at least
/// `FAST_PATH_MIN_ASCII_RATIO` of its characters must be ASCII, and more than 30% of its
/// characters must be alphabetic.
// Character counts are far below the range where the usize -> f64 cast loses precision.
#[allow(clippy::cast_precision_loss)]
fn looks_good_enough(text: &str) -> bool {
    if text.len() < FAST_PATH_MIN_LEN {
        return false;
    }

    let total = text.chars().count();
    if total == 0 {
        return false;
    }

    let share_of = |matching: usize| matching as f64 / total as f64;
    let ascii_share = share_of(text.chars().filter(char::is_ascii).count());
    let letter_share = share_of(text.chars().filter(|c| c.is_alphabetic()).count());

    ascii_share >= FAST_PATH_MIN_ASCII_RATIO && letter_share > 0.3
}
/// Normalizes fast-path output so downstream consumers see consistent Markdown.
///
/// Currently this is exactly `reflow_markdown` applied to `text`.
fn normalize_fast_text(text: &str) -> String {
reflow_markdown(text)
}
/// Tidies Markdown produced by the vision path: removes every carriage return, trims
/// surrounding whitespace, then reflows the result.
pub(super) fn post_process(markdown: &str) -> String {
    let without_cr: String = markdown.chars().filter(|&ch| ch != '\r').collect();
    reflow_markdown(without_cr.trim())
}
/// Joins hard-wrapped paragraph text while preserving structural Markdown lines.
///
/// * Consecutive non-empty, non-structural lines are trimmed and joined with a single
///   space into one paragraph.
/// * Structural lines (see [`is_structural_line`]) stay on their own.
/// * Fenced code blocks (``` or `~~~`) are copied verbatim — inner lines keep their
///   indentation and blank lines — so code is never reflowed. An unterminated fence
///   extends to the end of the input.
///
/// The resulting paragraphs are separated by one blank line.
fn reflow_markdown(input: &str) -> String {
    let mut paragraphs: Vec<String> = Vec::new();
    let mut buffer: Vec<&str> = Vec::new();
    // Fence character of the code block currently being copied, if any.
    let mut open_fence: Option<char> = None;
    let mut code_lines: Vec<&str> = Vec::new();

    for line in input.lines() {
        let trimmed = line.trim();

        if let Some(fence_char) = open_fence {
            // Inside a fence: keep leading indentation, drop only trailing whitespace.
            code_lines.push(line.trim_end());
            let closes = trimmed.len() >= 3 && trimmed.chars().all(|c| c == fence_char);
            if closes {
                paragraphs.push(code_lines.join("\n"));
                code_lines.clear();
                open_fence = None;
            }
            continue;
        }

        if let Some(fence_char) = opening_fence(trimmed) {
            flush_paragraph(&mut buffer, &mut paragraphs);
            code_lines.push(trimmed);
            open_fence = Some(fence_char);
            continue;
        }

        if trimmed.is_empty() {
            flush_paragraph(&mut buffer, &mut paragraphs);
            continue;
        }

        if is_structural_line(trimmed) {
            flush_paragraph(&mut buffer, &mut paragraphs);
            paragraphs.push(trimmed.to_string());
            continue;
        }

        buffer.push(trimmed);
    }

    // An unterminated fence still contributes what was collected.
    if !code_lines.is_empty() {
        paragraphs.push(code_lines.join("\n"));
    }
    flush_paragraph(&mut buffer, &mut paragraphs);

    paragraphs.join("\n\n")
}

/// Moves the buffered paragraph lines, joined by spaces, into `paragraphs`.
fn flush_paragraph(buffer: &mut Vec<&str>, paragraphs: &mut Vec<String>) {
    if !buffer.is_empty() {
        paragraphs.push(buffer.join(" "));
        buffer.clear();
    }
}

/// Returns the fence character when `line` opens a fenced code block.
///
/// A backtick fence whose remainder contains another backtick (for example
/// "```inline```") is not an opener; such a line is handled as an ordinary structural line.
fn opening_fence(line: &str) -> Option<char> {
    if line.starts_with("~~~") {
        return Some('~');
    }
    let info = line.strip_prefix("```")?.trim_start_matches('`');
    if info.contains('`') {
        None
    } else {
        Some('`')
    }
}

/// Detects whether a line is structural Markdown that should remain on its own.
///
/// A line is structural when it starts with `#`, `-`, `*`, `>`, `~`, a code fence,
/// `"| "` or `"+-"`, or when it starts with an ASCII digit and contains a `.` anywhere.
fn is_structural_line(line: &str) -> bool {
    let looks_numbered =
        line.chars().next().is_some_and(|c| c.is_ascii_digit()) && line.contains('.');

    line.starts_with(['#', '-', '*', '>', '~'])
        || line.starts_with("```")
        || line.starts_with("| ")
        || line.starts_with("+-")
        || looks_numbered
}
#[cfg(test)]
mod tests {
use super::*;
// Text shorter than the minimum length is rejected outright.
#[test]
fn test_looks_good_enough_short_text() {
assert!(!looks_good_enough("too short"));
}
// Long, letter-rich ASCII prose is accepted.
#[test]
fn test_looks_good_enough_ascii_text() {
let text = "This is a reasonably long ASCII text that should pass the heuristic. \
It contains multiple sentences and a decent amount of letters to satisfy the threshold.";
assert!(looks_good_enough(text));
}
// Wrapped prose lines are joined with a space while bullet lines survive as written.
#[test]
fn test_reflow_markdown_preserves_lists() {
let input = "Item one\nItem two\n\n- Bullet\n- Another";
let output = reflow_markdown(input);
assert!(output.contains("Item one Item two"));
assert!(output.contains("- Bullet"));
}
}
+226
View File
@@ -0,0 +1,226 @@
//! Vision-LLM transcription of rendered PDF pages into Markdown.
use async_openai::types::{
ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ImageDetail, ImageUrlArgs,
};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use tracing::{debug, warn};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Number of page images sent to the model in a single request.
const PAGES_PER_VISION_CHUNK: usize = 4;
/// Maximum number of requests made for one batch before giving up.
const MAX_VISION_ATTEMPTS: usize = 2;
/// Prompt used for the first attempt on a batch.
const PDF_MARKDOWN_PROMPT: &str = "Convert these PDF pages to clean Markdown. Preserve headings, lists, tables, blockquotes, code fences, and inline formatting. Keep the original reading order, avoid commentary, and do NOT wrap the entire response in a Markdown code block.";
/// More insistent prompt used for every attempt after the first.
const PDF_MARKDOWN_PROMPT_RETRY: &str = "You must transcribe the provided PDF page images into accurate Markdown. The images are already supplied, so do not respond that you cannot view them. Extract all visible text, tables, and structure, and do NOT wrap the overall response in a Markdown code block.";
/// Transcribes rendered page images to Markdown with the configured image-processing model.
///
/// Pages are processed in order, `PAGES_PER_VISION_CHUNK` at a time; the Markdown of each
/// batch is joined with a blank line.
///
/// # Errors
///
/// Propagates failures from loading the system settings and from transcribing any batch.
pub(super) async fn vision_markdown(
    rendered_pages: Vec<Vec<u8>>,
    db: &SurrealDbClient,
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<String, AppError> {
    let model = SystemSettings::get_current(db)
        .await?
        .image_processing_model;

    debug!(
        pages = rendered_pages.len(),
        "Preparing vision batches for PDF conversion"
    );

    let mut sections: Vec<String> = Vec::with_capacity(rendered_pages.len());
    let batches = rendered_pages.chunks(PAGES_PER_VISION_CHUNK);
    for (batch_idx, batch) in batches.enumerate() {
        let encoded = encode_batch(batch_idx, batch);
        let section = transcribe_batch(client, &model, batch_idx, &encoded).await?;
        sections.push(section);
    }

    Ok(sections.join("\n\n"))
}
/// Base64-encodes every PNG of a batch, preserving order.
///
/// Logs the batch size up front and warns about any encoded payload shorter than
/// 80 characters, which is too small to be a real page image.
fn encode_batch(batch_idx: usize, chunk: &[Vec<u8>]) -> Vec<String> {
    let total_image_bytes = chunk.iter().map(Vec::len).sum::<usize>();
    debug!(
        batch = batch_idx,
        pages = chunk.len(),
        bytes = total_image_bytes,
        "Encoding PDF images for vision batch"
    );

    let mut encoded_pages = Vec::with_capacity(chunk.len());
    for (page_index, png_bytes) in chunk.iter().enumerate() {
        let encoded = STANDARD.encode(png_bytes);
        if encoded.len() < 80 {
            warn!(
                batch = batch_idx,
                page_index,
                encoded_bytes = encoded.len(),
                "Encoded PDF image payload unusually small"
            );
        }
        encoded_pages.push(encoded);
    }
    encoded_pages
}
/// Asks the vision model to transcribe one batch of page images into Markdown.
///
/// Up to `MAX_VISION_ATTEMPTS` requests are made, each with the prompt chosen for that
/// attempt. A response with no choices, no content, or low-quality content triggers the
/// next attempt; the first acceptable response is returned trimmed.
///
/// # Errors
///
/// * Propagates request-building and API errors immediately, without retrying.
/// * Returns `AppError::Processing` when the final attempt yields low-quality content, or
///   when no attempt produced usable content at all.
async fn transcribe_batch(
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    model: &str,
    batch_idx: usize,
    encoded_images: &[String],
) -> Result<String, AppError> {
    let final_attempt = MAX_VISION_ATTEMPTS.saturating_sub(1);

    for attempt in 0..MAX_VISION_ATTEMPTS {
        let prompt = prompt_for_attempt(attempt);
        let request = build_request(model, prompt, encoded_images)?;
        let response = client.chat().create(request).await?;

        let content = match response.choices.first() {
            None => {
                warn!(
                    batch = batch_idx,
                    attempt, "Vision response contained zero choices"
                );
                continue;
            }
            Some(choice) => match choice.message.content.as_ref() {
                None => {
                    warn!(
                        batch = batch_idx,
                        attempt, "Vision response missing content field"
                    );
                    continue;
                }
                Some(content) => content,
            },
        };

        debug!(
            batch = batch_idx,
            attempt,
            response_chars = content.len(),
            "Received Markdown response for PDF batch"
        );
        log_preview(batch_idx, attempt, content);

        if !is_low_quality_response(content) {
            return Ok(content.trim().to_string());
        }

        warn!(
            batch = batch_idx,
            attempt, "Vision model returned low quality response"
        );
        if attempt == final_attempt {
            return Err(AppError::Processing(
                "Vision model failed to transcribe PDF page contents".into(),
            ));
        }
    }

    Err(AppError::Processing(
        "Vision model did not return usable Markdown".into(),
    ))
}
/// Assembles a chat-completion request with one user message: the prompt text followed by
/// each image as a high-detail `data:image/png;base64,` URL.
///
/// # Errors
///
/// Propagates any builder error from the `async_openai` request types.
fn build_request(
    model: &str,
    prompt_text: &str,
    encoded_images: &[String],
) -> Result<CreateChatCompletionRequest, AppError> {
    let prompt_part = ChatCompletionRequestMessageContentPartTextArgs::default()
        .text(prompt_text)
        .build()?;

    // One slot for the prompt plus one per image.
    let mut content_parts = Vec::with_capacity(encoded_images.len().saturating_add(1));
    content_parts.push(prompt_part.into());

    for encoded in encoded_images {
        let image = ImageUrlArgs::default()
            .url(format!("data:image/png;base64,{encoded}"))
            .detail(ImageDetail::High)
            .build()?;
        let image_part = ChatCompletionRequestMessageContentPartImageArgs::default()
            .image_url(image)
            .build()?;
        content_parts.push(image_part.into());
    }

    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content(content_parts)
        .build()?;

    let request = CreateChatCompletionRequestArgs::default()
        .model(model)
        .messages([user_message.into()])
        .build()?;
    Ok(request)
}
/// Logs a preview of a model response at debug level.
///
/// The preview holds at most the first 500 characters of `content`. An ellipsis is
/// appended only when characters were actually cut off; the check is made in characters,
/// not bytes, so multi-byte text is not marked as truncated when it was not.
fn log_preview(batch_idx: usize, attempt: usize, content: &str) {
    const PREVIEW_CHARS: usize = 500;

    let mut remaining = content.chars();
    let mut preview: String = remaining.by_ref().take(PREVIEW_CHARS).collect();
    // Anything left in the iterator means the response was longer than the preview.
    if remaining.next().is_some() {
        preview.push('…');
    }

    debug!(batch = batch_idx, attempt, preview = %preview, "Vision response content preview");
}
/// Heuristically detects a model response that is empty or a refusal rather than a
/// transcription.
///
/// A response is low quality when it is blank, or when one of the refusal phrases
/// ("unable to", "cannot") occurs within its first 200 characters. Only the opening of the
/// response is inspected because a refusal is stated up front, whereas a genuine
/// transcription may legitimately contain those words further down and must not be
/// rejected for it.
fn is_low_quality_response(content: &str) -> bool {
    const REFUSAL_WINDOW_CHARS: usize = 200;
    const REFUSAL_MARKERS: [&str; 2] = ["unable to", "cannot"];

    let trimmed = content.trim();
    if trimmed.is_empty() {
        return true;
    }

    let opening = trimmed
        .chars()
        .take(REFUSAL_WINDOW_CHARS)
        .collect::<String>()
        .to_ascii_lowercase();

    REFUSAL_MARKERS
        .iter()
        .any(|marker| opening.contains(marker))
}
/// Selects the prompt for a zero-based attempt number: the standard prompt for the first
/// attempt, the retry prompt for all later ones.
const fn prompt_for_attempt(attempt: usize) -> &'static str {
    match attempt {
        0 => PDF_MARKDOWN_PROMPT,
        _ => PDF_MARKDOWN_PROMPT_RETRY,
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Blank responses and refusal phrasing are rejected; ordinary Markdown is accepted.
#[test]
fn test_is_low_quality_response_detection() {
assert!(is_low_quality_response(""));
assert!(is_low_quality_response("I'm unable to help."));
assert!(is_low_quality_response("I cannot read this."));
assert!(!is_low_quality_response("# Heading\nValid content"));
}
// Attempt 0 uses the standard prompt; every later attempt uses the retry prompt.
#[test]
fn test_prompt_for_attempt_variants() {
assert_eq!(prompt_for_attempt(0), PDF_MARKDOWN_PROMPT);
assert_eq!(prompt_for_attempt(1), PDF_MARKDOWN_PROMPT_RETRY);
assert_eq!(prompt_for_attempt(5), PDF_MARKDOWN_PROMPT_RETRY);
}
// Neither prompt may itself contain a literal code fence.
#[test]
fn test_markdown_prompts_discourage_code_blocks() {
assert!(!PDF_MARKDOWN_PROMPT.contains("```"));
assert!(!PDF_MARKDOWN_PROMPT_RETRY.contains("```"));
}
}
@@ -1,801 +0,0 @@
use std::{
path::{Path, PathBuf},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_openai::types::{
ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequestArgs, ImageDetail, ImageUrlArgs,
};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use headless_chrome::{
protocol::cdp::{Emulation, Page, DOM},
Browser,
};
use lopdf::Document;
use serde_json::Value;
use tracing::{debug, warn};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
utils::config::PdfIngestMode,
};
/// Minimum byte length of extracted text for the fast path to be accepted.
const FAST_PATH_MIN_LEN: usize = 150;
/// Minimum share of ASCII characters in extracted text for the fast path to be accepted.
const FAST_PATH_MIN_ASCII_RATIO: f64 = 0.7;
/// Documents with more pages than this are refused by the vision path.
const MAX_VISION_PAGES: usize = 50;
/// Number of page images sent to the model in a single request.
const PAGES_PER_VISION_CHUNK: usize = 4;
/// Maximum number of requests made for one batch before giving up.
const MAX_VISION_ATTEMPTS: usize = 2;
/// Prompt used for the first attempt on a batch.
const PDF_MARKDOWN_PROMPT: &str = "Convert these PDF pages to clean Markdown. Preserve headings, lists, tables, blockquotes, code fences, and inline formatting. Keep the original reading order, avoid commentary, and do NOT wrap the entire response in a Markdown code block.";
/// More insistent prompt intended for retries.
const PDF_MARKDOWN_PROMPT_RETRY: &str = "You must transcribe the provided PDF page images into accurate Markdown. The images are already supplied, so do not respond that you cannot view them. Extract all visible text, tables, and structure, and do NOT wrap the overall response in a Markdown code block.";
/// Pause between successive checks for the initial page element after navigation.
const NAVIGATION_RETRY_INTERVAL_MS: u64 = 120;
/// Maximum number of checks for the initial page element before rendering is abandoned.
const NAVIGATION_RETRY_ATTEMPTS: usize = 10;
/// Screenshots smaller than this many bytes are logged as suspicious.
const MIN_PAGE_IMAGE_BYTES: usize = 1_024;
const DEFAULT_VIEWPORT_WIDTH: u32 = 1_248; // generous width to reduce horizontal clipping
const DEFAULT_VIEWPORT_HEIGHT: u32 = 1_800; // tall enough to capture full page at fit-to-width scale
/// Device scale factor for the rendering viewport.
const DEFAULT_DEVICE_SCALE_FACTOR: f64 = 1.0;
/// Number of polls for the page canvas rectangle before a full-page capture is used instead.
const CANVAS_VIEWPORT_ATTEMPTS: usize = 12;
/// Pause between canvas rectangle polls.
const CANVAS_VIEWPORT_WAIT_MS: u64 = 200;
/// Name of the environment variable used for debug screenshot output.
const DEBUG_IMAGE_ENV_VAR: &str = "MINNE_PDF_DEBUG_DIR";
/// Attempts to extract PDF content, using a fast text layer first and falling back to
/// rendering the document for a vision-enabled LLM when needed.
///
/// Steps, in order:
/// 1. Read the file and try the text-layer fast path; return its result when usable.
/// 2. In `PdfIngestMode::Classic`, stop here with an error instead of using the LLM.
/// 3. Otherwise discover the page numbers, enforce the `MAX_VISION_PAGES` limit, render
///    the pages and have the vision model transcribe them.
///
/// # Errors
///
/// Returns `AppError::Processing` when the fast path fails in classic mode, when the PDF
/// has no pages, or when it exceeds the page limit. Errors from reading the file, text
/// extraction, page discovery, rendering and transcription are propagated.
pub async fn extract_pdf_content(
    file_path: &Path,
    db: &SurrealDbClient,
    client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    mode: &PdfIngestMode,
) -> Result<String, AppError> {
    let pdf_bytes = tokio::fs::read(file_path).await?;

    // The fast path consumes its buffer and the bytes are needed again below for page
    // discovery, so this one copy is required.
    if let Some(candidate) = try_fast_path(pdf_bytes.clone()).await? {
        return Ok(candidate);
    }

    if matches!(mode, PdfIngestMode::Classic) {
        return Err(AppError::Processing(
            "PDF text extraction failed and LLM-first mode is disabled".into(),
        ));
    }

    // Last use of the buffer: hand it over instead of copying the whole document again.
    let page_numbers = load_page_numbers(pdf_bytes).await?;
    if page_numbers.is_empty() {
        return Err(AppError::Processing("PDF appears to have no pages".into()));
    }
    if page_numbers.len() > MAX_VISION_PAGES {
        return Err(AppError::Processing(format!(
            "PDF has {} pages which exceeds the configured vision processing limit of {}",
            page_numbers.len(),
            MAX_VISION_PAGES
        )));
    }

    let rendered_pages = render_pdf_pages(file_path, &page_numbers).await?;
    let combined_markdown = vision_markdown(rendered_pages, db, client).await?;
    Ok(post_process(&combined_markdown))
}
/// Extracts the PDF's text layer with `pdf-extract` and accepts it only if it looks usable.
///
/// Extraction runs on Tokio's blocking pool. Returns `Ok(None)` when the trimmed text is
/// empty or fails the quality heuristic; otherwise returns the normalized text.
///
/// # Errors
///
/// Returns `AppError::Processing` when `pdf-extract` reports an error, and propagates a
/// failure of the blocking task itself.
async fn try_fast_path(pdf_bytes: Vec<u8>) -> Result<Option<String>, AppError> {
    let extract_job = move || {
        pdf_extract::extract_text_from_mem(&pdf_bytes).map(|raw| raw.trim().to_string())
    };

    let text = match tokio::task::spawn_blocking(extract_job).await? {
        Ok(text) => text,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to extract text from PDF: {err}"
            )));
        }
    };

    let usable = !text.is_empty() && looks_good_enough(&text);
    Ok(usable.then(|| normalize_fast_text(&text)))
}
/// Reads the page tree of an in-memory PDF and returns its page numbers in ascending order.
///
/// Parsing is handed to Tokio's blocking pool so the async executor is not stalled.
///
/// # Errors
///
/// Returns `AppError::Processing` when the bytes cannot be parsed as a PDF, and propagates
/// a failure of the blocking task itself.
async fn load_page_numbers(pdf_bytes: Vec<u8>) -> Result<Vec<u32>, AppError> {
    let parse_job = move || -> Result<Vec<u32>, AppError> {
        let parsed = match Document::load_mem(&pdf_bytes) {
            Ok(document) => document,
            Err(err) => {
                return Err(AppError::Processing(format!("Failed to parse PDF: {err}")));
            }
        };
        let mut numbers = parsed.get_pages().keys().copied().collect::<Vec<u32>>();
        numbers.sort_unstable();
        Ok(numbers)
    };

    // The first `?` surfaces a join failure; the inner parse result is returned as-is.
    tokio::task::spawn_blocking(parse_job).await?
}
/// Rasterizes the requested PDF pages into PNG images using headless Chrome.
///
/// The browser work runs on Tokio's blocking pool. Afterwards every capture is offered to
/// the optional debug dump; a failed dump is only logged and never fails the render.
///
/// # Errors
///
/// Propagates any error from the blocking render task or from joining it.
async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> {
    let owned_path = file_path.to_path_buf();
    let requested = pages.to_vec();
    let render_job = move || render_pdf_pages_inner(&owned_path, &requested);
    let captures = tokio::task::spawn_blocking(render_job).await??;

    for (page_number, png) in pages.iter().zip(&captures) {
        match maybe_dump_debug_image(*page_number, png).await {
            Ok(()) => {}
            Err(err) => {
                warn!(
                    page = page_number,
                    error = %err,
                    "Failed to write debug screenshot to disk"
                );
            }
        }
    }

    Ok(captures)
}
/// Blocking worker that renders every requested page in a single Chrome tab.
///
/// Returns one PNG byte buffer per entry of `pages`, in the same order. For each page the
/// tab navigates to the file URL with a `#page=N` fragment, waits for the viewer, tries to
/// locate the page canvas and captures a screenshot clipped to it; when the canvas cannot
/// be located or the clipped capture fails, a full-page capture is taken instead.
///
/// # Errors
///
/// Returns `AppError::Processing` when the file URL cannot be built, the tab cannot be
/// created, navigation fails, the viewer never appears, or the fallback capture fails.
/// Errors from `create_browser`, `configure_tab` and `set_pdf_viewport` are propagated.
fn render_pdf_pages_inner(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> {
let file_url = url::Url::from_file_path(file_path)
.map_err(|()| AppError::Processing("Unable to construct PDF file URL".into()))?;
let browser = create_browser()?;
let tab = browser
.new_tab()
.map_err(|err| AppError::Processing(format!("Failed to create Chrome tab: {err}")))?;
tab.set_default_timeout(Duration::from_secs(10));
configure_tab(&tab)?;
set_pdf_viewport(&tab)?;
let mut captures = Vec::with_capacity(pages.len());
for page in pages.iter().copied() {
// The URL fragment selects the page and asks the viewer to hide its chrome.
let target = format!("{file_url}#page={page}&toolbar=0&statusbar=0&zoom=page-fit");
tab.navigate_to(&target)
.map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))?
.wait_until_navigated()
.map_err(|err| AppError::Processing(format!("Navigation to PDF page failed: {err}")))?;
// Poll until any of the expected top-level elements exists, sleeping between attempts.
let mut loaded = false;
for attempt in 0..NAVIGATION_RETRY_ATTEMPTS {
if tab
.wait_for_element("embed, canvas, body")
.map(|_| ())
.is_ok()
{
loaded = true;
break;
}
// No sleep after the final attempt.
if attempt < NAVIGATION_RETRY_ATTEMPTS.saturating_sub(1) {
std::thread::sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS));
}
}
if !loaded {
return Err(AppError::Processing(
"Timed out waiting for Chrome to render PDF page".into(),
));
}
wait_for_pdf_ready(&tab, page)?;
// NOTE(review): fixed settle delay, presumably to let the viewer finish painting — confirm.
std::thread::sleep(Duration::from_millis(350));
// Called for its effect on the page only; no result is consumed here.
prepare_pdf_viewer(&tab, page);
// Poll for the canvas rectangle; `None` means "not there yet", an error stops polling.
let mut viewport: Option<Page::Viewport> = None;
for attempt in 0..CANVAS_VIEWPORT_ATTEMPTS {
match canvas_viewport_for_page(&tab, page) {
Ok(Some(vp)) => {
viewport = Some(vp);
break;
}
Ok(None) => {
if attempt < CANVAS_VIEWPORT_ATTEMPTS.saturating_sub(1) {
std::thread::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS));
}
}
Err(err) => {
warn!(page, error = %err, "Failed to derive canvas viewport");
break;
}
}
}
// Prefer a capture clipped to the canvas; every failure path degrades to a full-page capture.
let png = if let Some(clip) = viewport {
match tab.call_method(Page::CaptureScreenshot {
format: Some(Page::CaptureScreenshotFormatOption::Png),
quality: None,
clip: Some(clip),
from_surface: Some(true),
capture_beyond_viewport: Some(true),
optimize_for_speed: Some(false),
}) {
// The screenshot payload arrives base64-encoded.
Ok(data) => match STANDARD.decode(data.data) {
Ok(bytes) => bytes,
Err(err) => {
warn!(error = %err, page, "Failed to decode clipped screenshot; falling back to full page capture");
capture_full_page_png(&tab)?
}
},
Err(err) => {
warn!(error = %err, page, "Clipped screenshot failed; falling back to full page capture");
capture_full_page_png(&tab)?
}
}
} else {
warn!(
page,
"Unable to determine canvas viewport; capturing full page"
);
capture_full_page_png(&tab)?
};
debug!(
page,
bytes = png.len(),
"Captured PDF page screenshot"
);
// A tiny image is only reported; it is still returned to the caller.
if is_suspicious_image(png.len()) {
warn!(
page,
bytes = png.len(),
"Screenshot size below threshold; check rendering output"
);
}
captures.push(png);
}
Ok(captures)
}
/// Launches a headless Chrome instance that respects the existing feature flags.
///
/// With the `docker` feature enabled the browser is started with Chrome's sandbox
/// disabled; without it the crate's default launch options are used.
///
/// # Errors
///
/// Returns `AppError::Processing` when the launch options cannot be built or Chrome fails
/// to start.
fn create_browser() -> Result<Browser, AppError> {
// Exactly one of the two cfg-gated blocks below is compiled and forms the return value.
#[cfg(feature = "docker")]
{
let options = headless_chrome::LaunchOptionsBuilder::default()
.sandbox(false)
.build()
.map_err(|err| AppError::Processing(format!("Failed to launch Chrome: {err}")))?;
Browser::new(options)
.map_err(|err| AppError::Processing(format!("Failed to start Chrome: {err}")))
}
#[cfg(not(feature = "docker"))]
{
Browser::default()
.map_err(|err| AppError::Processing(format!("Failed to start Chrome: {err}")))
}
}
/// Sends one or more rendered pages to the configured multimodal model and stitches the resulting Markdown chunks together.
///
/// Pages are processed in batches of `PAGES_PER_VISION_CHUNK`. Each batch is base64-encoded
/// and sent up to `MAX_VISION_ATTEMPTS` times; a response with no choices, no content, or
/// low-quality content triggers the next attempt. Batch results are joined by a blank line.
///
/// # Errors
///
/// Propagates settings-loading, request-building and API errors. Returns
/// `AppError::Processing` when the last attempt for a batch is low quality or when no
/// attempt for a batch produced usable Markdown.
#[allow(clippy::too_many_lines)]
async fn vision_markdown(
rendered_pages: Vec<Vec<u8>>,
db: &SurrealDbClient,
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<String, AppError> {
let settings = SystemSettings::get_current(db).await?;
let prompt = PDF_MARKDOWN_PROMPT;
debug!(
pages = rendered_pages.len(),
"Preparing vision batches for PDF conversion"
);
let mut markdown_sections = Vec::with_capacity(rendered_pages.len());
for (batch_idx, chunk) in rendered_pages.chunks(PAGES_PER_VISION_CHUNK).enumerate() {
let total_image_bytes: usize = chunk.iter().map(std::vec::Vec::len).sum();
debug!(
batch = batch_idx,
pages = chunk.len(),
bytes = total_image_bytes,
"Encoding PDF images for vision batch"
);
// Encode once per batch; the same payloads are reused for every attempt.
let encoded_images: Vec<String> = chunk
.iter()
.enumerate()
.map(|(idx, png_bytes)| {
let encoded = STANDARD.encode(png_bytes);
// An encoded payload this short is too small to be a real page image.
if encoded.len() < 80 {
warn!(
batch = batch_idx,
page_index = idx,
encoded_bytes = encoded.len(),
"Encoded PDF image payload unusually small"
);
}
encoded
})
.collect();
let mut batch_markdown: Option<String> = None;
let last_attempt = MAX_VISION_ATTEMPTS.saturating_sub(1);
for attempt in 0..MAX_VISION_ATTEMPTS {
// NOTE(review): `prompt_for_attempt` is defined outside this view; presumably it
// swaps in the retry prompt after the first attempt — confirm.
let prompt_text = prompt_for_attempt(attempt, prompt);
// One slot for the prompt plus one per image.
let mut content_parts = Vec::with_capacity(encoded_images.len().saturating_add(1));
content_parts.push(
ChatCompletionRequestMessageContentPartTextArgs::default()
.text(prompt_text)
.build()?
.into(),
);
for encoded in &encoded_images {
let image_url = format!("data:image/png;base64,{encoded}");
content_parts.push(
ChatCompletionRequestMessageContentPartImageArgs::default()
.image_url(
ImageUrlArgs::default()
.url(image_url)
.detail(ImageDetail::High)
.build()?,
)
.build()?
.into(),
);
}
let request = CreateChatCompletionRequestArgs::default()
.model(settings.image_processing_model.clone())
.messages([ChatCompletionRequestUserMessageArgs::default()
.content(content_parts)
.build()?
.into()])
.build()?;
// API errors abort immediately; only unusable responses are retried.
let response = client.chat().create(request).await?;
let Some(choice) = response.choices.first() else {
warn!(
batch = batch_idx,
attempt, "Vision response contained zero choices"
);
continue;
};
let Some(content) = choice.message.content.as_ref() else {
warn!(
batch = batch_idx,
attempt, "Vision response missing content field"
);
continue;
};
debug!(
batch = batch_idx,
attempt,
response_chars = content.len(),
"Received Markdown response for PDF batch"
);
// NOTE(review): the length test is in bytes while truncation is in characters.
let preview: String = if content.len() > 500 {
let mut snippet = content.chars().take(500).collect::<String>();
snippet.push('…');
snippet
} else {
content.clone()
};
debug!(batch = batch_idx, attempt, preview = %preview, "Vision response content preview");
if is_low_quality_response(content) {
warn!(
batch = batch_idx,
attempt, "Vision model returned low quality response"
);
if attempt == last_attempt {
return Err(AppError::Processing(
"Vision model failed to transcribe PDF page contents".into(),
));
}
continue;
}
batch_markdown = Some(content.trim().to_string());
break;
}
// `None` here means every attempt ended without usable content.
if let Some(markdown) = batch_markdown {
markdown_sections.push(markdown);
} else {
return Err(AppError::Processing(
"Vision model did not return usable Markdown".into(),
));
}
}
Ok(markdown_sections.join("\n\n"))
}
/// Decides whether fast-path text is clean enough to be treated as well-formed prose.
///
/// The text is accepted only when all of the following hold:
/// - its byte length is at least `FAST_PATH_MIN_LEN`;
/// - the share of ASCII characters is at least `FAST_PATH_MIN_ASCII_RATIO`;
/// - more than 30% of its characters are alphabetic.
#[allow(clippy::cast_precision_loss)]
fn looks_good_enough(text: &str) -> bool {
    if text.len() < FAST_PATH_MIN_LEN {
        return false;
    }

    // One walk over the characters gathers every count the ratios need.
    let (mut total, mut ascii, mut alphabetic) = (0_usize, 0_usize, 0_usize);
    for ch in text.chars() {
        total += 1;
        ascii += usize::from(ch.is_ascii());
        alphabetic += usize::from(ch.is_alphabetic());
    }
    if total == 0 {
        // Guards the divisions below when the minimum length permits empty input.
        return false;
    }

    let total = total as f64;
    let ascii_ratio = ascii as f64 / total;
    let letter_ratio = alphabetic as f64 / total;
    ascii_ratio >= FAST_PATH_MIN_ASCII_RATIO && letter_ratio > 0.3
}
/// Normalizes fast-path output so downstream consumers see consistent Markdown.
///
/// This is a thin wrapper: the text is passed to `reflow_markdown` unchanged and
/// the reflowed result is returned.
fn normalize_fast_text(text: &str) -> String {
    reflow_markdown(text)
}
/// Tidies Markdown created by the LLM path.
///
/// Every carriage return is dropped, surrounding whitespace is trimmed, and the
/// remainder is reflowed with `reflow_markdown`.
fn post_process(markdown: &str) -> String {
    let without_carriage_returns: String = markdown.chars().filter(|&ch| ch != '\r').collect();
    reflow_markdown(without_carriage_returns.trim())
}
/// Joins hard-wrapped paragraph text while preserving structural Markdown lines.
///
/// Each input line is trimmed. Consecutive non-structural lines are joined with a
/// single space into one paragraph; a blank line or a structural line (as judged
/// by `is_structural_line`) ends the current paragraph. Every structural line is
/// emitted as a block of its own. Blocks are separated by one blank line in the
/// returned string.
fn reflow_markdown(input: &str) -> String {
    // Moves the pending prose lines into `blocks` as one space-joined paragraph.
    fn flush(pending: &mut Vec<&str>, blocks: &mut Vec<String>) {
        if !pending.is_empty() {
            blocks.push(pending.join(" "));
            pending.clear();
        }
    }

    let mut blocks: Vec<String> = Vec::new();
    let mut pending: Vec<&str> = Vec::new();

    for line in input.lines().map(str::trim) {
        if line.is_empty() {
            flush(&mut pending, &mut blocks);
        } else if is_structural_line(line) {
            flush(&mut pending, &mut blocks);
            blocks.push(line.to_owned());
        } else {
            pending.push(line);
        }
    }
    flush(&mut pending, &mut blocks);

    blocks.join("\n\n")
}
/// Detects whether a line is structural Markdown that should remain on its own.
///
/// A line is structural when it starts with `#`, `-`, `*`, `>`, `~`, a code fence
/// (three backticks), `| ` or `+-`, or when it starts with an ASCII digit and
/// contains a `.` anywhere (a loose match for ordered-list items such as `1. foo`).
///
/// The caller is expected to pass an already-trimmed line; leading whitespace is
/// not skipped here.
fn is_structural_line(line: &str) -> bool {
    const CHAR_MARKERS: &[char] = &['#', '-', '*', '>', '~'];
    const STR_MARKERS: [&str; 3] = ["```", "| ", "+-"];

    // Digit and '.' checks are case-insensitive by nature, so no lowercased copy is needed.
    let looks_numbered = line.starts_with(|c: char| c.is_ascii_digit()) && line.contains('.');

    line.starts_with(CHAR_MARKERS)
        || STR_MARKERS.iter().any(|marker| line.starts_with(*marker))
        || looks_numbered
}
/// Reads the debug-image directory from the `DEBUG_IMAGE_ENV_VAR` environment variable.
///
/// Returns `None` when the variable is unset, cannot be read as Unicode, or is
/// empty after trimming whitespace; otherwise returns the trimmed value as a path.
fn debug_dump_directory() -> Option<PathBuf> {
    let raw = std::env::var(DEBUG_IMAGE_ENV_VAR).ok()?;
    let dir = raw.trim();
    if dir.is_empty() {
        None
    } else {
        Some(PathBuf::from(dir))
    }
}
/// Overrides the tab's default background color with opaque white.
///
/// # Errors
///
/// Returns `AppError::Processing` when the browser rejects the override call.
fn configure_tab(tab: &headless_chrome::Tab) -> Result<(), AppError> {
    let opaque_white = DOM::RGBA {
        r: 255,
        g: 255,
        b: 255,
        a: Some(1.0),
    };
    let request = Emulation::SetDefaultBackgroundColorOverride {
        color: Some(opaque_white),
    };

    match tab.call_method(request) {
        Ok(_) => Ok(()),
        Err(err) => Err(AppError::Processing(format!(
            "Failed to configure Chrome page background: {err}"
        ))),
    }
}
/// Applies the default viewport geometry to the tab.
///
/// First a device-metrics override is installed using `DEFAULT_VIEWPORT_WIDTH`,
/// `DEFAULT_VIEWPORT_HEIGHT` and `DEFAULT_DEVICE_SCALE_FACTOR` (non-mobile), then
/// the visible size is set to the same width and height.
///
/// # Errors
///
/// Returns `AppError::Processing` as soon as either browser call fails; the
/// second call is not attempted when the first one fails.
fn set_pdf_viewport(tab: &headless_chrome::Tab) -> Result<(), AppError> {
    let metrics_override = Emulation::SetDeviceMetricsOverride {
        width: DEFAULT_VIEWPORT_WIDTH,
        height: DEFAULT_VIEWPORT_HEIGHT,
        device_scale_factor: DEFAULT_DEVICE_SCALE_FACTOR,
        mobile: false,
        scale: None,
        screen_width: Some(DEFAULT_VIEWPORT_WIDTH),
        screen_height: Some(DEFAULT_VIEWPORT_HEIGHT),
        position_x: None,
        position_y: None,
        dont_set_visible_size: Some(false),
        screen_orientation: None,
        viewport: None,
        display_feature: None,
        device_posture: None,
    };
    if let Err(err) = tab.call_method(metrics_override) {
        return Err(AppError::Processing(format!(
            "Failed to configure Chrome viewport: {err}"
        )));
    }

    let visible_size = Emulation::SetVisibleSize {
        width: DEFAULT_VIEWPORT_WIDTH,
        height: DEFAULT_VIEWPORT_HEIGHT,
    };
    if let Err(err) = tab.call_method(visible_size) {
        return Err(AppError::Processing(format!(
            "Failed to apply Chrome visible size: {err}"
        )));
    }

    Ok(())
}
/// Waits for the PDF `<embed>` element to appear in the tab and returns it.
///
/// The specific selector `embed[type='application/pdf']` is tried first; if that
/// wait fails, any `embed` element is accepted. Each wait is limited to eight
/// seconds. A failure to scroll the element into view is only logged.
/// `page_number` is used for logging only.
///
/// # Errors
///
/// Returns `AppError::Processing` when neither selector yields an element in time.
fn wait_for_pdf_ready(
    tab: &headless_chrome::Tab,
    page_number: u32,
) -> Result<headless_chrome::Element<'_>, AppError> {
    let timeout = Duration::from_secs(8);

    let located = match tab
        .wait_for_element_with_custom_timeout("embed[type='application/pdf']", timeout)
    {
        Ok(found) => Ok(found),
        Err(_) => tab.wait_for_element_with_custom_timeout("embed", timeout),
    };
    let element = match located {
        Ok(found) => found,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Timed out waiting for PDF content: {err}"
            )));
        }
    };

    if let Err(err) = element.scroll_into_view() {
        debug!("Failed to scroll PDF element into view: {err}");
    }
    debug!(page = page_number, "PDF viewer element located");

    Ok(element)
}
/// Best-effort preparation of the embedded PDF viewer for capturing one page.
///
/// Runs a script in the tab that hides the viewer toolbar (when found), scrolls
/// the requested page into view, and reports whether a canvas labelled
/// `Page {page_number}` exists. The outcome is only logged: script failures are
/// swallowed and nothing is returned to the caller.
fn prepare_pdf_viewer(tab: &headless_chrome::Tab, page_number: u32) {
    // The script text is runtime data and is kept verbatim.
    let script = format!(
        r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return false;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return false;
const app = viewer.shadowRoot.querySelector('viewer-app');
if (app && app.shadowRoot) {{
const toolbar = app.shadowRoot.querySelector('#toolbar');
if (toolbar) {{ toolbar.style.display = 'none'; }}
}}
const page = viewer.shadowRoot.querySelector('viewer-page:nth-of-type({page_number})');
if (page && page.scrollIntoView) {{
page.scrollIntoView({{ block: 'start', inline: 'center' }});
}}
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page_number}"]');
return !!canvas;
}})()"#
    );

    let evaluation = match tab.evaluate(&script, false) {
        Ok(evaluation) => evaluation,
        Err(err) => {
            debug!(page = page_number, error = %err, "Unable to run PDF viewer preparation script");
            return;
        }
    };

    // Anything other than a boolean `true` result counts as "not ready".
    let ready = evaluation
        .value
        .as_ref()
        .and_then(Value::as_bool)
        .unwrap_or(false);
    debug!(page = page_number, ready, "Prepared PDF viewer page");
}
/// Looks up the on-screen rectangle of the canvas labelled `Page {page_number}`.
///
/// Returns `Ok(None)` when the script yields no value or `null` (embed, viewer or
/// canvas not found), or when the reported width or height is not positive.
/// Otherwise returns a viewport with scale `1.0`; negative `x`/`y` values are
/// clamped to zero and missing numeric fields are treated as `0.0`.
///
/// # Errors
///
/// Returns `AppError::Processing` when the script cannot be evaluated in the tab.
fn canvas_viewport_for_page(
    tab: &headless_chrome::Tab,
    page_number: u32,
) -> Result<Option<Page::Viewport>, AppError> {
    // The script text is runtime data and is kept verbatim.
    let script = format!(
        r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return null;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return null;
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page_number}"]');
if (!canvas) return null;
const rect = canvas.getBoundingClientRect();
return {{ x: rect.x, y: rect.y, width: rect.width, height: rect.height }};
}})()"#
    );

    let evaluation = match tab.evaluate(&script, false) {
        Ok(evaluation) => evaluation,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to inspect PDF canvas: {err}"
            )));
        }
    };
    let rect = match evaluation.value {
        Some(rect) if !rect.is_null() => rect,
        _ => return Ok(None),
    };

    // Reads one numeric field of the rectangle, defaulting to 0.0 when absent.
    let number = |key: &str| rect.get(key).and_then(Value::as_f64).unwrap_or_default();

    let (x, y) = (number("x").max(0.0), number("y").max(0.0));
    let (width, height) = (number("width"), number("height"));
    if width <= 0.0 || height <= 0.0 {
        return Ok(None);
    }

    debug!(
        page = page_number,
        x, y, width, height, "Derived canvas viewport"
    );
    Ok(Some(Page::Viewport {
        x,
        y,
        width,
        height,
        scale: 1.0,
    }))
}
/// Captures an unclipped PNG screenshot of the tab and returns the decoded bytes.
///
/// The capture is taken from the surface, may extend beyond the viewport, and is
/// not optimized for speed. The error messages label this as the fallback path.
///
/// # Errors
///
/// Returns `AppError::Processing` when the screenshot call fails or when the
/// returned data cannot be base64-decoded.
fn capture_full_page_png(tab: &headless_chrome::Tab) -> Result<Vec<u8>, AppError> {
    let request = Page::CaptureScreenshot {
        format: Some(Page::CaptureScreenshotFormatOption::Png),
        quality: None,
        clip: None,
        from_surface: Some(true),
        capture_beyond_viewport: Some(true),
        optimize_for_speed: Some(false),
    };

    let screenshot = match tab.call_method(request) {
        Ok(response) => response,
        Err(err) => {
            return Err(AppError::Processing(format!(
                "Failed to capture PDF page (fallback): {err}"
            )));
        }
    };

    match STANDARD.decode(screenshot.data) {
        Ok(png_bytes) => Ok(png_bytes),
        Err(err) => Err(AppError::Processing(format!(
            "Failed to decode PDF screenshot (fallback): {err}"
        ))),
    }
}
/// Reports whether an image of `len` bytes is smaller than `MIN_PAGE_IMAGE_BYTES`.
///
/// The comparison is strict: a length exactly equal to the minimum is not
/// considered suspicious.
const fn is_suspicious_image(len: usize) -> bool {
    len < MIN_PAGE_IMAGE_BYTES
}
/// Writes a page screenshot to the debug directory, if one is configured.
///
/// Does nothing when `debug_dump_directory` returns `None`. Otherwise the
/// directory is created if needed and `bytes` is written to
/// `page-<index>-<millis>.png`, where `<millis>` is the current time in
/// milliseconds since the Unix epoch (0 if the clock reads earlier than the epoch).
///
/// # Errors
///
/// Propagates failures from creating the directory or writing the file.
async fn maybe_dump_debug_image(page_index: u32, bytes: &[u8]) -> Result<(), AppError> {
    let Some(dir) = debug_dump_directory() else {
        return Ok(());
    };

    tokio::fs::create_dir_all(&dir).await?;

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis());
    let file_path = dir.join(format!("page-{page_index:04}-{timestamp}.png"));

    tokio::fs::write(&file_path, bytes).await?;
    debug!(?file_path, size = bytes.len(), "Wrote PDF debug screenshot");

    Ok(())
}
/// Flags vision-model output that is empty or that opens like a refusal.
///
/// A response is low quality when it is blank after trimming, or when the phrases
/// `unable to` / `cannot` (ASCII case-insensitive) occur within its first
/// `REFUSAL_SCAN_CHARS` characters.
///
/// Only the opening is inspected on purpose: the response is a transcription of
/// arbitrary document text, so the same words appearing deeper in the body are
/// ordinary content. Scanning the whole response would reject every document
/// that happens to contain "cannot".
fn is_low_quality_response(content: &str) -> bool {
    // Heuristic window: a refusal is expected at the start of the response.
    const REFUSAL_SCAN_CHARS: usize = 200;
    const REFUSAL_MARKERS: [&str; 2] = ["unable to", "cannot"];

    let trimmed = content.trim();
    if trimmed.is_empty() {
        return true;
    }

    // Cut on a character boundary so multi-byte text can never cause a slicing panic.
    let scan_end = trimmed
        .char_indices()
        .nth(REFUSAL_SCAN_CHARS)
        .map_or(trimmed.len(), |(byte_idx, _)| byte_idx);
    let opening = trimmed[..scan_end].to_ascii_lowercase();

    REFUSAL_MARKERS
        .iter()
        .any(|marker| opening.contains(*marker))
}
/// Selects the prompt text for a given attempt number.
///
/// Attempt `0` uses `base_prompt`; every later attempt uses
/// `PDF_MARKDOWN_PROMPT_RETRY`.
const fn prompt_for_attempt(attempt: usize, base_prompt: &str) -> &str {
    match attempt {
        0 => base_prompt,
        _ => PDF_MARKDOWN_PROMPT_RETRY,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::{self};

    /// A very short string must be rejected by the fast-path heuristic.
    #[test]
    fn test_looks_good_enough_short_text() {
        assert!(!looks_good_enough("too short"));
    }

    /// Plain English ASCII prose is expected to pass the heuristic.
    #[test]
    fn test_looks_good_enough_ascii_text() {
        let text = "This is a reasonably long ASCII text that should pass the heuristic. \
            It contains multiple sentences and a decent amount of letters to satisfy the threshold.";
        assert!(looks_good_enough(text));
    }

    /// Wrapped prose lines are joined while list items stay intact.
    #[test]
    fn test_reflow_markdown_preserves_lists() {
        let input = "Item one\nItem two\n\n- Bullet\n- Another";
        let output = reflow_markdown(input);
        assert!(output.contains("Item one Item two"));
        assert!(output.contains("- Bullet"));
    }

    /// The debug directory follows the environment variable.
    ///
    /// This test mutates process-wide environment state, so the variable is
    /// cleared both before and after the assertions.
    #[test]
    fn test_debug_dump_directory_env_var() -> anyhow::Result<()> {
        std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
        assert!(debug_dump_directory().is_none());
        std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug");
        let dir =
            debug_dump_directory().ok_or_else(|| anyhow::anyhow!("expected debug directory"))?;
        assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug"));
        std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
        Ok(())
    }

    /// Sizes below the minimum are suspicious; the minimum itself and above are not.
    #[test]
    fn test_is_suspicious_image_threshold() {
        assert!(is_suspicious_image(0));
        assert!(is_suspicious_image(MIN_PAGE_IMAGE_BYTES - 1));
        // Pin the exact boundary so a `<` -> `<=` regression is caught.
        assert!(!is_suspicious_image(MIN_PAGE_IMAGE_BYTES));
        assert!(!is_suspicious_image(MIN_PAGE_IMAGE_BYTES + 1));
    }

    /// Empty and refusal-style responses are flagged; real Markdown is not.
    #[test]
    fn test_is_low_quality_response_detection() {
        assert!(is_low_quality_response(""));
        assert!(is_low_quality_response("I'm unable to help."));
        assert!(is_low_quality_response("I cannot read this."));
        assert!(!is_low_quality_response("# Heading\nValid content"));
    }

    /// Only attempt 0 uses the base prompt; all later attempts use the retry prompt.
    #[test]
    fn test_prompt_for_attempt_variants() {
        assert_eq!(
            prompt_for_attempt(0, PDF_MARKDOWN_PROMPT),
            PDF_MARKDOWN_PROMPT
        );
        assert_eq!(
            prompt_for_attempt(1, PDF_MARKDOWN_PROMPT),
            PDF_MARKDOWN_PROMPT_RETRY
        );
        assert_eq!(
            prompt_for_attempt(5, PDF_MARKDOWN_PROMPT),
            PDF_MARKDOWN_PROMPT_RETRY
        );
    }

    /// Neither prompt may contain a code fence.
    #[test]
    fn test_markdown_prompts_discourage_code_blocks() {
        assert!(!PDF_MARKDOWN_PROMPT.contains("```"));
        assert!(!PDF_MARKDOWN_PROMPT_RETRY.contains("```"));
    }
}
@@ -6,7 +6,6 @@ use common::{
storage::{db::SurrealDbClient, store::StorageManager, types::file_info::FileInfo}, storage::{db::SurrealDbClient, store::StorageManager, types::file_info::FileInfo},
}; };
use dom_smoothie::{Article, Readability, TextMode}; use dom_smoothie::{Article, Readability, TextMode};
use headless_chrome::Browser;
use std::{ use std::{
io::{Seek, SeekFrom, Write}, io::{Seek, SeekFrom, Write},
net::IpAddr, net::IpAddr,
@@ -23,22 +22,7 @@ pub async fn extract_text_from_url(
info!("Fetching URL: {}", url); info!("Fetching URL: {}", url);
let now = Instant::now(); let now = Instant::now();
let browser = { let browser = crate::utils::browser::launch_browser()?;
#[cfg(feature = "docker")]
{
let options = headless_chrome::LaunchOptionsBuilder::default()
.sandbox(false)
.build()
.map_err(|e| AppError::InternalError(e.to_string()))?;
Browser::new(options)
.map_err(|e| AppError::InternalError(e.to_string()))?
}
#[cfg(not(feature = "docker"))]
{
Browser::default()
.map_err(|e| AppError::InternalError(e.to_string()))?
}
};
let tab = browser let tab = browser
.new_tab() .new_tab()