retrieval simplified

This commit is contained in:
Per Stark
2025-12-09 20:35:42 +01:00
parent 192e6480e0
commit 8121e04125
55 changed files with 469 additions and 1208 deletions
+1
View File
@@ -5,6 +5,7 @@ use tokio::task::JoinError;
use crate::storage::types::file_info::FileError; use crate::storage::types::file_info::FileError;
// Core internal errors // Core internal errors
#[allow(clippy::module_name_repetitions)]
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum AppError { pub enum AppError {
#[error("Database error: {0}")] #[error("Database error: {0}")]
+2
View File
@@ -1,3 +1,5 @@
#![allow(clippy::doc_markdown)]
//! Shared utilities and storage helpers for the workspace crates.
pub mod error; pub mod error;
pub mod storage; pub mod storage;
pub mod utils; pub mod utils;
+2
View File
@@ -13,12 +13,14 @@ use surrealdb::{
use surrealdb_migrations::MigrationRunner; use surrealdb_migrations::MigrationRunner;
use tracing::debug; use tracing::debug;
/// Embedded SurrealDB migration directory packaged with the crate.
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/"); static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
#[derive(Clone)] #[derive(Clone)]
pub struct SurrealDbClient { pub struct SurrealDbClient {
pub client: Surreal<Any>, pub client: Surreal<Any>,
} }
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesDb { pub trait ProvidesDb {
fn db(&self) -> &Arc<SurrealDbClient>; fn db(&self) -> &Arc<SurrealDbClient>;
} }
+28 -5
View File
@@ -1,3 +1,13 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss,
clippy::redundant_closure_for_method_calls,
clippy::single_match_else,
clippy::uninlined_format_args
)]
use std::time::Duration; use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@@ -234,12 +244,25 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
analyzer = FTS_ANALYZER_NAME analyzer = FTS_ANALYZER_NAME
); );
db.client let res = db
.client
.query(fallback_query) .query(fallback_query)
.await .await
.context("creating fallback FTS analyzer")? .context("creating fallback FTS analyzer")?;
.check()
.context("failed to create fallback FTS analyzer")?; if let Err(err) = res.check() {
warn!(
error = %err,
"Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})",
FTS_ANALYZER_NAME
);
return Err(err).context("failed to create fallback FTS analyzer");
}
warn!(
"Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only",
FTS_ANALYZER_NAME
);
Ok(()) Ok(())
} }
@@ -466,7 +489,7 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
let rows: Vec<CountRow> = response let rows: Vec<CountRow> = response
.take(0) .take(0)
.context("failed to deserialize count() response")?; .context("failed to deserialize count() response")?;
Ok(rows.first().map(|r| r.count).unwrap_or(0)) Ok(rows.first().map_or(0, |r| r.count))
} }
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> { async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
+1 -1
View File
@@ -183,7 +183,7 @@ impl StorageManager {
while current.starts_with(base) && current.as_path() != base.as_path() { while current.starts_with(base) && current.as_path() != base.as_path() {
match tokio::fs::remove_dir(&current).await { match tokio::fs::remove_dir(&current).await {
Ok(_) => {} Ok(()) => {}
Err(err) => match err.kind() { Err(err) => match err.kind() {
ErrorKind::NotFound => {} ErrorKind::NotFound => {}
ErrorKind::DirectoryNotEmpty => break, ErrorKind::DirectoryNotEmpty => break,
+2 -1
View File
@@ -71,6 +71,7 @@ impl Analytics {
// We need to use a direct query for COUNT aggregation // We need to use a direct query for COUNT aggregation
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct CountResult { struct CountResult {
/// Total user count.
count: i64, count: i64,
} }
@@ -81,7 +82,7 @@ impl Analytics {
.await? .await?
.take(0)?; .take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0)) Ok(result.map_or(0, |r| r.count))
} }
} }
+23 -17
View File
@@ -3,12 +3,10 @@ use bytes;
use mime_guess::from_path; use mime_guess::from_path;
use object_store::Error as ObjectStoreError; use object_store::Error as ObjectStoreError;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::{ use std::{io::{BufReader, Read}, path::Path};
io::{BufReader, Read},
path::Path,
};
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use thiserror::Error; use thiserror::Error;
use tokio::task;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
@@ -71,21 +69,29 @@ impl FileInfo {
/// ///
/// # Returns /// # Returns
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error. /// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
#[allow(clippy::indexing_slicing)]
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> { async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
let mut reader = BufReader::new(file.as_file()); let mut file_clone = file.as_file().try_clone()?;
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192]; // 8KB buffer
loop { let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> {
let n = reader.read(&mut buffer)?; let mut reader = BufReader::new(&mut file_clone);
if n == 0 { let mut hasher = Sha256::new();
break; let mut buffer = [0u8; 8192]; // 8KB buffer
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
hasher.update(&buffer[..n]);
} }
hasher.update(&buffer[..n]);
}
let digest = hasher.finalize(); Ok::<_, std::io::Error>(hasher.finalize())
Ok(format!("{:x}", digest)) })
.await
.map_err(std::io::Error::other)??;
Ok(format!("{digest:x}"))
} }
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal. /// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
@@ -103,7 +109,7 @@ impl FileInfo {
} }
}) })
.collect(); .collect();
format!("{}{}", sanitized_name, ext) format!("{sanitized_name}{ext}")
} else { } else {
// No extension // No extension
file_name file_name
@@ -292,7 +298,7 @@ impl FileInfo {
storage: &StorageManager, storage: &StorageManager,
) -> Result<String, FileError> { ) -> Result<String, FileError> {
// Logical object location relative to the store root // Logical object location relative to the store root
let location = format!("{}/{}/{}", user_id, uuid, file_name); let location = format!("{user_id}/{uuid}/{file_name}");
info!("Persisting to object location: {}", location); info!("Persisting to object location: {}", location);
let bytes = tokio::fs::read(file.path()).await?; let bytes = tokio::fs::read(file.path()).await?;
@@ -1,3 +1,9 @@
#![allow(
clippy::result_large_err,
clippy::needless_pass_by_value,
clippy::implicit_clone,
clippy::semicolon_if_nothing_returned
)]
use crate::{error::AppError, storage::types::file_info::FileInfo}; use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
@@ -38,6 +44,7 @@ impl IngestionPayload {
/// # Returns /// # Returns
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects /// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
/// (one per file/content type). On failure, returns an `AppError`. /// (one per file/content type). On failure, returns an `AppError`.
#[allow(clippy::similar_names)]
pub fn create_ingestion_payload( pub fn create_ingestion_payload(
content: Option<String>, content: Option<String>,
context: String, context: String,
@@ -1,3 +1,12 @@
#![allow(
clippy::cast_possible_wrap,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_sign_loss,
clippy::missing_docs_in_private_items,
clippy::trivially_copy_pass_by_ref,
clippy::expect_used
)]
use std::time::Duration; use std::time::Duration;
use chrono::Duration as ChronoDuration; use chrono::Duration as ChronoDuration;
@@ -1,3 +1,14 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::match_same_arms,
clippy::format_push_string,
clippy::uninlined_format_args,
clippy::explicit_iter_loop,
clippy::items_after_statements,
clippy::get_first,
clippy::redundant_closure_for_method_calls
)]
use std::collections::HashMap; use std::collections::HashMap;
use crate::{ use crate::{
@@ -72,7 +72,7 @@ impl KnowledgeEntityEmbedding {
return Ok(HashMap::new()); return Ok(HashMap::new());
} }
let ids_list: Vec<RecordId> = entity_ids.iter().cloned().collect(); let ids_list: Vec<RecordId> = entity_ids.to_vec();
let query = format!( let query = format!(
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids", "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
@@ -110,6 +110,7 @@ impl KnowledgeEntityEmbedding {
} }
/// Delete embeddings by source_id (via joining to knowledge_entity table) /// Delete embeddings by source_id (via joining to knowledge_entity table)
#[allow(clippy::items_after_statements)]
pub async fn delete_by_source_id( pub async fn delete_by_source_id(
source_id: &str, source_id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
@@ -121,6 +122,7 @@ impl KnowledgeEntityEmbedding {
.bind(("source_id", source_id.to_owned())) .bind(("source_id", source_id.to_owned()))
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)] #[derive(Deserialize)]
struct IdRow { struct IdRow {
id: RecordId, id: RecordId,
@@ -65,8 +65,7 @@ impl KnowledgeRelationship {
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let query = format!( let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'", "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
source_id
); );
db_client.query(query).await?; db_client.query(query).await?;
@@ -81,15 +80,14 @@ impl KnowledgeRelationship {
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let mut authorized_result = db_client let mut authorized_result = db_client
.query(format!( .query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'", "SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
id, user_id
)) ))
.await?; .await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default(); let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
if authorized.is_empty() { if authorized.is_empty() {
let mut exists_result = db_client let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{}`", id)) .query(format!("SELECT * FROM relates_to:`{id}`"))
.await?; .await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?; let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -98,11 +96,11 @@ impl KnowledgeRelationship {
"Not authorized to delete relationship".into(), "Not authorized to delete relationship".into(),
)) ))
} else { } else {
Err(AppError::NotFound(format!("Relationship {} not found", id))) Err(AppError::NotFound(format!("Relationship {id} not found")))
} }
} else { } else {
db_client db_client
.query(format!("DELETE relates_to:`{}`", id)) .query(format!("DELETE relates_to:`{id}`"))
.await?; .await?;
Ok(()) Ok(())
} }
+2 -1
View File
@@ -1,3 +1,4 @@
#![allow(clippy::module_name_repetitions)]
use uuid::Uuid; use uuid::Uuid;
use crate::stored_object; use crate::stored_object;
@@ -56,7 +57,7 @@ impl fmt::Display for Message {
pub fn format_history(history: &[Message]) -> String { pub fn format_history(history: &[Message]) -> String {
history history
.iter() .iter()
.map(|msg| format!("{}", msg)) .map(|msg| format!("{msg}"))
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("\n") .join("\n")
} }
+6 -2
View File
@@ -1,3 +1,4 @@
#![allow(clippy::unsafe_derive_deserialize)]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod analytics; pub mod analytics;
pub mod conversation; pub mod conversation;
@@ -23,7 +24,7 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
#[macro_export] #[macro_export]
macro_rules! stored_object { macro_rules! stored_object {
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => { ($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing; use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject; use $crate::storage::types::StoredObject;
@@ -87,6 +88,7 @@ macro_rules! stored_object {
} }
#[allow(dead_code)] #[allow(dead_code)]
#[allow(clippy::ref_option)]
fn serialize_option_datetime<S>( fn serialize_option_datetime<S>(
date: &Option<DateTime<Utc>>, date: &Option<DateTime<Utc>>,
serializer: S, serializer: S,
@@ -102,6 +104,7 @@ macro_rules! stored_object {
} }
#[allow(dead_code)] #[allow(dead_code)]
#[allow(clippy::ref_option)]
fn deserialize_option_datetime<'de, D>( fn deserialize_option_datetime<'de, D>(
deserializer: D, deserializer: D,
) -> Result<Option<DateTime<Utc>>, D::Error> ) -> Result<Option<DateTime<Utc>>, D::Error>
@@ -113,6 +116,7 @@ macro_rules! stored_object {
} }
$(#[$struct_attr])*
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct $name { pub struct $name {
#[serde(deserialize_with = "deserialize_flexible_id")] #[serde(deserialize_with = "deserialize_flexible_id")]
@@ -121,7 +125,7 @@ macro_rules! stored_object {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)] #[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
$( $(#[$attr])* pub $field: $ty),* $( $(#[$field_attr])* pub $field: $ty),*
} }
impl StoredObject for $name { impl StoredObject for $name {
+17 -7
View File
@@ -1,4 +1,6 @@
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Write;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
@@ -18,6 +20,7 @@ stored_object!(TextChunk, "text_chunk", {
}); });
/// Search result including hydrated chunk. /// Search result including hydrated chunk.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)] #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct TextChunkSearchResult { pub struct TextChunkSearchResult {
pub chunk: TextChunk, pub chunk: TextChunk,
@@ -98,6 +101,7 @@ impl TextChunk {
db: &SurrealDbClient, db: &SurrealDbClient,
user_id: &str, user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> { ) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)] #[derive(Deserialize)]
struct Row { struct Row {
chunk_id: TextChunk, chunk_id: TextChunk,
@@ -160,6 +164,8 @@ impl TextChunk {
score: f32, score: f32,
} }
let limit = i64::try_from(take).unwrap_or(i64::MAX);
let sql = format!( let sql = format!(
r#" r#"
SELECT SELECT
@@ -183,7 +189,7 @@ impl TextChunk {
.query(&sql) .query(&sql)
.bind(("terms", terms.to_owned())) .bind(("terms", terms.to_owned()))
.bind(("user_id", user_id.to_owned())) .bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64)) .bind(("limit", limit))
.await .await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
@@ -245,7 +251,7 @@ impl TextChunk {
// Generate all new embeddings in memory // Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new(); let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks..."); info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() { for chunk in &all_chunks {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || { let embedding = Retry::spawn(retry_strategy, || {
@@ -283,12 +289,13 @@ impl TextChunk {
"[{}]", "[{}]",
embedding embedding
.iter() .iter()
.map(|f| f.to_string()) .map(ToString::to_string)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(",") .join(",")
); );
// Use the chunk id as the embedding record id to keep a 1:1 mapping // Use the chunk id as the embedding record id to keep a 1:1 mapping
transaction_query.push_str(&format!( write!(
&mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \ "UPSERT type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \ chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \ source_id = '{source_id}', \
@@ -300,13 +307,16 @@ impl TextChunk {
embedding = embedding_str, embedding = embedding_str,
user_id = user_id, user_id = user_id,
source_id = source_id source_id = source_id
)); )
.map_err(|e| AppError::InternalError(e.to_string()))?;
} }
transaction_query.push_str(&format!( write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions new_dimensions
)); )
.map_err(|e| AppError::InternalError(e.to_string()))?;
transaction_query.push_str("COMMIT TRANSACTION;"); transaction_query.push_str("COMMIT TRANSACTION;");
@@ -110,6 +110,11 @@ impl TextChunkEmbedding {
source_id: &str, source_id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids_query = format!( let ids_query = format!(
"SELECT id FROM {} WHERE source_id = $source_id", "SELECT id FROM {} WHERE source_id = $source_id",
TextChunk::table_name() TextChunk::table_name()
@@ -120,10 +125,6 @@ impl TextChunkEmbedding {
.bind(("source_id", source_id.to_owned())) .bind(("source_id", source_id.to_owned()))
.await .await
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?; let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
if ids.is_empty() { if ids.is_empty() {
+1
View File
@@ -5,6 +5,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::file_info::FileInfo; use super::file_info::FileInfo;
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
pub struct TextContentSearchResult { pub struct TextContentSearchResult {
#[serde(deserialize_with = "deserialize_flexible_id")] #[serde(deserialize_with = "deserialize_flexible_id")]
+26 -14
View File
@@ -1,4 +1,5 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use anyhow::anyhow;
use async_trait::async_trait; use async_trait::async_trait;
use axum_session_auth::Authentication; use axum_session_auth::Authentication;
use chrono_tz::Tz; use chrono_tz::Tz;
@@ -17,12 +18,16 @@ use super::{
use chrono::Duration; use chrono::Duration;
use futures::try_join; use futures::try_join;
/// Result row for returning user category.
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct CategoryResponse { pub struct CategoryResponse {
/// Category name tied to the user.
category: String, category: String,
} }
stored_object!(User, "user", { stored_object!(
#[allow(clippy::unsafe_derive_deserialize)]
User, "user", {
email: String, email: String,
password: String, password: String,
anonymous: bool, anonymous: bool,
@@ -35,11 +40,11 @@ stored_object!(User, "user", {
#[async_trait] #[async_trait]
impl Authentication<User, String, Surreal<Any>> for User { impl Authentication<User, String, Surreal<Any>> for User {
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> { async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
let db = db.unwrap(); let db = db.ok_or_else(|| anyhow!("Database handle missing"))?;
Ok(db Ok(db
.select((Self::table_name(), userid.as_str())) .select((Self::table_name(), userid.as_str()))
.await? .await?
.unwrap()) .ok_or_else(|| anyhow!("User {userid} not found"))?)
} }
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
@@ -55,14 +60,14 @@ impl Authentication<User, String, Surreal<Any>> for User {
} }
} }
/// Ensures a timezone string parses, defaulting to UTC when invalid.
fn validate_timezone(input: &str) -> String { fn validate_timezone(input: &str) -> String {
match input.parse::<Tz>() { if input.parse::<Tz>().is_ok() {
Ok(_) => input.to_owned(), return input.to_owned();
Err(_) => {
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
"UTC".to_owned()
}
} }
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
"UTC".to_owned()
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -77,12 +82,15 @@ pub struct DashboardStats {
pub new_text_chunks_week: i64, pub new_text_chunks_week: i64,
} }
/// Helper for aggregating `SurrealDB` count responses.
#[derive(Deserialize)] #[derive(Deserialize)]
struct CountResult { struct CountResult {
/// Row count returned by the query.
count: i64, count: i64,
} }
impl User { impl User {
/// Counts all objects of a given type belonging to the user.
async fn count_total<T: crate::storage::types::StoredObject>( async fn count_total<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient, db: &SurrealDbClient,
user_id: &str, user_id: &str,
@@ -94,9 +102,10 @@ impl User {
.bind(("user_id", user_id.to_string())) .bind(("user_id", user_id.to_string()))
.await? .await?
.take(0)?; .take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0)) Ok(result.map_or(0, |r| r.count))
} }
/// Counts objects of a given type created after a specific timestamp.
async fn count_since<T: crate::storage::types::StoredObject>( async fn count_since<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient, db: &SurrealDbClient,
user_id: &str, user_id: &str,
@@ -112,14 +121,16 @@ impl User {
.bind(("since", surrealdb::Datetime::from(since))) .bind(("since", surrealdb::Datetime::from(since)))
.await? .await?
.take(0)?; .take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0)) Ok(result.map_or(0, |r| r.count))
} }
pub async fn get_dashboard_stats( pub async fn get_dashboard_stats(
user_id: &str, user_id: &str,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<DashboardStats, AppError> { ) -> Result<DashboardStats, AppError> {
let since = chrono::Utc::now() - Duration::days(7); let since = chrono::Utc::now()
.checked_sub_signed(Duration::days(7))
.unwrap_or_else(chrono::Utc::now);
let ( let (
total_documents, total_documents,
@@ -261,7 +272,7 @@ impl User {
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> { pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
// Generate a secure random API key // Generate a secure random API key
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", "")); let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace('-', ""));
// Update the user record with the new API key // Update the user record with the new API key
let user: Option<Self> = db let user: Option<Self> = db
@@ -341,6 +352,7 @@ impl User {
) -> Result<Vec<String>, AppError> { ) -> Result<Vec<String>, AppError> {
#[derive(Deserialize)] #[derive(Deserialize)]
struct EntityTypeResponse { struct EntityTypeResponse {
/// Raw entity type value from the database.
entity_type: String, entity_type: String,
} }
@@ -358,7 +370,7 @@ impl User {
.into_iter() .into_iter()
.map(|item| { .map(|item| {
let normalized = KnowledgeEntityType::from(item.entity_type); let normalized = KnowledgeEntityType::from(item.entity_type);
format!("{:?}", normalized) format!("{normalized:?}")
}) })
.collect(); .collect();
+9
View File
@@ -9,6 +9,7 @@ pub enum StorageKind {
Memory, Memory,
} }
/// Default storage backend when none is configured.
fn default_storage_kind() -> StorageKind { fn default_storage_kind() -> StorageKind {
StorageKind::Local StorageKind::Local
} }
@@ -23,10 +24,13 @@ pub enum PdfIngestMode {
LlmFirst, LlmFirst,
} }
/// Default PDF ingestion mode when unset.
fn default_pdf_ingest_mode() -> PdfIngestMode { fn default_pdf_ingest_mode() -> PdfIngestMode {
PdfIngestMode::LlmFirst PdfIngestMode::LlmFirst
} }
/// Application configuration loaded from files and environment variables.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Deserialize, Debug)] #[derive(Clone, Deserialize, Debug)]
pub struct AppConfig { pub struct AppConfig {
pub openai_api_key: String, pub openai_api_key: String,
@@ -58,14 +62,17 @@ pub struct AppConfig {
pub retrieval_strategy: Option<String>, pub retrieval_strategy: Option<String>,
} }
/// Default data directory for persisted assets.
fn default_data_dir() -> String { fn default_data_dir() -> String {
"./data".to_string() "./data".to_string()
} }
/// Default base URL used for OpenAI-compatible APIs.
fn default_base_url() -> String { fn default_base_url() -> String {
"https://api.openai.com/v1".to_string() "https://api.openai.com/v1".to_string()
} }
/// Whether reranking is enabled by default.
fn default_reranking_enabled() -> bool { fn default_reranking_enabled() -> bool {
false false
} }
@@ -124,6 +131,8 @@ impl Default for AppConfig {
} }
} }
/// Loads the application configuration from the environment and optional config file.
#[allow(clippy::module_name_repetitions)]
pub fn get_config() -> Result<AppConfig, ConfigError> { pub fn get_config() -> Result<AppConfig, ConfigError> {
ensure_ort_path(); ensure_ort_path();
+40 -24
View File
@@ -16,19 +16,16 @@ use crate::{
storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
}; };
#[derive(Debug, Clone, Copy, PartialEq, Eq)] /// Supported embedding backends.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend { pub enum EmbeddingBackend {
OpenAI, OpenAI,
#[default]
FastEmbed, FastEmbed,
Hashed, Hashed,
} }
impl Default for EmbeddingBackend {
fn default() -> Self {
Self::FastEmbed
}
}
impl std::str::FromStr for EmbeddingBackend { impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error; type Err = anyhow::Error;
@@ -44,24 +41,38 @@ impl std::str::FromStr for EmbeddingBackend {
} }
} }
/// Wrapper around the chosen embedding backend.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)] #[derive(Clone)]
pub struct EmbeddingProvider { pub struct EmbeddingProvider {
/// Concrete backend implementation.
inner: EmbeddingInner, inner: EmbeddingInner,
} }
/// Concrete embedding implementations.
#[derive(Clone)] #[derive(Clone)]
enum EmbeddingInner { enum EmbeddingInner {
/// Uses an `OpenAI`-compatible API.
OpenAI { OpenAI {
/// Client used to issue embedding requests.
client: Arc<Client<async_openai::config::OpenAIConfig>>, client: Arc<Client<async_openai::config::OpenAIConfig>>,
/// Model identifier for the API.
model: String, model: String,
/// Expected output dimensions.
dimensions: u32, dimensions: u32,
}, },
/// Generates deterministic hashed embeddings without external calls.
Hashed { Hashed {
/// Output vector length.
dimension: usize, dimension: usize,
}, },
/// Uses `FastEmbed` running locally.
FastEmbed { FastEmbed {
/// Shared `FastEmbed` model.
model: Arc<Mutex<TextEmbedding>>, model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging.
model_name: EmbeddingModel, model_name: EmbeddingModel,
/// Output vector length.
dimension: usize, dimension: usize,
}, },
} }
@@ -77,8 +88,9 @@ impl EmbeddingProvider {
pub fn dimension(&self) -> usize { pub fn dimension(&self) -> usize {
match &self.inner { match &self.inner {
EmbeddingInner::Hashed { dimension } => *dimension, EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
EmbeddingInner::FastEmbed { dimension, .. } => *dimension, *dimension
}
EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize, EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
} }
} }
@@ -172,12 +184,12 @@ impl EmbeddingProvider {
} }
} }
pub async fn new_openai( pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>, client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String, model: String,
dimensions: u32, dimensions: u32,
) -> Result<Self> { ) -> Result<Self> {
Ok(EmbeddingProvider { Ok(Self {
inner: EmbeddingInner::OpenAI { inner: EmbeddingInner::OpenAI {
client, client,
model, model,
@@ -226,6 +238,7 @@ impl EmbeddingProvider {
} }
// Helper functions for hashed embeddings // Helper functions for hashed embeddings
/// Generates a hashed embedding vector without external dependencies.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> { fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
let dim = dimension.max(1); let dim = dimension.max(1);
let mut vector = vec![0.0f32; dim]; let mut vector = vec![0.0f32; dim];
@@ -233,15 +246,11 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
return vector; return vector;
} }
let mut token_count = 0f32;
for token in tokens(text) { for token in tokens(text) {
token_count += 1.0;
let idx = bucket(&token, dim); let idx = bucket(&token, dim);
vector[idx] += 1.0; if let Some(slot) = vector.get_mut(idx) {
} *slot += 1.0;
}
if token_count == 0.0 {
return vector;
} }
let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt(); let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
@@ -254,16 +263,22 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
vector vector
} }
/// Tokenizes the text into alphanumeric lowercase tokens.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ { fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
text.split(|c: char| !c.is_ascii_alphanumeric()) text.split(|c: char| !c.is_ascii_alphanumeric())
.filter(|token| !token.is_empty()) .filter(|token| !token.is_empty())
.map(|token| token.to_ascii_lowercase()) .map(str::to_ascii_lowercase)
} }
/// Buckets a token into the hashed embedding vector.
#[allow(clippy::arithmetic_side_effects)]
fn bucket(token: &str, dimension: usize) -> usize { fn bucket(token: &str, dimension: usize) -> usize {
let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
token.hash(&mut hasher); token.hash(&mut hasher);
(hasher.finish() as usize) % dimension usize::try_from(hasher.finish())
.unwrap_or_default()
% safe_dimension
} }
// Backward compatibility function // Backward compatibility function
@@ -274,15 +289,15 @@ pub async fn generate_embedding_with_provider(
provider.embed(input).await.map_err(AppError::from) provider.embed(input).await.map_err(AppError::from)
} }
/// Generates an embedding vector for the given input text using OpenAI's embedding model. /// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
/// ///
/// This function takes a text input and converts it into a numerical vector representation (embedding) /// This function takes a text input and converts it into a numerical vector representation (embedding)
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity /// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
/// comparisons, vector search, and other natural language processing tasks. /// comparisons, vector search, and other natural language processing tasks.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `client`: The OpenAI client instance used to make API requests. /// * `client`: The `OpenAI` client instance used to make API requests.
/// * `input`: The text string to generate embeddings for. /// * `input`: The text string to generate embeddings for.
/// ///
/// # Returns /// # Returns
@@ -294,9 +309,10 @@ pub async fn generate_embedding_with_provider(
/// # Errors /// # Errors
/// ///
/// This function can return a `AppError` in the following cases: /// This function can return a `AppError` in the following cases:
/// * If the OpenAI API request fails /// * If the `OpenAI` API request fails
/// * If the request building fails /// * If the request building fails
/// * If no embedding data is received in the response /// * If no embedding data is received in the response
#[allow(clippy::module_name_repetitions)]
pub async fn generate_embedding( pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>, client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str, input: &str,
+1
View File
@@ -4,6 +4,7 @@ pub use minijinja_contrib;
pub use minijinja_embed; pub use minijinja_embed;
use std::sync::Arc; use std::sync::Arc;
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesTemplateEngine { pub trait ProvidesTemplateEngine {
fn template_engine(&self) -> &Arc<TemplateEngine>; fn template_engine(&self) -> &Arc<TemplateEngine>;
} }
+5 -14
View File
@@ -28,19 +28,14 @@ fn default_ingestion_cache_dir() -> PathBuf {
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025; pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
#[value(rename_all = "lowercase")] #[value(rename_all = "lowercase")]
pub enum EmbeddingBackend { pub enum EmbeddingBackend {
Hashed, Hashed,
#[default]
FastEmbed, FastEmbed,
} }
impl Default for EmbeddingBackend {
fn default() -> Self {
Self::FastEmbed
}
}
impl std::fmt::Display for EmbeddingBackend { impl std::fmt::Display for EmbeddingBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@@ -109,7 +104,7 @@ pub struct RetrievalSettings {
pub require_verified_chunks: bool, pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy /// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Initial)] #[arg(long, default_value_t = RetrievalStrategy::Default)]
pub strategy: RetrievalStrategy, pub strategy: RetrievalStrategy,
} }
@@ -130,7 +125,7 @@ impl Default for RetrievalSettings {
chunk_rrf_use_vector: None, chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None, chunk_rrf_use_fts: None,
require_verified_chunks: true, require_verified_chunks: true,
strategy: RetrievalStrategy::Initial, strategy: RetrievalStrategy::Default,
} }
} }
} }
@@ -378,11 +373,7 @@ impl Config {
self.summary_sample = self.sample.max(1); self.summary_sample = self.sample.max(1);
// Handle retrieval settings // Handle retrieval settings
if self.llm_mode { self.retrieval.require_verified_chunks = !self.llm_mode;
self.retrieval.require_verified_chunks = false;
} else {
self.retrieval.require_verified_chunks = true;
}
if self.dataset == DatasetKind::Beir { if self.dataset == DatasetKind::Beir {
self.negative_multiplier = 9.0; self.negative_multiplier = 9.0;
+6 -6
View File
@@ -14,13 +14,13 @@ pub use store::{
}; };
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig { pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
let mut tuning = ingestion_pipeline::IngestionTuning::default();
tuning.chunk_min_tokens = config.ingest.ingest_chunk_min_tokens;
tuning.chunk_max_tokens = config.ingest.ingest_chunk_max_tokens;
tuning.chunk_overlap_tokens = config.ingest.ingest_chunk_overlap_tokens;
ingestion_pipeline::IngestionConfig { ingestion_pipeline::IngestionConfig {
tuning, tuning: ingestion_pipeline::IngestionTuning {
chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
..Default::default()
},
chunk_only: config.ingest.ingest_chunks_only, chunk_only: config.ingest.ingest_chunks_only,
} }
} }
+12 -10
View File
@@ -106,6 +106,7 @@ struct IngestionStats {
negative_ingested: usize, negative_ingested: usize,
} }
#[allow(clippy::too_many_arguments)]
pub async fn ensure_corpus( pub async fn ensure_corpus(
dataset: &ConvertedDataset, dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>, slice: &ResolvedSlice<'_>,
@@ -337,11 +338,9 @@ pub async fn ensure_corpus(
}); });
} }
for record in &mut records { for entry in records.iter_mut().flatten() {
if let Some(ref mut entry) = record { if entry.dirty {
if entry.dirty { store.persist(&entry.shard)?;
store.persist(&entry.shard)?;
}
} }
} }
@@ -403,6 +402,7 @@ pub async fn ensure_corpus(
Ok(handle) Ok(handle)
} }
#[allow(clippy::too_many_arguments)]
async fn ingest_paragraph_batch( async fn ingest_paragraph_batch(
dataset: &ConvertedDataset, dataset: &ConvertedDataset,
targets: &[IngestRequest<'_>], targets: &[IngestRequest<'_>],
@@ -430,8 +430,10 @@ async fn ingest_paragraph_batch(
.await .await
.context("applying migrations for ingestion")?; .context("applying migrations for ingestion")?;
let mut app_config = AppConfig::default(); let app_config = AppConfig {
app_config.storage = StorageKind::Memory; storage: StorageKind::Memory,
..Default::default()
};
let backend: DynStore = Arc::new(InMemory::new()); let backend: DynStore = Arc::new(InMemory::new());
let storage = StorageManager::with_backend(backend, StorageKind::Memory); let storage = StorageManager::with_backend(backend, StorageKind::Memory);
@@ -444,8 +446,7 @@ async fn ingest_paragraph_batch(
storage, storage,
embedding.clone(), embedding.clone(),
pipeline_config, pipeline_config,
) )?;
.await?;
let pipeline = Arc::new(pipeline); let pipeline = Arc::new(pipeline);
let mut shards = Vec::with_capacity(targets.len()); let mut shards = Vec::with_capacity(targets.len());
@@ -454,7 +455,7 @@ async fn ingest_paragraph_batch(
info!( info!(
batch = batch_index, batch = batch_index,
batch_size = batch.len(), batch_size = batch.len(),
total_batches = (targets.len() + batch_size - 1) / batch_size, total_batches = targets.len().div_ceil(batch_size),
"Ingesting paragraph batch" "Ingesting paragraph batch"
); );
let model_clone = embedding_model.clone(); let model_clone = embedding_model.clone();
@@ -486,6 +487,7 @@ async fn ingest_paragraph_batch(
Ok(shards) Ok(shards)
} }
#[allow(clippy::too_many_arguments)]
async fn ingest_single_paragraph( async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>, pipeline: Arc<IngestionPipeline>,
request: IngestRequest<'_>, request: IngestRequest<'_>,
+4 -9
View File
@@ -481,6 +481,7 @@ impl ParagraphShardStore {
} }
impl ParagraphShard { impl ParagraphShard {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
paragraph: &ConvertedParagraph, paragraph: &ConvertedParagraph,
shard_path: String, shard_path: String,
@@ -674,10 +675,8 @@ async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
let slice = &batches[start..group_end]; let slice = &batches[start..group_end];
let mut query = db.client.query("BEGIN TRANSACTION;"); let mut query = db.client.query("BEGIN TRANSACTION;");
let mut bind_index = 0usize; for (bind_index, batch) in slice.iter().enumerate() {
for batch in slice {
let name = format!("{prefix}{bind_index}"); let name = format!("{prefix}{bind_index}");
bind_index += 1;
query = query query = query
.query(format!("{} ${};", statement.as_ref(), name)) .query(format!("{} ${};", statement.as_ref(), name))
.bind((name, batch.items.clone())); .bind((name, batch.items.clone()));
@@ -702,7 +701,7 @@ async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?; let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
let result = (|| async { let result = async {
execute_batched_inserts( execute_batched_inserts(
db, db,
format!("INSERT INTO {}", TextContent::table_name()), format!("INSERT INTO {}", TextContent::table_name()),
@@ -752,7 +751,7 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
.await?; .await?;
Ok(()) Ok(())
})() }
.await; .await;
if result.is_err() { if result.is_err() {
@@ -778,7 +777,6 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::db_helpers::change_embedding_length_in_hnsw_indexes;
use chrono::Utc; use chrono::Utc;
use common::storage::types::knowledge_entity::KnowledgeEntityType; use common::storage::types::knowledge_entity::KnowledgeEntityType;
use uuid::Uuid; use uuid::Uuid;
@@ -905,9 +903,6 @@ mod tests {
db.apply_migrations() db.apply_migrations()
.await .await
.expect("apply migrations for memory db"); .expect("apply migrations for memory db");
change_embedding_length_in_hnsw_indexes(&db, 3)
.await
.expect("set embedding index dimension for test");
let manifest = build_manifest(); let manifest = build_manifest();
seed_manifest_into_db(&db, &manifest) seed_manifest_into_db(&db, &manifest)
+2 -7
View File
@@ -245,8 +245,9 @@ fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
catalog.dataset(kind.id()) catalog.dataset(kind.id())
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum DatasetKind { pub enum DatasetKind {
#[default]
SquadV2, SquadV2,
NaturalQuestions, NaturalQuestions,
Beir, Beir,
@@ -368,12 +369,6 @@ impl std::fmt::Display for DatasetKind {
} }
} }
impl Default for DatasetKind {
fn default() -> Self {
Self::SquadV2
}
}
impl FromStr for DatasetKind { impl FromStr for DatasetKind {
type Err = anyhow::Error; type Err = anyhow::Error;
+8 -7
View File
@@ -36,13 +36,14 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
Ok(()) Ok(())
} }
// Test helper to force index dimension change // // Test helper to force index dimension change
pub async fn change_embedding_length_in_hnsw_indexes( // #[allow(dead_code)]
db: &SurrealDbClient, // pub async fn change_embedding_length_in_hnsw_indexes(
dimension: usize, // db: &SurrealDbClient,
) -> Result<()> { // dimension: usize,
recreate_indexes(db, dimension).await // ) -> Result<()> {
} // recreate_indexes(db, dimension).await
// }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
+2 -1
View File
@@ -86,6 +86,7 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
} }
/// Determine if we can reuse an existing namespace based on cached state. /// Determine if we can reuse an existing namespace based on cached state.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn can_reuse_namespace( pub(crate) async fn can_reuse_namespace(
db: &SurrealDbClient, db: &SurrealDbClient,
descriptor: &snapshot::Descriptor, descriptor: &snapshot::Descriptor,
@@ -213,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
timezone: "UTC".to_string(), timezone: "UTC".to_string(),
}; };
if let Some(existing) = db.get_item::<User>(&user.get_id()).await? { if let Some(existing) = db.get_item::<User>(user.get_id()).await? {
return Ok(existing); return Ok(existing);
} }
+1 -1
View File
@@ -154,7 +154,7 @@ impl<'a> EvaluationContext<'a> {
} }
pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) { pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
let elapsed = duration.as_millis() as u128; let elapsed = duration.as_millis();
match stage { match stage {
EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed, EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed, EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,
+1 -3
View File
@@ -21,9 +21,7 @@ pub async fn run_evaluation(
let machine = stages::prepare_namespace(machine, &mut ctx).await?; let machine = stages::prepare_namespace(machine, &mut ctx).await?;
let machine = stages::run_queries(machine, &mut ctx).await?; let machine = stages::run_queries(machine, &mut ctx).await?;
let machine = stages::summarize(machine, &mut ctx).await?; let machine = stages::summarize(machine, &mut ctx).await?;
let machine = stages::finalize(machine, &mut ctx).await?; let _ = stages::finalize(machine, &mut ctx).await?;
drop(machine);
Ok(ctx.into_summary()) Ok(ctx.into_summary())
} }
@@ -113,7 +113,7 @@ pub(crate) async fn prepare_corpus(
.metadata .metadata
.ingestion_fingerprint .ingestion_fingerprint
.clone(); .clone();
let ingestion_duration_ms = ingestion_timer.elapsed().as_millis() as u128; let ingestion_duration_ms = ingestion_timer.elapsed().as_millis();
info!( info!(
cache = %corpus_handle.path.display(), cache = %corpus_handle.path.display(),
reused_ingestion = corpus_handle.reused_ingestion, reused_ingestion = corpus_handle.reused_ingestion,
@@ -119,7 +119,7 @@ pub(crate) async fn prepare_namespace(
corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed) corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
.await .await
.context("seeding ingestion corpus from manifest")?; .context("seeding ingestion corpus from manifest")?;
namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128); namespace_seed_ms = Some(seed_start.elapsed().as_millis());
// Recreate indexes AFTER data is loaded (correct bulk loading pattern) // Recreate indexes AFTER data is loaded (correct bulk loading pattern)
if indexes_disabled { if indexes_disabled {
@@ -50,8 +50,10 @@ pub(crate) async fn run_queries(
None None
}; };
let mut retrieval_config = RetrievalConfig::default(); let mut retrieval_config = RetrievalConfig {
retrieval_config.strategy = config.retrieval.strategy; strategy: config.retrieval.strategy,
..Default::default()
};
retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top; retrieval_config.tuning.rerank_keep_top = config.retrieval.rerank_keep_top;
if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top { if retrieval_config.tuning.fallback_min_results < config.retrieval.rerank_keep_top {
retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top; retrieval_config.tuning.fallback_min_results = config.retrieval.rerank_keep_top;
@@ -213,7 +215,7 @@ pub(crate) async fn run_queries(
.with_context(|| format!("running pipeline for question {}", question_id))?; .with_context(|| format!("running pipeline for question {}", question_id))?;
(outcome.results, None, outcome.stage_timings) (outcome.results, None, outcome.stage_timings)
}; };
let query_latency = query_start.elapsed().as_millis() as u128; let query_latency = query_start.elapsed().as_millis();
let candidates = adapt_strategy_output(result_output); let candidates = adapt_strategy_output(result_output);
let mut retrieved = Vec::new(); let mut retrieved = Vec::new();
+2 -2
View File
@@ -436,8 +436,8 @@ pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a
select_window(resolved, 0, None) select_window(resolved, 0, None)
} }
fn load_explicit_slice<'a>( fn load_explicit_slice(
dataset: &'a ConvertedDataset, dataset: &ConvertedDataset,
index: &DatasetIndex, index: &DatasetIndex,
config: &SliceConfig<'_>, config: &SliceConfig<'_>,
slice_arg: &str, slice_arg: &str,
+1 -1
View File
@@ -46,7 +46,7 @@ impl HtmlState {
.retrieval_strategy .retrieval_strategy
.as_deref() .as_deref()
.and_then(|value| value.parse().ok()) .and_then(|value| value.parse().ok())
.unwrap_or(RetrievalStrategy::Initial) .unwrap_or(RetrievalStrategy::Default)
} }
} }
impl ProvidesDb for HtmlState { impl ProvidesDb for HtmlState {
@@ -15,7 +15,10 @@ use futures::{
use json_stream_parser::JsonStreamParser; use json_stream_parser::JsonStreamParser;
use minijinja::Value; use minijinja::Value;
use retrieval_pipeline::{ use retrieval_pipeline::{
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat}, answer_retrieval::{
chunks_to_chat_context, create_chat_request, create_user_message_with_history,
LLMResponseFormat,
},
retrieved_entities_to_json, retrieved_entities_to_json,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -126,7 +129,7 @@ pub async fn get_response_stream(
let strategy = state.retrieval_strategy(); let strategy = state.retrieval_strategy();
let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy); let config = retrieval_pipeline::RetrievalConfig::for_chat(strategy);
let entities = match retrieval_pipeline::retrieve_entities( let retrieval_result = match retrieval_pipeline::retrieve_entities(
&state.db, &state.db,
&state.openai_client, &state.openai_client,
&user_message.content, &user_message.content,
@@ -136,19 +139,21 @@ pub async fn get_response_stream(
) )
.await .await
{ {
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => entities, Ok(result) => result,
Ok(retrieval_pipeline::StrategyOutput::Chunks(_chunks)) => {
return Sse::new(create_error_stream("Chat retrieval currently only supports Entity-based strategies (Initial). Revised strategy returns Chunks which are not yet supported by this handler."));
}
Err(_e) => { Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge entities")); return Sse::new(create_error_stream("Failed to retrieve knowledge"));
} }
}; };
// 3. Create the OpenAI request // 3. Create the OpenAI request with appropriate context format
let entities_json = retrieved_entities_to_json(&entities); let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
}
};
let formatted_user_message = let formatted_user_message =
create_user_message_with_history(&entities_json, &history, &user_message.content); create_user_message_with_history(&context_json, &history, &user_message.content);
let settings = match SystemSettings::get_current(&state.db).await { let settings = match SystemSettings::get_current(&state.db).await {
Ok(s) => s, Ok(s) => s,
Err(_) => { Err(_) => {
+5
View File
@@ -1,3 +1,8 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::result_large_err
)]
pub mod pipeline; pub mod pipeline;
pub mod utils; pub mod utils;
+1 -10
View File
@@ -31,17 +31,8 @@ impl Default for IngestionTuning {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct IngestionConfig { pub struct IngestionConfig {
pub tuning: IngestionTuning, pub tuning: IngestionTuning,
pub chunk_only: bool, pub chunk_only: bool,
} }
impl Default for IngestionConfig {
fn default() -> Self {
Self {
tuning: IngestionTuning::default(),
chunk_only: false,
}
}
}
@@ -52,7 +52,7 @@ impl LLMEnrichmentResult {
entity_concurrency: usize, entity_concurrency: usize,
embedding_provider: Option<&EmbeddingProvider>, embedding_provider: Option<&EmbeddingProvider>,
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> { ) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
let mapper = Arc::new(self.create_mapper()?); let mapper = Arc::new(self.create_mapper());
let entities = self let entities = self
.process_entities( .process_entities(
@@ -66,21 +66,22 @@ impl LLMEnrichmentResult {
) )
.await?; .await?;
let relationships = self.process_relationships(source_id, user_id, Arc::clone(&mapper))?; let relationships = self.process_relationships(source_id, user_id, mapper.as_ref())?;
Ok((entities, relationships)) Ok((entities, relationships))
} }
fn create_mapper(&self) -> Result<GraphMapper, AppError> { fn create_mapper(&self) -> GraphMapper {
let mut mapper = GraphMapper::new(); let mut mapper = GraphMapper::new();
for entity in &self.knowledge_entities { for entity in &self.knowledge_entities {
mapper.assign_id(&entity.key); mapper.assign_id(&entity.key);
} }
Ok(mapper) mapper
} }
#[allow(clippy::too_many_arguments)]
async fn process_entities( async fn process_entities(
&self, &self,
source_id: &str, source_id: &str,
@@ -91,7 +92,7 @@ impl LLMEnrichmentResult {
entity_concurrency: usize, entity_concurrency: usize,
embedding_provider: Option<&EmbeddingProvider>, embedding_provider: Option<&EmbeddingProvider>,
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> { ) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
stream::iter(self.knowledge_entities.iter().cloned().map(|entity| { stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| {
let mapper = Arc::clone(&mapper); let mapper = Arc::clone(&mapper);
let openai_client = openai_client.clone(); let openai_client = openai_client.clone();
let source_id = source_id.to_string(); let source_id = source_id.to_string();
@@ -120,7 +121,7 @@ impl LLMEnrichmentResult {
&self, &self,
source_id: &str, source_id: &str,
user_id: &str, user_id: &str,
mapper: Arc<GraphMapper>, mapper: &GraphMapper,
) -> Result<Vec<KnowledgeRelationship>, AppError> { ) -> Result<Vec<KnowledgeRelationship>, AppError> {
self.relationships self.relationships
.iter() .iter()
@@ -170,9 +171,9 @@ async fn create_single_entity(
id: assigned_id, id: assigned_id,
created_at: now, created_at: now,
updated_at: now, updated_at: now,
name: llm_entity.name.to_string(), name: llm_entity.name.clone(),
description: llm_entity.description.to_string(), description: llm_entity.description.clone(),
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()), entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
source_id: source_id.to_string(), source_id: source_id.to_string(),
metadata: None, metadata: None,
user_id: user_id.into(), user_id: user_id.into(),
+19 -12
View File
@@ -8,6 +8,7 @@ mod state;
pub use config::{IngestionConfig, IngestionTuning}; pub use config::{IngestionConfig, IngestionTuning};
pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship}; pub use enrichment_result::{LLMEnrichmentResult, LLMKnowledgeEntity, LLMRelationship};
#[allow(clippy::module_name_repetitions)]
pub use services::{DefaultPipelineServices, PipelineServices}; pub use services::{DefaultPipelineServices, PipelineServices};
use std::{ use std::{
@@ -37,6 +38,7 @@ use self::{
state::ready, state::ready,
}; };
#[allow(clippy::module_name_repetitions)]
pub struct IngestionPipeline { pub struct IngestionPipeline {
db: Arc<SurrealDbClient>, db: Arc<SurrealDbClient>,
pipeline_config: IngestionConfig, pipeline_config: IngestionConfig,
@@ -44,7 +46,7 @@ pub struct IngestionPipeline {
} }
impl IngestionPipeline { impl IngestionPipeline {
pub async fn new( pub fn new(
db: Arc<SurrealDbClient>, db: Arc<SurrealDbClient>,
openai_client: Arc<Client<async_openai::config::OpenAIConfig>>, openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
config: AppConfig, config: AppConfig,
@@ -61,10 +63,9 @@ impl IngestionPipeline {
embedding_provider, embedding_provider,
IngestionConfig::default(), IngestionConfig::default(),
) )
.await
} }
pub async fn new_with_config( pub fn new_with_config(
db: Arc<SurrealDbClient>, db: Arc<SurrealDbClient>,
openai_client: Arc<Client<async_openai::config::OpenAIConfig>>, openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
config: AppConfig, config: AppConfig,
@@ -74,9 +75,9 @@ impl IngestionPipeline {
pipeline_config: IngestionConfig, pipeline_config: IngestionConfig,
) -> Result<Self, AppError> { ) -> Result<Self, AppError> {
let services = DefaultPipelineServices::new( let services = DefaultPipelineServices::new(
db.clone(), Arc::clone(&db),
openai_client.clone(), openai_client,
config.clone(), config,
reranker_pool, reranker_pool,
storage, storage,
embedding_provider, embedding_provider,
@@ -181,11 +182,17 @@ impl IngestionPipeline {
.saturating_sub(1) .saturating_sub(1)
.min(tuning.retry_backoff_cap_exponent); .min(tuning.retry_backoff_cap_exponent);
let multiplier = 2_u64.pow(capped_attempt); let multiplier = 2_u64.pow(capped_attempt);
let delay = tuning.retry_base_delay_secs * multiplier; let delay = tuning
.retry_base_delay_secs
.saturating_mul(multiplier);
Duration::from_secs(delay.min(tuning.retry_max_delay_secs)) Duration::from_secs(delay.min(tuning.retry_max_delay_secs))
} }
fn duration_millis(duration: Duration) -> u64 {
u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
}
#[tracing::instrument( #[tracing::instrument(
skip_all, skip_all,
fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id) fields(task_id = %task.id, attempt = task.attempts, user_id = %task.user_id)
@@ -231,14 +238,14 @@ impl IngestionPipeline {
let persist_duration = stage_start.elapsed(); let persist_duration = stage_start.elapsed();
let total_duration = pipeline_started.elapsed(); let total_duration = pipeline_started.elapsed();
let prepare_ms = prepare_duration.as_millis() as u64; let prepare_ms = Self::duration_millis(prepare_duration);
let retrieve_ms = retrieve_duration.as_millis() as u64; let retrieve_ms = Self::duration_millis(retrieve_duration);
let enrich_ms = enrich_duration.as_millis() as u64; let enrich_ms = Self::duration_millis(enrich_duration);
let persist_ms = persist_duration.as_millis() as u64; let persist_ms = Self::duration_millis(persist_duration);
info!( info!(
task_id = %ctx.task_id, task_id = %ctx.task_id,
attempt = ctx.attempt, attempt = ctx.attempt,
total_ms = total_duration.as_millis() as u64, total_ms = Self::duration_millis(total_duration),
prepare_ms, prepare_ms,
retrieve_ms, retrieve_ms,
enrich_ms, enrich_ms,
+3 -3
View File
@@ -228,7 +228,7 @@ impl PipelineServices for DefaultPipelineServices {
) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> { ) -> Result<(Vec<EmbeddedKnowledgeEntity>, Vec<KnowledgeRelationship>), AppError> {
analysis analysis
.to_database_entities( .to_database_entities(
&content.get_id(), content.get_id(),
&content.user_id, &content.user_id,
&self.openai_client, &self.openai_client,
&self.db, &self.db,
@@ -327,13 +327,13 @@ fn truncate_for_embedding(text: &str, max_chars: usize) -> String {
return text.to_string(); return text.to_string();
} }
let mut truncated = String::with_capacity(max_chars + 3); let mut truncated = String::with_capacity(max_chars.saturating_add(3));
for (idx, ch) in text.chars().enumerate() { for (idx, ch) in text.chars().enumerate() {
if idx >= max_chars { if idx >= max_chars {
break; break;
} }
truncated.push(ch); truncated.push(ch);
} }
truncated.push_str(""); truncated.push('…');
truncated truncated
} }
+32 -29
View File
@@ -20,6 +20,22 @@ use super::{
state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved},
}; };
const STORE_RELATIONSHIPS: &str = r"
BEGIN TRANSACTION;
LET $relationships = $relationships;
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
";
#[instrument( #[instrument(
level = "trace", level = "trace",
skip_all, skip_all,
@@ -40,8 +56,7 @@ pub async fn prepare_content(
let context_len = text_content let context_len = text_content
.context .context
.as_ref() .as_ref()
.map(|c| c.chars().count()) .map_or(0, |c| c.chars().count());
.unwrap_or(0);
tracing::info!( tracing::info!(
task_id = %ctx.task_id, task_id = %ctx.task_id,
@@ -65,7 +80,7 @@ pub async fn prepare_content(
machine machine
.prepare() .prepare()
.map_err(|(_, guard)| map_guard_error("prepare", guard)) .map_err(|(_, guard)| map_guard_error("prepare", &guard))
} }
#[instrument( #[instrument(
@@ -80,7 +95,7 @@ pub async fn retrieve_related(
if ctx.pipeline_config.chunk_only { if ctx.pipeline_config.chunk_only {
return machine return machine
.retrieve() .retrieve()
.map_err(|(_, guard)| map_guard_error("retrieve", guard)); .map_err(|(_, guard)| map_guard_error("retrieve", &guard));
} }
let content = ctx.text_content()?; let content = ctx.text_content()?;
@@ -97,7 +112,7 @@ pub async fn retrieve_related(
machine machine
.retrieve() .retrieve()
.map_err(|(_, guard)| map_guard_error("retrieve", guard)) .map_err(|(_, guard)| map_guard_error("retrieve", &guard))
} }
#[instrument( #[instrument(
@@ -116,7 +131,7 @@ pub async fn enrich(
}); });
return machine return machine
.enrich() .enrich()
.map_err(|(_, guard)| map_guard_error("enrich", guard)); .map_err(|(_, guard)| map_guard_error("enrich", &guard));
} }
let content = ctx.text_content()?; let content = ctx.text_content()?;
@@ -137,7 +152,7 @@ pub async fn enrich(
machine machine
.enrich() .enrich()
.map_err(|(_, guard)| map_guard_error("enrich", guard)) .map_err(|(_, guard)| map_guard_error("enrich", &guard))
} }
#[instrument( #[instrument(
@@ -182,10 +197,10 @@ pub async fn persist(
machine machine
.persist() .persist()
.map_err(|(_, guard)| map_guard_error("persist", guard)) .map_err(|(_, guard)| map_guard_error("persist", &guard))
} }
fn map_guard_error(event: &str, guard: GuardError) -> AppError { fn map_guard_error(event: &str, guard: &GuardError) -> AppError {
AppError::InternalError(format!( AppError::InternalError(format!(
"invalid ingestion pipeline transition during {event}: {guard:?}" "invalid ingestion pipeline transition during {event}: {guard:?}"
)) ))
@@ -206,43 +221,31 @@ async fn store_graph_entities(
return Ok(()); return Ok(());
} }
const STORE_RELATIONSHIPS: &str = r"
BEGIN TRANSACTION;
LET $relationships = $relationships;
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
";
let relationships = Arc::new(relationships); let relationships = Arc::new(relationships);
let mut backoff_ms = tuning.graph_initial_backoff_ms; let mut backoff_ms = tuning.graph_initial_backoff_ms;
let last_attempt = tuning.graph_store_attempts.saturating_sub(1);
for attempt in 0..tuning.graph_store_attempts { for attempt in 0..tuning.graph_store_attempts {
let result = db let result = db
.client .client
.query(STORE_RELATIONSHIPS) .query(STORE_RELATIONSHIPS)
.bind(("relationships", relationships.clone())) .bind(("relationships", Arc::clone(&relationships)))
.await; .await;
match result { match result {
Ok(_) => return Ok(()), Ok(_) => return Ok(()),
Err(err) => { Err(err) => {
if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts { if is_retryable_conflict(&err) && attempt < last_attempt {
let next_attempt = attempt.saturating_add(1);
warn!( warn!(
attempt = attempt + 1, attempt = next_attempt,
"Transient SurrealDB conflict while storing graph data; retrying" "Transient SurrealDB conflict while storing graph data; retrying"
); );
sleep(Duration::from_millis(backoff_ms)).await; sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms); backoff_ms = backoff_ms
.saturating_mul(2)
.min(tuning.graph_max_backoff_ms);
continue; continue;
} }
@@ -65,7 +65,7 @@ fn infer_extension(file_info: &FileInfo) -> Option<String> {
Path::new(&file_info.path) Path::new(&file_info.path)
.extension() .extension()
.and_then(|ext| ext.to_str()) .and_then(|ext| ext.to_str())
.map(|ext| ext.to_string()) .map(std::string::ToString::to_string)
} }
pub async fn extract_text_from_file( pub async fn extract_text_from_file(
@@ -116,6 +116,7 @@ async fn load_page_numbers(pdf_bytes: Vec<u8>) -> Result<Vec<u32>, AppError> {
} }
/// Uses the existing headless Chrome dependency to rasterize the requested PDF pages into PNGs. /// Uses the existing headless Chrome dependency to rasterize the requested PDF pages into PNGs.
#[allow(clippy::too_many_lines)]
async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> { async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> {
let file_url = url::Url::from_file_path(file_path) let file_url = url::Url::from_file_path(file_path)
.map_err(|()| AppError::Processing("Unable to construct PDF file URL".into()))?; .map_err(|()| AppError::Processing("Unable to construct PDF file URL".into()))?;
@@ -148,7 +149,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
loaded = true; loaded = true;
break; break;
} }
if attempt + 1 < NAVIGATION_RETRY_ATTEMPTS { if attempt < NAVIGATION_RETRY_ATTEMPTS.saturating_sub(1) {
sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS)).await; sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS)).await;
} }
} }
@@ -172,7 +173,7 @@ async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>
break; break;
} }
Ok(None) => { Ok(None) => {
if attempt + 1 < CANVAS_VIEWPORT_ATTEMPTS { if attempt < CANVAS_VIEWPORT_ATTEMPTS.saturating_sub(1) {
tokio::time::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS)).await; tokio::time::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS)).await;
} }
} }
@@ -260,6 +261,7 @@ fn create_browser() -> Result<Browser, AppError> {
} }
/// Sends one or more rendered pages to the configured multimodal model and stitches the resulting Markdown chunks together. /// Sends one or more rendered pages to the configured multimodal model and stitches the resulting Markdown chunks together.
#[allow(clippy::too_many_lines)]
async fn vision_markdown( async fn vision_markdown(
rendered_pages: Vec<Vec<u8>>, rendered_pages: Vec<Vec<u8>>,
db: &SurrealDbClient, db: &SurrealDbClient,
@@ -303,10 +305,11 @@ async fn vision_markdown(
let mut batch_markdown: Option<String> = None; let mut batch_markdown: Option<String> = None;
let last_attempt = MAX_VISION_ATTEMPTS.saturating_sub(1);
for attempt in 0..MAX_VISION_ATTEMPTS { for attempt in 0..MAX_VISION_ATTEMPTS {
let prompt_text = prompt_for_attempt(attempt, prompt); let prompt_text = prompt_for_attempt(attempt, prompt);
let mut content_parts = Vec::with_capacity(encoded_images.len() + 1); let mut content_parts = Vec::with_capacity(encoded_images.len().saturating_add(1));
content_parts.push( content_parts.push(
ChatCompletionRequestMessageContentPartTextArgs::default() ChatCompletionRequestMessageContentPartTextArgs::default()
.text(prompt_text) .text(prompt_text)
@@ -375,7 +378,7 @@ async fn vision_markdown(
batch = batch_idx, batch = batch_idx,
attempt, "Vision model returned low quality response" attempt, "Vision model returned low quality response"
); );
if attempt + 1 == MAX_VISION_ATTEMPTS { if attempt == last_attempt {
return Err(AppError::Processing( return Err(AppError::Processing(
"Vision model failed to transcribe PDF page contents".into(), "Vision model failed to transcribe PDF page contents".into(),
)); ));
@@ -400,6 +403,7 @@ async fn vision_markdown(
} }
/// Heuristic that determines whether the fast-path text looks like well-formed prose. /// Heuristic that determines whether the fast-path text looks like well-formed prose.
#[allow(clippy::cast_precision_loss)]
fn looks_good_enough(text: &str) -> bool { fn looks_good_enough(text: &str) -> bool {
if text.len() < FAST_PATH_MIN_LEN { if text.len() < FAST_PATH_MIN_LEN {
return false; return false;
@@ -50,7 +50,7 @@ pub async fn extract_text_from_url(
)?; )?;
let mut tmp_file = NamedTempFile::new()?; let mut tmp_file = NamedTempFile::new()?;
let temp_path_str = format!("{:?}", tmp_file.path()); let temp_path_str = tmp_file.path().display().to_string();
tmp_file.write_all(&screenshot)?; tmp_file.write_all(&screenshot)?;
tmp_file.as_file().sync_all()?; tmp_file.as_file().sync_all()?;
@@ -108,14 +108,11 @@ fn ensure_ingestion_url_allowed(url: &url::Url) -> Result<String, AppError> {
} }
} }
let host = match url.host_str() { let Some(host) = url.host_str() else {
Some(host) => host, warn!(%url, "Rejected ingestion URL missing host");
None => { return Err(AppError::Validation(
warn!(%url, "Rejected ingestion URL missing host"); "URL is missing a host component".to_string(),
return Err(AppError::Validation( ));
"URL is missing a host component".to_string(),
));
}
}; };
if host.eq_ignore_ascii_case("localhost") { if host.eq_ignore_ascii_case("localhost") {
-1
View File
@@ -138,7 +138,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
storage.clone(), storage.clone(),
embedding_provider, embedding_provider,
) )
.await
.unwrap(), .unwrap(),
); );
+1 -1
View File
@@ -53,7 +53,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
storage, storage,
embedding_provider, embedding_provider,
) )
.await?, ?,
); );
run_worker_loop(db, ingestion_pipeline).await run_worker_loop(db, ingestion_pipeline).await
@@ -51,6 +51,24 @@ pub fn create_user_message(entities_json: &Value, query: &str) -> String {
) )
} }
/// Convert chunk-based retrieval results to JSON format for LLM context
pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}
serde_json::json!(chunks
.iter()
.map(|chunk| {
serde_json::json!({
"content": chunk.chunk.chunk,
"source_id": chunk.chunk.source_id,
"score": round_score(chunk.score),
})
})
.collect::<Vec<_>>())
}
pub fn create_user_message_with_history( pub fn create_user_message_with_history(
entities_json: &Value, entities_json: &Value,
history: &[Message], history: &[Message],
-268
View File
@@ -1,268 +0,0 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
/// One row of the FTS scoring query: a record id together with its optional score.
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
/// Record id, deserialized into a plain string via `deserialize_flexible_id`.
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
/// Value selected as `fts_score`; callers fall back to the default score when it is `None`.
fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
///
/// The search runs in two steps: a scoring query selects the ids and FTS scores of the best
/// `take` matches owned by `user_id`, then a second query loads the full records for those
/// ids. Results are returned in the order of the scoring query; score rows whose record is
/// not returned by the second query are dropped.
///
/// # Errors
///
/// Returns `AppError::Validation` when `table` has no FTS configuration, and propagates any
/// error raised while running either query or deserializing its rows.
pub async fn find_items_by_fts<T>(
    take: usize,
    query: &str,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    let (filter_clause, score_clause) = match table {
        "knowledge_entity" => (
            "(name @0@ $terms OR description @1@ $terms)",
            "(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
             (IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
        ),
        "text_chunk" => (
            "(chunk @0@ $terms)",
            "IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
        ),
        _ => {
            return Err(AppError::Validation(format!(
                "FTS not configured for table '{table}'"
            )))
        }
    };

    // `table` is safe to interpolate: the match above only lets the two known names through.
    let sql = format!(
        "SELECT id, {score_clause} AS fts_score \
         FROM {table} \
         WHERE {filter_clause} \
         AND user_id = $user_id \
         ORDER BY fts_score DESC \
         LIMIT $limit"
    );

    debug!(
        table = table,
        limit = take,
        "Executing FTS query with filter clause: {}",
        filter_clause
    );

    // Saturate instead of wrapping: an `as` cast could turn a huge `take` into a negative LIMIT.
    let limit = i64::try_from(take).unwrap_or(i64::MAX);

    let mut response = db_client
        .query(sql)
        .bind(("terms", query.to_owned()))
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", limit))
        .await?;

    let score_rows: Vec<FtsScoreRow> = response.take(0)?;

    if score_rows.is_empty() {
        return Ok(Vec::new());
    }

    // Build the record ids straight from the score rows; no intermediate copy of the ids.
    let thing_ids: Vec<Thing> = score_rows
        .iter()
        .map(|row| Thing::from((table, row.id.as_str())))
        .collect();

    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;

    let items: Vec<T> = items_response.take(0)?;

    // Index the loaded records by id so they can be re-ordered to match the score ranking.
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();

    let mut results = Vec::with_capacity(score_rows.len());
    for row in score_rows {
        if let Some(item) = item_map.remove(&row.id) {
            let score = row.fts_score.unwrap_or_default();
            results.push(Scored::new(item).with_fts_score(score));
        }
    }

    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::indexes::ensure_runtime_indexes;
    use common::storage::types::{
        knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
        text_chunk::TextChunk,
        StoredObject,
    };
    use uuid::Uuid;

    /// Creates an in-memory database under `namespace` with a random database name, then
    /// applies migrations and builds the runtime indexes.
    async fn fresh_fts_db(namespace: &str) -> SurrealDbClient {
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("failed to create in-memory surreal");

        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("failed to build runtime indexes");

        db
    }

    /// A term that only occurs in the entity name must still yield an FTS score.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_name() {
        let db = fresh_fts_db("fts_test_ns").await;
        let owner = "user_fts";

        let handbook = KnowledgeEntity::new(
            "source_a".into(),
            "Rustacean handbook".into(),
            "completely unrelated description".into(),
            KnowledgeEntityType::Document,
            None,
            owner.into(),
        );
        db.store_item(handbook)
            .await
            .expect("failed to insert entity");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let hits = find_items_by_fts::<KnowledgeEntity>(
            5,
            "rustacean",
            &db,
            KnowledgeEntity::table_name(),
            owner,
        )
        .await
        .expect("fts query failed");

        assert!(!hits.is_empty(), "expected at least one FTS result");
        assert!(
            hits[0].scores.fts.is_some(),
            "expected an FTS score when only the name matched"
        );
    }

    /// A term that only occurs in the entity description must still yield an FTS score.
    #[tokio::test]
    async fn fts_preserves_single_field_score_for_description() {
        let db = fresh_fts_db("fts_test_ns_desc").await;
        let owner = "user_fts_desc";

        let notes = KnowledgeEntity::new(
            "source_b".into(),
            "neutral name".into(),
            "Detailed notes about async runtimes".into(),
            KnowledgeEntityType::Document,
            None,
            owner.into(),
        );
        db.store_item(notes)
            .await
            .expect("failed to insert entity");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let hits = find_items_by_fts::<KnowledgeEntity>(
            5,
            "async",
            &db,
            KnowledgeEntity::table_name(),
            owner,
        )
        .await
        .expect("fts query failed");

        assert!(!hits.is_empty(), "expected at least one FTS result");
        assert!(
            hits[0].scores.fts.is_some(),
            "expected an FTS score when only the description matched"
        );
    }

    /// Text chunks matched on their `chunk` field must carry an FTS score.
    #[tokio::test]
    async fn fts_preserves_scores_for_text_chunks() {
        let db = fresh_fts_db("fts_test_ns_chunks").await;
        let owner = "user_fts_chunk";

        let reference = TextChunk::new(
            "source_chunk".into(),
            "GraphQL documentation reference".into(),
            owner.into(),
        );
        TextChunk::store_with_embedding(reference, vec![0.0; 1536], &db)
            .await
            .expect("failed to insert chunk");
        db.rebuild_indexes()
            .await
            .expect("failed to rebuild indexes");

        let hits =
            find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), owner)
                .await
                .expect("fts query failed");

        assert!(!hits.is_empty(), "expected at least one FTS result");
        assert!(
            hits[0].scores.fts.is_some(),
            "expected an FTS score when chunk field matched"
        );
    }
}
+7 -190
View File
@@ -10,54 +10,17 @@ use common::storage::{
}, },
}; };
/// Retrieves database entries that match a specific source identifier. /// Find entities related to the given entity via graph relationships.
/// ///
/// This function queries the database for all records in a specified table that have /// Queries the `relates_to` edge table for all relationships involving the entity,
/// a matching `source_id` field. It's commonly used to find related entities or /// then fetches and returns the neighboring entities.
/// track the origin of database entries.
/// ///
/// # Arguments /// # Arguments
/// /// * `db` - Database client
/// * `source_id` - The identifier to search for in the database /// * `entity_id` - ID of the entity to find neighbors for
/// * `table_name` - The name of the table to search in /// * `user_id` - User ID for access control
/// * `db_client` - The `SurrealDB` client instance for database operations /// * `limit` - Maximum number of neighbors to return
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
///
/// # Returns
///
/// Returns a `Result` containing either:
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
/// * `Err(Error)` - An error if the database query fails
///
/// # Errors
///
/// This function will return a `Error` if:
/// * The database query fails to execute
/// * The results cannot be deserialized into type `T`
pub async fn find_entities_by_source_ids<T>(
    source_ids: Vec<String>,
    table_name: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
    T: for<'de> serde::Deserialize<'de>,
{
    // Select every record of the table whose `source_id` is in the given list and that
    // belongs to the given user.
    let mut response = db
        .query(
            "SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id",
        )
        .bind(("table", table_name.to_owned()))
        .bind(("source_ids", source_ids))
        .bind(("user_id", user_id.to_owned()))
        .await?;

    // The first (and only) statement holds the matching rows.
    response.take(0)
}
/// Find entities by their relationship to the id
pub async fn find_entities_by_relationship_by_id( pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient, db: &SurrealDbClient,
entity_id: &str, entity_id: &str,
@@ -153,154 +116,8 @@ mod tests {
use super::*; use super::*;
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use common::storage::types::knowledge_relationship::KnowledgeRelationship; use common::storage::types::knowledge_relationship::KnowledgeRelationship;
use common::storage::types::StoredObject;
use uuid::Uuid; use uuid::Uuid;
#[tokio::test]
async fn test_find_entities_by_source_ids() {
    // Isolated in-memory database; the random database name keeps runs independent.
    let database = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("test_ns", &database)
        .await
        .expect("Failed to start in-memory surrealdb");

    let source_id1 = "source123".to_string();
    let source_id2 = "source456".to_string();
    let source_id3 = "source789".to_string();
    let user_id = "user123".to_string();

    // Every fixture is a document entity owned by the same user.
    let build = |source: &String, name: &str, description: &str| {
        KnowledgeEntity::new(
            source.clone(),
            name.to_string(),
            description.to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.clone(),
        )
    };

    // Two entities share source_id1; source_id2 and source_id3 have one entity each.
    let fixtures = vec![
        (
            build(&source_id1, "Entity 1", "Description 1"),
            "Failed to store entity 1",
        ),
        (
            build(&source_id2, "Entity 2", "Description 2"),
            "Failed to store entity 2",
        ),
        (
            build(&source_id1, "Entity 3", "Description 3"),
            "Failed to store entity 3",
        ),
        (
            build(&source_id3, "Entity 4", "Description 4"),
            "Failed to store entity 4",
        ),
    ];
    for (entity, failure) in fixtures {
        db.store_item(entity).await.expect(failure);
    }

    // Lookup by several source ids: expects both source_id1 entities plus the source_id2 one.
    let by_two_sources: Vec<KnowledgeEntity> = find_entities_by_source_ids(
        vec![source_id1.clone(), source_id2.clone()],
        KnowledgeEntity::table_name(),
        &user_id,
        &db,
    )
    .await
    .expect("Failed to find entities by source_ids");

    assert_eq!(
        by_two_sources.len(),
        3,
        "Should find 3 entities with the specified source_ids"
    );
    assert!(
        by_two_sources.iter().any(|e| e.source_id == source_id1),
        "Should find entities with source_id1"
    );
    assert!(
        by_two_sources.iter().any(|e| e.source_id == source_id2),
        "Should find entities with source_id2"
    );
    assert!(
        !by_two_sources.iter().any(|e| e.source_id == source_id3),
        "Should not find entities with source_id3"
    );

    // Lookup by a single source id: only the two source_id1 entities qualify.
    let by_one_source: Vec<KnowledgeEntity> = find_entities_by_source_ids(
        vec![source_id1.clone()],
        KnowledgeEntity::table_name(),
        &user_id,
        &db,
    )
    .await
    .expect("Failed to find entities by single source_id");

    assert_eq!(
        by_one_source.len(),
        2,
        "Should find 2 entities with source_id1"
    );
    for entity in &by_one_source {
        assert_eq!(
            entity.source_id, source_id1,
            "All found entities should have source_id1"
        );
    }

    // Lookup by an unknown source id: nothing may be returned.
    let by_unknown_source: Vec<KnowledgeEntity> = find_entities_by_source_ids(
        vec!["non_existent_source".to_string()],
        KnowledgeEntity::table_name(),
        &user_id,
        &db,
    )
    .await
    .expect("Failed to find entities by non-existent source_id");

    assert_eq!(
        by_unknown_source.len(),
        0,
        "Should find 0 entities with non-existent source_id"
    );
}
#[tokio::test] #[tokio::test]
async fn test_find_entities_by_relationship_by_id() { async fn test_find_entities_by_relationship_by_id() {
+29 -119
View File
@@ -1,6 +1,6 @@
pub mod answer_retrieval; pub mod answer_retrieval;
pub mod answer_retrieval_helper; pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph; pub mod graph;
pub mod pipeline; pub mod pipeline;
pub mod reranking; pub mod reranking;
@@ -70,11 +70,7 @@ mod tests {
use super::*; use super::*;
use async_openai::Client; use async_openai::Client;
use common::storage::indexes::ensure_runtime_indexes; use common::storage::indexes::ensure_runtime_indexes;
use common::storage::types::{ use common::storage::types::text_chunk::TextChunk;
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
};
use pipeline::{RetrievalConfig, RetrievalStrategy}; use pipeline::{RetrievalConfig, RetrievalStrategy};
use uuid::Uuid; use uuid::Uuid;
@@ -82,14 +78,6 @@ mod tests {
vec![0.9, 0.1, 0.0] vec![0.9, 0.1, 0.0]
} }
/// Three-dimensional test embedding weighted towards the first axis.
fn entity_embedding_high() -> Vec<f32> {
    let components: [f32; 3] = [0.8, 0.2, 0.0];
    components.to_vec()
}
/// Three-dimensional test embedding weighted towards the second axis.
fn entity_embedding_low() -> Vec<f32> {
    let components: [f32; 3] = [0.1, 0.9, 0.0];
    components.to_vec()
}
fn chunk_embedding_primary() -> Vec<f32> { fn chunk_embedding_primary() -> Vec<f32> {
vec![0.85, 0.15, 0.0] vec![0.85, 0.15, 0.0]
} }
@@ -113,41 +101,19 @@ mod tests {
.await .await
.expect("failed to build runtime indexes"); .expect("failed to build runtime indexes");
db.query(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding;
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
)
.await
.expect("Failed to configure indices");
db db
} }
#[tokio::test] #[tokio::test]
async fn test_retrieve_entities_with_embedding_basic_flow() { async fn test_default_strategy_retrieves_chunks() {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "test_user"; let user_id = "test_user";
let entity = KnowledgeEntity::new(
"source_1".into(),
"Rust async guide".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
let chunk = TextChunk::new( let chunk = TextChunk::new(
entity.source_id.clone(), "source_1".into(),
"Tokio uses cooperative scheduling for fairness.".into(), "Tokio uses cooperative scheduling for fairness.".into(),
user_id.into(), user_id.into(),
); );
KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store entity");
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
.await .await
.expect("Failed to store chunk"); .expect("Failed to store chunk");
@@ -164,64 +130,32 @@ mod tests {
None, None,
) )
.await .await
.expect("Hybrid retrieval failed"); .expect("Default strategy retrieval failed");
let entities = match results { let chunks = match results {
StrategyOutput::Entities(items) => items, StrategyOutput::Chunks(items) => items,
other => panic!("expected entity results, got {:?}", other), other => panic!("expected chunk results, got {:?}", other),
}; };
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!( assert!(
!entities.is_empty(), chunks[0].chunk.chunk.contains("Tokio"),
"Expected at least one retrieval result" "Expected chunk about Tokio"
);
let top = &entities[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
);
assert!(
!top.chunks.is_empty(),
"Expected Rust entity to include supporting chunks"
); );
} }
#[tokio::test] #[tokio::test]
async fn test_graph_relationship_enriches_results() { async fn test_default_strategy_returns_chunks_from_multiple_sources() {
let db = setup_test_db().await; let db = setup_test_db().await;
let user_id = "graph_user"; let user_id = "multi_source_user";
let primary = KnowledgeEntity::new(
"primary_source".into(),
"Async Rust patterns".into(),
"Explores async runtimes and scheduling strategies.".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
let neighbor = KnowledgeEntity::new(
"neighbor_source".into(),
"Tokio Scheduler Deep Dive".into(),
"Details on Tokio's cooperative scheduler.".into(),
KnowledgeEntityType::Document,
None,
user_id.into(),
);
KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store primary entity");
KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db)
.await
.expect("Failed to store neighbor entity");
let primary_chunk = TextChunk::new( let primary_chunk = TextChunk::new(
primary.source_id.clone(), "primary_source".into(),
"Rust async tasks use Tokio's cooperative scheduler.".into(), "Rust async tasks use Tokio's cooperative scheduler.".into(),
user_id.into(), user_id.into(),
); );
let neighbor_chunk = TextChunk::new( let secondary_chunk = TextChunk::new(
neighbor.source_id.clone(), "secondary_source".into(),
"Tokio's scheduler manages task fairness across executors.".into(), "Tokio's scheduler manages task fairness across executors.".into(),
user_id.into(), user_id.into(),
); );
@@ -229,23 +163,11 @@ mod tests {
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db) TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
.await .await
.expect("Failed to store primary chunk"); .expect("Failed to store primary chunk");
TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db) TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db)
.await .await
.expect("Failed to store neighbor chunk"); .expect("Failed to store secondary chunk");
let openai_client = Client::new(); let openai_client = Client::new();
let relationship = KnowledgeRelationship::new(
primary.id.clone(),
neighbor.id.clone(),
user_id.into(),
"relationship_source".into(),
"references".into(),
);
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
let results = pipeline::run_pipeline_with_embedding( let results = pipeline::run_pipeline_with_embedding(
&db, &db,
&openai_client, &openai_client,
@@ -257,35 +179,23 @@ mod tests {
None, None,
) )
.await .await
.expect("Hybrid retrieval failed"); .expect("Default strategy retrieval failed");
let entities = match results { let chunks = match results {
StrategyOutput::Entities(items) => items, StrategyOutput::Chunks(items) => items,
other => panic!("expected entity results, got {:?}", other), other => panic!("expected chunk results, got {:?}", other),
}; };
let mut neighbor_entry = None; assert!(chunks.len() >= 2, "Expected chunks from multiple sources");
for entity in &entities {
if entity.entity.id == neighbor.id {
neighbor_entry = Some(entity.clone());
}
}
println!("{:?}", entities);
let neighbor_entry =
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
assert!( assert!(
neighbor_entry.score > 0.2, chunks.iter().any(|c| c.chunk.source_id == "primary_source"),
"Graph-enriched entity should have a meaningful fused score" "Should include primary source chunk"
); );
assert!( assert!(
neighbor_entry chunks
.chunks
.iter() .iter()
.all(|chunk| chunk.chunk.source_id == neighbor.source_id), .any(|c| c.chunk.source_id == "secondary_source"),
"Neighbor entity should surface its own supporting chunks" "Should include secondary source chunk"
); );
} }
@@ -311,7 +221,7 @@ mod tests {
.await .await
.expect("Failed to store chunk two"); .expect("Failed to store chunk two");
let config = RetrievalConfig::with_strategy(RetrievalStrategy::Revised); let config = RetrievalConfig::with_strategy(RetrievalStrategy::Default);
let openai_client = Client::new(); let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding( let results = pipeline::run_pipeline_with_embedding(
&db, &db,
+17 -8
View File
@@ -6,15 +6,17 @@ use crate::scoring::FusionWeights;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy { pub enum RetrievalStrategy {
Initial, /// Primary hybrid chunk retrieval for search/chat (formerly Revised)
Revised, Default,
/// Entity retrieval for suggesting relationships when creating manual entities
RelationshipSuggestion, RelationshipSuggestion,
/// Entity retrieval for context during content ingestion
Ingestion, Ingestion,
} }
impl Default for RetrievalStrategy { impl Default for RetrievalStrategy {
fn default() -> Self { fn default() -> Self {
Self::Initial Self::Default
} }
} }
@@ -23,8 +25,16 @@ impl std::str::FromStr for RetrievalStrategy {
fn from_str(value: &str) -> Result<Self, Self::Err> { fn from_str(value: &str) -> Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() { match value.to_ascii_lowercase().as_str() {
"initial" => Ok(Self::Initial), "default" => Ok(Self::Default),
"revised" => Ok(Self::Revised), // Backward compatibility: treat "initial" and "revised" as "default"
"initial" | "revised" => {
tracing::warn!(
"Retrieval strategy '{}' is deprecated. Use 'default' instead. \
The 'initial' strategy has been removed in favor of the simpler hybrid chunk retrieval.",
value
);
Ok(Self::Default)
}
"relationship_suggestion" => Ok(Self::RelationshipSuggestion), "relationship_suggestion" => Ok(Self::RelationshipSuggestion),
"ingestion" => Ok(Self::Ingestion), "ingestion" => Ok(Self::Ingestion),
other => Err(format!("unknown retrieval strategy '{other}'")), other => Err(format!("unknown retrieval strategy '{other}'")),
@@ -35,8 +45,7 @@ impl std::str::FromStr for RetrievalStrategy {
impl fmt::Display for RetrievalStrategy { impl fmt::Display for RetrievalStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let label = match self { let label = match self {
RetrievalStrategy::Initial => "initial", RetrievalStrategy::Default => "default",
RetrievalStrategy::Revised => "revised",
RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion", RetrievalStrategy::RelationshipSuggestion => "relationship_suggestion",
RetrievalStrategy::Ingestion => "ingestion", RetrievalStrategy::Ingestion => "ingestion",
}; };
@@ -136,7 +145,7 @@ pub struct RetrievalConfig {
impl RetrievalConfig { impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self { pub fn new(tuning: RetrievalTuning) -> Self {
Self { Self {
strategy: RetrievalStrategy::Initial, strategy: RetrievalStrategy::Default,
tuning, tuning,
} }
} }
+9 -87
View File
@@ -17,9 +17,7 @@ use std::time::{Duration, Instant};
use tracing::info; use tracing::info;
use stages::PipelineContext; use stages::PipelineContext;
use strategies::{ use strategies::{DefaultStrategyDriver, IngestionDriver, RelationshipSuggestionDriver};
IngestionDriver, InitialStrategyDriver, RelationshipSuggestionDriver, RevisedStrategyDriver,
};
// Export StrategyOutput publicly from this module // Export StrategyOutput publicly from this module
// (it's defined in lib.rs but we re-export it here) // (it's defined in lib.rs but we re-export it here)
@@ -132,25 +130,8 @@ pub async fn run_pipeline(
); );
match config.strategy { match config.strategy {
RetrievalStrategy::Initial => { RetrievalStrategy::Default => {
let driver = InitialStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(
driver, driver,
db_client, db_client,
@@ -214,25 +195,8 @@ pub async fn run_pipeline_with_embedding(
reranker: Option<RerankerLease>, reranker: Option<RerankerLease>,
) -> Result<StrategyOutput, AppError> { ) -> Result<StrategyOutput, AppError> {
match config.strategy { match config.strategy {
RetrievalStrategy::Initial => { RetrievalStrategy::Default => {
let driver = InitialStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(StrategyOutput::Entities(run.results))
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(
driver, driver,
db_client, db_client,
@@ -301,29 +265,8 @@ pub async fn run_pipeline_with_embedding_with_metrics(
reranker: Option<RerankerLease>, reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> { ) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy { match config.strategy {
RetrievalStrategy::Initial => { RetrievalStrategy::Default => {
let driver = InitialStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
false,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(
driver, driver,
db_client, db_client,
@@ -361,29 +304,8 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
reranker: Option<RerankerLease>, reranker: Option<RerankerLease>,
) -> Result<PipelineRunOutput<StrategyOutput>, AppError> { ) -> Result<PipelineRunOutput<StrategyOutput>, AppError> {
match config.strategy { match config.strategy {
RetrievalStrategy::Initial => { RetrievalStrategy::Default => {
let driver = InitialStrategyDriver::new(); let driver = DefaultStrategyDriver::new();
let run = execute_strategy(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
config,
reranker,
true,
)
.await?;
Ok(PipelineRunOutput {
results: StrategyOutput::Entities(run.results),
diagnostics: run.diagnostics,
stage_timings: run.stage_timings,
})
}
RetrievalStrategy::Revised => {
let driver = RevisedStrategyDriver::new();
let run = execute_strategy( let run = execute_strategy(
driver, driver,
db_client, db_client,
+29 -264
View File
@@ -12,13 +12,13 @@ use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, StreamExt};
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
collections::{HashMap, HashSet}, collections::HashMap,
}; };
use tracing::{debug, instrument, warn}; use tracing::{debug, instrument, warn};
use crate::{ use crate::{
fts::find_items_by_fts,
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids}, graph::find_entities_by_relationship_by_id,
reranking::RerankerLease, reranking::RerankerLease,
scoring::{ scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion, clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, reciprocal_rank_fusion,
@@ -45,7 +45,6 @@ pub struct PipelineContext<'a> {
pub config: RetrievalConfig, pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>, pub query_embedding: Option<Vec<f32>>,
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>, pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>, pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>, pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>, pub revised_chunk_values: Vec<Scored<TextChunk>>,
@@ -75,7 +74,6 @@ impl<'a> PipelineContext<'a> {
config, config,
query_embedding: None, query_embedding: None,
entity_candidates: HashMap::new(), entity_candidates: HashMap::new(),
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(), filtered_entities: Vec::new(),
chunk_values: Vec::new(), chunk_values: Vec::new(),
revised_chunk_values: Vec::new(), revised_chunk_values: Vec::new(),
@@ -209,20 +207,6 @@ impl PipelineStage for GraphExpansionStage {
} }
} }
/// Pipeline stage that runs the chunk-attachment step.
///
/// The stage carries no state of its own; all work happens on the
/// `PipelineContext` passed to `execute`.
#[derive(Debug, Clone, Copy)]
pub struct ChunkAttachStage;
#[async_trait]
impl PipelineStage for ChunkAttachStage {
/// Identifies this stage as `StageKind::ChunkAttach`.
fn kind(&self) -> StageKind {
StageKind::ChunkAttach
}
/// Delegates to `attach_chunks`, forwarding its result unchanged.
async fn execute(&self, ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
attach_chunks(ctx).await
}
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct RerankStage; pub struct RerankStage;
@@ -324,75 +308,68 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
let weights = FusionWeights::default(); let weights = FusionWeights::default();
let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!( let (vector_entity_results, fts_entity_results) = tokio::try_join!(
KnowledgeEntity::vector_search( KnowledgeEntity::vector_search(
tuning.entity_vector_take, tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
&ctx.user_id,
),
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding, embedding,
ctx.db_client, ctx.db_client,
&ctx.user_id, &ctx.user_id,
), ),
find_items_by_fts( KnowledgeEntity::search(
tuning.entity_fts_take,
&ctx.input_text,
ctx.db_client, ctx.db_client,
"knowledge_entity", &ctx.input_text,
&ctx.user_id, &ctx.user_id,
), tuning.entity_fts_take,
find_items_by_fts( )
tuning.chunk_fts_take,
&ctx.input_text,
ctx.db_client,
"text_chunk",
&ctx.user_id
),
)?; )?;
#[allow(clippy::useless_conversion)]
let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results
.into_iter() .into_iter()
.map(|row| Scored::new(row.entity).with_vector_score(row.score)) .map(|row| Scored::new(row.entity).with_vector_score(row.score))
.collect(); .collect();
let vector_chunks: Vec<Scored<TextChunk>> = vector_chunk_results
let mut fts_entities: Vec<Scored<KnowledgeEntity>> = fts_entity_results
.into_iter() .into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score)) .map(|res| {
let entity = KnowledgeEntity {
id: res.id,
created_at: res.created_at,
updated_at: res.updated_at,
source_id: res.source_id,
name: res.name,
description: res.description,
entity_type: res.entity_type,
metadata: res.metadata,
user_id: res.user_id,
};
Scored::new(entity).with_fts_score(res.score)
})
.collect(); .collect();
debug!( debug!(
vector_entities = vector_entities.len(), vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(), fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts" "Hybrid retrieval initial candidate counts"
); );
if ctx.diagnostics_enabled() { if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats { ctx.record_collect_candidates(CollectCandidatesStats {
vector_entity_candidates: vector_entities.len(), vector_entity_candidates: vector_entities.len(),
vector_chunk_candidates: vector_chunks.len(), vector_chunk_candidates: 0,
fts_entity_candidates: fts_entities.len(), fts_entity_candidates: fts_entities.len(),
fts_chunk_candidates: fts_chunks.len(), fts_chunk_candidates: 0,
vector_chunk_scores: sample_scores(&vector_chunks, |chunk| { vector_chunk_scores: Vec::new(),
chunk.scores.vector.unwrap_or(0.0) fts_chunk_scores: Vec::new(),
}),
fts_chunk_scores: sample_scores(&fts_chunks, |chunk| chunk.scores.fts.unwrap_or(0.0)),
}); });
} }
normalize_fts_scores(&mut fts_entities); normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities); merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities); merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
apply_fusion(&mut ctx.entity_candidates, weights); apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
Ok(()) Ok(())
} }
@@ -467,82 +444,6 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError>
Ok(()) Ok(())
} }
/// Links chunk candidates to entity candidates and stores the ranked results
/// on the context.
///
/// Steps, in order:
/// 1. Group the current chunk candidates by source id.
/// 2. Backfill entities for sources that only have chunks, then boost
///    entities that share a source with a chunk.
/// 3. Keep entities whose fused score reaches `tuning.score_threshold`; if
///    fewer than `tuning.fallback_min_results` survive, keep the top
///    `fallback_min_results` entities instead.
/// 4. Enrich the chunk set with chunks belonging to the surviving entities and
///    sort it by fused score.
///
/// Writes `ctx.filtered_entities` and `ctx.chunk_values`, and records
/// chunk-enrichment diagnostics when diagnostics are enabled.
///
/// # Errors
///
/// Propagates errors from `backfill_entities_from_chunks` and
/// `enrich_chunks_from_entities`.
#[instrument(level = "trace", skip_all)]
pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
    debug!("Attaching chunks to surviving entities");

    let tuning = &ctx.config.tuning;
    let weights = FusionWeights::default();

    // Snapshot the chunk landscape before any enrichment happens.
    let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
    let chunk_candidates_before = ctx.chunk_candidates.len();
    let chunk_sources_considered = chunk_by_source.len();

    backfill_entities_from_chunks(
        &mut ctx.entity_candidates,
        &chunk_by_source,
        ctx.db_client,
        &ctx.user_id,
        weights,
    )
    .await?;
    boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);

    // Rank every entity candidate, best fused score first.
    let mut ranked_entities: Vec<Scored<KnowledgeEntity>> =
        ctx.entity_candidates.values().cloned().collect();
    sort_by_fused_desc(&mut ranked_entities);

    let above_threshold: Vec<Scored<KnowledgeEntity>> = ranked_entities
        .iter()
        .filter(|candidate| candidate.fused >= tuning.score_threshold)
        .cloned()
        .collect();

    // Too few survivors: fall back to the top-N ranked entities regardless of score.
    ctx.filtered_entities = if above_threshold.len() < tuning.fallback_min_results {
        ranked_entities.truncate(tuning.fallback_min_results);
        ranked_entities
    } else {
        above_threshold
    };

    // Re-key the chunk candidates by their own id, inserting in fused order.
    let mut ranked_chunks: Vec<Scored<TextChunk>> =
        ctx.chunk_candidates.values().cloned().collect();
    sort_by_fused_desc(&mut ranked_chunks);
    let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = ranked_chunks
        .into_iter()
        .map(|chunk| (chunk.item.id.clone(), chunk))
        .collect();

    enrich_chunks_from_entities(
        &mut chunk_by_id,
        &ctx.filtered_entities,
        ctx.db_client,
        &ctx.user_id,
        weights,
    )
    .await?;

    let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
    sort_by_fused_desc(&mut chunk_values);

    if ctx.diagnostics_enabled() {
        let stats = ChunkEnrichmentStats {
            filtered_entity_count: ctx.filtered_entities.len(),
            fallback_min_results: tuning.fallback_min_results,
            chunk_sources_considered,
            chunk_candidates_before_enrichment: chunk_candidates_before,
            chunk_candidates_after_enrichment: chunk_values.len(),
            top_chunk_scores: sample_scores(&chunk_values, |chunk| chunk.fused),
        };
        ctx.record_chunk_enrichment(stats);
    }

    ctx.chunk_values = chunk_values;
    Ok(())
}
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { pub async fn rerank(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
@@ -960,142 +861,6 @@ where
} }
} }
/// Buckets scored chunks by the `source_id` of the chunk they wrap.
///
/// Every chunk in `chunks` is cloned into the bucket for its source; the keys
/// of the input map are ignored.
fn group_chunks_by_source(
    chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
    chunks.values().fold(
        HashMap::new(),
        |mut grouped: HashMap<String, Vec<Scored<TextChunk>>>, scored_chunk| {
            grouped
                .entry(scored_chunk.item.source_id.clone())
                .or_insert_with(Vec::new)
                .push(scored_chunk.clone());
            grouped
        },
    )
}
/// Adds entity candidates for chunk sources that have no entity candidate yet.
///
/// For every source id in `chunk_by_source` that no entry of
/// `entity_candidates` refers to, the entities of that source are loaded from
/// the database. Each loaded entity is inserted (keyed by its id) with a vector
/// score equal to the best fused score among the chunks of its source, and its
/// fused score is computed with `weights`.
///
/// The lookup is best-effort: if the database query fails, the error is logged
/// and no entities are backfilled.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` is kept so callers can keep
/// using `?`.
async fn backfill_entities_from_chunks(
    entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
    chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
    db_client: &SurrealDbClient,
    user_id: &str,
    weights: FusionWeights,
) -> Result<(), AppError> {
    // Collect the sources already covered once, instead of rescanning every
    // entity candidate for each chunk source.
    let covered_sources: std::collections::HashSet<&str> = entity_candidates
        .values()
        .map(|entity| entity.item.source_id.as_str())
        .collect();

    let missing_sources: Vec<String> = chunk_by_source
        .keys()
        .filter(|source_id| !covered_sources.contains(source_id.as_str()))
        .cloned()
        .collect();

    if missing_sources.is_empty() {
        return Ok(());
    }

    let related_entities = match find_entities_by_source_ids::<KnowledgeEntity>(
        missing_sources,
        "knowledge_entity",
        user_id,
        db_client,
    )
    .await
    {
        Ok(entities) => entities,
        Err(err) => {
            // Best-effort: a failed lookup must not abort retrieval, but it
            // should not be mistaken for "no entities exist" either.
            warn!(
                error = %err,
                "failed to load entities for missing chunk sources; skipping backfill"
            );
            return Ok(());
        }
    };

    if related_entities.is_empty() {
        warn!("expected related entities for missing chunk sources, but none were found");
    }

    for entity in related_entities {
        if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
            let best_chunk_score = chunks
                .iter()
                .map(|chunk| chunk.fused)
                .fold(0.0f32, f32::max);

            // Take the key first so the entity itself can be moved, not cloned.
            let entity_id = entity.id.clone();
            let mut scored = Scored::new(entity).with_vector_score(best_chunk_score);
            let fused = fuse_scores(&scored.scores, weights);
            scored.update_fused(fused);
            entity_candidates.insert(entity_id, scored);
        }
    }

    Ok(())
}
/// Raises entity vector scores using the chunks that share their source.
///
/// For each entity whose `source_id` has chunks in `chunk_by_source`, the best
/// fused chunk score is computed. When that score is positive, the entity's
/// vector score becomes the larger of its current value (a missing score
/// counts as `0.0`) and the chunk score, and the fused score is recomputed
/// with `weights`.
fn boost_entities_with_chunks(
    entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
    chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
    weights: FusionWeights,
) {
    for candidate in entity_candidates.values_mut() {
        let top_chunk_score = match chunk_by_source.get(&candidate.item.source_id) {
            Some(chunks) => chunks
                .iter()
                .fold(0.0f32, |best, chunk| best.max(chunk.fused)),
            None => continue,
        };

        if top_chunk_score > 0.0 {
            let current_vector = candidate.scores.vector.unwrap_or(0.0);
            candidate.scores.vector = Some(current_vector.max(top_chunk_score));

            let refreshed = fuse_scores(&candidate.scores, weights);
            candidate.update_fused(refreshed);
        }
    }
}
/// Pulls in the chunks that belong to the given entities' sources and scores
/// them relative to those entities.
///
/// All chunks whose source matches one of `entities` are loaded from the
/// database. Each one is inserted into `chunk_candidates` if absent (starting
/// from a vector score of `0.0`). Its vector score is then raised to at least
/// `0.8` times the fused score recorded for its source, and its fused score is
/// recomputed with `weights`. The stored chunk is replaced by the freshly
/// loaded one.
///
/// When several entities share a source, the fused score of the last one in
/// `entities` is used for that source.
///
/// # Errors
///
/// Returns the error produced by the chunk lookup.
async fn enrich_chunks_from_entities(
    chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
    entities: &[Scored<KnowledgeEntity>],
    db_client: &SurrealDbClient,
    user_id: &str,
    weights: FusionWeights,
) -> Result<(), AppError> {
    // No entities means no sources to look up.
    if entities.is_empty() {
        return Ok(());
    }

    // One pass yields both the distinct source ids (the keys) and the
    // per-source entity score (later entities overwrite earlier ones).
    let score_by_source: HashMap<String, f32> = entities
        .iter()
        .map(|entity| (entity.item.source_id.clone(), entity.fused))
        .collect();

    let related_chunks = find_entities_by_source_ids::<TextChunk>(
        score_by_source.keys().cloned().collect(),
        "text_chunk",
        user_id,
        db_client,
    )
    .await?;

    for chunk in related_chunks {
        let source_score = score_by_source
            .get(&chunk.source_id)
            .copied()
            .unwrap_or(0.0);

        let slot = chunk_candidates
            .entry(chunk.id.clone())
            .or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));

        let current_vector = slot.scores.vector.unwrap_or(0.0);
        slot.scores.vector = Some(current_vector.max(source_score * 0.8));

        let refreshed = fuse_scores(&slot.scores, weights);
        slot.update_fused(refreshed);
        slot.item = chunk;
    }

    Ok(())
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> { fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() { if ctx.filtered_entities.is_empty() {
return Vec::new(); return Vec::new();
+7 -33
View File
@@ -1,50 +1,24 @@
use super::{ use super::{
stages::{ stages::{
AssembleEntitiesStage, ChunkAssembleStage, ChunkAttachStage, ChunkRerankStage, AssembleEntitiesStage, ChunkAssembleStage, ChunkRerankStage, ChunkVectorStage,
ChunkVectorStage, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, CollectCandidatesStage, EmbedStage, GraphExpansionStage, PipelineContext, RerankStage,
RerankStage,
}, },
BoxedStage, StrategyDriver, BoxedStage, StrategyDriver,
}; };
use crate::{RetrievedChunk, RetrievedEntity}; use crate::{RetrievedChunk, RetrievedEntity};
use common::error::AppError; use common::error::AppError;
pub struct InitialStrategyDriver;
impl InitialStrategyDriver {
pub struct DefaultStrategyDriver;
impl DefaultStrategyDriver {
pub fn new() -> Self { pub fn new() -> Self {
Self Self
} }
} }
impl StrategyDriver for InitialStrategyDriver { impl StrategyDriver for DefaultStrategyDriver {
type Output = Vec<RetrievedEntity>;
fn stages(&self) -> Vec<BoxedStage> {
vec![
Box::new(EmbedStage),
Box::new(CollectCandidatesStage),
Box::new(GraphExpansionStage),
Box::new(ChunkAttachStage),
Box::new(RerankStage),
Box::new(AssembleEntitiesStage),
]
}
fn finalize(&self, ctx: &mut PipelineContext<'_>) -> Result<Self::Output, AppError> {
Ok(ctx.take_entity_results())
}
}
/// Stateless driver for the "revised" retrieval strategy.
///
/// The type is a zero-sized marker; its behaviour comes from its
/// `StrategyDriver` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct RevisedStrategyDriver;

impl RevisedStrategyDriver {
    /// Creates the driver. Equivalent to `RevisedStrategyDriver::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl StrategyDriver for RevisedStrategyDriver {
type Output = Vec<RetrievedChunk>; type Output = Vec<RetrievedChunk>;
fn stages(&self) -> Vec<BoxedStage> { fn stages(&self) -> Vec<BoxedStage> {