chore: harden common storage bootstrap and slim embedded db assets

Unify embedding config, build providers from system settings, and fail
startup when index builds error or time out. Move Surreal assets under
common/db so embeds exclude crate source, and read storage via streams.
This commit is contained in:
Per Stark
2026-05-29 12:26:26 +02:00
parent 93d11b66eb
commit e3bb2935d0
62 changed files with 672 additions and 443 deletions
Generated
+1
View File
@@ -3810,6 +3810,7 @@ dependencies = [
"api-router", "api-router",
"async-openai", "async-openai",
"axum", "axum",
"chrono",
"common", "common",
"futures", "futures",
"html-router", "html-router",
+2 -2
View File
@@ -13,8 +13,8 @@ use surrealdb::{
use surrealdb_migrations::MigrationRunner; use surrealdb_migrations::MigrationRunner;
use tracing::debug; use tracing::debug;
/// Embedded SurrealDB migration directory packaged with the crate. /// Embedded SurrealDB project root (`migrations/`, `schemas/`, `.surrealdb`).
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/"); static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/db");
#[derive(Clone)] #[derive(Clone)]
pub struct SurrealDbClient { pub struct SurrealDbClient {
+162 -70
View File
@@ -9,6 +9,7 @@ use tracing::{debug, info, warn};
use crate::{error::AppError, storage::db::SurrealDbClient}; use crate::{error::AppError, storage::db::SurrealDbClient};
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50); const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer"; const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
/// HNSW index options used by runtime index creation (includes CONCURRENTLY). /// HNSW index options used by runtime index creation (includes CONCURRENTLY).
@@ -296,14 +297,10 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
return Ok("unknown".to_string()); return Ok("unknown".to_string());
}; };
let building = info.get("building"); let parsed: IndexInfoForIndex = serde_json::from_value(info)
let status = building .context("deserializing INFO FOR INDEX response")?;
.and_then(|b| b.get("status"))
.and_then(|s| s.as_str())
.unwrap_or("ready")
.to_string();
Ok(status) Ok(parsed.building_status())
} }
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> { async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
@@ -531,8 +528,21 @@ async fn poll_index_build_status(
poll_every: Duration, poll_every: Duration,
) -> Result<()> { ) -> Result<()> {
let started_at = std::time::Instant::now(); let started_at = std::time::Instant::now();
let mut last_snapshot: Option<IndexBuildSnapshot> = None;
loop { loop {
if started_at.elapsed() >= INDEX_BUILD_TIMEOUT {
return Err(anyhow::anyhow!(
"index build timed out after {:?} for {index_name} on {table} (last status: {})",
INDEX_BUILD_TIMEOUT,
last_snapshot
.as_ref()
.map(|snapshot| snapshot.status.as_str())
.unwrap_or("unknown")
))
.with_context(|| format!("index {index_name} on table {table} did not become ready"));
}
tokio::time::sleep(poll_every).await; tokio::time::sleep(poll_every).await;
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};"); let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
@@ -546,14 +556,13 @@ async fn poll_index_build_status(
.context("failed to deserialize INFO FOR INDEX result")?; .context("failed to deserialize INFO FOR INDEX result")?;
let Some(snapshot) = parse_index_build_info(info, total_rows) else { let Some(snapshot) = parse_index_build_info(info, total_rows) else {
warn!( return Err(anyhow::anyhow!(
index = %index_name, "INFO FOR INDEX returned no data for {index_name} on {table}"
table = %table, ));
"INFO FOR INDEX returned no data; assuming index definition might be missing"
);
break;
}; };
last_snapshot = Some(snapshot.clone());
if let Some(pct) = snapshot.progress_pct { if let Some(pct) = snapshot.progress_pct {
debug!( debug!(
index = %index_name, index = %index_name,
@@ -589,25 +598,87 @@ async fn poll_index_build_status(
total = snapshot.total_rows, total = snapshot.total_rows,
"Index is ready" "Index is ready"
); );
break; return Ok(());
} }
if snapshot.status.eq_ignore_ascii_case("error") { if snapshot.status.eq_ignore_ascii_case("error") {
warn!( return Err(anyhow::anyhow!(
index = %index_name, "index build failed for {index_name} on {table}: status=error, processed={}, total={:?}",
table = %table, snapshot.processed,
status = snapshot.status, snapshot.total_rows
"Index build reported error status; stopping polling" ));
); }
break; }
}
/// `building` block from SurrealDB `INFO FOR INDEX` (concurrent index builds).
///
/// Every field is `#[serde(default)]`, so a key that is missing from the payload
/// deserializes to `0` for the counters or to an empty string for `status`.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
struct IndexBuildingProgress {
/// Added to `updated` to obtain the snapshot's `processed` count.
/// NOTE(review): presumably the number of rows seen when the build started — confirm
/// against the SurrealDB documentation.
#[serde(default)]
initial: u64,
/// Copied into the snapshot unchanged; it does not contribute to `processed`.
/// NOTE(review): presumably rows still queued for indexing — confirm.
#[serde(default)]
pending: u64,
/// Added to `initial` to obtain the snapshot's `processed` count.
/// NOTE(review): presumably rows written after the build started — confirm.
#[serde(default)]
updated: u64,
/// Build status string as reported by the database. An empty value is reported as
/// `"ready"` by `IndexInfoForIndex::building_status`.
#[serde(default)]
status: String,
}
/// Top-level `INFO FOR INDEX` payload shape (SurrealDB v2.x).
///
/// Only the `building` key is modelled here.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
struct IndexInfoForIndex {
/// Progress of a build, if the payload carries one. `None` (key missing, via
/// `#[serde(default)]`) is reported as status `"ready"` by the methods on this type.
#[serde(default)]
building: Option<IndexBuildingProgress>,
}
impl IndexInfoForIndex {
fn building_status(&self) -> String {
match &self.building {
None => "ready".to_string(),
Some(progress) if progress.status.is_empty() => "ready".to_string(),
Some(progress) => progress.status.clone(),
} }
} }
Ok(()) fn into_build_snapshot(self, total_rows: Option<u64>) -> IndexBuildSnapshot {
let (initial, pending, updated, status) = match self.building {
None => (0, 0, 0, "ready".to_string()),
Some(progress) => {
let status = if progress.status.is_empty() {
"ready".to_string()
} else {
progress.status
};
(progress.initial, progress.pending, progress.updated, status)
}
};
let processed = initial.saturating_add(updated);
let progress_pct = total_rows.map(|total| {
if total == 0 {
0.0
} else {
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
/ f64::from(u32::try_from(total).unwrap_or(1)))
.min(1.0))
* 100.0
}
});
IndexBuildSnapshot {
status,
initial,
pending,
updated,
processed,
total_rows,
progress_pct,
}
}
} }
/// Snapshot of an index build progress as reported by SurrealDB's `INFO FOR INDEX`. /// Snapshot of an index build progress as reported by SurrealDB's `INFO FOR INDEX`.
#[derive(Debug, PartialEq)] #[derive(Debug, Clone, PartialEq)]
struct IndexBuildSnapshot { struct IndexBuildSnapshot {
/// Current build status string (e.g., `"indexing"`, `"ready"`, `"error"`). /// Current build status string (e.g., `"indexing"`, `"ready"`, `"error"`).
status: String, status: String,
@@ -636,53 +707,8 @@ fn parse_index_build_info(
total_rows: Option<u64>, total_rows: Option<u64>,
) -> Option<IndexBuildSnapshot> { ) -> Option<IndexBuildSnapshot> {
let info = info?; let info = info?;
let building = info.get("building"); let parsed: IndexInfoForIndex = serde_json::from_value(info).ok()?;
Some(parsed.into_build_snapshot(total_rows))
let status = building
.and_then(|b| b.get("status"))
.and_then(|s| s.as_str())
// If there's no `building` block at all, treat as "ready" (index not building anymore)
.unwrap_or("ready")
.to_string();
let initial = building
.and_then(|b| b.get("initial"))
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
let pending = building
.and_then(|b| b.get("pending"))
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
let updated = building
.and_then(|b| b.get("updated"))
.and_then(serde_json::Value::as_u64)
.unwrap_or(0);
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
let processed = initial.saturating_add(updated);
let progress_pct = total_rows.map(|total| {
if total == 0 {
0.0
} else {
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
/ f64::from(u32::try_from(total).unwrap_or(1)))
.min(1.0))
* 100.0
}
});
Some(IndexBuildSnapshot {
status,
initial,
pending,
updated,
processed,
total_rows,
progress_pct,
})
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -786,6 +812,72 @@ mod tests {
Ok(()) Ok(())
} }
/// A `building` block that only carries `status: "ready"` must deserialize, report
/// `"ready"`, and produce a ready snapshot whose counters default to zero.
#[test]
fn index_info_for_index_deserializes_ready_status_shape() -> anyhow::Result<()> {
    let payload = json!({ "building": { "status": "ready" } });

    // Typed deserialization path.
    let decoded = serde_json::from_value::<IndexInfoForIndex>(payload.clone())
        .context("deserialize ready shape")?;
    assert_eq!(decoded.building_status(), "ready");

    // Snapshot path, fed the same payload and no row total.
    let ready = parse_index_build_info(Some(payload), None).context("snapshot")?;
    assert!(ready.is_ready());
    assert_eq!(ready.initial, 0);

    Ok(())
}
/// An in-progress `building` block must keep its status and counters, and the
/// snapshot's `processed` value must equal `initial + updated`.
#[test]
fn index_info_for_index_deserializes_indexing_shape_from_surreal_docs() -> anyhow::Result<()> {
    let payload = json!({
        "building": {
            "initial": 8143,
            "pending": 19,
            "status": "indexing",
            "updated": 80
        }
    });

    let decoded = serde_json::from_value::<IndexInfoForIndex>(payload.clone())
        .context("deserialize indexing shape")?;
    assert_eq!(decoded.building_status(), "indexing");

    let progress = parse_index_build_info(Some(payload), None).context("snapshot")?;
    assert_eq!(progress.status, "indexing");
    assert_eq!(progress.initial, 8143);
    assert_eq!(progress.pending, 19);
    assert_eq!(progress.updated, 80);
    // 8143 + 80
    assert_eq!(progress.processed, 8223);
    assert!(!progress.is_ready());

    Ok(())
}
/// A `building` block whose status is `"error"` must surface that status in the
/// snapshot and must not be considered ready.
#[test]
fn parse_index_build_info_reports_error_status() -> anyhow::Result<()> {
    let failed_build = json!({
        "building": {
            "initial": 100,
            "pending": 5,
            "status": "error",
            "updated": 10
        }
    });
    let known_total = Some(200);

    let failed = parse_index_build_info(Some(failed_build), known_total).context("snapshot")?;
    assert_eq!(failed.status, "error");
    assert!(!failed.is_ready());

    Ok(())
}
#[test] #[test]
fn extract_dimension_parses_value() { fn extract_dimension_parses_value() {
let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;"; let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
+26 -13
View File
@@ -2,7 +2,7 @@ use std::io::ErrorKind;
use std::path::{Component, Path, PathBuf}; use std::path::{Component, Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Result as AnyResult}; use anyhow::{anyhow, Context, Result as AnyResult};
use bytes::Bytes; use bytes::Bytes;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
@@ -107,15 +107,18 @@ impl StorageManager {
/// Retrieve bytes from the specified location. /// Retrieve bytes from the specified location.
/// ///
/// Returns the full contents buffered in memory. /// Reads via [`Self::get_stream`] and buffers the full object in memory.
/// ///
/// # Errors /// # Errors
/// ///
/// Returns `Err` if the location does not exist or the underlying backend fails. /// Returns `Err` if the location does not exist or the underlying backend fails.
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> { pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
let path = ObjPath::from(location); let mut stream = self.get_stream(location).await?;
let result = self.store.get(&path).await?; let mut collected = Vec::new();
result.bytes().await while let Some(chunk) = stream.next().await {
collected.extend_from_slice(&chunk?);
}
Ok(Bytes::from(collected))
} }
/// Get a streaming handle for large objects. /// Get a streaming handle for large objects.
@@ -252,7 +255,10 @@ async fn create_storage_backend(
) -> object_store::Result<(DynStorage, Option<PathBuf>)> { ) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
match cfg.storage { match cfg.storage {
StorageKind::Local => { StorageKind::Local => {
let base = resolve_base_dir(cfg); let base = resolve_base_dir(cfg).map_err(|err| object_store::Error::Generic {
store: "LocalFileSystem",
source: err.into(),
})?;
if !base.exists() { if !base.exists() {
tokio::fs::create_dir_all(&base).await.map_err(|e| { tokio::fs::create_dir_all(&base).await.map_err(|e| {
object_store::Error::Generic { object_store::Error::Generic {
@@ -576,15 +582,22 @@ pub mod testing {
/// Resolve the absolute base directory used for local storage from config. /// Resolve the absolute base directory used for local storage from config.
/// ///
/// If `data_dir` is relative, it is resolved against the current working directory. /// If `data_dir` is relative, it is resolved against the process current working directory.
#[must_use] ///
pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf { /// # Errors
///
/// Returns `Err` when `data_dir` is relative and the current working directory cannot be read.
pub fn resolve_base_dir(cfg: &AppConfig) -> AnyResult<PathBuf> {
if cfg.data_dir.starts_with('/') { if cfg.data_dir.starts_with('/') {
PathBuf::from(&cfg.data_dir) Ok(PathBuf::from(&cfg.data_dir))
} else { } else {
std::env::current_dir() let cwd = std::env::current_dir().with_context(|| {
.unwrap_or_else(|_| PathBuf::from(".")) format!(
.join(&cfg.data_dir) "failed to resolve relative data_dir '{}' against the current working directory",
cfg.data_dir
)
})?;
Ok(cwd.join(&cfg.data_dir))
} }
} }
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::utils::embedding::EmbeddingBackend; use crate::utils::config::EmbeddingBackend;
use crate::utils::serde_helpers::deserialize_flexible_id; use crate::utils::serde_helpers::deserialize_flexible_id;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
+49 -8
View File
@@ -1,9 +1,18 @@
use config::{Config, ConfigError, Environment, File}; use config::{Config, ConfigError, Environment, File};
use serde::Deserialize; use serde::{Deserialize, Serialize};
use std::env; use std::{env, sync::Once, str::FromStr};
use thiserror::Error;
/// Error returned when parsing an embedding backend name.
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
pub struct ParseEmbeddingBackendError {
/// The unrecognized input string.
pub input: String,
}
/// Selects the embedding backend for vector generation. /// Selects the embedding backend for vector generation.
#[derive(Clone, Copy, Deserialize, Debug, Default, PartialEq)] #[derive(Clone, Copy, Deserialize, Serialize, Debug, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum EmbeddingBackend { pub enum EmbeddingBackend {
/// Use OpenAI-compatible API for embeddings. /// Use OpenAI-compatible API for embeddings.
@@ -15,6 +24,32 @@ pub enum EmbeddingBackend {
Hashed, Hashed,
} }
impl EmbeddingBackend {
    /// Returns the lowercase name of this backend as a static string.
    ///
    /// Each name is also accepted (case-insensitively) by the `FromStr` implementation.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        let name = match self {
            Self::Hashed => "hashed",
            Self::FastEmbed => "fastembed",
            Self::OpenAI => "openai",
        };
        name
    }
}
impl FromStr for EmbeddingBackend {
type Err = ParseEmbeddingBackendError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(ParseEmbeddingBackendError {
input: other.to_string(),
}),
}
}
}
#[derive(Clone, Copy, Deserialize, Debug, PartialEq)] #[derive(Clone, Copy, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum StorageKind { pub enum StorageKind {
@@ -133,11 +168,17 @@ fn default_ingest_max_category_bytes() -> usize {
128 128
} }
static ORT_PATH_INIT: Once = Once::new();
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
pub fn ensure_ort_path() { pub fn ensure_ort_path() {
if env::var_os("ORT_DYLIB_PATH").is_some() { ORT_PATH_INIT.call_once(|| {
return; if env::var_os("ORT_DYLIB_PATH").is_some() {
} return;
if let Ok(mut exe) = env::current_exe() { }
let Ok(mut exe) = env::current_exe() else {
return;
};
exe.pop(); exe.pop();
if cfg!(target_os = "windows") { if cfg!(target_os = "windows") {
@@ -160,7 +201,7 @@ pub fn ensure_ort_path() {
if p.exists() { if p.exists() {
env::set_var("ORT_DYLIB_PATH", p); env::set_var("ORT_DYLIB_PATH", p);
} }
} });
} }
impl Default for AppConfig { impl Default for AppConfig {
+25 -60
View File
@@ -8,59 +8,15 @@ use std::{
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client}; use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::debug; use tracing::debug;
use crate::{ use crate::{
error::AppError, error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
utils::config::AppConfig,
}; };
/// Error returned when parsing an embedding backend name. pub use crate::utils::config::{EmbeddingBackend, ParseEmbeddingBackendError};
#[derive(Debug, Error, PartialEq, Eq)]
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
pub struct ParseEmbeddingBackendError {
/// The unrecognized input string.
pub input: String,
}
/// Supported embedding backends.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingBackend {
#[default]
OpenAI,
FastEmbed,
Hashed,
}
impl EmbeddingBackend {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::OpenAI => "openai",
Self::FastEmbed => "fastembed",
Self::Hashed => "hashed",
}
}
}
impl std::str::FromStr for EmbeddingBackend {
type Err = ParseEmbeddingBackendError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(ParseEmbeddingBackendError {
input: other.to_string(),
}),
}
}
}
/// Wrapper around the chosen embedding backend. /// Wrapper around the chosen embedding backend.
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
@@ -281,30 +237,31 @@ impl EmbeddingProvider {
}) })
} }
/// Creates an embedding provider based on application configuration. /// Creates an embedding provider from persisted settings and bootstrap config.
/// ///
/// Dispatches to the appropriate constructor based on `config.embedding_backend`: /// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
/// - `OpenAI`: Requires a valid OpenAI client /// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
/// - `FastEmbed`: Uses local embedding model /// persists the resolved backend to the database.
/// - `Hashed`: Uses deterministic hashed embeddings (for testing) pub async fn from_system_settings(
pub async fn from_config( settings: &SystemSettings,
config: &crate::utils::config::AppConfig, config: &AppConfig,
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>, openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
) -> Result<Self> { ) -> Result<Self> {
use crate::utils::config::EmbeddingBackend; let dimensions = settings.embedding_dimensions;
match config.embedding_backend { match config.embedding_backend {
EmbeddingBackend::OpenAI => { EmbeddingBackend::OpenAI => {
let client = openai_client let client = openai_client
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?; .ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
// Use defaults that match SystemSettings initial values Self::new_openai(client, settings.embedding_model.clone(), dimensions)
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
} }
EmbeddingBackend::FastEmbed => { EmbeddingBackend::FastEmbed => {
// Use nomic-embed-text-v1.5 as the default FastEmbed model Self::new_fastembed(Some(settings.embedding_model.clone())).await
Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await }
EmbeddingBackend::Hashed => {
let dimension = usize::try_from(dimensions)
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
Self::new_hashed(dimension)
} }
EmbeddingBackend::Hashed => Self::new_hashed(384),
} }
} }
} }
@@ -460,6 +417,14 @@ mod tests {
use crate::storage::types::system_settings::SystemSettings; use crate::storage::types::system_settings::SystemSettings;
use serde_json::json; use serde_json::json;
#[test]
fn embedding_backend_defaults_to_fastembed() {
assert_eq!(
EmbeddingBackend::default(),
EmbeddingBackend::FastEmbed
);
}
#[test] #[test]
fn embedding_backend_as_str_matches_serde_names() { fn embedding_backend_as_str_matches_serde_names() {
assert_eq!(EmbeddingBackend::OpenAI.as_str(), "openai"); assert_eq!(EmbeddingBackend::OpenAI.as_str(), "openai");
+22
View File
@@ -29,6 +29,14 @@ pub fn validate_ingest_input(
category: &str, category: &str,
file_count: usize, file_count: usize,
) -> Result<(), IngestValidationError> { ) -> Result<(), IngestValidationError> {
let text_field_bytes = content.map(str::len).unwrap_or(0) + ctx.len() + category.len();
if text_field_bytes > config.ingest_max_body_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"request text fields exceed maximum allowed body size of {} bytes",
config.ingest_max_body_bytes
)));
}
if file_count > config.ingest_max_files { if file_count > config.ingest_max_files {
return Err(IngestValidationError::BadRequest(format!( return Err(IngestValidationError::BadRequest(format!(
"too many files: maximum allowed is {}", "too many files: maximum allowed is {}",
@@ -127,4 +135,18 @@ mod tests {
assert!(result.is_ok()); assert!(result.is_ok());
} }
/// Text fields whose combined byte length exceeds `ingest_max_body_bytes` must be
/// rejected with `PayloadTooLarge`.
#[test]
fn validate_ingest_input_rejects_oversized_text_fields() {
    // content (6) + ctx (3) + category (3) = 12 bytes against a 10-byte limit.
    let tight_limit = AppConfig {
        ingest_max_body_bytes: 10,
        ..Default::default()
    };

    let outcome = validate_ingest_input(&tight_limit, Some("123456"), "ctx", "cat", 0);

    assert!(matches!(
        outcome,
        Err(IngestValidationError::PayloadTooLarge(_))
    ));
}
} }
+1 -2
View File
@@ -27,8 +27,7 @@
filter = let filter = let
extraPaths = [ extraPaths = [
(toString ./Cargo.lock) (toString ./Cargo.lock)
(toString ./common/migrations) (toString ./common/db)
(toString ./common/schemas)
(toString ./html-router/templates) (toString ./html-router/templates)
(toString ./html-router/assets) (toString ./html-router/assets)
]; ];
+1
View File
@@ -30,6 +30,7 @@ retrieval-pipeline = { path = "../retrieval-pipeline" }
[dev-dependencies] [dev-dependencies]
tower = "0.5" tower = "0.5"
uuid = { workspace = true } uuid = { workspace = true }
chrono = { workspace = true }
common = { path = "../common", features = ["test-utils"] } common = { path = "../common", features = ["test-utils"] }
[[bin]] [[bin]]
-73
View File
@@ -1,73 +0,0 @@
use std::sync::Arc;
use async_openai::Client;
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
},
utils::{
config::{get_config, AppConfig},
embedding::EmbeddingProvider,
},
};
use retrieval_pipeline::reranking::RerankerPool;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
pub struct SharedServices {
pub db: Arc<SurrealDbClient>,
pub openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
pub embedding_provider: Arc<EmbeddingProvider>,
pub storage: StorageManager,
pub reranker_pool: Option<Arc<RerankerPool>>,
pub config: AppConfig,
}
pub async fn init() -> anyhow::Result<SharedServices> {
tracing_subscriber::registry()
.with(fmt::layer().with_writer(std::io::stderr))
.with(EnvFilter::from_default_env())
.try_init()
.ok();
let config = get_config()?;
init_with_config(config).await
}
pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<SharedServices> {
let db = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
&config.surrealdb_namespace,
&config.surrealdb_database,
)
.await?,
);
db.apply_migrations().await?;
let openai_client = Arc::new(Client::with_config(
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
let embedding_provider = Arc::new(
EmbeddingProvider::from_config(&config, Some(Arc::clone(&openai_client))).await?,
);
let reranker_pool = RerankerPool::maybe_from_config(&config)?;
let storage = StorageManager::new(&config).await?;
Ok(SharedServices {
db,
openai_client,
embedding_provider,
storage,
reranker_pool,
config,
})
}
+136
View File
@@ -0,0 +1,136 @@
mod startup;
pub mod wiring;
pub use startup::prepare_embedding_runtime;
use std::sync::Arc;
use anyhow::Context;
use async_openai::Client;
use common::{
storage::{
db::SurrealDbClient,
store::StorageManager,
types::system_settings::SystemSettings,
},
utils::{
config::{get_config, AppConfig},
embedding::EmbeddingProvider,
},
};
use retrieval_pipeline::reranking::RerankerPool;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
/// Bundle of service handles assembled once by `init_with_config` and then shared
/// with the rest of the application.
pub struct SharedServices {
/// Shared SurrealDB client; migrations have already been applied when this is built.
pub db: Arc<SurrealDbClient>,
/// OpenAI client configured with the API key and base URL from `config`.
pub openai_client: Arc<Client<async_openai::config::OpenAIConfig>>,
/// Embedding provider built from the stored system settings and `config`.
pub embedding_provider: Arc<EmbeddingProvider>,
/// Storage manager initialized from `config`.
pub storage: StorageManager,
/// Reranker pool as returned by `RerankerPool::maybe_from_config`; may be `None`.
pub reranker_pool: Option<Arc<RerankerPool>>,
/// The application configuration the services above were built from.
pub config: AppConfig,
}
/// Installs the tracing subscriber (stderr output, filter taken from the environment),
/// loads the application configuration, and builds the shared services from it.
///
/// # Errors
///
/// Returns `Err` if the configuration cannot be loaded or if any service fails to
/// initialize. A failure to install the tracing subscriber is ignored.
pub async fn init() -> anyhow::Result<SharedServices> {
    let subscriber = tracing_subscriber::registry()
        .with(fmt::layer().with_writer(std::io::stderr))
        .with(EnvFilter::from_default_env());
    // Deliberately best-effort: the result of installing the subscriber is discarded.
    let _ = subscriber.try_init();

    let loaded = get_config()?;
    init_with_config(loaded).await
}
/// Builds every shared service from an already-loaded [`AppConfig`].
///
/// Steps, in order: connect to SurrealDB, apply migrations, load the current
/// [`SystemSettings`], create the OpenAI client, build the embedding provider from
/// settings and config, create the optional reranker pool, and initialize storage.
///
/// # Errors
///
/// Returns `Err` as soon as any step fails. Every fallible step attaches a context
/// message naming the step, so the failing stage is visible in the error chain.
pub(crate) async fn init_with_config(config: AppConfig) -> anyhow::Result<SharedServices> {
    let db = Arc::new(
        SurrealDbClient::new(
            &config.surrealdb_address,
            &config.surrealdb_username,
            &config.surrealdb_password,
            &config.surrealdb_namespace,
            &config.surrealdb_database,
        )
        .await
        .context("connect to surrealdb")?,
    );

    db.apply_migrations()
        .await
        .context("apply database migrations")?;

    // Settings are read after migrations have been applied.
    let settings = SystemSettings::get_current(&db)
        .await
        .context("load system settings")?;

    let openai_client = Arc::new(Client::with_config(
        async_openai::config::OpenAIConfig::new()
            .with_api_key(&config.openai_api_key)
            .with_api_base(&config.openai_base_url),
    ));

    let embedding_provider = Arc::new(
        EmbeddingProvider::from_system_settings(
            &settings,
            &config,
            Some(Arc::clone(&openai_client)),
        )
        .await
        .context("initialize embedding provider")?,
    );

    // Same treatment as the other steps: name the failing stage in the error chain.
    let reranker_pool =
        RerankerPool::maybe_from_config(&config).context("initialize reranker pool")?;

    let storage = StorageManager::new(&config)
        .await
        .context("initialize storage manager")?;

    Ok(SharedServices {
        db,
        openai_client,
        embedding_provider,
        storage,
        reranker_pool,
        config,
    })
}
#[cfg(test)]
pub(crate) mod tests {
    use std::path::Path;

    use anyhow::Context;
    use common::utils::config::{AppConfig, EmbeddingBackend, PdfIngestMode, StorageKind};
    use uuid::Uuid;

    /// Builds an [`AppConfig`] for smoke tests: `mem://` SurrealDB address, local
    /// storage rooted at `data_dir`, and the hashed embedding backend. All other
    /// fields keep their defaults.
    pub fn smoke_test_config(namespace: &str, database: &str, data_dir: &Path) -> AppConfig {
        let mut cfg = AppConfig::default();

        // Database connection.
        cfg.surrealdb_address = "mem://".into();
        cfg.surrealdb_username = "root".into();
        cfg.surrealdb_password = "root".into();
        cfg.surrealdb_namespace = namespace.into();
        cfg.surrealdb_database = database.into();

        // OpenAI client settings.
        cfg.openai_api_key = "test-key".into();
        cfg.openai_base_url = "https://example.com".into();

        // Storage, server and pipeline settings.
        cfg.data_dir = data_dir.to_string_lossy().into_owned();
        cfg.http_port = 0;
        cfg.storage = StorageKind::Local;
        cfg.pdf_ingest_mode = PdfIngestMode::LlmFirst;
        cfg.embedding_backend = EmbeddingBackend::Hashed;

        cfg
    }

    /// Initializes the shared services against a uniquely named database and a fresh
    /// temporary data directory, returning the services together with that directory.
    pub async fn init_smoke_services() -> anyhow::Result<(super::SharedServices, std::path::PathBuf)>
    {
        let db_name = format!("test_db_{}", Uuid::new_v4());
        let scratch_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));

        tokio::fs::create_dir_all(&scratch_dir)
            .await
            .context("create temp data directory")?;

        let cfg = smoke_test_config("test_ns", &db_name, &scratch_dir);
        super::init_with_config(cfg)
            .await
            .map(|services| (services, scratch_dir))
    }
}
+66
View File
@@ -0,0 +1,66 @@
use anyhow::Context;
use common::{
storage::{
db::SurrealDbClient,
indexes::ensure_runtime,
types::{
knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
text_chunk::TextChunk,
},
},
utils::embedding::EmbeddingProvider,
};
use tracing::{info, warn};
use super::SharedServices;
/// Syncs embedding settings, re-embeds stored vectors when dimensions change, and
/// ensures runtime indexes match the active embedding dimension.
///
/// Returns the [`SystemSettings`] produced by the sync step.
///
/// # Errors
///
/// Returns `Err` if the settings sync fails, if the embedding dimension does not fit
/// in `usize`, if re-embedding fails, or if the runtime indexes cannot be ensured.
pub async fn prepare_embedding_runtime(services: &SharedServices) -> anyhow::Result<SystemSettings> {
    let (settings, dimensions_changed) =
        SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider)
            .await
            .context("sync system settings from embedding provider")?;

    // Checked conversion instead of an `as` cast; done up front so a bad value is
    // rejected before any re-embedding work starts.
    let index_dimension = usize::try_from(settings.embedding_dimensions)
        .context("embedding_dimensions does not fit in usize")?;

    if dimensions_changed {
        re_embed_all(
            &services.db,
            &services.embedding_provider,
            settings.embedding_dimensions,
        )
        .await
        .context("re-embed stored data")?;
    }

    ensure_runtime(&services.db, index_dimension)
        .await
        .context("ensure runtime indexes")?;

    Ok(settings)
}
/// Re-embeds all text chunks and then all knowledge entities with the given provider.
///
/// `embedding_dimensions` is only used for the warning log entry. The first failing
/// step aborts the function; knowledge entities are not touched if text chunks fail.
async fn re_embed_all(
    db: &SurrealDbClient,
    embedding_provider: &EmbeddingProvider,
    embedding_dimensions: u32,
) -> anyhow::Result<()> {
    warn!(
        embedding_dimensions,
        "Embedding configuration changed; re-embedding existing data"
    );

    info!("Re-embedding TextChunks");
    let chunk_outcome = TextChunk::update_all_embeddings_with_provider(db, embedding_provider).await;
    chunk_outcome.context("re-embed text chunks after embedding dimension change")?;

    info!("Re-embedding KnowledgeEntities");
    let entity_outcome =
        KnowledgeEntity::update_all_embeddings_with_provider(db, embedding_provider).await;
    entity_outcome.context("re-embed knowledge entities after embedding dimension change")?;

    info!("Re-embedding complete");
    Ok(())
}
+54
View File
@@ -0,0 +1,54 @@
use std::sync::Arc;
use anyhow::Context;
use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
};
use super::SharedServices;
/// Assembles the Minne routers: the v1 API nested under `/api/v1`, merged with the
/// HTML routes. The outer Axum state type `S` is left generic, so consumers can merge
/// further routers and attach their own state, provided `ApiState` and `HtmlState`
/// both implement `FromRef<S>`.
pub fn minne_routes<S>(api_state: &ApiState, html_state: &HtmlState) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
    ApiState: FromRef<S>,
    HtmlState: FromRef<S>,
{
    let api_v1 = api_routes_v1(api_state);
    let html = html_routes(html_state);

    Router::new().nest("/api/v1", api_v1).merge(html)
}
pub fn build_api_state(services: &SharedServices) -> ApiState {
ApiState {
db: Arc::clone(&services.db),
config: services.config.clone(),
storage: services.storage.clone(),
}
}
/// Creates the HTML router state from the shared services.
///
/// A session store is created from the database first; the remaining resources are
/// clones of the shared handles. No template engine is supplied (`None`).
///
/// # Errors
///
/// Returns `Err` if the session store cannot be created.
pub async fn build_html_state(services: &SharedServices) -> anyhow::Result<HtmlState> {
    let store = services
        .db
        .create_session_store()
        .await
        .context("create session store")?;

    let resources = StateResources {
        db: Arc::clone(&services.db),
        openai_client: Arc::clone(&services.openai_client),
        session_store: Arc::new(store),
        storage: services.storage.clone(),
        config: services.config.clone(),
        reranker_pool: services.reranker_pool.clone(),
        embedding_provider: Arc::clone(&services.embedding_provider),
        template_engine: None,
    };

    Ok(HtmlState::new_with_resources(resources))
}
+55 -178
View File
@@ -2,50 +2,17 @@ mod bootstrap;
use std::sync::Arc; use std::sync::Arc;
use api_router::{api_routes_v1, api_state::ApiState}; use axum::extract::FromRef;
use axum::{extract::FromRef, Router}; use bootstrap::{
use common::{ init, prepare_embedding_runtime,
storage::{ wiring::{build_api_state, build_html_state, minne_routes},
indexes::ensure_runtime,
types::{
knowledge_entity::KnowledgeEntity, system_settings::SystemSettings,
text_chunk::TextChunk,
},
},
};
use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
}; };
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use tracing::{error, info, warn}; use tracing::info;
use tokio::task::LocalSet;
fn spawn_server_thread(
listener: tokio::net::TcpListener,
app: Router,
) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
let rt = match tokio::runtime::Runtime::new() {
Ok(rt) => rt,
Err(e) => {
error!("Failed to create server runtime: {e}");
return;
}
};
rt.block_on(async {
if let Err(e) = axum::serve(listener, app).await {
error!("Server error: {}", e);
}
});
})
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let services = bootstrap::init().await?; let services = init().await?;
let session_store = Arc::new(services.db.create_session_store().await?);
info!( info!(
embedding_backend = ?services.config.embedding_backend, embedding_backend = ?services.config.embedding_backend,
@@ -53,64 +20,16 @@ async fn main() -> anyhow::Result<()> {
"Embedding provider initialized" "Embedding provider initialized"
); );
let (settings, dimensions_changed) = prepare_embedding_runtime(&services).await?;
SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider)
.await?;
if dimensions_changed { let html_state = build_html_state(&services).await?;
warn!( let api_state = build_api_state(&services);
new_dimensions = settings.embedding_dimensions,
"Embedding configuration changed; re-embedding existing data"
);
info!("Re-embedding TextChunks"); let app = minne_routes(&api_state, &html_state).with_state(AppState {
if let Err(e) = api_state,
TextChunk::update_all_embeddings_with_provider(&services.db, &services.embedding_provider) html_state,
.await
{
error!(
"Failed to re-embed TextChunks: {}. Search results may be stale.",
e
);
}
info!("Re-embedding KnowledgeEntities");
if let Err(e) =
KnowledgeEntity::update_all_embeddings_with_provider(&services.db, &services.embedding_provider)
.await
{
error!(
"Failed to re-embed KnowledgeEntities: {}. Search results may be stale.",
e
);
}
info!("Re-embedding complete.");
}
ensure_runtime(&services.db, settings.embedding_dimensions as usize).await?;
let html_state = HtmlState::new_with_resources(StateResources {
db: Arc::clone(&services.db),
openai_client: Arc::clone(&services.openai_client),
session_store,
storage: services.storage.clone(),
config: services.config.clone(),
reranker_pool: services.reranker_pool.clone(),
embedding_provider: Arc::clone(&services.embedding_provider),
template_engine: None,
}); });
let api_state = ApiState::new(&services.config, services.storage.clone()).await?;
let app = Router::new()
.nest("/api/v1", api_routes_v1(&api_state))
.merge(html_routes(&html_state))
.with_state(AppState {
api_state,
html_state,
});
info!( info!(
"Starting server listening on 0.0.0.0:{}", "Starting server listening on 0.0.0.0:{}",
services.config.http_port services.config.http_port
@@ -118,28 +37,32 @@ async fn main() -> anyhow::Result<()> {
let serve_address = format!("0.0.0.0:{}", services.config.http_port); let serve_address = format!("0.0.0.0:{}", services.config.http_port);
let listener = tokio::net::TcpListener::bind(serve_address).await?; let listener = tokio::net::TcpListener::bind(serve_address).await?;
let server_handle = spawn_server_thread(listener, app); let worker_db = Arc::clone(&services.db);
let worker_openai = Arc::clone(&services.openai_client);
let worker_embedding = Arc::clone(&services.embedding_provider);
let worker_config = services.config.clone();
let worker_reranker = services.reranker_pool.clone();
let worker_storage = services.storage.clone();
let ingestion_pipeline = Arc::new(IngestionPipeline::new( let server = tokio::spawn(async move { axum::serve(listener, app).await });
Arc::clone(&services.db), let worker = tokio::spawn(async move {
Arc::clone(&services.openai_client),
services.config.clone(),
services.reranker_pool.clone(),
services.storage,
Arc::clone(&services.embedding_provider),
)?);
let local = LocalSet::new();
local.spawn_local(async move {
info!("Starting worker process"); info!("Starting worker process");
if let Err(e) = run_worker_loop(services.db, ingestion_pipeline).await {
error!("Worker error: {}", e);
}
});
local.await;
if let Err(e) = server_handle.join() { let ingestion_pipeline = Arc::new(IngestionPipeline::new(
error!("Server thread panicked: {:?}", e); Arc::clone(&worker_db),
worker_openai,
worker_config,
worker_reranker,
worker_storage,
worker_embedding,
)?);
run_worker_loop(worker_db, ingestion_pipeline).await
});
tokio::select! {
result = server => result??,
result = worker => result??,
} }
Ok(()) Ok(())
@@ -147,8 +70,8 @@ async fn main() -> anyhow::Result<()> {
#[derive(Clone, FromRef)] #[derive(Clone, FromRef)]
struct AppState { struct AppState {
api_state: ApiState, api_state: api_router::api_state::ApiState,
html_state: HtmlState, html_state: html_router::html_state::HtmlState,
} }
#[cfg(test)] #[cfg(test)]
@@ -160,79 +83,33 @@ mod tests {
response::Response, response::Response,
Router, Router,
}; };
use common::storage::{ use bootstrap::{
db::SurrealDbClient, prepare_embedding_runtime,
store::StorageManager, tests::init_smoke_services,
types::{system_settings::SystemSettings, user::User}, wiring::{build_api_state, build_html_state, minne_routes},
}; };
use common::utils::config::{AppConfig, EmbeddingBackend, PdfIngestMode, StorageKind}; use common::storage::types::{system_settings::SystemSettings, user::User};
use std::{path::Path, sync::Arc};
use tower::ServiceExt; use tower::ServiceExt;
use uuid::Uuid;
fn smoke_test_config(namespace: &str, database: &str, data_dir: &Path) -> AppConfig { async fn build_test_app() -> (Router, Arc<common::storage::db::SurrealDbClient>, std::path::PathBuf) {
AppConfig { let (services, data_dir) = init_smoke_services()
openai_api_key: "test-key".into(),
surrealdb_address: "mem://".into(),
surrealdb_username: "root".into(),
surrealdb_password: "root".into(),
surrealdb_namespace: namespace.into(),
surrealdb_database: database.into(),
data_dir: data_dir.to_string_lossy().into_owned(),
http_port: 0,
openai_base_url: "https://example.com".into(),
storage: StorageKind::Local,
pdf_ingest_mode: PdfIngestMode::LlmFirst,
embedding_backend: EmbeddingBackend::Hashed,
..Default::default()
}
}
async fn build_test_app() -> (Router, Arc<SurrealDbClient>, std::path::PathBuf) {
let namespace = "test_ns";
let database = format!("test_db_{}", Uuid::new_v4());
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
tokio::fs::create_dir_all(&data_dir).await
.expect("failed to create temp data directory");
let config = smoke_test_config(namespace, &database, &data_dir);
let services = crate::bootstrap::init_with_config(config.clone())
.await .await
.expect("failed to init services"); .expect("failed to init services");
let session_store = Arc::new( prepare_embedding_runtime(&services)
services .await
.db .expect("failed to prepare embedding runtime");
.create_session_store()
.await
.expect("failed to create session store"),
);
let html_state = HtmlState::new_with_resources(StateResources { let html_state = build_html_state(&services)
db: Arc::clone(&services.db), .await
openai_client: Arc::clone(&services.openai_client), .expect("failed to build html state");
session_store, let api_state = build_api_state(&services);
storage: services.storage.clone(),
config: services.config.clone(), let app = minne_routes(&api_state, &html_state).with_state(AppState {
reranker_pool: services.reranker_pool.clone(), api_state,
embedding_provider: Arc::clone(&services.embedding_provider), html_state,
template_engine: None,
}); });
let api_state = ApiState {
db: Arc::clone(&services.db),
config: services.config.clone(),
storage: services.storage,
};
let app = Router::new()
.nest("/api/v1", api_routes_v1(&api_state))
.merge(html_routes(&html_state))
.with_state(AppState {
api_state,
html_state,
});
(app, services.db, data_dir) (app, services.db, data_dir)
} }
+13 -35
View File
@@ -1,47 +1,25 @@
mod bootstrap; mod bootstrap;
use std::sync::Arc; use axum::extract::FromRef;
use bootstrap::{
use api_router::{api_routes_v1, api_state::ApiState}; init, prepare_embedding_runtime,
use axum::{extract::FromRef, Router}; wiring::{build_api_state, build_html_state, minne_routes},
use common::storage::types::system_settings::SystemSettings;
use html_router::{
html_routes,
html_state::{HtmlState, StateResources},
}; };
use tracing::info; use tracing::info;
#[tokio::main(flavor = "multi_thread", worker_threads = 2)] #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let services = bootstrap::init().await?; let services = init().await?;
prepare_embedding_runtime(&services).await?;
let session_store = Arc::new(services.db.create_session_store().await?); let html_state = build_html_state(&services).await?;
let api_state = build_api_state(&services);
let (_settings, _dimensions_changed) = let app = minne_routes(&api_state, &html_state).with_state(AppState {
SystemSettings::sync_from_embedding_provider(&services.db, &services.embedding_provider) api_state,
.await?; html_state,
let html_state = HtmlState::new_with_resources(StateResources {
db: Arc::clone(&services.db),
openai_client: Arc::clone(&services.openai_client),
session_store,
storage: services.storage.clone(),
config: services.config.clone(),
reranker_pool: services.reranker_pool.clone(),
embedding_provider: Arc::clone(&services.embedding_provider),
template_engine: None,
}); });
let api_state = ApiState::new(&services.config, services.storage).await?;
let app = Router::new()
.nest("/api/v1", api_routes_v1(&api_state))
.merge(html_routes(&html_state))
.with_state(AppState {
api_state,
html_state,
});
info!( info!(
"Starting server listening on 0.0.0.0:{}", "Starting server listening on 0.0.0.0:{}",
services.config.http_port services.config.http_port
@@ -55,6 +33,6 @@ async fn main() -> anyhow::Result<()> {
#[derive(Clone, FromRef)] #[derive(Clone, FromRef)]
struct AppState { struct AppState {
api_state: ApiState, api_state: api_router::api_state::ApiState,
html_state: HtmlState, html_state: html_router::html_state::HtmlState,
} }
+58 -1
View File
@@ -2,12 +2,14 @@ mod bootstrap;
use std::sync::Arc; use std::sync::Arc;
use bootstrap::{init, prepare_embedding_runtime};
use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop};
use tracing::info; use tracing::info;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let services = bootstrap::init().await?; let services = init().await?;
prepare_embedding_runtime(&services).await?;
info!( info!(
embedding_backend = ?services.config.embedding_backend, embedding_backend = ?services.config.embedding_backend,
@@ -25,3 +27,58 @@ async fn main() -> anyhow::Result<()> {
run_worker_loop(services.db, ingestion_pipeline).await run_worker_loop(services.db, ingestion_pipeline).await
} }
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use chrono::Utc;
    use common::storage::types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS};
    use ingestion_pipeline::pipeline::IngestionPipeline;

    use crate::bootstrap::tests::init_smoke_services;

    /// Smoke test for the worker binary: initializes the services, checks that
    /// no task can be claimed, and verifies that the worker loop is still
    /// running after a short idle period.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn worker_smoke_initializes_and_claims_idle() -> anyhow::Result<()> {
        let (services, data_dir) = init_smoke_services().await?;

        // Build the pipeline before probing the queue, as the worker binary does.
        let pipeline = Arc::new(IngestionPipeline::new(
            Arc::clone(&services.db),
            Arc::clone(&services.openai_client),
            services.config.clone(),
            services.reranker_pool.clone(),
            services.storage,
            Arc::clone(&services.embedding_provider),
        )?);

        // The test expects this claim to come back empty.
        let lease = Duration::from_secs(DEFAULT_LEASE_SECS as u64);
        let next_task =
            IngestionTask::claim_next_ready(&services.db, "worker-smoke", Utc::now(), lease)
                .await?;
        assert!(
            next_task.is_none(),
            "worker smoke test should find no pending tasks"
        );

        // Run the loop in the background and give it a moment to settle.
        let loop_db = Arc::clone(&services.db);
        let handle = tokio::spawn(async move {
            ingestion_pipeline::run_worker_loop(loop_db, pipeline).await
        });
        tokio::time::sleep(Duration::from_millis(250)).await;

        let still_running = !handle.is_finished();
        assert!(still_running, "worker loop should keep running while idle");

        // Stop the background task, then remove the temp data dir (best effort).
        handle.abort();
        tokio::fs::remove_dir_all(&data_dir).await.ok();

        Ok(())
    }
}