fix: arc-share retrieved chunks, centralize entity embeddings, and trim hot-path clones.

This commit is contained in:
Per Stark
2026-06-06 23:05:53 +02:00
parent 676fdbc132
commit 4559ee0aa8
41 changed files with 368 additions and 289 deletions
+4 -3
View File
@@ -92,9 +92,10 @@ mod tests {
#[test] #[test]
fn extract_api_key_rejects_invalid_header_values() { fn extract_api_key_rejects_invalid_header_values() {
let mut request = request_with_headers(&[]); let mut request = request_with_headers(&[]);
request request.headers_mut().insert(
.headers_mut() "X-API-Key",
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header")); HeaderValue::from_bytes(&[0xFF]).expect("invalid header"),
);
assert_eq!(extract_api_key(&request), None); assert_eq!(extract_api_key(&request), None);
} }
} }
+3 -11
View File
@@ -9,11 +9,7 @@ use axum::{
Router, Router,
}; };
use common::{ use common::{
storage::{ storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::config::{AppConfig, StorageKind}, utils::config::{AppConfig, StorageKind},
}; };
use tower::ServiceExt; use tower::ServiceExt;
@@ -34,9 +30,7 @@ async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
storage: StorageKind::Memory, storage: StorageKind::Memory,
..Default::default() ..Default::default()
}; };
let storage = StorageManager::new(&config) let storage = StorageManager::new(&config).await.expect("storage manager");
.await
.expect("storage manager");
let state = ApiState { let state = ApiState {
db: Arc::clone(&db), db: Arc::clone(&db),
@@ -147,9 +141,7 @@ async fn authenticated_user_can_list_categories() {
.await .await
.expect("test user"); .expect("test user");
let api_key = User::set_api_key(&user.id, &db) let api_key = User::set_api_key(&user.id, &db).await.expect("api key");
.await
.expect("api key");
let response = app let response = app
.clone() .clone()
+52 -7
View File
@@ -3,9 +3,12 @@ use std::collections::HashMap;
use std::fmt::Write; use std::fmt::Write;
use crate::{ use crate::{
error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql, error::AppError,
storage::db::SurrealDbClient,
storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings, stored_object, storage::types::system_settings::SystemSettings,
stored_object,
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE}, utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
}; };
use tracing::{error, info}; use tracing::{error, info};
@@ -25,6 +28,17 @@ impl KnowledgeEntityType {
pub fn variants() -> &'static [&'static str] { pub fn variants() -> &'static [&'static str] {
&["Idea", "Project", "Document", "Page", "TextSnippet"] &["Idea", "Project", "Document", "Page", "TextSnippet"]
} }
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::Idea => "Idea",
Self::Project => "Project",
Self::Document => "Document",
Self::Page => "Page",
Self::TextSnippet => "TextSnippet",
}
}
} }
impl From<String> for KnowledgeEntityType { impl From<String> for KnowledgeEntityType {
@@ -80,6 +94,27 @@ impl KnowledgeEntity {
} }
} }
/// Canonical text fed to the embedding provider for a knowledge entity.
#[must_use]
pub fn embedding_input_text(
name: &str,
description: &str,
entity_type: KnowledgeEntityType,
) -> String {
let mut out = String::with_capacity(
name.len()
.saturating_add(description.len())
.saturating_add(entity_type.as_str().len())
.saturating_add(32),
);
let _ = write!(
out,
"name: {name}, description: {description}, type: {}",
entity_type.as_str()
);
out
}
/// Full-text search over knowledge entities using the BM25 FTS index. /// Full-text search over knowledge entities using the BM25 FTS index.
pub async fn fts_search( pub async fn fts_search(
take: usize, take: usize,
@@ -314,8 +349,7 @@ impl KnowledgeEntity {
db_client: &SurrealDbClient, db_client: &SurrealDbClient,
embedding_provider: &EmbeddingProvider, embedding_provider: &EmbeddingProvider,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let embedding_input = let embedding_input = Self::embedding_input_text(name, description, *entity_type);
format!("name: {name}, description: {description}, type: {entity_type:?}",);
let embedding = embedding_provider.embed(&embedding_input).await?; let embedding = embedding_provider.embed(&embedding_input).await?;
let entity: KnowledgeEntity = db_client let entity: KnowledgeEntity = db_client
@@ -402,9 +436,10 @@ impl KnowledgeEntity {
let inputs: Vec<String> = batch let inputs: Vec<String> = batch
.iter() .iter()
.map(|entity| { .map(|entity| {
format!( Self::embedding_input_text(
"name: {}, description: {}, type: {:?}", &entity.name,
entity.name, entity.description, entity.entity_type &entity.description,
entity.entity_type,
) )
}) })
.collect(); .collect();
@@ -523,6 +558,16 @@ mod tests {
use anyhow::{self, Context}; use anyhow::{self, Context};
use uuid::Uuid; use uuid::Uuid;
#[test]
fn embedding_input_text_uses_canonical_type_label() {
let text = KnowledgeEntity::embedding_input_text(
"Alpha",
"Beta",
KnowledgeEntityType::TextSnippet,
);
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
}
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> { async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
let snowball_sql = r#" let snowball_sql = r#"
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
@@ -122,8 +122,7 @@ impl KnowledgeRelationship {
.bind(("user_id", user_id.to_owned())) .bind(("user_id", user_id.to_owned()))
.await .await
.map_err(AppError::from)?; .map_err(AppError::from)?;
let deleted: Vec<KnowledgeRelationship> = let deleted: Vec<KnowledgeRelationship> = delete_result.take(0).map_err(AppError::from)?;
delete_result.take(0).map_err(AppError::from)?;
if !deleted.is_empty() { if !deleted.is_empty() {
return Ok(()); return Ok(());
@@ -567,8 +566,8 @@ mod tests {
shared_source.to_string(), shared_source.to_string(),
"references".to_string(), "references".to_string(),
); );
let rel_a_id = rel_a.id.clone(); let owner_relationship_id = rel_a.id.clone();
let rel_b_id = rel_b.id.clone(); let other_relationship_id = rel_b.id.clone();
rel_a.store_relationship(&db).await?; rel_a.store_relationship(&db).await?;
rel_b.store_relationship(&db).await?; rel_b.store_relationship(&db).await?;
@@ -576,8 +575,12 @@ mod tests {
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db) KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
.await?; .await?;
assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none()); assert!(get_relationship_by_id(&owner_relationship_id, &db)
assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some()); .await
.is_none());
assert!(get_relationship_by_id(&other_relationship_id, &db)
.await
.is_some());
Ok(()) Ok(())
} }
+5 -1
View File
@@ -299,7 +299,11 @@ impl TextChunk {
} }
processed = processed.saturating_add(batch.len()); processed = processed.saturating_add(batch.len());
info!(progress = processed, total = total_chunks, "Re-embedding progress"); info!(
progress = processed,
total = total_chunks,
"Re-embedding progress"
);
} }
info!("Successfully generated all new embeddings."); info!("Successfully generated all new embeddings.");
+1 -2
View File
@@ -140,8 +140,7 @@ impl TextContent {
.await .await
.map_err(AppError::from)?; .map_err(AppError::from)?;
let existing: Option<surrealdb::sql::Thing> = let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::from)?;
response.take(0).map_err(AppError::from)?;
Ok(existing.is_some()) Ok(existing.is_some())
} }
+21 -18
View File
@@ -38,7 +38,7 @@ enum EmbeddingInner {
/// Client used to issue embedding requests. /// Client used to issue embedding requests.
client: Arc<Client<async_openai::config::OpenAIConfig>>, client: Arc<Client<async_openai::config::OpenAIConfig>>,
/// Model identifier for the API. /// Model identifier for the API.
model: String, model: Arc<str>,
/// Expected output dimensions. /// Expected output dimensions.
dimensions: u32, dimensions: u32,
}, },
@@ -272,8 +272,9 @@ struct FastEmbedLease {
} }
impl FastEmbedLease { impl FastEmbedLease {
async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> { async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
let engine = Arc::clone(&self.engine); let engine = Arc::clone(&self.engine);
let texts = texts.to_vec();
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> { tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?; let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
guard.embed(texts, None).map_err(EmbeddingError::fastembed) guard.embed(texts, None).map_err(EmbeddingError::fastembed)
@@ -293,7 +294,7 @@ impl Drop for FastEmbedLease {
async fn run_fastembed( async fn run_fastembed(
pool: &Arc<FastEmbedPool>, pool: &Arc<FastEmbedPool>,
texts: Vec<String>, texts: &[String],
) -> Result<Vec<Vec<f32>>, EmbeddingError> { ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
let lease = pool.checkout().await?; let lease = pool.checkout().await?;
lease.embed(texts).await lease.embed(texts).await
@@ -323,7 +324,7 @@ impl EmbeddingProvider {
pub fn model_code(&self) -> Option<String> { pub fn model_code(&self) -> Option<String> {
match &self.inner { match &self.inner {
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()), EmbeddingInner::OpenAI { model, .. } => Some(model.as_ref().to_owned()),
EmbeddingInner::Hashed { .. } => None, EmbeddingInner::Hashed { .. } => None,
} }
} }
@@ -338,7 +339,8 @@ impl EmbeddingProvider {
match &self.inner { match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { pool, .. } => { EmbeddingInner::FastEmbed { pool, .. } => {
let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?; let text = text.to_owned();
let embeddings = run_fastembed(pool, std::slice::from_ref(&text)).await?;
embeddings.into_iter().next().ok_or(EmbeddingError::NoData) embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
} }
EmbeddingInner::OpenAI { EmbeddingInner::OpenAI {
@@ -347,7 +349,7 @@ impl EmbeddingProvider {
dimensions, dimensions,
} => { } => {
let request = CreateEmbeddingRequestArgs::default() let request = CreateEmbeddingRequestArgs::default()
.model(model.clone()) .model(model.as_ref())
.input([text]) .input([text])
.dimensions(*dimensions) .dimensions(*dimensions)
.build()?; .build()?;
@@ -382,7 +384,7 @@ impl EmbeddingProvider {
if texts.is_empty() { if texts.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
run_fastembed(pool, texts.to_vec()).await run_fastembed(pool, texts).await
} }
EmbeddingInner::OpenAI { EmbeddingInner::OpenAI {
client, client,
@@ -394,7 +396,7 @@ impl EmbeddingProvider {
} }
let request = CreateEmbeddingRequestArgs::default() let request = CreateEmbeddingRequestArgs::default()
.model(model.clone()) .model(model.as_ref())
.input(texts.to_vec()) .input(texts.to_vec())
.dimensions(*dimensions) .dimensions(*dimensions)
.build()?; .build()?;
@@ -417,13 +419,13 @@ impl EmbeddingProvider {
/// Currently infallible; reserved for future validation. /// Currently infallible; reserved for future validation.
pub fn new_openai( pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>, client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String, model: impl AsRef<str>,
dimensions: u32, dimensions: u32,
) -> Result<Self, EmbeddingError> { ) -> Result<Self, EmbeddingError> {
Ok(Self { Ok(Self {
inner: EmbeddingInner::OpenAI { inner: EmbeddingInner::OpenAI {
client, client,
model, model: Arc::from(model.as_ref()),
dimensions, dimensions,
}, },
}) })
@@ -520,7 +522,7 @@ impl EmbeddingProvider {
"openai embedding backend requires an openai client".into(), "openai embedding backend requires an openai client".into(),
) )
})?; })?;
Self::new_openai(client, settings.embedding_model.clone(), dimensions) Self::new_openai(client, settings.embedding_model.as_str(), dimensions)
} }
EmbeddingBackend::FastEmbed => { EmbeddingBackend::FastEmbed => {
let pool_size = config let pool_size = config
@@ -586,11 +588,12 @@ mod tests {
#![allow(clippy::expect_used)] #![allow(clippy::expect_used)]
use super::{ use super::{
align_fastembed_system_settings, fastembed_model_dimension, list_fastembed_embedding_models, align_fastembed_system_settings, fastembed_model_dimension,
resolve_fastembed_model_code, DEFAULT_FASTEMBED_MODEL_CODE, EmbeddingError, list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError,
DEFAULT_FASTEMBED_MODEL_CODE,
}; };
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::system_settings::SystemSettings;
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
use serde_json::json; use serde_json::json;
#[test] #[test]
@@ -656,16 +659,16 @@ mod tests {
fastembed_model: Some("Xenova/bge-base-en-v1.5".into()), fastembed_model: Some("Xenova/bge-base-en-v1.5".into()),
..AppConfig::default() ..AppConfig::default()
}; };
let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small") let resolved =
.expect("config model"); resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("config model");
assert_eq!(resolved, "Xenova/bge-base-en-v1.5"); assert_eq!(resolved, "Xenova/bge-base-en-v1.5");
} }
#[test] #[test]
fn resolve_fastembed_model_falls_back_from_openai_default() { fn resolve_fastembed_model_falls_back_from_openai_default() {
let config = AppConfig::default(); let config = AppConfig::default();
let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small") let resolved =
.expect("default model"); resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("default model");
assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE); assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE);
} }
@@ -42,14 +42,12 @@ pub(crate) async fn prepare_db(
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed) // Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
let embedding_provider = match config.embedding_backend { let embedding_provider = match config.embedding_backend {
crate::args::EmbeddingBackend::FastEmbed => { crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
EmbeddingProvider::new_fastembed( config.embedding_model.clone(),
config.embedding_model.clone(), default_embedding_pool_size(),
default_embedding_pool_size(), )
) .await
.await .context("creating FastEmbed provider")?,
.context("creating FastEmbed provider")?
}
crate::args::EmbeddingBackend::Hashed => { crate::args::EmbeddingBackend::Hashed => {
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")? EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
} }
+4 -4
View File
@@ -196,7 +196,7 @@ pub struct EvaluationCandidate {
} }
impl EvaluationCandidate { impl EvaluationCandidate {
fn from_entity(entity: RetrievedEntity) -> Self { fn from_entity(entity: &RetrievedEntity) -> Self {
let entity_category = Some(format!("{:?}", entity.entity.entity_type)); let entity_category = Some(format!("{:?}", entity.entity.entity_type));
Self { Self {
entity_id: entity.entity.id().to_string(), entity_id: entity.entity.id().to_string(),
@@ -223,9 +223,9 @@ impl EvaluationCandidate {
} }
} }
fn candidates_from_entities(entities: Vec<RetrievedEntity>) -> Vec<EvaluationCandidate> { fn candidates_from_entities(entities: &[RetrievedEntity]) -> Vec<EvaluationCandidate> {
entities entities
.into_iter() .iter()
.map(EvaluationCandidate::from_entity) .map(EvaluationCandidate::from_entity)
.collect() .collect()
} }
@@ -241,7 +241,7 @@ pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidat
match output { match output {
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks), RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
RetrievalOutput::WithEntities { chunks, entities } => { RetrievalOutput::WithEntities { chunks, entities } => {
let mut candidates = candidates_from_entities(entities); let mut candidates = candidates_from_entities(&entities);
candidates.extend(candidates_from_chunks(chunks)); candidates.extend(candidates_from_chunks(chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score)); candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates candidates
+3 -1
View File
@@ -142,7 +142,9 @@ impl HtmlState {
return; return;
} }
let overflow = cache.len().saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS); let overflow = cache
.len()
.saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS);
let mut by_expiry: Vec<(String, Instant)> = cache let mut by_expiry: Vec<(String, Instant)> = cache
.iter() .iter()
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at)) .map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
@@ -183,9 +183,7 @@ fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap)
} }
} }
fn context_to_map( fn context_to_map(value: &Value) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
value: &Value,
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
match value.kind() { match value.kind() {
minijinja::value::ValueKind::Map => { minijinja::value::ValueKind::Map => {
let mut map = HashMap::new(); let mut map = HashMap::new();
+2 -2
View File
@@ -8,8 +8,8 @@ use surrealdb::{engine::any::Any, Surreal};
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
middlewares::{ middlewares::{
analytics_middleware::analytics_middleware, auth_middleware::require_auth, analytics_middleware::analytics_middleware, auth_middleware::require_auth, compression,
compression, response_middleware::with_template_response, response_middleware::with_template_response,
}, },
}; };
+1 -3
View File
@@ -164,9 +164,7 @@ pub async fn update_theme(
)) ))
} }
pub async fn show_change_password( pub async fn show_change_password(RequireUser(_user): RequireUser) -> TemplateResult {
RequireUser(_user): RequireUser,
) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"auth/change_password_form.html", "auth/change_password_form.html",
(), (),
+22 -27
View File
@@ -18,8 +18,8 @@ use common::{
utils::{ utils::{
config::AppConfig, config::AppConfig,
embedding::{ embedding::{
fastembed_model_dimension, is_valid_fastembed_model_code, list_fastembed_embedding_models, fastembed_model_dimension, is_valid_fastembed_model_code,
EmbeddingBackend, FastEmbedModelOption, list_fastembed_embedding_models, EmbeddingBackend, FastEmbedModelOption,
}, },
}, },
}; };
@@ -52,7 +52,6 @@ pub enum AdminSection {
Models, Models,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct AdminPanelQuery { pub struct AdminPanelQuery {
section: Option<String>, section: Option<String>,
@@ -101,8 +100,9 @@ pub async fn show_admin_panel(
(None, None, false) (None, None, false)
}; };
let effective_backend = let effective_backend = effective_embedding_backend(&settings, &state.config)
effective_embedding_backend(&settings, &state.config).as_str().to_string(); .as_str()
.to_string();
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"admin/base.html", "admin/base.html",
@@ -187,7 +187,9 @@ struct EmbeddingSettingsPlan {
} }
fn effective_embedding_backend(settings: &SystemSettings, config: &AppConfig) -> EmbeddingBackend { fn effective_embedding_backend(settings: &SystemSettings, config: &AppConfig) -> EmbeddingBackend {
settings.embedding_backend.unwrap_or(config.embedding_backend) settings
.embedding_backend
.unwrap_or(config.embedding_backend)
} }
fn is_fastembed_admin_context(settings: &SystemSettings, config: &AppConfig) -> bool { fn is_fastembed_admin_context(settings: &SystemSettings, config: &AppConfig) -> bool {
@@ -241,11 +243,10 @@ fn plan_embedding_settings_update(
))); )));
} }
let embedding_dimensions = fastembed_model_dimension(&embedding_model) let embedding_dimensions =
.map_err(AppError::from)?; fastembed_model_dimension(&embedding_model).map_err(AppError::from)?;
let reembedding_needed = embedding_dimensions != current.embedding_dimensions; let reembedding_needed = embedding_dimensions != current.embedding_dimensions;
let restart_needed = let restart_needed = embedding_model != current.embedding_model || reembedding_needed;
embedding_model != current.embedding_model || reembedding_needed;
Ok(EmbeddingSettingsPlan { Ok(EmbeddingSettingsPlan {
embedding_model, embedding_model,
@@ -274,8 +275,7 @@ pub async fn update_model_settings(
Form(input): Form<ModelSettingsInput>, Form(input): Form<ModelSettingsInput>,
) -> TemplateResult { ) -> TemplateResult {
let current_settings = SystemSettings::get_current(&state.db).await?; let current_settings = SystemSettings::get_current(&state.db).await?;
let embedding_plan = let embedding_plan = plan_embedding_settings_update(&current_settings, &input, &state.config)?;
plan_embedding_settings_update(&current_settings, &input, &state.config)?;
let new_settings = SystemSettingsPatch { let new_settings = SystemSettingsPatch {
query_model: Some(input.query_model), query_model: Some(input.query_model),
@@ -309,10 +309,11 @@ pub async fn update_model_settings(
.await .await
.map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?; .map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?;
let effective_backend = let effective_backend = effective_embedding_backend(&new_settings, &state.config)
effective_embedding_backend(&new_settings, &state.config).as_str().to_string(); .as_str()
let show_fastembed_models = .to_string();
is_fastembed_admin_context(&new_settings, &state.config).then(list_fastembed_embedding_models); let show_fastembed_models = is_fastembed_admin_context(&new_settings, &state.config)
.then(list_fastembed_embedding_models);
Ok(TemplateResponse::new_partial( Ok(TemplateResponse::new_partial(
"admin/sections/models.html", "admin/sections/models.html",
@@ -368,8 +369,8 @@ mod tests {
embedding_model: Some("Xenova/bge-base-en-v1.5".into()), embedding_model: Some("Xenova/bge-base-en-v1.5".into()),
embedding_dimensions: None, embedding_dimensions: None,
}; };
let plan = plan_embedding_settings_update(&current, &input, &AppConfig::default()) let plan =
.expect("plan"); plan_embedding_settings_update(&current, &input, &AppConfig::default()).expect("plan");
assert_eq!(plan.embedding_model, "Xenova/bge-base-en-v1.5"); assert_eq!(plan.embedding_model, "Xenova/bge-base-en-v1.5");
assert_eq!(plan.embedding_dimensions, 768); assert_eq!(plan.embedding_dimensions, 768);
assert!(plan.reembedding_needed); assert!(plan.reembedding_needed);
@@ -407,9 +408,7 @@ pub struct SystemPromptEditData {
default_query_prompt: String, default_query_prompt: String,
} }
pub async fn show_edit_system_prompt( pub async fn show_edit_system_prompt(State(state): State<HtmlState>) -> TemplateResult {
State(state): State<HtmlState>,
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -457,9 +456,7 @@ pub struct IngestionPromptEditData {
default_ingestion_prompt: String, default_ingestion_prompt: String,
} }
pub async fn show_edit_ingestion_prompt( pub async fn show_edit_ingestion_prompt(State(state): State<HtmlState>) -> TemplateResult {
State(state): State<HtmlState>,
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
@@ -502,9 +499,7 @@ pub struct ImagePromptEditData {
default_image_prompt: String, default_image_prompt: String,
} }
pub async fn show_edit_image_prompt( pub async fn show_edit_image_prompt(State(state): State<HtmlState>) -> TemplateResult {
State(state): State<HtmlState>,
) -> TemplateResult {
let settings = SystemSettings::get_current(&state.db).await?; let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
+4 -1
View File
@@ -2,7 +2,10 @@ use axum::{extract::State, Form};
use axum_htmx::HxBoosted; use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use common::{error::AppError, storage::types::user::{Theme, User}}; use common::{
error::AppError,
storage::types::user::{Theme, User},
};
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
+2 -2
View File
@@ -18,8 +18,8 @@ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_as_response, template_with_headers, TemplateResponse, TemplateResult, template_as_response, template_with_headers, ResponseResult, TemplateResponse,
ResponseResult, TemplateResult,
}, },
}, },
}; };
@@ -22,8 +22,8 @@ use retrieval_pipeline::answer_retrieval::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::from_str; use serde_json::from_str;
use tokio::sync::Mutex;
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
use tokio::sync::Mutex;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
use common::storage::{ use common::storage::{
@@ -36,10 +36,7 @@ use common::storage::{
}, },
}; };
use crate::{ use crate::{html_state::HtmlState, middlewares::auth_middleware::RequireUser};
html_state::HtmlState,
middlewares::auth_middleware::RequireUser,
};
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references}; use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
+3 -7
View File
@@ -3,15 +3,11 @@ mod message_response_stream;
mod reference_validation; mod reference_validation;
mod references; mod references;
use axum::{ use axum::{extract::FromRef, routing::get, Router};
extract::FromRef,
routing::get,
Router,
};
pub use chat_handlers::{ pub use chat_handlers::{
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title, delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
reload_sidebar, show_conversation_editing_title, reload_sidebar, show_chat_base as show_base, show_conversation_editing_title,
show_chat_base as show_base, show_existing_chat as show_existing, show_existing_chat as show_existing,
}; };
use message_response_stream::get_response_stream; use message_response_stream::get_response_stream;
use references::show_reference_tooltip; use references::show_reference_tooltip;
+1 -3
View File
@@ -1,8 +1,6 @@
#![allow(clippy::missing_docs_in_private_items)] #![allow(clippy::missing_docs_in_private_items)]
use axum::{ use axum::extract::{Path, State};
extract::{Path, State},
};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use chrono_tz::Tz; use chrono_tz::Tz;
use serde::Serialize; use serde::Serialize;
+3 -2
View File
@@ -13,7 +13,7 @@ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_as_response, TemplateResponse, TemplateResult, ResponseResult, template_as_response, ResponseResult, TemplateResponse, TemplateResult,
}, },
}, },
utils::text_content_preview::truncate_text_contents, utils::text_content_preview::truncate_text_contents,
@@ -84,7 +84,8 @@ pub async fn delete_text_content(
// Delete the text content and any related data // Delete the text content and any related data
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?; TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?; KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?; KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db)
.await?;
state state
.db .db
.delete_item::<TextContent>(&text_content.id) .delete_item::<TextContent>(&text_content.id)
+2 -5
View File
@@ -63,9 +63,7 @@ pub async fn show_ingest_form(
)) ))
} }
pub async fn hide_ingest_form( pub async fn hide_ingest_form(RequireUser(_user): RequireUser) -> TemplateResult {
RequireUser(_user): RequireUser,
) -> TemplateResult {
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"ingestion/add_content_button.html", "ingestion/add_content_button.html",
(), (),
@@ -148,8 +146,7 @@ pub async fn process_ingest_form(
user.id.clone(), user.id.clone(),
)?; )?;
let tasks = let tasks = IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"dashboard/current_task.html", "dashboard/current_task.html",
+12 -20
View File
@@ -37,7 +37,7 @@ use crate::{
middlewares::{ middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_with_headers, TemplateResponse, TemplateResult, ResponseResult, template_with_headers, ResponseResult, TemplateResponse, TemplateResult,
}, },
}, },
utils::pagination::{paginate_items, paginate_slice, Pagination}, utils::pagination::{paginate_items, paginate_slice, Pagination},
@@ -185,8 +185,7 @@ pub async fn create_knowledge_entity(
let description = form.description.trim().to_string(); let description = form.description.trim().to_string();
let entity_type = KnowledgeEntityType::from(form.entity_type.trim().to_string()); let entity_type = KnowledgeEntityType::from(form.entity_type.trim().to_string());
let embedding_input = let embedding_input = KnowledgeEntity::embedding_input_text(&name, &description, entity_type);
format!("name: {name}, description: {description}, type: {entity_type:?}");
let embedding = state let embedding = state
.embedding_provider .embedding_provider
.embed(&embedding_input) .embed(&embedding_input)
@@ -290,10 +289,12 @@ pub async fn suggest_knowledge_relationships(
if !query_parts.is_empty() { if !query_parts.is_empty() {
let name = form.name.as_deref().unwrap_or("").trim(); let name = form.name.as_deref().unwrap_or("").trim();
let description = form.description.as_deref().unwrap_or("").trim(); let description = form.description.as_deref().unwrap_or("").trim();
let entity_type = form.entity_type.as_deref().map_or( let entity_type = form
KnowledgeEntityType::Document, .entity_type
|value| KnowledgeEntityType::from(value.to_string()), .as_deref()
); .map_or(KnowledgeEntityType::Document, |value| {
KnowledgeEntityType::from(value.to_string())
});
let suggested = suggest_related_entities( let suggested = suggest_related_entities(
&state.db, &state.db,
@@ -374,10 +375,8 @@ async fn suggest_related_entities(
draft: DraftEntityQuery<'_>, draft: DraftEntityQuery<'_>,
entity_lookup: &HashMap<String, KnowledgeEntity>, entity_lookup: &HashMap<String, KnowledgeEntity>,
) -> Result<HashMap<String, f32>, AppError> { ) -> Result<HashMap<String, f32>, AppError> {
let embedding_input = format!( let embedding_input =
"name: {}, description: {}, type: {:?}", KnowledgeEntity::embedding_input_text(draft.name, draft.description, draft.entity_type);
draft.name, draft.description, draft.entity_type
);
let embedding = embedding_provider.embed(&embedding_input).await?; let embedding = embedding_provider.embed(&embedding_input).await?;
let take = MAX_RELATIONSHIP_SUGGESTIONS * 2; let take = MAX_RELATIONSHIP_SUGGESTIONS * 2;
@@ -484,11 +483,7 @@ fn build_relationship_options(
fn build_relationship_rows( fn build_relationship_rows(
relationships: Vec<KnowledgeRelationship>, relationships: Vec<KnowledgeRelationship>,
) -> ( ) -> (Vec<RelationshipTableRow>, Vec<String>, String) {
Vec<RelationshipTableRow>,
Vec<String>,
String,
) {
let relationship_type_options = collect_relationship_type_options(&relationships); let relationship_type_options = collect_relationship_type_options(&relationships);
let mut frequency: HashMap<String, usize> = HashMap::new(); let mut frequency: HashMap<String, usize> = HashMap::new();
let relationships = relationships let relationships = relationships
@@ -509,10 +504,7 @@ fn build_relationship_rows(
let default_relationship_type = frequency let default_relationship_type = frequency
.into_iter() .into_iter()
.max_by_key(|(_, count)| *count) .max_by_key(|(_, count)| *count)
.map_or_else( .map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
|| DEFAULT_RELATIONSHIP_TYPE.to_string(),
|(label, _)| label,
);
( (
relationships, relationships,
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
use crate::middlewares::{ use crate::middlewares::{
auth_middleware::RequireUser, auth_middleware::RequireUser,
response_middleware::{ response_middleware::{
template_with_headers, TemplateResponse, TemplateResult, ResponseResult, template_with_headers, ResponseResult, TemplateResponse, TemplateResult,
}, },
}; };
use common::storage::types::{ use common::storage::types::{
+11 -14
View File
@@ -1,11 +1,11 @@
use std::collections::HashSet; use std::collections::HashSet;
use axum::{ use axum::extract::{Query, State};
extract::{Query, State},
};
use axum_htmx::{HxBoosted, HxRequest}; use axum_htmx::{HxBoosted, HxRequest};
use common::storage::types::{text_content::TextContent, user::User}; use common::storage::types::{text_content::TextContent, user::User};
use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity}; use retrieval_pipeline::{
retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity,
};
use serde::{de, Deserialize, Deserializer, Serialize}; use serde::{de, Deserialize, Deserializer, Serialize};
use std::{fmt, str::FromStr}; use std::{fmt, str::FromStr};
@@ -48,9 +48,7 @@ impl<'de> Deserialize<'de> for SearchView {
Some("chunks") => SearchView::Chunks, Some("chunks") => SearchView::Chunks,
Some("entities") => SearchView::Entities, Some("entities") => SearchView::Entities,
Some(other) => { Some(other) => {
return Err(de::Error::custom(format!( return Err(de::Error::custom(format!("invalid search view: {other}")));
"invalid search view: {other}"
)));
} }
}) })
} }
@@ -121,13 +119,12 @@ pub async fn search_result_handler(
HxBoosted(is_boosted): HxBoosted, HxBoosted(is_boosted): HxBoosted,
) -> TemplateResult { ) -> TemplateResult {
let view = params.view; let view = params.view;
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) = let (search_results_for_template, final_query_param_for_template) =
params.query if let Some(actual_query) = params.query {
{ perform_search(&state, &user, actual_query, view).await?
perform_search(&state, &user, actual_query, view).await? } else {
} else { (Vec::<SearchResultForTemplate>::new(), String::new())
(Vec::<SearchResultForTemplate>::new(), String::new()) };
};
let data = AnswerData { let data = AnswerData {
search_result: search_results_for_template, search_result: search_results_for_template,
+1 -3
View File
@@ -2,9 +2,7 @@ mod handlers;
use axum::{extract::FromRef, routing::get, Router}; use axum::{extract::FromRef, routing::get, Router};
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub use handlers::{ pub use handlers::{search_result_handler as result_handler, SearchParams as SearchQueryParams};
search_result_handler as result_handler, SearchParams as SearchQueryParams,
};
use crate::html_state::HtmlState; use crate::html_state::HtmlState;
+15 -3
View File
@@ -31,8 +31,16 @@ impl Pagination {
} else { } else {
0 0
}; };
let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) }; let start_index = if page_len == 0 {
let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) }; 0
} else {
offset.saturating_add(1)
};
let end_index = if page_len == 0 {
0
} else {
offset.saturating_add(page_len)
};
Self { Self {
current_page, current_page,
@@ -109,7 +117,11 @@ pub fn paginate_items<T>(
let total_pages = if total_items == 0 { let total_pages = if total_items == 0 {
0 0
} else { } else {
total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1) total_items
.saturating_sub(1)
.checked_div(per_page)
.unwrap_or(0)
.saturating_add(1)
}; };
let mut current_page = requested_page.unwrap_or(1); let mut current_page = requested_page.unwrap_or(1);
+5 -16
View File
@@ -9,11 +9,7 @@ use axum::{
Router, Router,
}; };
use common::{ use common::{
storage::{ storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
db::SurrealDbClient,
store::StorageManager,
types::user::User,
},
utils::{ utils::{
config::{AppConfig, StorageKind}, config::{AppConfig, StorageKind},
embedding::EmbeddingProvider, embedding::EmbeddingProvider,
@@ -37,24 +33,17 @@ async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
.await .await
.expect("migrations should apply"); .expect("migrations should apply");
let session_store = Arc::new( let session_store = Arc::new(db.create_session_store().await.expect("session store"));
db.create_session_store()
.await
.expect("session store"),
);
let config = AppConfig { let config = AppConfig {
storage: StorageKind::Memory, storage: StorageKind::Memory,
..Default::default() ..Default::default()
}; };
let storage = StorageManager::new(&config) let storage = StorageManager::new(&config).await.expect("storage manager");
.await
.expect("storage manager");
let embedding_provider = Arc::new( let embedding_provider =
EmbeddingProvider::new_hashed(8).expect("embedding provider"), Arc::new(EmbeddingProvider::new_hashed(8).expect("embedding provider"));
);
let state = HtmlState::new_with_resources(StateResources { let state = HtmlState::new_with_resources(StateResources {
db: Arc::clone(&db), db: Arc::clone(&db),
@@ -6,11 +6,9 @@ use serde::{Deserialize, Serialize};
use common::{ use common::{
error::AppError, error::AppError,
storage::{ storage::types::{
types::{ knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, knowledge_relationship::KnowledgeRelationship,
knowledge_relationship::KnowledgeRelationship,
},
}, },
utils::embedding::EmbeddingProvider, utils::embedding::EmbeddingProvider,
}; };
@@ -83,25 +81,32 @@ impl LLMEnrichmentResult {
entity_concurrency: usize, entity_concurrency: usize,
embedding_provider: &EmbeddingProvider, embedding_provider: &EmbeddingProvider,
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> { ) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| { let tasks: Vec<_> = self
let mapper = Arc::clone(&mapper); .knowledge_entities
let source_id = source_id.to_string(); .iter()
let user_id = user_id.to_string(); .map(|entity| {
let llm_entity = entity.clone();
let mapper = Arc::clone(&mapper);
let source_id = source_id.to_string();
let user_id = user_id.to_string();
async move { async move {
create_single_entity( create_single_entity(
&entity, llm_entity,
&source_id, &source_id,
&user_id, &user_id,
mapper, mapper,
embedding_provider, embedding_provider,
) )
.await .await
} }
})) })
.buffer_unordered(entity_concurrency.max(1)) .collect();
.try_collect()
.await stream::iter(tasks)
.buffer_unordered(entity_concurrency.max(1))
.try_collect()
.await
} }
fn process_relationships( fn process_relationships(
@@ -129,7 +134,7 @@ impl LLMEnrichmentResult {
} }
async fn create_single_entity( async fn create_single_entity(
llm_entity: &LLMKnowledgeEntity, llm_entity: LLMKnowledgeEntity,
source_id: &str, source_id: &str,
user_id: &str, user_id: &str,
mapper: Arc<GraphMapper>, mapper: Arc<GraphMapper>,
@@ -137,9 +142,11 @@ async fn create_single_entity(
) -> Result<EmbeddedKnowledgeEntity, AppError> { ) -> Result<EmbeddedKnowledgeEntity, AppError> {
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
let embedding_input = format!( let entity_type = KnowledgeEntityType::from(llm_entity.entity_type);
"name: {}, description: {}, type: {}", let embedding_input = KnowledgeEntity::embedding_input_text(
llm_entity.name, llm_entity.description, llm_entity.entity_type &llm_entity.name,
&llm_entity.description,
entity_type,
); );
let embedding = embedding_provider.embed(&embedding_input).await?; let embedding = embedding_provider.embed(&embedding_input).await?;
@@ -149,9 +156,9 @@ async fn create_single_entity(
id: assigned_id, id: assigned_id,
created_at: now, created_at: now,
updated_at: now, updated_at: now,
name: llm_entity.name.clone(), name: llm_entity.name,
description: llm_entity.description.clone(), description: llm_entity.description,
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()), entity_type,
source_id: source_id.to_string(), source_id: source_id.to_string(),
metadata: None, metadata: None,
user_id: user_id.into(), user_id: user_id.into(),
@@ -48,20 +48,20 @@ const STORE_RELATIONSHIPS: &str = r"
pub(super) async fn store_vector_chunks( pub(super) async fn store_vector_chunks(
db: &SurrealDbClient, db: &SurrealDbClient,
task_id: &str, task_id: &str,
chunks: &[EmbeddedTextChunk], chunks: Vec<EmbeddedTextChunk>,
) -> Result<usize, AppError> { ) -> Result<usize, AppError> {
let chunk_count = chunks.len();
for embedded in chunks { for embedded in chunks {
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db)
.await?;
debug!( debug!(
task_id = %task_id, task_id = %task_id,
chunk_id = %embedded.chunk.id, chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(), chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted" "chunk persisted"
); );
TextChunk::store_with_embedding(embedded.chunk, embedded.embedding, db).await?;
} }
Ok(chunks.len()) Ok(chunk_count)
} }
/// Persists knowledge entities and their relationships. /// Persists knowledge entities and their relationships.
+1 -1
View File
@@ -155,7 +155,7 @@ pub async fn persist(
let entity_count = entities.len(); let entity_count = entities.len();
let relationship_count = relationships.len(); let relationship_count = relationships.len();
let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), &chunks).await?; let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), chunks).await?;
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?; store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
ctx.db.store_item(text_content).await?; ctx.db.store_item(text_content).await?;
rebuild(ctx.db).await?; rebuild(ctx.db).await?;
+1 -1
View File
@@ -92,7 +92,7 @@ impl MockServices {
entity: retrieved_entity, entity: retrieved_entity,
score: 0.8, score: 0.8,
chunks: std::sync::Arc::new(vec![RetrievedChunk { chunks: std::sync::Arc::new(vec![RetrievedChunk {
chunk: retrieved_chunk, chunk: std::sync::Arc::new(retrieved_chunk),
score: 0.7, score: 0.7,
}]), }]),
}], }],
@@ -74,10 +74,7 @@ pub async fn extract_text_from_file(
config: &AppConfig, config: &AppConfig,
storage: &StorageManager, storage: &StorageManager,
) -> Result<String, AppError> { ) -> Result<String, AppError> {
let file_bytes = storage let file_bytes = storage.get(&file_info.path).await.map_err(AppError::from)?;
.get(&file_info.path)
.await
.map_err(AppError::from)?;
let local_path = resolve_existing_local_path(storage, &file_info.path).await; let local_path = resolve_existing_local_path(storage, &file_info.path).await;
match file_info.mime_type.as_str() { match file_info.mime_type.as_str() {
+1 -4
View File
@@ -8,10 +8,7 @@ use std::sync::Arc;
use anyhow::Context; use anyhow::Context;
use async_openai::Client; use async_openai::Client;
use common::{ use common::{
storage::{ storage::{db::SurrealDbClient, store::StorageManager},
db::SurrealDbClient,
store::StorageManager,
},
utils::{ utils::{
config::{get_config, AppConfig}, config::{get_config, AppConfig},
embedding::{align_fastembed_system_settings, EmbeddingProvider}, embedding::{align_fastembed_system_settings, EmbeddingProvider},
+15 -13
View File
@@ -67,7 +67,8 @@ pub async fn prepare_embedding_runtime(
let index_dim = if mismatch { let index_dim = if mismatch {
match role { match role {
EmbeddingRuntimeRole::Maintainer => { EmbeddingRuntimeRole::Maintainer => {
reconcile_embeddings(&services.db, &services.embedding_provider, target_dim).await?; reconcile_embeddings(&services.db, &services.embedding_provider, target_dim)
.await?;
target_dim target_dim
} }
EmbeddingRuntimeRole::ReadOnly => { EmbeddingRuntimeRole::ReadOnly => {
@@ -238,9 +239,7 @@ mod tests {
stored_dim: usize, stored_dim: usize,
target_dim: usize, target_dim: usize,
) -> (super::SharedServices, std::path::PathBuf) { ) -> (super::SharedServices, std::path::PathBuf) {
let (mut services, data_dir) = init_smoke_services() let (mut services, data_dir) = init_smoke_services().await.expect("smoke services");
.await
.expect("smoke services");
ensure_runtime(&services.db, stored_dim) ensure_runtime(&services.db, stored_dim)
.await .await
@@ -254,9 +253,8 @@ mod tests {
.await .await
.expect("update settings"); .expect("update settings");
services.embedding_provider = Arc::new( services.embedding_provider =
EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test"), Arc::new(EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test"));
);
(services, data_dir) (services, data_dir)
} }
@@ -270,7 +268,9 @@ mod tests {
.expect("maintainer startup"); .expect("maintainer startup");
assert_eq!( assert_eq!(
embedding_index_dimension(&services.db).await.expect("index dim"), embedding_index_dimension(&services.db)
.await
.expect("index dim"),
Some(5), Some(5),
"maintainer should rebuild the index to the provider dimension" "maintainer should rebuild the index to the provider dimension"
); );
@@ -287,7 +287,9 @@ mod tests {
.expect("read-only startup"); .expect("read-only startup");
assert_eq!( assert_eq!(
embedding_index_dimension(&services.db).await.expect("index dim"), embedding_index_dimension(&services.db)
.await
.expect("index dim"),
Some(3), Some(3),
"read-only server must not overwrite the index before a maintainer re-embeds" "read-only server must not overwrite the index before a maintainer re-embeds"
); );
@@ -297,9 +299,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn maintainer_reembeds_chunks_when_index_dimension_differs() { async fn maintainer_reembeds_chunks_when_index_dimension_differs() {
let (mut services, data_dir) = init_smoke_services() let (mut services, data_dir) = init_smoke_services().await.expect("smoke services");
.await
.expect("smoke services");
let mut settings = SystemSettings::get_current(&services.db) let mut settings = SystemSettings::get_current(&services.db)
.await .await
@@ -339,7 +339,9 @@ mod tests {
.expect("maintainer startup with data"); .expect("maintainer startup with data");
assert_eq!( assert_eq!(
embedding_index_dimension(&services.db).await.expect("index dim"), embedding_index_dimension(&services.db)
.await
.expect("index dim"),
Some(5) Some(5)
); );
+5 -1
View File
@@ -95,7 +95,11 @@ mod tests {
use common::storage::types::{system_settings::SystemSettings, user::User}; use common::storage::types::{system_settings::SystemSettings, user::User};
use tower::ServiceExt; use tower::ServiceExt;
async fn build_test_app() -> (Router, Arc<common::storage::db::SurrealDbClient>, std::path::PathBuf) { async fn build_test_app() -> (
Router,
Arc<common::storage::db::SurrealDbClient>,
std::path::PathBuf,
) {
let (services, data_dir) = init_smoke_services() let (services, data_dir) = init_smoke_services()
.await .await
.expect("failed to init services"); .expect("failed to init services");
+2 -3
View File
@@ -68,9 +68,8 @@ mod tests {
let db = Arc::clone(&services.db); let db = Arc::clone(&services.db);
let pipeline = Arc::new(pipeline); let pipeline = Arc::new(pipeline);
let worker = tokio::spawn(async move { let worker =
ingestion_pipeline::run_worker_loop(db, pipeline).await tokio::spawn(async move { ingestion_pipeline::run_worker_loop(db, pipeline).await });
});
tokio::time::sleep(Duration::from_millis(250)).await; tokio::time::sleep(Duration::from_millis(250)).await;
assert!( assert!(
+6 -2
View File
@@ -7,6 +7,8 @@ pub mod scoring;
use std::sync::Arc; use std::sync::Arc;
pub use scoring::RetrievalCandidate;
use common::{ use common::{
error::AppError, error::AppError,
storage::{ storage::{
@@ -45,7 +47,7 @@ pub(crate) fn round_score(value: f32) -> f64 {
// Captures a supporting chunk plus its fused retrieval score for downstream prompts. // Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RetrievedChunk { pub struct RetrievedChunk {
pub chunk: TextChunk, pub chunk: Arc<TextChunk>,
pub score: f32, pub score: f32,
} }
@@ -159,7 +161,9 @@ mod tests {
assert!(!chunks.is_empty(), "Expected at least one retrieval result"); assert!(!chunks.is_empty(), "Expected at least one retrieval result");
assert!( assert!(
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")), chunks
.first()
.is_some_and(|c| c.chunk.chunk.contains("Tokio")),
"Expected chunk about Tokio" "Expected chunk about Tokio"
); );
Ok(()) Ok(())
+2 -2
View File
@@ -11,7 +11,7 @@ use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
use super::{ use super::{
config::RetrievalConfig, config::RetrievalConfig,
diagnostics::{AssembleStats, Diagnostics, SearchStats}, diagnostics::{AssembleStats, Diagnostics, SearchStats},
StageKind, StageTimings, RetrievalParams, RetrievalParams, StageKind, StageTimings,
}; };
/// Mutable working state threaded through every retrieval stage. /// Mutable working state threaded through every retrieval stage.
@@ -22,7 +22,7 @@ pub(crate) struct PipelineContext<'a> {
pub user_id: String, pub user_id: String,
pub config: RetrievalConfig, pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>, pub query_embedding: Option<Vec<f32>>,
pub chunk_values: Vec<Scored<TextChunk>>, pub chunk_values: Vec<Scored<std::sync::Arc<TextChunk>>>,
pub reranker: Option<RerankerLease>, pub reranker: Option<RerankerLease>,
pub diagnostics: Option<Diagnostics>, pub diagnostics: Option<Diagnostics>,
pub entity_results: Vec<RetrievedEntity>, pub entity_results: Vec<RetrievedEntity>,
+57 -27
View File
@@ -131,14 +131,14 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
let vector_candidates = vector_rows.len(); let vector_candidates = vector_rows.len();
let fts_candidates = fts_rows.len(); let fts_candidates = fts_rows.len();
let vector_scored: Vec<Scored<TextChunk>> = vector_rows let vector_scored: Vec<Scored<Arc<TextChunk>>> = vector_rows
.into_iter() .into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score)) .map(|row| Scored::new(Arc::new(row.chunk)).with_vector_score(row.score))
.collect(); .collect();
let fts_scored: Vec<Scored<TextChunk>> = fts_rows let fts_scored: Vec<Scored<Arc<TextChunk>>> = fts_rows
.into_iter() .into_iter()
.map(|row| Scored::new(row.chunk).with_fts_score(row.score)) .map(|row| Scored::new(Arc::new(row.chunk)).with_fts_score(row.score))
.collect(); .collect();
let mut fts_weight = tuning.chunk_rrf_fts_weight; let mut fts_weight = tuning.chunk_rrf_fts_weight;
@@ -222,40 +222,63 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
/// and the contributing chunks are attached. /// and the contributing chunks are attached.
#[instrument(level = "trace", skip_all)] #[instrument(level = "trace", skip_all)]
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
if ctx.chunk_values.is_empty() { let chunk_values = std::mem::take(&mut ctx.chunk_values);
if chunk_values.is_empty() {
return Ok(()); return Ok(());
} }
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1); let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
struct IndexedChunk {
idx: usize,
score: f32,
}
let mut source_order: Vec<String> = Vec::new(); let mut source_order: Vec<String> = Vec::new();
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new(); let mut chunks_by_source: HashMap<String, Vec<IndexedChunk>> = HashMap::new();
let mut best_score: HashMap<String, f32> = HashMap::new(); let mut best_score: HashMap<String, f32> = HashMap::new();
for scored in &ctx.chunk_values { for (idx, scored) in chunk_values.iter().enumerate() {
let source_id = &scored.item.source_id; if let Some(attached) = chunks_by_source.get_mut(&scored.item.source_id) {
let is_new_source = !chunks_by_source.contains_key(source_id); if attached.len() < max_chunks {
if is_new_source { attached.push(IndexedChunk {
source_order.push(source_id.clone()); idx,
score: scored.fused,
});
}
} else {
let source_id = scored.item.source_id.clone();
best_score.insert(source_id.clone(), scored.fused); best_score.insert(source_id.clone(), scored.fused);
} source_order.push(source_id.clone());
chunks_by_source.insert(
let attached = chunks_by_source source_id,
.entry(source_id.clone()) vec![IndexedChunk {
.or_default(); idx,
if attached.len() < max_chunks { score: scored.fused,
attached.push(RetrievedChunk { }],
chunk: scored.item.clone(), );
score: scored.fused,
});
} }
} }
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
.into_iter() .into_iter()
.map(|(source, chunks)| (source, Arc::new(chunks))) .map(|(source, indices)| {
let chunks = indices
.into_iter()
.filter_map(|indexed| {
let scored = chunk_values.get(indexed.idx)?;
Some(RetrievedChunk {
chunk: Arc::clone(&scored.item),
score: indexed.score,
})
})
.collect();
(source, Arc::new(chunks))
})
.collect(); .collect();
ctx.chunk_values = chunk_values;
let entities = let entities =
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?; KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
@@ -336,10 +359,17 @@ fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
where where
F: FnMut(&Scored<T>) -> f32, F: FnMut(&Scored<T>) -> f32,
{ {
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect() items
.iter()
.take(SCORE_SAMPLE_LIMIT)
.map(extractor)
.collect()
} }
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> { fn build_chunk_rerank_documents(
chunks: &[Scored<Arc<TextChunk>>],
max_chunks: usize,
) -> Vec<String> {
let take = chunks.len().min(max_chunks); let take = chunks.len().min(max_chunks);
let mut documents = Vec::with_capacity(take); let mut documents = Vec::with_capacity(take);
let mut buffer = String::with_capacity(512); let mut buffer = String::with_capacity(512);
@@ -363,7 +393,7 @@ fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize)
} }
fn apply_chunk_rerank_results( fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>, chunks: &mut Vec<Scored<Arc<TextChunk>>>,
tuning: &RetrievalTuning, tuning: &RetrievalTuning,
results: Vec<RerankResult>, results: Vec<RerankResult>,
) { ) {
@@ -371,7 +401,7 @@ fn apply_chunk_rerank_results(
return; return;
} }
let mut remaining: Vec<Option<Scored<TextChunk>>> = let mut remaining: Vec<Option<Scored<Arc<TextChunk>>>> =
std::mem::take(chunks).into_iter().map(Some).collect(); std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect(); let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
@@ -384,7 +414,7 @@ fn apply_chunk_rerank_results(
clamp_unit(tuning.rerank_blend_weight) clamp_unit(tuning.rerank_blend_weight)
}; };
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len()); let mut reranked: Vec<Scored<Arc<TextChunk>>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) { if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() { if let Some(mut candidate) = slot.take() {
+6 -11
View File
@@ -29,8 +29,7 @@ impl RerankerPool {
/// Build the pool at startup. /// Build the pool at startup.
/// `pool_size` controls max parallel reranks. /// `pool_size` controls max parallel reranks.
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> { pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
let init_options = let init_options = RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
Self::new_with_options(pool_size, &init_options) Self::new_with_options(pool_size, &init_options)
} }
@@ -44,8 +43,7 @@ impl RerankerPool {
))); )));
} }
fs::create_dir_all(&init_options.cache_dir) fs::create_dir_all(&init_options.cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
.map_err(|e| Box::new(AppError::from(e)))?;
let mut engines = Vec::with_capacity(pool_size); let mut engines = Vec::with_capacity(pool_size);
for x in 0..pool_size { for x in 0..pool_size {
@@ -77,10 +75,7 @@ impl RerankerPool {
/// This returns a lease that can perform `rerank()`. /// This returns a lease that can perform `rerank()`.
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> { pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
// Acquire a permit. This enforces backpressure. // Acquire a permit. This enforces backpressure.
let permit = Arc::clone(&self.semaphore) let permit = Arc::clone(&self.semaphore).acquire_owned().await.ok()?;
.acquire_owned()
.await
.ok()?;
// Pick an engine. // Pick an engine.
// This is naive: just pick based on a simple modulo counter. // This is naive: just pick based on a simple modulo counter.
@@ -165,9 +160,9 @@ impl RerankerLease {
let engine = Arc::clone(&self.engine); let engine = Arc::clone(&self.engine);
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
let mut guard = engine.lock().map_err(|_| { let mut guard = engine
AppError::InternalError("reranker engine mutex poisoned".into()) .lock()
})?; .map_err(|_| AppError::InternalError("reranker engine mutex poisoned".into()))?;
guard guard
.rerank(query, documents, false, None) .rerank(query, documents, false, None)
.map_err(|e| AppError::InternalError(e.to_string())) .map_err(|e| AppError::InternalError(e.to_string()))
+34 -8
View File
@@ -1,9 +1,35 @@
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
collections::{hash_map::Entry, HashMap}, collections::{hash_map::Entry, HashMap},
sync::Arc,
}; };
use common::storage::types::StoredObject; use common::storage::types::{
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject,
};
/// Identifier access for retrieval fusion and sorting.
pub trait RetrievalCandidate {
fn candidate_id(&self) -> &str;
}
impl RetrievalCandidate for TextChunk {
fn candidate_id(&self) -> &str {
self.id()
}
}
impl RetrievalCandidate for Arc<TextChunk> {
fn candidate_id(&self) -> &str {
self.as_ref().id()
}
}
impl RetrievalCandidate for KnowledgeEntity {
fn candidate_id(&self) -> &str {
self.id()
}
}
/// Holds optional subscores gathered from the vector and full-text retrieval signals. /// Holds optional subscores gathered from the vector and full-text retrieval signals.
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
@@ -102,13 +128,13 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>]) pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where where
T: StoredObject, T: RetrievalCandidate,
{ {
items.sort_by(|a, b| { items.sort_by(|a, b| {
b.fused b.fused
.partial_cmp(&a.fused) .partial_cmp(&a.fused)
.unwrap_or(Ordering::Equal) .unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id())) .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
}); });
} }
@@ -122,7 +148,7 @@ pub fn reciprocal_rank_fusion<T>(
config: RrfConfig, config: RrfConfig,
) -> Vec<Scored<T>> ) -> Vec<Scored<T>>
where where
T: StoredObject, T: RetrievalCandidate,
{ {
let mut merged: HashMap<String, Scored<T>> = HashMap::new(); let mut merged: HashMap<String, Scored<T>> = HashMap::new();
let k = if config.k <= 0.0 { 60.0 } else { config.k }; let k = if config.k <= 0.0 { 60.0 } else { config.k };
@@ -144,11 +170,11 @@ where
b_score b_score
.partial_cmp(&a_score) .partial_cmp(&a_score)
.unwrap_or(Ordering::Equal) .unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id())) .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
}); });
for (rank, candidate) in vector_ranked.into_iter().enumerate() { for (rank, candidate) in vector_ranked.into_iter().enumerate() {
let id = candidate.item.id().to_owned(); let id = candidate.item.candidate_id().to_owned();
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
let contribution = vector_weight / (k + rank_f32 + 1.0); let contribution = vector_weight / (k + rank_f32 + 1.0);
@@ -183,11 +209,11 @@ where
b_score b_score
.partial_cmp(&a_score) .partial_cmp(&a_score)
.unwrap_or(Ordering::Equal) .unwrap_or(Ordering::Equal)
.then_with(|| a.item.id().cmp(b.item.id())) .then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
}); });
for (rank, candidate) in fts_ranked.into_iter().enumerate() { for (rank, candidate) in fts_ranked.into_iter().enumerate() {
let id = candidate.item.id().to_owned(); let id = candidate.item.candidate_id().to_owned();
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
let contribution = fts_weight / (k + rank_f32 + 1.0); let contribution = fts_weight / (k + rank_f32 + 1.0);