mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-04 20:11:42 +02:00
fix: arc-share retrieved chunks, centralize entity embeddings, and trim hot-path clones.
This commit is contained in:
@@ -92,9 +92,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn extract_api_key_rejects_invalid_header_values() {
|
fn extract_api_key_rejects_invalid_header_values() {
|
||||||
let mut request = request_with_headers(&[]);
|
let mut request = request_with_headers(&[]);
|
||||||
request
|
request.headers_mut().insert(
|
||||||
.headers_mut()
|
"X-API-Key",
|
||||||
.insert("X-API-Key", HeaderValue::from_bytes(&[0xFF]).expect("invalid header"));
|
HeaderValue::from_bytes(&[0xFF]).expect("invalid header"),
|
||||||
|
);
|
||||||
assert_eq!(extract_api_key(&request), None);
|
assert_eq!(extract_api_key(&request), None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,11 +9,7 @@ use axum::{
|
|||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use common::{
|
use common::{
|
||||||
storage::{
|
storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
|
||||||
db::SurrealDbClient,
|
|
||||||
store::StorageManager,
|
|
||||||
types::user::User,
|
|
||||||
},
|
|
||||||
utils::config::{AppConfig, StorageKind},
|
utils::config::{AppConfig, StorageKind},
|
||||||
};
|
};
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
@@ -34,9 +30,7 @@ async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
|
|||||||
storage: StorageKind::Memory,
|
storage: StorageKind::Memory,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let storage = StorageManager::new(&config)
|
let storage = StorageManager::new(&config).await.expect("storage manager");
|
||||||
.await
|
|
||||||
.expect("storage manager");
|
|
||||||
|
|
||||||
let state = ApiState {
|
let state = ApiState {
|
||||||
db: Arc::clone(&db),
|
db: Arc::clone(&db),
|
||||||
@@ -147,9 +141,7 @@ async fn authenticated_user_can_list_categories() {
|
|||||||
.await
|
.await
|
||||||
.expect("test user");
|
.expect("test user");
|
||||||
|
|
||||||
let api_key = User::set_api_key(&user.id, &db)
|
let api_key = User::set_api_key(&user.id, &db).await.expect("api key");
|
||||||
.await
|
|
||||||
.expect("api key");
|
|
||||||
|
|
||||||
let response = app
|
let response = app
|
||||||
.clone()
|
.clone()
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ use std::collections::HashMap;
|
|||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql,
|
error::AppError,
|
||||||
|
storage::db::SurrealDbClient,
|
||||||
|
storage::indexes::hnsw_index_overwrite_sql,
|
||||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||||
storage::types::system_settings::SystemSettings, stored_object,
|
storage::types::system_settings::SystemSettings,
|
||||||
|
stored_object,
|
||||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||||
};
|
};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
@@ -25,6 +28,17 @@ impl KnowledgeEntityType {
|
|||||||
pub fn variants() -> &'static [&'static str] {
|
pub fn variants() -> &'static [&'static str] {
|
||||||
&["Idea", "Project", "Document", "Page", "TextSnippet"]
|
&["Idea", "Project", "Document", "Page", "TextSnippet"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub const fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Idea => "Idea",
|
||||||
|
Self::Project => "Project",
|
||||||
|
Self::Document => "Document",
|
||||||
|
Self::Page => "Page",
|
||||||
|
Self::TextSnippet => "TextSnippet",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<String> for KnowledgeEntityType {
|
impl From<String> for KnowledgeEntityType {
|
||||||
@@ -80,6 +94,27 @@ impl KnowledgeEntity {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Canonical text fed to the embedding provider for a knowledge entity.
|
||||||
|
#[must_use]
|
||||||
|
pub fn embedding_input_text(
|
||||||
|
name: &str,
|
||||||
|
description: &str,
|
||||||
|
entity_type: KnowledgeEntityType,
|
||||||
|
) -> String {
|
||||||
|
let mut out = String::with_capacity(
|
||||||
|
name.len()
|
||||||
|
.saturating_add(description.len())
|
||||||
|
.saturating_add(entity_type.as_str().len())
|
||||||
|
.saturating_add(32),
|
||||||
|
);
|
||||||
|
let _ = write!(
|
||||||
|
out,
|
||||||
|
"name: {name}, description: {description}, type: {}",
|
||||||
|
entity_type.as_str()
|
||||||
|
);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// Full-text search over knowledge entities using the BM25 FTS index.
|
/// Full-text search over knowledge entities using the BM25 FTS index.
|
||||||
pub async fn fts_search(
|
pub async fn fts_search(
|
||||||
take: usize,
|
take: usize,
|
||||||
@@ -314,8 +349,7 @@ impl KnowledgeEntity {
|
|||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
embedding_provider: &EmbeddingProvider,
|
embedding_provider: &EmbeddingProvider,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let embedding_input =
|
let embedding_input = Self::embedding_input_text(name, description, *entity_type);
|
||||||
format!("name: {name}, description: {description}, type: {entity_type:?}",);
|
|
||||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||||
|
|
||||||
let entity: KnowledgeEntity = db_client
|
let entity: KnowledgeEntity = db_client
|
||||||
@@ -402,9 +436,10 @@ impl KnowledgeEntity {
|
|||||||
let inputs: Vec<String> = batch
|
let inputs: Vec<String> = batch
|
||||||
.iter()
|
.iter()
|
||||||
.map(|entity| {
|
.map(|entity| {
|
||||||
format!(
|
Self::embedding_input_text(
|
||||||
"name: {}, description: {}, type: {:?}",
|
&entity.name,
|
||||||
entity.name, entity.description, entity.entity_type
|
&entity.description,
|
||||||
|
entity.entity_type,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -523,6 +558,16 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embedding_input_text_uses_canonical_type_label() {
|
||||||
|
let text = KnowledgeEntity::embedding_input_text(
|
||||||
|
"Alpha",
|
||||||
|
"Beta",
|
||||||
|
KnowledgeEntityType::TextSnippet,
|
||||||
|
);
|
||||||
|
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
|
||||||
|
}
|
||||||
|
|
||||||
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
|
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||||
let snowball_sql = r#"
|
let snowball_sql = r#"
|
||||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||||
|
|||||||
@@ -122,8 +122,7 @@ impl KnowledgeRelationship {
|
|||||||
.bind(("user_id", user_id.to_owned()))
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map_err(AppError::from)?;
|
||||||
let deleted: Vec<KnowledgeRelationship> =
|
let deleted: Vec<KnowledgeRelationship> = delete_result.take(0).map_err(AppError::from)?;
|
||||||
delete_result.take(0).map_err(AppError::from)?;
|
|
||||||
|
|
||||||
if !deleted.is_empty() {
|
if !deleted.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@@ -567,8 +566,8 @@ mod tests {
|
|||||||
shared_source.to_string(),
|
shared_source.to_string(),
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
let rel_a_id = rel_a.id.clone();
|
let owner_relationship_id = rel_a.id.clone();
|
||||||
let rel_b_id = rel_b.id.clone();
|
let other_relationship_id = rel_b.id.clone();
|
||||||
|
|
||||||
rel_a.store_relationship(&db).await?;
|
rel_a.store_relationship(&db).await?;
|
||||||
rel_b.store_relationship(&db).await?;
|
rel_b.store_relationship(&db).await?;
|
||||||
@@ -576,8 +575,12 @@ mod tests {
|
|||||||
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none());
|
assert!(get_relationship_by_id(&owner_relationship_id, &db)
|
||||||
assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some());
|
.await
|
||||||
|
.is_none());
|
||||||
|
assert!(get_relationship_by_id(&other_relationship_id, &db)
|
||||||
|
.await
|
||||||
|
.is_some());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -299,7 +299,11 @@ impl TextChunk {
|
|||||||
}
|
}
|
||||||
|
|
||||||
processed = processed.saturating_add(batch.len());
|
processed = processed.saturating_add(batch.len());
|
||||||
info!(progress = processed, total = total_chunks, "Re-embedding progress");
|
info!(
|
||||||
|
progress = processed,
|
||||||
|
total = total_chunks,
|
||||||
|
"Re-embedding progress"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
info!("Successfully generated all new embeddings.");
|
info!("Successfully generated all new embeddings.");
|
||||||
|
|
||||||
|
|||||||
@@ -140,8 +140,7 @@ impl TextContent {
|
|||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
let existing: Option<surrealdb::sql::Thing> =
|
let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::from)?;
|
||||||
response.take(0).map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(existing.is_some())
|
Ok(existing.is_some())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ enum EmbeddingInner {
|
|||||||
/// Client used to issue embedding requests.
|
/// Client used to issue embedding requests.
|
||||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||||
/// Model identifier for the API.
|
/// Model identifier for the API.
|
||||||
model: String,
|
model: Arc<str>,
|
||||||
/// Expected output dimensions.
|
/// Expected output dimensions.
|
||||||
dimensions: u32,
|
dimensions: u32,
|
||||||
},
|
},
|
||||||
@@ -272,8 +272,9 @@ struct FastEmbedLease {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FastEmbedLease {
|
impl FastEmbedLease {
|
||||||
async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
let engine = Arc::clone(&self.engine);
|
let engine = Arc::clone(&self.engine);
|
||||||
|
let texts = texts.to_vec();
|
||||||
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||||
@@ -293,7 +294,7 @@ impl Drop for FastEmbedLease {
|
|||||||
|
|
||||||
async fn run_fastembed(
|
async fn run_fastembed(
|
||||||
pool: &Arc<FastEmbedPool>,
|
pool: &Arc<FastEmbedPool>,
|
||||||
texts: Vec<String>,
|
texts: &[String],
|
||||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
let lease = pool.checkout().await?;
|
let lease = pool.checkout().await?;
|
||||||
lease.embed(texts).await
|
lease.embed(texts).await
|
||||||
@@ -323,7 +324,7 @@ impl EmbeddingProvider {
|
|||||||
pub fn model_code(&self) -> Option<String> {
|
pub fn model_code(&self) -> Option<String> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
||||||
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
|
EmbeddingInner::OpenAI { model, .. } => Some(model.as_ref().to_owned()),
|
||||||
EmbeddingInner::Hashed { .. } => None,
|
EmbeddingInner::Hashed { .. } => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -338,7 +339,8 @@ impl EmbeddingProvider {
|
|||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||||
let embeddings = run_fastembed(pool, vec![text.to_owned()]).await?;
|
let text = text.to_owned();
|
||||||
|
let embeddings = run_fastembed(pool, std::slice::from_ref(&text)).await?;
|
||||||
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
@@ -347,7 +349,7 @@ impl EmbeddingProvider {
|
|||||||
dimensions,
|
dimensions,
|
||||||
} => {
|
} => {
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model(model.clone())
|
.model(model.as_ref())
|
||||||
.input([text])
|
.input([text])
|
||||||
.dimensions(*dimensions)
|
.dimensions(*dimensions)
|
||||||
.build()?;
|
.build()?;
|
||||||
@@ -382,7 +384,7 @@ impl EmbeddingProvider {
|
|||||||
if texts.is_empty() {
|
if texts.is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
run_fastembed(pool, texts.to_vec()).await
|
run_fastembed(pool, texts).await
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -394,7 +396,7 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model(model.clone())
|
.model(model.as_ref())
|
||||||
.input(texts.to_vec())
|
.input(texts.to_vec())
|
||||||
.dimensions(*dimensions)
|
.dimensions(*dimensions)
|
||||||
.build()?;
|
.build()?;
|
||||||
@@ -417,13 +419,13 @@ impl EmbeddingProvider {
|
|||||||
/// Currently infallible; reserved for future validation.
|
/// Currently infallible; reserved for future validation.
|
||||||
pub fn new_openai(
|
pub fn new_openai(
|
||||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||||
model: String,
|
model: impl AsRef<str>,
|
||||||
dimensions: u32,
|
dimensions: u32,
|
||||||
) -> Result<Self, EmbeddingError> {
|
) -> Result<Self, EmbeddingError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: EmbeddingInner::OpenAI {
|
inner: EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
model,
|
model: Arc::from(model.as_ref()),
|
||||||
dimensions,
|
dimensions,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -520,7 +522,7 @@ impl EmbeddingProvider {
|
|||||||
"openai embedding backend requires an openai client".into(),
|
"openai embedding backend requires an openai client".into(),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
Self::new_openai(client, settings.embedding_model.as_str(), dimensions)
|
||||||
}
|
}
|
||||||
EmbeddingBackend::FastEmbed => {
|
EmbeddingBackend::FastEmbed => {
|
||||||
let pool_size = config
|
let pool_size = config
|
||||||
@@ -586,11 +588,12 @@ mod tests {
|
|||||||
#![allow(clippy::expect_used)]
|
#![allow(clippy::expect_used)]
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
align_fastembed_system_settings, fastembed_model_dimension, list_fastembed_embedding_models,
|
align_fastembed_system_settings, fastembed_model_dimension,
|
||||||
resolve_fastembed_model_code, DEFAULT_FASTEMBED_MODEL_CODE, EmbeddingError,
|
list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError,
|
||||||
|
DEFAULT_FASTEMBED_MODEL_CODE,
|
||||||
};
|
};
|
||||||
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
|
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
use crate::storage::types::system_settings::SystemSettings;
|
||||||
|
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -656,16 +659,16 @@ mod tests {
|
|||||||
fastembed_model: Some("Xenova/bge-base-en-v1.5".into()),
|
fastembed_model: Some("Xenova/bge-base-en-v1.5".into()),
|
||||||
..AppConfig::default()
|
..AppConfig::default()
|
||||||
};
|
};
|
||||||
let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small")
|
let resolved =
|
||||||
.expect("config model");
|
resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("config model");
|
||||||
assert_eq!(resolved, "Xenova/bge-base-en-v1.5");
|
assert_eq!(resolved, "Xenova/bge-base-en-v1.5");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_fastembed_model_falls_back_from_openai_default() {
|
fn resolve_fastembed_model_falls_back_from_openai_default() {
|
||||||
let config = AppConfig::default();
|
let config = AppConfig::default();
|
||||||
let resolved = resolve_fastembed_model_code(&config, "text-embedding-3-small")
|
let resolved =
|
||||||
.expect("default model");
|
resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("default model");
|
||||||
assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE);
|
assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,14 +42,12 @@ pub(crate) async fn prepare_db(
|
|||||||
|
|
||||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
||||||
let embedding_provider = match config.embedding_backend {
|
let embedding_provider = match config.embedding_backend {
|
||||||
crate::args::EmbeddingBackend::FastEmbed => {
|
crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||||
EmbeddingProvider::new_fastembed(
|
config.embedding_model.clone(),
|
||||||
config.embedding_model.clone(),
|
default_embedding_pool_size(),
|
||||||
default_embedding_pool_size(),
|
)
|
||||||
)
|
.await
|
||||||
.await
|
.context("creating FastEmbed provider")?,
|
||||||
.context("creating FastEmbed provider")?
|
|
||||||
}
|
|
||||||
crate::args::EmbeddingBackend::Hashed => {
|
crate::args::EmbeddingBackend::Hashed => {
|
||||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ pub struct EvaluationCandidate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EvaluationCandidate {
|
impl EvaluationCandidate {
|
||||||
fn from_entity(entity: RetrievedEntity) -> Self {
|
fn from_entity(entity: &RetrievedEntity) -> Self {
|
||||||
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
|
let entity_category = Some(format!("{:?}", entity.entity.entity_type));
|
||||||
Self {
|
Self {
|
||||||
entity_id: entity.entity.id().to_string(),
|
entity_id: entity.entity.id().to_string(),
|
||||||
@@ -223,9 +223,9 @@ impl EvaluationCandidate {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn candidates_from_entities(entities: Vec<RetrievedEntity>) -> Vec<EvaluationCandidate> {
|
fn candidates_from_entities(entities: &[RetrievedEntity]) -> Vec<EvaluationCandidate> {
|
||||||
entities
|
entities
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(EvaluationCandidate::from_entity)
|
.map(EvaluationCandidate::from_entity)
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@@ -241,7 +241,7 @@ pub fn adapt_retrieval_output(output: RetrievalOutput) -> Vec<EvaluationCandidat
|
|||||||
match output {
|
match output {
|
||||||
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
RetrievalOutput::Chunks(chunks) => candidates_from_chunks(chunks),
|
||||||
RetrievalOutput::WithEntities { chunks, entities } => {
|
RetrievalOutput::WithEntities { chunks, entities } => {
|
||||||
let mut candidates = candidates_from_entities(entities);
|
let mut candidates = candidates_from_entities(&entities);
|
||||||
candidates.extend(candidates_from_chunks(chunks));
|
candidates.extend(candidates_from_chunks(chunks));
|
||||||
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
|
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||||
candidates
|
candidates
|
||||||
|
|||||||
@@ -142,7 +142,9 @@ impl HtmlState {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let overflow = cache.len().saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS);
|
let overflow = cache
|
||||||
|
.len()
|
||||||
|
.saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS);
|
||||||
let mut by_expiry: Vec<(String, Instant)> = cache
|
let mut by_expiry: Vec<(String, Instant)> = cache
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
|
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
|
||||||
|
|||||||
@@ -183,9 +183,7 @@ fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn context_to_map(
|
fn context_to_map(value: &Value) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
|
||||||
value: &Value,
|
|
||||||
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
|
|
||||||
match value.kind() {
|
match value.kind() {
|
||||||
minijinja::value::ValueKind::Map => {
|
minijinja::value::ValueKind::Map => {
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ use surrealdb::{engine::any::Any, Surreal};
|
|||||||
use crate::{
|
use crate::{
|
||||||
html_state::HtmlState,
|
html_state::HtmlState,
|
||||||
middlewares::{
|
middlewares::{
|
||||||
analytics_middleware::analytics_middleware, auth_middleware::require_auth,
|
analytics_middleware::analytics_middleware, auth_middleware::require_auth, compression,
|
||||||
compression, response_middleware::with_template_response,
|
response_middleware::with_template_response,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -164,9 +164,7 @@ pub async fn update_theme(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn show_change_password(
|
pub async fn show_change_password(RequireUser(_user): RequireUser) -> TemplateResult {
|
||||||
RequireUser(_user): RequireUser,
|
|
||||||
) -> TemplateResult {
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
"auth/change_password_form.html",
|
"auth/change_password_form.html",
|
||||||
(),
|
(),
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ use common::{
|
|||||||
utils::{
|
utils::{
|
||||||
config::AppConfig,
|
config::AppConfig,
|
||||||
embedding::{
|
embedding::{
|
||||||
fastembed_model_dimension, is_valid_fastembed_model_code, list_fastembed_embedding_models,
|
fastembed_model_dimension, is_valid_fastembed_model_code,
|
||||||
EmbeddingBackend, FastEmbedModelOption,
|
list_fastembed_embedding_models, EmbeddingBackend, FastEmbedModelOption,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@@ -52,7 +52,6 @@ pub enum AdminSection {
|
|||||||
Models,
|
Models,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct AdminPanelQuery {
|
pub struct AdminPanelQuery {
|
||||||
section: Option<String>,
|
section: Option<String>,
|
||||||
@@ -101,8 +100,9 @@ pub async fn show_admin_panel(
|
|||||||
(None, None, false)
|
(None, None, false)
|
||||||
};
|
};
|
||||||
|
|
||||||
let effective_backend =
|
let effective_backend = effective_embedding_backend(&settings, &state.config)
|
||||||
effective_embedding_backend(&settings, &state.config).as_str().to_string();
|
.as_str()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
"admin/base.html",
|
"admin/base.html",
|
||||||
@@ -187,7 +187,9 @@ struct EmbeddingSettingsPlan {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn effective_embedding_backend(settings: &SystemSettings, config: &AppConfig) -> EmbeddingBackend {
|
fn effective_embedding_backend(settings: &SystemSettings, config: &AppConfig) -> EmbeddingBackend {
|
||||||
settings.embedding_backend.unwrap_or(config.embedding_backend)
|
settings
|
||||||
|
.embedding_backend
|
||||||
|
.unwrap_or(config.embedding_backend)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_fastembed_admin_context(settings: &SystemSettings, config: &AppConfig) -> bool {
|
fn is_fastembed_admin_context(settings: &SystemSettings, config: &AppConfig) -> bool {
|
||||||
@@ -241,11 +243,10 @@ fn plan_embedding_settings_update(
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let embedding_dimensions = fastembed_model_dimension(&embedding_model)
|
let embedding_dimensions =
|
||||||
.map_err(AppError::from)?;
|
fastembed_model_dimension(&embedding_model).map_err(AppError::from)?;
|
||||||
let reembedding_needed = embedding_dimensions != current.embedding_dimensions;
|
let reembedding_needed = embedding_dimensions != current.embedding_dimensions;
|
||||||
let restart_needed =
|
let restart_needed = embedding_model != current.embedding_model || reembedding_needed;
|
||||||
embedding_model != current.embedding_model || reembedding_needed;
|
|
||||||
|
|
||||||
Ok(EmbeddingSettingsPlan {
|
Ok(EmbeddingSettingsPlan {
|
||||||
embedding_model,
|
embedding_model,
|
||||||
@@ -274,8 +275,7 @@ pub async fn update_model_settings(
|
|||||||
Form(input): Form<ModelSettingsInput>,
|
Form(input): Form<ModelSettingsInput>,
|
||||||
) -> TemplateResult {
|
) -> TemplateResult {
|
||||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||||
let embedding_plan =
|
let embedding_plan = plan_embedding_settings_update(¤t_settings, &input, &state.config)?;
|
||||||
plan_embedding_settings_update(¤t_settings, &input, &state.config)?;
|
|
||||||
|
|
||||||
let new_settings = SystemSettingsPatch {
|
let new_settings = SystemSettingsPatch {
|
||||||
query_model: Some(input.query_model),
|
query_model: Some(input.query_model),
|
||||||
@@ -309,10 +309,11 @@ pub async fn update_model_settings(
|
|||||||
.await
|
.await
|
||||||
.map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?;
|
.map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?;
|
||||||
|
|
||||||
let effective_backend =
|
let effective_backend = effective_embedding_backend(&new_settings, &state.config)
|
||||||
effective_embedding_backend(&new_settings, &state.config).as_str().to_string();
|
.as_str()
|
||||||
let show_fastembed_models =
|
.to_string();
|
||||||
is_fastembed_admin_context(&new_settings, &state.config).then(list_fastembed_embedding_models);
|
let show_fastembed_models = is_fastembed_admin_context(&new_settings, &state.config)
|
||||||
|
.then(list_fastembed_embedding_models);
|
||||||
|
|
||||||
Ok(TemplateResponse::new_partial(
|
Ok(TemplateResponse::new_partial(
|
||||||
"admin/sections/models.html",
|
"admin/sections/models.html",
|
||||||
@@ -368,8 +369,8 @@ mod tests {
|
|||||||
embedding_model: Some("Xenova/bge-base-en-v1.5".into()),
|
embedding_model: Some("Xenova/bge-base-en-v1.5".into()),
|
||||||
embedding_dimensions: None,
|
embedding_dimensions: None,
|
||||||
};
|
};
|
||||||
let plan = plan_embedding_settings_update(¤t, &input, &AppConfig::default())
|
let plan =
|
||||||
.expect("plan");
|
plan_embedding_settings_update(¤t, &input, &AppConfig::default()).expect("plan");
|
||||||
assert_eq!(plan.embedding_model, "Xenova/bge-base-en-v1.5");
|
assert_eq!(plan.embedding_model, "Xenova/bge-base-en-v1.5");
|
||||||
assert_eq!(plan.embedding_dimensions, 768);
|
assert_eq!(plan.embedding_dimensions, 768);
|
||||||
assert!(plan.reembedding_needed);
|
assert!(plan.reembedding_needed);
|
||||||
@@ -407,9 +408,7 @@ pub struct SystemPromptEditData {
|
|||||||
default_query_prompt: String,
|
default_query_prompt: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn show_edit_system_prompt(
|
pub async fn show_edit_system_prompt(State(state): State<HtmlState>) -> TemplateResult {
|
||||||
State(state): State<HtmlState>,
|
|
||||||
) -> TemplateResult {
|
|
||||||
let settings = SystemSettings::get_current(&state.db).await?;
|
let settings = SystemSettings::get_current(&state.db).await?;
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
@@ -457,9 +456,7 @@ pub struct IngestionPromptEditData {
|
|||||||
default_ingestion_prompt: String,
|
default_ingestion_prompt: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn show_edit_ingestion_prompt(
|
pub async fn show_edit_ingestion_prompt(State(state): State<HtmlState>) -> TemplateResult {
|
||||||
State(state): State<HtmlState>,
|
|
||||||
) -> TemplateResult {
|
|
||||||
let settings = SystemSettings::get_current(&state.db).await?;
|
let settings = SystemSettings::get_current(&state.db).await?;
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
@@ -502,9 +499,7 @@ pub struct ImagePromptEditData {
|
|||||||
default_image_prompt: String,
|
default_image_prompt: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn show_edit_image_prompt(
|
pub async fn show_edit_image_prompt(State(state): State<HtmlState>) -> TemplateResult {
|
||||||
State(state): State<HtmlState>,
|
|
||||||
) -> TemplateResult {
|
|
||||||
let settings = SystemSettings::get_current(&state.db).await?;
|
let settings = SystemSettings::get_current(&state.db).await?;
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ use axum::{extract::State, Form};
|
|||||||
use axum_htmx::HxBoosted;
|
use axum_htmx::HxBoosted;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use common::{error::AppError, storage::types::user::{Theme, User}};
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::types::user::{Theme, User},
|
||||||
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
html_state::HtmlState,
|
html_state::HtmlState,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ use crate::{
|
|||||||
middlewares::{
|
middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_as_response, template_with_headers, TemplateResponse, TemplateResult,
|
template_as_response, template_with_headers, ResponseResult, TemplateResponse,
|
||||||
ResponseResult,
|
TemplateResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ use retrieval_pipeline::answer_retrieval::{
|
|||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tokio::sync::mpsc::channel;
|
use tokio::sync::mpsc::channel;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
use common::storage::{
|
use common::storage::{
|
||||||
@@ -36,10 +36,7 @@ use common::storage::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{html_state::HtmlState, middlewares::auth_middleware::RequireUser};
|
||||||
html_state::HtmlState,
|
|
||||||
middlewares::auth_middleware::RequireUser,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,11 @@ mod message_response_stream;
|
|||||||
mod reference_validation;
|
mod reference_validation;
|
||||||
mod references;
|
mod references;
|
||||||
|
|
||||||
use axum::{
|
use axum::{extract::FromRef, routing::get, Router};
|
||||||
extract::FromRef,
|
|
||||||
routing::get,
|
|
||||||
Router,
|
|
||||||
};
|
|
||||||
pub use chat_handlers::{
|
pub use chat_handlers::{
|
||||||
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
||||||
reload_sidebar, show_conversation_editing_title,
|
reload_sidebar, show_chat_base as show_base, show_conversation_editing_title,
|
||||||
show_chat_base as show_base, show_existing_chat as show_existing,
|
show_existing_chat as show_existing,
|
||||||
};
|
};
|
||||||
use message_response_stream::get_response_stream;
|
use message_response_stream::get_response_stream;
|
||||||
use references::show_reference_tooltip;
|
use references::show_reference_tooltip;
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
#![allow(clippy::missing_docs_in_private_items)]
|
#![allow(clippy::missing_docs_in_private_items)]
|
||||||
|
|
||||||
use axum::{
|
use axum::extract::{Path, State};
|
||||||
extract::{Path, State},
|
|
||||||
};
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use chrono_tz::Tz;
|
use chrono_tz::Tz;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
middlewares::{
|
middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_as_response, TemplateResponse, TemplateResult, ResponseResult,
|
template_as_response, ResponseResult, TemplateResponse, TemplateResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
utils::text_content_preview::truncate_text_contents,
|
utils::text_content_preview::truncate_text_contents,
|
||||||
@@ -84,7 +84,8 @@ pub async fn delete_text_content(
|
|||||||
// Delete the text content and any related data
|
// Delete the text content and any related data
|
||||||
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db).await?;
|
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db)
|
||||||
|
.await?;
|
||||||
state
|
state
|
||||||
.db
|
.db
|
||||||
.delete_item::<TextContent>(&text_content.id)
|
.delete_item::<TextContent>(&text_content.id)
|
||||||
|
|||||||
@@ -63,9 +63,7 @@ pub async fn show_ingest_form(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn hide_ingest_form(
|
pub async fn hide_ingest_form(RequireUser(_user): RequireUser) -> TemplateResult {
|
||||||
RequireUser(_user): RequireUser,
|
|
||||||
) -> TemplateResult {
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
"ingestion/add_content_button.html",
|
"ingestion/add_content_button.html",
|
||||||
(),
|
(),
|
||||||
@@ -148,8 +146,7 @@ pub async fn process_ingest_form(
|
|||||||
user.id.clone(),
|
user.id.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let tasks =
|
let tasks = IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
|
||||||
IngestionTask::create_all_and_add_to_db(payloads, &user.id, &state.db).await?;
|
|
||||||
|
|
||||||
Ok(TemplateResponse::new_template(
|
Ok(TemplateResponse::new_template(
|
||||||
"dashboard/current_task.html",
|
"dashboard/current_task.html",
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ use crate::{
|
|||||||
middlewares::{
|
middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
template_with_headers, ResponseResult, TemplateResponse, TemplateResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
utils::pagination::{paginate_items, paginate_slice, Pagination},
|
utils::pagination::{paginate_items, paginate_slice, Pagination},
|
||||||
@@ -185,8 +185,7 @@ pub async fn create_knowledge_entity(
|
|||||||
let description = form.description.trim().to_string();
|
let description = form.description.trim().to_string();
|
||||||
let entity_type = KnowledgeEntityType::from(form.entity_type.trim().to_string());
|
let entity_type = KnowledgeEntityType::from(form.entity_type.trim().to_string());
|
||||||
|
|
||||||
let embedding_input =
|
let embedding_input = KnowledgeEntity::embedding_input_text(&name, &description, entity_type);
|
||||||
format!("name: {name}, description: {description}, type: {entity_type:?}");
|
|
||||||
let embedding = state
|
let embedding = state
|
||||||
.embedding_provider
|
.embedding_provider
|
||||||
.embed(&embedding_input)
|
.embed(&embedding_input)
|
||||||
@@ -290,10 +289,12 @@ pub async fn suggest_knowledge_relationships(
|
|||||||
if !query_parts.is_empty() {
|
if !query_parts.is_empty() {
|
||||||
let name = form.name.as_deref().unwrap_or("").trim();
|
let name = form.name.as_deref().unwrap_or("").trim();
|
||||||
let description = form.description.as_deref().unwrap_or("").trim();
|
let description = form.description.as_deref().unwrap_or("").trim();
|
||||||
let entity_type = form.entity_type.as_deref().map_or(
|
let entity_type = form
|
||||||
KnowledgeEntityType::Document,
|
.entity_type
|
||||||
|value| KnowledgeEntityType::from(value.to_string()),
|
.as_deref()
|
||||||
);
|
.map_or(KnowledgeEntityType::Document, |value| {
|
||||||
|
KnowledgeEntityType::from(value.to_string())
|
||||||
|
});
|
||||||
|
|
||||||
let suggested = suggest_related_entities(
|
let suggested = suggest_related_entities(
|
||||||
&state.db,
|
&state.db,
|
||||||
@@ -374,10 +375,8 @@ async fn suggest_related_entities(
|
|||||||
draft: DraftEntityQuery<'_>,
|
draft: DraftEntityQuery<'_>,
|
||||||
entity_lookup: &HashMap<String, KnowledgeEntity>,
|
entity_lookup: &HashMap<String, KnowledgeEntity>,
|
||||||
) -> Result<HashMap<String, f32>, AppError> {
|
) -> Result<HashMap<String, f32>, AppError> {
|
||||||
let embedding_input = format!(
|
let embedding_input =
|
||||||
"name: {}, description: {}, type: {:?}",
|
KnowledgeEntity::embedding_input_text(draft.name, draft.description, draft.entity_type);
|
||||||
draft.name, draft.description, draft.entity_type
|
|
||||||
);
|
|
||||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||||
|
|
||||||
let take = MAX_RELATIONSHIP_SUGGESTIONS * 2;
|
let take = MAX_RELATIONSHIP_SUGGESTIONS * 2;
|
||||||
@@ -484,11 +483,7 @@ fn build_relationship_options(
|
|||||||
|
|
||||||
fn build_relationship_rows(
|
fn build_relationship_rows(
|
||||||
relationships: Vec<KnowledgeRelationship>,
|
relationships: Vec<KnowledgeRelationship>,
|
||||||
) -> (
|
) -> (Vec<RelationshipTableRow>, Vec<String>, String) {
|
||||||
Vec<RelationshipTableRow>,
|
|
||||||
Vec<String>,
|
|
||||||
String,
|
|
||||||
) {
|
|
||||||
let relationship_type_options = collect_relationship_type_options(&relationships);
|
let relationship_type_options = collect_relationship_type_options(&relationships);
|
||||||
let mut frequency: HashMap<String, usize> = HashMap::new();
|
let mut frequency: HashMap<String, usize> = HashMap::new();
|
||||||
let relationships = relationships
|
let relationships = relationships
|
||||||
@@ -509,10 +504,7 @@ fn build_relationship_rows(
|
|||||||
let default_relationship_type = frequency
|
let default_relationship_type = frequency
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.max_by_key(|(_, count)| *count)
|
.max_by_key(|(_, count)| *count)
|
||||||
.map_or_else(
|
.map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
|
||||||
|| DEFAULT_RELATIONSHIP_TYPE.to_string(),
|
|
||||||
|(label, _)| label,
|
|
||||||
);
|
|
||||||
|
|
||||||
(
|
(
|
||||||
relationships,
|
relationships,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use crate::html_state::HtmlState;
|
|||||||
use crate::middlewares::{
|
use crate::middlewares::{
|
||||||
auth_middleware::RequireUser,
|
auth_middleware::RequireUser,
|
||||||
response_middleware::{
|
response_middleware::{
|
||||||
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
template_with_headers, ResponseResult, TemplateResponse, TemplateResult,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use common::storage::types::{
|
use common::storage::types::{
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use axum::{
|
use axum::extract::{Query, State};
|
||||||
extract::{Query, State},
|
|
||||||
};
|
|
||||||
use axum_htmx::{HxBoosted, HxRequest};
|
use axum_htmx::{HxBoosted, HxRequest};
|
||||||
use common::storage::types::{text_content::TextContent, user::User};
|
use common::storage::types::{text_content::TextContent, user::User};
|
||||||
use retrieval_pipeline::{retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity};
|
use retrieval_pipeline::{
|
||||||
|
retrieve, RetrievalConfig, RetrievalOutput, RetrievedChunk, RetrievedEntity,
|
||||||
|
};
|
||||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||||
use std::{fmt, str::FromStr};
|
use std::{fmt, str::FromStr};
|
||||||
|
|
||||||
@@ -48,9 +48,7 @@ impl<'de> Deserialize<'de> for SearchView {
|
|||||||
Some("chunks") => SearchView::Chunks,
|
Some("chunks") => SearchView::Chunks,
|
||||||
Some("entities") => SearchView::Entities,
|
Some("entities") => SearchView::Entities,
|
||||||
Some(other) => {
|
Some(other) => {
|
||||||
return Err(de::Error::custom(format!(
|
return Err(de::Error::custom(format!("invalid search view: {other}")));
|
||||||
"invalid search view: {other}"
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -121,13 +119,12 @@ pub async fn search_result_handler(
|
|||||||
HxBoosted(is_boosted): HxBoosted,
|
HxBoosted(is_boosted): HxBoosted,
|
||||||
) -> TemplateResult {
|
) -> TemplateResult {
|
||||||
let view = params.view;
|
let view = params.view;
|
||||||
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
|
let (search_results_for_template, final_query_param_for_template) =
|
||||||
params.query
|
if let Some(actual_query) = params.query {
|
||||||
{
|
perform_search(&state, &user, actual_query, view).await?
|
||||||
perform_search(&state, &user, actual_query, view).await?
|
} else {
|
||||||
} else {
|
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
};
|
||||||
};
|
|
||||||
|
|
||||||
let data = AnswerData {
|
let data = AnswerData {
|
||||||
search_result: search_results_for_template,
|
search_result: search_results_for_template,
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ mod handlers;
|
|||||||
|
|
||||||
use axum::{extract::FromRef, routing::get, Router};
|
use axum::{extract::FromRef, routing::get, Router};
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
pub use handlers::{
|
pub use handlers::{search_result_handler as result_handler, SearchParams as SearchQueryParams};
|
||||||
search_result_handler as result_handler, SearchParams as SearchQueryParams,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::html_state::HtmlState;
|
use crate::html_state::HtmlState;
|
||||||
|
|
||||||
|
|||||||
@@ -31,8 +31,16 @@ impl Pagination {
|
|||||||
} else {
|
} else {
|
||||||
0
|
0
|
||||||
};
|
};
|
||||||
let start_index = if page_len == 0 { 0 } else { offset.saturating_add(1) };
|
let start_index = if page_len == 0 {
|
||||||
let end_index = if page_len == 0 { 0 } else { offset.saturating_add(page_len) };
|
0
|
||||||
|
} else {
|
||||||
|
offset.saturating_add(1)
|
||||||
|
};
|
||||||
|
let end_index = if page_len == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
offset.saturating_add(page_len)
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
current_page,
|
current_page,
|
||||||
@@ -109,7 +117,11 @@ pub fn paginate_items<T>(
|
|||||||
let total_pages = if total_items == 0 {
|
let total_pages = if total_items == 0 {
|
||||||
0
|
0
|
||||||
} else {
|
} else {
|
||||||
total_items.saturating_sub(1).checked_div(per_page).unwrap_or(0).saturating_add(1)
|
total_items
|
||||||
|
.saturating_sub(1)
|
||||||
|
.checked_div(per_page)
|
||||||
|
.unwrap_or(0)
|
||||||
|
.saturating_add(1)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut current_page = requested_page.unwrap_or(1);
|
let mut current_page = requested_page.unwrap_or(1);
|
||||||
|
|||||||
@@ -9,11 +9,7 @@ use axum::{
|
|||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use common::{
|
use common::{
|
||||||
storage::{
|
storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
|
||||||
db::SurrealDbClient,
|
|
||||||
store::StorageManager,
|
|
||||||
types::user::User,
|
|
||||||
},
|
|
||||||
utils::{
|
utils::{
|
||||||
config::{AppConfig, StorageKind},
|
config::{AppConfig, StorageKind},
|
||||||
embedding::EmbeddingProvider,
|
embedding::EmbeddingProvider,
|
||||||
@@ -37,24 +33,17 @@ async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
|
|||||||
.await
|
.await
|
||||||
.expect("migrations should apply");
|
.expect("migrations should apply");
|
||||||
|
|
||||||
let session_store = Arc::new(
|
let session_store = Arc::new(db.create_session_store().await.expect("session store"));
|
||||||
db.create_session_store()
|
|
||||||
.await
|
|
||||||
.expect("session store"),
|
|
||||||
);
|
|
||||||
|
|
||||||
let config = AppConfig {
|
let config = AppConfig {
|
||||||
storage: StorageKind::Memory,
|
storage: StorageKind::Memory,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let storage = StorageManager::new(&config)
|
let storage = StorageManager::new(&config).await.expect("storage manager");
|
||||||
.await
|
|
||||||
.expect("storage manager");
|
|
||||||
|
|
||||||
let embedding_provider = Arc::new(
|
let embedding_provider =
|
||||||
EmbeddingProvider::new_hashed(8).expect("embedding provider"),
|
Arc::new(EmbeddingProvider::new_hashed(8).expect("embedding provider"));
|
||||||
);
|
|
||||||
|
|
||||||
let state = HtmlState::new_with_resources(StateResources {
|
let state = HtmlState::new_with_resources(StateResources {
|
||||||
db: Arc::clone(&db),
|
db: Arc::clone(&db),
|
||||||
|
|||||||
@@ -6,11 +6,9 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{
|
storage::types::{
|
||||||
types::{
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
knowledge_relationship::KnowledgeRelationship,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
utils::embedding::EmbeddingProvider,
|
utils::embedding::EmbeddingProvider,
|
||||||
};
|
};
|
||||||
@@ -83,25 +81,32 @@ impl LLMEnrichmentResult {
|
|||||||
entity_concurrency: usize,
|
entity_concurrency: usize,
|
||||||
embedding_provider: &EmbeddingProvider,
|
embedding_provider: &EmbeddingProvider,
|
||||||
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
|
) -> Result<Vec<EmbeddedKnowledgeEntity>, AppError> {
|
||||||
stream::iter(self.knowledge_entities.clone().into_iter().map(|entity| {
|
let tasks: Vec<_> = self
|
||||||
let mapper = Arc::clone(&mapper);
|
.knowledge_entities
|
||||||
let source_id = source_id.to_string();
|
.iter()
|
||||||
let user_id = user_id.to_string();
|
.map(|entity| {
|
||||||
|
let llm_entity = entity.clone();
|
||||||
|
let mapper = Arc::clone(&mapper);
|
||||||
|
let source_id = source_id.to_string();
|
||||||
|
let user_id = user_id.to_string();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
create_single_entity(
|
create_single_entity(
|
||||||
&entity,
|
llm_entity,
|
||||||
&source_id,
|
&source_id,
|
||||||
&user_id,
|
&user_id,
|
||||||
mapper,
|
mapper,
|
||||||
embedding_provider,
|
embedding_provider,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
}))
|
})
|
||||||
.buffer_unordered(entity_concurrency.max(1))
|
.collect();
|
||||||
.try_collect()
|
|
||||||
.await
|
stream::iter(tasks)
|
||||||
|
.buffer_unordered(entity_concurrency.max(1))
|
||||||
|
.try_collect()
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process_relationships(
|
fn process_relationships(
|
||||||
@@ -129,7 +134,7 @@ impl LLMEnrichmentResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn create_single_entity(
|
async fn create_single_entity(
|
||||||
llm_entity: &LLMKnowledgeEntity,
|
llm_entity: LLMKnowledgeEntity,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
mapper: Arc<GraphMapper>,
|
mapper: Arc<GraphMapper>,
|
||||||
@@ -137,9 +142,11 @@ async fn create_single_entity(
|
|||||||
) -> Result<EmbeddedKnowledgeEntity, AppError> {
|
) -> Result<EmbeddedKnowledgeEntity, AppError> {
|
||||||
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
|
let assigned_id = mapper.get_id(&llm_entity.key)?.to_string();
|
||||||
|
|
||||||
let embedding_input = format!(
|
let entity_type = KnowledgeEntityType::from(llm_entity.entity_type);
|
||||||
"name: {}, description: {}, type: {}",
|
let embedding_input = KnowledgeEntity::embedding_input_text(
|
||||||
llm_entity.name, llm_entity.description, llm_entity.entity_type
|
&llm_entity.name,
|
||||||
|
&llm_entity.description,
|
||||||
|
entity_type,
|
||||||
);
|
);
|
||||||
|
|
||||||
let embedding = embedding_provider.embed(&embedding_input).await?;
|
let embedding = embedding_provider.embed(&embedding_input).await?;
|
||||||
@@ -149,9 +156,9 @@ async fn create_single_entity(
|
|||||||
id: assigned_id,
|
id: assigned_id,
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now,
|
updated_at: now,
|
||||||
name: llm_entity.name.clone(),
|
name: llm_entity.name,
|
||||||
description: llm_entity.description.clone(),
|
description: llm_entity.description,
|
||||||
entity_type: KnowledgeEntityType::from(llm_entity.entity_type.clone()),
|
entity_type,
|
||||||
source_id: source_id.to_string(),
|
source_id: source_id.to_string(),
|
||||||
metadata: None,
|
metadata: None,
|
||||||
user_id: user_id.into(),
|
user_id: user_id.into(),
|
||||||
|
|||||||
@@ -48,20 +48,20 @@ const STORE_RELATIONSHIPS: &str = r"
|
|||||||
pub(super) async fn store_vector_chunks(
|
pub(super) async fn store_vector_chunks(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
task_id: &str,
|
task_id: &str,
|
||||||
chunks: &[EmbeddedTextChunk],
|
chunks: Vec<EmbeddedTextChunk>,
|
||||||
) -> Result<usize, AppError> {
|
) -> Result<usize, AppError> {
|
||||||
|
let chunk_count = chunks.len();
|
||||||
for embedded in chunks {
|
for embedded in chunks {
|
||||||
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db)
|
|
||||||
.await?;
|
|
||||||
debug!(
|
debug!(
|
||||||
task_id = %task_id,
|
task_id = %task_id,
|
||||||
chunk_id = %embedded.chunk.id,
|
chunk_id = %embedded.chunk.id,
|
||||||
chunk_len = embedded.chunk.chunk.chars().count(),
|
chunk_len = embedded.chunk.chunk.chars().count(),
|
||||||
"chunk persisted"
|
"chunk persisted"
|
||||||
);
|
);
|
||||||
|
TextChunk::store_with_embedding(embedded.chunk, embedded.embedding, db).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(chunks.len())
|
Ok(chunk_count)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Persists knowledge entities and their relationships.
|
/// Persists knowledge entities and their relationships.
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ pub async fn persist(
|
|||||||
let entity_count = entities.len();
|
let entity_count = entities.len();
|
||||||
let relationship_count = relationships.len();
|
let relationship_count = relationships.len();
|
||||||
|
|
||||||
let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), &chunks).await?;
|
let chunk_count = store_vector_chunks(ctx.db, ctx.task_id.as_str(), chunks).await?;
|
||||||
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
|
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
|
||||||
ctx.db.store_item(text_content).await?;
|
ctx.db.store_item(text_content).await?;
|
||||||
rebuild(ctx.db).await?;
|
rebuild(ctx.db).await?;
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ impl MockServices {
|
|||||||
entity: retrieved_entity,
|
entity: retrieved_entity,
|
||||||
score: 0.8,
|
score: 0.8,
|
||||||
chunks: std::sync::Arc::new(vec![RetrievedChunk {
|
chunks: std::sync::Arc::new(vec![RetrievedChunk {
|
||||||
chunk: retrieved_chunk,
|
chunk: std::sync::Arc::new(retrieved_chunk),
|
||||||
score: 0.7,
|
score: 0.7,
|
||||||
}]),
|
}]),
|
||||||
}],
|
}],
|
||||||
|
|||||||
@@ -74,10 +74,7 @@ pub async fn extract_text_from_file(
|
|||||||
config: &AppConfig,
|
config: &AppConfig,
|
||||||
storage: &StorageManager,
|
storage: &StorageManager,
|
||||||
) -> Result<String, AppError> {
|
) -> Result<String, AppError> {
|
||||||
let file_bytes = storage
|
let file_bytes = storage.get(&file_info.path).await.map_err(AppError::from)?;
|
||||||
.get(&file_info.path)
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
let local_path = resolve_existing_local_path(storage, &file_info.path).await;
|
let local_path = resolve_existing_local_path(storage, &file_info.path).await;
|
||||||
|
|
||||||
match file_info.mime_type.as_str() {
|
match file_info.mime_type.as_str() {
|
||||||
|
|||||||
@@ -8,10 +8,7 @@ use std::sync::Arc;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use common::{
|
use common::{
|
||||||
storage::{
|
storage::{db::SurrealDbClient, store::StorageManager},
|
||||||
db::SurrealDbClient,
|
|
||||||
store::StorageManager,
|
|
||||||
},
|
|
||||||
utils::{
|
utils::{
|
||||||
config::{get_config, AppConfig},
|
config::{get_config, AppConfig},
|
||||||
embedding::{align_fastembed_system_settings, EmbeddingProvider},
|
embedding::{align_fastembed_system_settings, EmbeddingProvider},
|
||||||
|
|||||||
@@ -67,7 +67,8 @@ pub async fn prepare_embedding_runtime(
|
|||||||
let index_dim = if mismatch {
|
let index_dim = if mismatch {
|
||||||
match role {
|
match role {
|
||||||
EmbeddingRuntimeRole::Maintainer => {
|
EmbeddingRuntimeRole::Maintainer => {
|
||||||
reconcile_embeddings(&services.db, &services.embedding_provider, target_dim).await?;
|
reconcile_embeddings(&services.db, &services.embedding_provider, target_dim)
|
||||||
|
.await?;
|
||||||
target_dim
|
target_dim
|
||||||
}
|
}
|
||||||
EmbeddingRuntimeRole::ReadOnly => {
|
EmbeddingRuntimeRole::ReadOnly => {
|
||||||
@@ -238,9 +239,7 @@ mod tests {
|
|||||||
stored_dim: usize,
|
stored_dim: usize,
|
||||||
target_dim: usize,
|
target_dim: usize,
|
||||||
) -> (super::SharedServices, std::path::PathBuf) {
|
) -> (super::SharedServices, std::path::PathBuf) {
|
||||||
let (mut services, data_dir) = init_smoke_services()
|
let (mut services, data_dir) = init_smoke_services().await.expect("smoke services");
|
||||||
.await
|
|
||||||
.expect("smoke services");
|
|
||||||
|
|
||||||
ensure_runtime(&services.db, stored_dim)
|
ensure_runtime(&services.db, stored_dim)
|
||||||
.await
|
.await
|
||||||
@@ -254,9 +253,8 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("update settings");
|
.expect("update settings");
|
||||||
|
|
||||||
services.embedding_provider = Arc::new(
|
services.embedding_provider =
|
||||||
EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test"),
|
Arc::new(EmbeddingProvider::new_hashed(target_dim).expect("hashed provider for test"));
|
||||||
);
|
|
||||||
|
|
||||||
(services, data_dir)
|
(services, data_dir)
|
||||||
}
|
}
|
||||||
@@ -270,7 +268,9 @@ mod tests {
|
|||||||
.expect("maintainer startup");
|
.expect("maintainer startup");
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
embedding_index_dimension(&services.db).await.expect("index dim"),
|
embedding_index_dimension(&services.db)
|
||||||
|
.await
|
||||||
|
.expect("index dim"),
|
||||||
Some(5),
|
Some(5),
|
||||||
"maintainer should rebuild the index to the provider dimension"
|
"maintainer should rebuild the index to the provider dimension"
|
||||||
);
|
);
|
||||||
@@ -287,7 +287,9 @@ mod tests {
|
|||||||
.expect("read-only startup");
|
.expect("read-only startup");
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
embedding_index_dimension(&services.db).await.expect("index dim"),
|
embedding_index_dimension(&services.db)
|
||||||
|
.await
|
||||||
|
.expect("index dim"),
|
||||||
Some(3),
|
Some(3),
|
||||||
"read-only server must not overwrite the index before a maintainer re-embeds"
|
"read-only server must not overwrite the index before a maintainer re-embeds"
|
||||||
);
|
);
|
||||||
@@ -297,9 +299,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn maintainer_reembeds_chunks_when_index_dimension_differs() {
|
async fn maintainer_reembeds_chunks_when_index_dimension_differs() {
|
||||||
let (mut services, data_dir) = init_smoke_services()
|
let (mut services, data_dir) = init_smoke_services().await.expect("smoke services");
|
||||||
.await
|
|
||||||
.expect("smoke services");
|
|
||||||
|
|
||||||
let mut settings = SystemSettings::get_current(&services.db)
|
let mut settings = SystemSettings::get_current(&services.db)
|
||||||
.await
|
.await
|
||||||
@@ -339,7 +339,9 @@ mod tests {
|
|||||||
.expect("maintainer startup with data");
|
.expect("maintainer startup with data");
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
embedding_index_dimension(&services.db).await.expect("index dim"),
|
embedding_index_dimension(&services.db)
|
||||||
|
.await
|
||||||
|
.expect("index dim"),
|
||||||
Some(5)
|
Some(5)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
+5
-1
@@ -95,7 +95,11 @@ mod tests {
|
|||||||
use common::storage::types::{system_settings::SystemSettings, user::User};
|
use common::storage::types::{system_settings::SystemSettings, user::User};
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
|
|
||||||
async fn build_test_app() -> (Router, Arc<common::storage::db::SurrealDbClient>, std::path::PathBuf) {
|
async fn build_test_app() -> (
|
||||||
|
Router,
|
||||||
|
Arc<common::storage::db::SurrealDbClient>,
|
||||||
|
std::path::PathBuf,
|
||||||
|
) {
|
||||||
let (services, data_dir) = init_smoke_services()
|
let (services, data_dir) = init_smoke_services()
|
||||||
.await
|
.await
|
||||||
.expect("failed to init services");
|
.expect("failed to init services");
|
||||||
|
|||||||
+2
-3
@@ -68,9 +68,8 @@ mod tests {
|
|||||||
|
|
||||||
let db = Arc::clone(&services.db);
|
let db = Arc::clone(&services.db);
|
||||||
let pipeline = Arc::new(pipeline);
|
let pipeline = Arc::new(pipeline);
|
||||||
let worker = tokio::spawn(async move {
|
let worker =
|
||||||
ingestion_pipeline::run_worker_loop(db, pipeline).await
|
tokio::spawn(async move { ingestion_pipeline::run_worker_loop(db, pipeline).await });
|
||||||
});
|
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ pub mod scoring;
|
|||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub use scoring::RetrievalCandidate;
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{
|
storage::{
|
||||||
@@ -45,7 +47,7 @@ pub(crate) fn round_score(value: f32) -> f64 {
|
|||||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RetrievedChunk {
|
pub struct RetrievedChunk {
|
||||||
pub chunk: TextChunk,
|
pub chunk: Arc<TextChunk>,
|
||||||
pub score: f32,
|
pub score: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,7 +161,9 @@ mod tests {
|
|||||||
|
|
||||||
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
assert!(!chunks.is_empty(), "Expected at least one retrieval result");
|
||||||
assert!(
|
assert!(
|
||||||
chunks.first().is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
chunks
|
||||||
|
.first()
|
||||||
|
.is_some_and(|c| c.chunk.chunk.contains("Tokio")),
|
||||||
"Expected chunk about Tokio"
|
"Expected chunk about Tokio"
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity};
|
|||||||
use super::{
|
use super::{
|
||||||
config::RetrievalConfig,
|
config::RetrievalConfig,
|
||||||
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
diagnostics::{AssembleStats, Diagnostics, SearchStats},
|
||||||
StageKind, StageTimings, RetrievalParams,
|
RetrievalParams, StageKind, StageTimings,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Mutable working state threaded through every retrieval stage.
|
/// Mutable working state threaded through every retrieval stage.
|
||||||
@@ -22,7 +22,7 @@ pub(crate) struct PipelineContext<'a> {
|
|||||||
pub user_id: String,
|
pub user_id: String,
|
||||||
pub config: RetrievalConfig,
|
pub config: RetrievalConfig,
|
||||||
pub query_embedding: Option<Vec<f32>>,
|
pub query_embedding: Option<Vec<f32>>,
|
||||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
pub chunk_values: Vec<Scored<std::sync::Arc<TextChunk>>>,
|
||||||
pub reranker: Option<RerankerLease>,
|
pub reranker: Option<RerankerLease>,
|
||||||
pub diagnostics: Option<Diagnostics>,
|
pub diagnostics: Option<Diagnostics>,
|
||||||
pub entity_results: Vec<RetrievedEntity>,
|
pub entity_results: Vec<RetrievedEntity>,
|
||||||
|
|||||||
@@ -131,14 +131,14 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
|||||||
let vector_candidates = vector_rows.len();
|
let vector_candidates = vector_rows.len();
|
||||||
let fts_candidates = fts_rows.len();
|
let fts_candidates = fts_rows.len();
|
||||||
|
|
||||||
let vector_scored: Vec<Scored<TextChunk>> = vector_rows
|
let vector_scored: Vec<Scored<Arc<TextChunk>>> = vector_rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
.map(|row| Scored::new(Arc::new(row.chunk)).with_vector_score(row.score))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let fts_scored: Vec<Scored<TextChunk>> = fts_rows
|
let fts_scored: Vec<Scored<Arc<TextChunk>>> = fts_rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|row| Scored::new(row.chunk).with_fts_score(row.score))
|
.map(|row| Scored::new(Arc::new(row.chunk)).with_fts_score(row.score))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
let mut fts_weight = tuning.chunk_rrf_fts_weight;
|
||||||
@@ -222,40 +222,63 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
|||||||
/// and the contributing chunks are attached.
|
/// and the contributing chunks are attached.
|
||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
if ctx.chunk_values.is_empty() {
|
let chunk_values = std::mem::take(&mut ctx.chunk_values);
|
||||||
|
if chunk_values.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
let max_chunks = ctx.config.tuning.max_chunks_per_entity.max(1);
|
||||||
|
|
||||||
|
struct IndexedChunk {
|
||||||
|
idx: usize,
|
||||||
|
score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
let mut source_order: Vec<String> = Vec::new();
|
let mut source_order: Vec<String> = Vec::new();
|
||||||
let mut chunks_by_source: HashMap<String, Vec<RetrievedChunk>> = HashMap::new();
|
let mut chunks_by_source: HashMap<String, Vec<IndexedChunk>> = HashMap::new();
|
||||||
let mut best_score: HashMap<String, f32> = HashMap::new();
|
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||||
|
|
||||||
for scored in &ctx.chunk_values {
|
for (idx, scored) in chunk_values.iter().enumerate() {
|
||||||
let source_id = &scored.item.source_id;
|
if let Some(attached) = chunks_by_source.get_mut(&scored.item.source_id) {
|
||||||
let is_new_source = !chunks_by_source.contains_key(source_id);
|
if attached.len() < max_chunks {
|
||||||
if is_new_source {
|
attached.push(IndexedChunk {
|
||||||
source_order.push(source_id.clone());
|
idx,
|
||||||
|
score: scored.fused,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let source_id = scored.item.source_id.clone();
|
||||||
best_score.insert(source_id.clone(), scored.fused);
|
best_score.insert(source_id.clone(), scored.fused);
|
||||||
}
|
source_order.push(source_id.clone());
|
||||||
|
chunks_by_source.insert(
|
||||||
let attached = chunks_by_source
|
source_id,
|
||||||
.entry(source_id.clone())
|
vec![IndexedChunk {
|
||||||
.or_default();
|
idx,
|
||||||
if attached.len() < max_chunks {
|
score: scored.fused,
|
||||||
attached.push(RetrievedChunk {
|
}],
|
||||||
chunk: scored.item.clone(),
|
);
|
||||||
score: scored.fused,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
|
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(source, chunks)| (source, Arc::new(chunks)))
|
.map(|(source, indices)| {
|
||||||
|
let chunks = indices
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|indexed| {
|
||||||
|
let scored = chunk_values.get(indexed.idx)?;
|
||||||
|
Some(RetrievedChunk {
|
||||||
|
chunk: Arc::clone(&scored.item),
|
||||||
|
score: indexed.score,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(source, Arc::new(chunks))
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
ctx.chunk_values = chunk_values;
|
||||||
|
|
||||||
let entities =
|
let entities =
|
||||||
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||||
|
|
||||||
@@ -336,10 +359,17 @@ fn sample_scores<T, F>(items: &[Scored<T>], extractor: F) -> Vec<f32>
|
|||||||
where
|
where
|
||||||
F: FnMut(&Scored<T>) -> f32,
|
F: FnMut(&Scored<T>) -> f32,
|
||||||
{
|
{
|
||||||
items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect()
|
items
|
||||||
|
.iter()
|
||||||
|
.take(SCORE_SAMPLE_LIMIT)
|
||||||
|
.map(extractor)
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
fn build_chunk_rerank_documents(
|
||||||
|
chunks: &[Scored<Arc<TextChunk>>],
|
||||||
|
max_chunks: usize,
|
||||||
|
) -> Vec<String> {
|
||||||
let take = chunks.len().min(max_chunks);
|
let take = chunks.len().min(max_chunks);
|
||||||
let mut documents = Vec::with_capacity(take);
|
let mut documents = Vec::with_capacity(take);
|
||||||
let mut buffer = String::with_capacity(512);
|
let mut buffer = String::with_capacity(512);
|
||||||
@@ -363,7 +393,7 @@ fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize)
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn apply_chunk_rerank_results(
|
fn apply_chunk_rerank_results(
|
||||||
chunks: &mut Vec<Scored<TextChunk>>,
|
chunks: &mut Vec<Scored<Arc<TextChunk>>>,
|
||||||
tuning: &RetrievalTuning,
|
tuning: &RetrievalTuning,
|
||||||
results: Vec<RerankResult>,
|
results: Vec<RerankResult>,
|
||||||
) {
|
) {
|
||||||
@@ -371,7 +401,7 @@ fn apply_chunk_rerank_results(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
let mut remaining: Vec<Option<Scored<Arc<TextChunk>>>> =
|
||||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||||
|
|
||||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||||
@@ -384,7 +414,7 @@ fn apply_chunk_rerank_results(
|
|||||||
clamp_unit(tuning.rerank_blend_weight)
|
clamp_unit(tuning.rerank_blend_weight)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
let mut reranked: Vec<Scored<Arc<TextChunk>>> = Vec::with_capacity(remaining.len());
|
||||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||||
if let Some(slot) = remaining.get_mut(result.index) {
|
if let Some(slot) = remaining.get_mut(result.index) {
|
||||||
if let Some(mut candidate) = slot.take() {
|
if let Some(mut candidate) = slot.take() {
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ impl RerankerPool {
|
|||||||
/// Build the pool at startup.
|
/// Build the pool at startup.
|
||||||
/// `pool_size` controls max parallel reranks.
|
/// `pool_size` controls max parallel reranks.
|
||||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
|
pub fn new(pool_size: usize) -> Result<Arc<Self>, Box<AppError>> {
|
||||||
let init_options =
|
let init_options = RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
||||||
RerankInitOptions::new(fastembed::RerankerModel::JINARerankerV1TurboEn);
|
|
||||||
Self::new_with_options(pool_size, &init_options)
|
Self::new_with_options(pool_size, &init_options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,8 +43,7 @@ impl RerankerPool {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
fs::create_dir_all(&init_options.cache_dir)
|
fs::create_dir_all(&init_options.cache_dir).map_err(|e| Box::new(AppError::from(e)))?;
|
||||||
.map_err(|e| Box::new(AppError::from(e)))?;
|
|
||||||
|
|
||||||
let mut engines = Vec::with_capacity(pool_size);
|
let mut engines = Vec::with_capacity(pool_size);
|
||||||
for x in 0..pool_size {
|
for x in 0..pool_size {
|
||||||
@@ -77,10 +75,7 @@ impl RerankerPool {
|
|||||||
/// This returns a lease that can perform `rerank()`.
|
/// This returns a lease that can perform `rerank()`.
|
||||||
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
|
pub async fn checkout(self: &Arc<Self>) -> Option<RerankerLease> {
|
||||||
// Acquire a permit. This enforces backpressure.
|
// Acquire a permit. This enforces backpressure.
|
||||||
let permit = Arc::clone(&self.semaphore)
|
let permit = Arc::clone(&self.semaphore).acquire_owned().await.ok()?;
|
||||||
.acquire_owned()
|
|
||||||
.await
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
// Pick an engine.
|
// Pick an engine.
|
||||||
// This is naive: just pick based on a simple modulo counter.
|
// This is naive: just pick based on a simple modulo counter.
|
||||||
@@ -165,9 +160,9 @@ impl RerankerLease {
|
|||||||
let engine = Arc::clone(&self.engine);
|
let engine = Arc::clone(&self.engine);
|
||||||
|
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
let mut guard = engine.lock().map_err(|_| {
|
let mut guard = engine
|
||||||
AppError::InternalError("reranker engine mutex poisoned".into())
|
.lock()
|
||||||
})?;
|
.map_err(|_| AppError::InternalError("reranker engine mutex poisoned".into()))?;
|
||||||
guard
|
guard
|
||||||
.rerank(query, documents, false, None)
|
.rerank(query, documents, false, None)
|
||||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||||
|
|||||||
@@ -1,9 +1,35 @@
|
|||||||
use std::{
|
use std::{
|
||||||
cmp::Ordering,
|
cmp::Ordering,
|
||||||
collections::{hash_map::Entry, HashMap},
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::{
|
||||||
|
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Identifier access for retrieval fusion and sorting.
|
||||||
|
pub trait RetrievalCandidate {
|
||||||
|
fn candidate_id(&self) -> &str;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetrievalCandidate for TextChunk {
|
||||||
|
fn candidate_id(&self) -> &str {
|
||||||
|
self.id()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetrievalCandidate for Arc<TextChunk> {
|
||||||
|
fn candidate_id(&self) -> &str {
|
||||||
|
self.as_ref().id()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetrievalCandidate for KnowledgeEntity {
|
||||||
|
fn candidate_id(&self) -> &str {
|
||||||
|
self.id()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
/// Holds optional subscores gathered from the vector and full-text retrieval signals.
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
@@ -102,13 +128,13 @@ pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
|||||||
|
|
||||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||||
where
|
where
|
||||||
T: StoredObject,
|
T: RetrievalCandidate,
|
||||||
{
|
{
|
||||||
items.sort_by(|a, b| {
|
items.sort_by(|a, b| {
|
||||||
b.fused
|
b.fused
|
||||||
.partial_cmp(&a.fused)
|
.partial_cmp(&a.fused)
|
||||||
.unwrap_or(Ordering::Equal)
|
.unwrap_or(Ordering::Equal)
|
||||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,7 +148,7 @@ pub fn reciprocal_rank_fusion<T>(
|
|||||||
config: RrfConfig,
|
config: RrfConfig,
|
||||||
) -> Vec<Scored<T>>
|
) -> Vec<Scored<T>>
|
||||||
where
|
where
|
||||||
T: StoredObject,
|
T: RetrievalCandidate,
|
||||||
{
|
{
|
||||||
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
||||||
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
||||||
@@ -144,11 +170,11 @@ where
|
|||||||
b_score
|
b_score
|
||||||
.partial_cmp(&a_score)
|
.partial_cmp(&a_score)
|
||||||
.unwrap_or(Ordering::Equal)
|
.unwrap_or(Ordering::Equal)
|
||||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||||
});
|
});
|
||||||
|
|
||||||
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
||||||
let id = candidate.item.id().to_owned();
|
let id = candidate.item.candidate_id().to_owned();
|
||||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
let contribution = vector_weight / (k + rank_f32 + 1.0);
|
let contribution = vector_weight / (k + rank_f32 + 1.0);
|
||||||
|
|
||||||
@@ -183,11 +209,11 @@ where
|
|||||||
b_score
|
b_score
|
||||||
.partial_cmp(&a_score)
|
.partial_cmp(&a_score)
|
||||||
.unwrap_or(Ordering::Equal)
|
.unwrap_or(Ordering::Equal)
|
||||||
.then_with(|| a.item.id().cmp(b.item.id()))
|
.then_with(|| a.item.candidate_id().cmp(b.item.candidate_id()))
|
||||||
});
|
});
|
||||||
|
|
||||||
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
||||||
let id = candidate.item.id().to_owned();
|
let id = candidate.item.candidate_id().to_owned();
|
||||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
let contribution = fts_weight / (k + rank_f32 + 1.0);
|
let contribution = fts_weight / (k + rank_f32 + 1.0);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user