mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-30 03:10:45 +02:00
chore: harden common storage bootstrap and slim embedded db assets
Unify embedding config, build providers from system settings, and fail startup when index builds error or time out. Move Surreal assets under common/db so embeds exclude crate source, and read storage via streams.
This commit is contained in:
@@ -13,8 +13,8 @@ use surrealdb::{
|
||||
use surrealdb_migrations::MigrationRunner;
|
||||
use tracing::debug;
|
||||
|
||||
/// Embedded SurrealDB migration directory packaged with the crate.
|
||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
|
||||
/// Embedded SurrealDB project root (`migrations/`, `schemas/`, `.surrealdb`).
|
||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/db");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurrealDbClient {
|
||||
|
||||
+162
-70
@@ -9,6 +9,7 @@ use tracing::{debug, info, warn};
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
||||
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
||||
|
||||
/// HNSW index options used by runtime index creation (includes CONCURRENTLY).
|
||||
@@ -296,14 +297,10 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
return Ok("unknown".to_string());
|
||||
};
|
||||
|
||||
let building = info.get("building");
|
||||
let status = building
|
||||
.and_then(|b| b.get("status"))
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("ready")
|
||||
.to_string();
|
||||
let parsed: IndexInfoForIndex = serde_json::from_value(info)
|
||||
.context("deserializing INFO FOR INDEX response")?;
|
||||
|
||||
Ok(status)
|
||||
Ok(parsed.building_status())
|
||||
}
|
||||
|
||||
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
@@ -531,8 +528,21 @@ async fn poll_index_build_status(
|
||||
poll_every: Duration,
|
||||
) -> Result<()> {
|
||||
let started_at = std::time::Instant::now();
|
||||
let mut last_snapshot: Option<IndexBuildSnapshot> = None;
|
||||
|
||||
loop {
|
||||
if started_at.elapsed() >= INDEX_BUILD_TIMEOUT {
|
||||
return Err(anyhow::anyhow!(
|
||||
"index build timed out after {:?} for {index_name} on {table} (last status: {})",
|
||||
INDEX_BUILD_TIMEOUT,
|
||||
last_snapshot
|
||||
.as_ref()
|
||||
.map(|snapshot| snapshot.status.as_str())
|
||||
.unwrap_or("unknown")
|
||||
))
|
||||
.with_context(|| format!("index {index_name} on table {table} did not become ready"));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_every).await;
|
||||
|
||||
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||
@@ -546,14 +556,13 @@ async fn poll_index_build_status(
|
||||
.context("failed to deserialize INFO FOR INDEX result")?;
|
||||
|
||||
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
|
||||
warn!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
"INFO FOR INDEX returned no data; assuming index definition might be missing"
|
||||
);
|
||||
break;
|
||||
return Err(anyhow::anyhow!(
|
||||
"INFO FOR INDEX returned no data for {index_name} on {table}"
|
||||
));
|
||||
};
|
||||
|
||||
last_snapshot = Some(snapshot.clone());
|
||||
|
||||
if let Some(pct) = snapshot.progress_pct {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
@@ -589,25 +598,87 @@ async fn poll_index_build_status(
|
||||
total = snapshot.total_rows,
|
||||
"Index is ready"
|
||||
);
|
||||
break;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if snapshot.status.eq_ignore_ascii_case("error") {
|
||||
warn!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
"Index build reported error status; stopping polling"
|
||||
);
|
||||
break;
|
||||
return Err(anyhow::anyhow!(
|
||||
"index build failed for {index_name} on {table}: status=error, processed={}, total={:?}",
|
||||
snapshot.processed,
|
||||
snapshot.total_rows
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `building` block from SurrealDB `INFO FOR INDEX` (concurrent index builds).
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||
struct IndexBuildingProgress {
|
||||
#[serde(default)]
|
||||
initial: u64,
|
||||
#[serde(default)]
|
||||
pending: u64,
|
||||
#[serde(default)]
|
||||
updated: u64,
|
||||
#[serde(default)]
|
||||
status: String,
|
||||
}
|
||||
|
||||
/// Top-level `INFO FOR INDEX` payload shape (SurrealDB v2.x).
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
|
||||
struct IndexInfoForIndex {
|
||||
#[serde(default)]
|
||||
building: Option<IndexBuildingProgress>,
|
||||
}
|
||||
|
||||
impl IndexInfoForIndex {
|
||||
fn building_status(&self) -> String {
|
||||
match &self.building {
|
||||
None => "ready".to_string(),
|
||||
Some(progress) if progress.status.is_empty() => "ready".to_string(),
|
||||
Some(progress) => progress.status.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
fn into_build_snapshot(self, total_rows: Option<u64>) -> IndexBuildSnapshot {
|
||||
let (initial, pending, updated, status) = match self.building {
|
||||
None => (0, 0, 0, "ready".to_string()),
|
||||
Some(progress) => {
|
||||
let status = if progress.status.is_empty() {
|
||||
"ready".to_string()
|
||||
} else {
|
||||
progress.status
|
||||
};
|
||||
(progress.initial, progress.pending, progress.updated, status)
|
||||
}
|
||||
};
|
||||
|
||||
let processed = initial.saturating_add(updated);
|
||||
let progress_pct = total_rows.map(|total| {
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
|
||||
/ f64::from(u32::try_from(total).unwrap_or(1)))
|
||||
.min(1.0))
|
||||
* 100.0
|
||||
}
|
||||
});
|
||||
|
||||
IndexBuildSnapshot {
|
||||
status,
|
||||
initial,
|
||||
pending,
|
||||
updated,
|
||||
processed,
|
||||
total_rows,
|
||||
progress_pct,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot of an index build progress as reported by SurrealDB's `INFO FOR INDEX`.
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
struct IndexBuildSnapshot {
|
||||
/// Current build status string (e.g., `"indexing"`, `"ready"`, `"error"`).
|
||||
status: String,
|
||||
@@ -636,53 +707,8 @@ fn parse_index_build_info(
|
||||
total_rows: Option<u64>,
|
||||
) -> Option<IndexBuildSnapshot> {
|
||||
let info = info?;
|
||||
let building = info.get("building");
|
||||
|
||||
let status = building
|
||||
.and_then(|b| b.get("status"))
|
||||
.and_then(|s| s.as_str())
|
||||
// If there's no `building` block at all, treat as "ready" (index not building anymore)
|
||||
.unwrap_or("ready")
|
||||
.to_string();
|
||||
|
||||
let initial = building
|
||||
.and_then(|b| b.get("initial"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let pending = building
|
||||
.and_then(|b| b.get("pending"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let updated = building
|
||||
.and_then(|b| b.get("updated"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
||||
let processed = initial.saturating_add(updated);
|
||||
|
||||
let progress_pct = total_rows.map(|total| {
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
|
||||
/ f64::from(u32::try_from(total).unwrap_or(1)))
|
||||
.min(1.0))
|
||||
* 100.0
|
||||
}
|
||||
});
|
||||
|
||||
Some(IndexBuildSnapshot {
|
||||
status,
|
||||
initial,
|
||||
pending,
|
||||
updated,
|
||||
processed,
|
||||
total_rows,
|
||||
progress_pct,
|
||||
})
|
||||
let parsed: IndexInfoForIndex = serde_json::from_value(info).ok()?;
|
||||
Some(parsed.into_build_snapshot(total_rows))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -786,6 +812,72 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_info_for_index_deserializes_ready_status_shape() -> anyhow::Result<()> {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"status": "ready"
|
||||
}
|
||||
});
|
||||
|
||||
let parsed: IndexInfoForIndex =
|
||||
serde_json::from_value(info).context("deserialize ready shape")?;
|
||||
assert_eq!(parsed.building_status(), "ready");
|
||||
|
||||
let snapshot = parse_index_build_info(
|
||||
Some(json!({
|
||||
"building": { "status": "ready" }
|
||||
})),
|
||||
None,
|
||||
)
|
||||
.context("snapshot")?;
|
||||
assert!(snapshot.is_ready());
|
||||
assert_eq!(snapshot.initial, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_info_for_index_deserializes_indexing_shape_from_surreal_docs() -> anyhow::Result<()> {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"initial": 8143,
|
||||
"pending": 19,
|
||||
"status": "indexing",
|
||||
"updated": 80
|
||||
}
|
||||
});
|
||||
|
||||
let parsed: IndexInfoForIndex =
|
||||
serde_json::from_value(info.clone()).context("deserialize indexing shape")?;
|
||||
assert_eq!(parsed.building_status(), "indexing");
|
||||
|
||||
let snapshot = parse_index_build_info(Some(info), None).context("snapshot")?;
|
||||
assert_eq!(snapshot.status, "indexing");
|
||||
assert_eq!(snapshot.initial, 8143);
|
||||
assert_eq!(snapshot.pending, 19);
|
||||
assert_eq!(snapshot.updated, 80);
|
||||
assert_eq!(snapshot.processed, 8223);
|
||||
assert!(!snapshot.is_ready());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_reports_error_status() -> anyhow::Result<()> {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"initial": 100,
|
||||
"pending": 5,
|
||||
"status": "error",
|
||||
"updated": 10
|
||||
}
|
||||
});
|
||||
|
||||
let snapshot = parse_index_build_info(Some(info), Some(200)).context("snapshot")?;
|
||||
assert_eq!(snapshot.status, "error");
|
||||
assert!(!snapshot.is_ready());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_dimension_parses_value() {
|
||||
let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
|
||||
|
||||
+26
-13
@@ -2,7 +2,7 @@ use std::io::ErrorKind;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result as AnyResult};
|
||||
use anyhow::{anyhow, Context, Result as AnyResult};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
@@ -107,15 +107,18 @@ impl StorageManager {
|
||||
|
||||
/// Retrieve bytes from the specified location.
|
||||
///
|
||||
/// Returns the full contents buffered in memory.
|
||||
/// Reads via [`Self::get_stream`] and buffers the full object in memory.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the location does not exist or the underlying backend fails.
|
||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||
let path = ObjPath::from(location);
|
||||
let result = self.store.get(&path).await?;
|
||||
result.bytes().await
|
||||
let mut stream = self.get_stream(location).await?;
|
||||
let mut collected = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
collected.extend_from_slice(&chunk?);
|
||||
}
|
||||
Ok(Bytes::from(collected))
|
||||
}
|
||||
|
||||
/// Get a streaming handle for large objects.
|
||||
@@ -252,7 +255,10 @@ async fn create_storage_backend(
|
||||
) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
|
||||
match cfg.storage {
|
||||
StorageKind::Local => {
|
||||
let base = resolve_base_dir(cfg);
|
||||
let base = resolve_base_dir(cfg).map_err(|err| object_store::Error::Generic {
|
||||
store: "LocalFileSystem",
|
||||
source: err.into(),
|
||||
})?;
|
||||
if !base.exists() {
|
||||
tokio::fs::create_dir_all(&base).await.map_err(|e| {
|
||||
object_store::Error::Generic {
|
||||
@@ -576,15 +582,22 @@ pub mod testing {
|
||||
|
||||
/// Resolve the absolute base directory used for local storage from config.
|
||||
///
|
||||
/// If `data_dir` is relative, it is resolved against the current working directory.
|
||||
#[must_use]
|
||||
pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf {
|
||||
/// If `data_dir` is relative, it is resolved against the process current working directory.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` when `data_dir` is relative and the current working directory cannot be read.
|
||||
pub fn resolve_base_dir(cfg: &AppConfig) -> AnyResult<PathBuf> {
|
||||
if cfg.data_dir.starts_with('/') {
|
||||
PathBuf::from(&cfg.data_dir)
|
||||
Ok(PathBuf::from(&cfg.data_dir))
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| PathBuf::from("."))
|
||||
.join(&cfg.data_dir)
|
||||
let cwd = std::env::current_dir().with_context(|| {
|
||||
format!(
|
||||
"failed to resolve relative data_dir '{}' against the current working directory",
|
||||
cfg.data_dir
|
||||
)
|
||||
})?;
|
||||
Ok(cwd.join(&cfg.data_dir))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::utils::embedding::EmbeddingBackend;
|
||||
use crate::utils::config::EmbeddingBackend;
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, sync::Once, str::FromStr};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||
pub struct ParseEmbeddingBackendError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Selects the embedding backend for vector generation.
|
||||
#[derive(Clone, Copy, Deserialize, Debug, Default, PartialEq)]
|
||||
#[derive(Clone, Copy, Deserialize, Serialize, Debug, Default, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingBackend {
|
||||
/// Use OpenAI-compatible API for embeddings.
|
||||
@@ -15,6 +24,32 @@ pub enum EmbeddingBackend {
|
||||
Hashed,
|
||||
}
|
||||
|
||||
impl EmbeddingBackend {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::OpenAI => "openai",
|
||||
Self::FastEmbed => "fastembed",
|
||||
Self::Hashed => "hashed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EmbeddingBackend {
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(ParseEmbeddingBackendError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum StorageKind {
|
||||
@@ -133,11 +168,17 @@ fn default_ingest_max_category_bytes() -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
static ORT_PATH_INIT: Once = Once::new();
|
||||
|
||||
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
||||
pub fn ensure_ort_path() {
|
||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||
return;
|
||||
}
|
||||
if let Ok(mut exe) = env::current_exe() {
|
||||
ORT_PATH_INIT.call_once(|| {
|
||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||
return;
|
||||
}
|
||||
let Ok(mut exe) = env::current_exe() else {
|
||||
return;
|
||||
};
|
||||
exe.pop();
|
||||
|
||||
if cfg!(target_os = "windows") {
|
||||
@@ -160,7 +201,7 @@ pub fn ensure_ort_path() {
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
|
||||
@@ -8,59 +8,15 @@ use std::{
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
/// Error returned when parsing an embedding backend name.
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||
pub struct ParseEmbeddingBackendError {
|
||||
/// The unrecognized input string.
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Supported embedding backends.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingBackend {
|
||||
#[default]
|
||||
OpenAI,
|
||||
FastEmbed,
|
||||
Hashed,
|
||||
}
|
||||
|
||||
impl EmbeddingBackend {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::OpenAI => "openai",
|
||||
Self::FastEmbed => "fastembed",
|
||||
Self::Hashed => "hashed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = ParseEmbeddingBackendError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(ParseEmbeddingBackendError {
|
||||
input: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use crate::utils::config::{EmbeddingBackend, ParseEmbeddingBackendError};
|
||||
|
||||
/// Wrapper around the chosen embedding backend.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
@@ -281,30 +237,31 @@ impl EmbeddingProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an embedding provider based on application configuration.
|
||||
/// Creates an embedding provider from persisted settings and bootstrap config.
|
||||
///
|
||||
/// Dispatches to the appropriate constructor based on `config.embedding_backend`:
|
||||
/// - `OpenAI`: Requires a valid OpenAI client
|
||||
/// - `FastEmbed`: Uses local embedding model
|
||||
/// - `Hashed`: Uses deterministic hashed embeddings (for testing)
|
||||
pub async fn from_config(
|
||||
config: &crate::utils::config::AppConfig,
|
||||
/// Model name and dimensions come from [`SystemSettings`]. The active backend is taken
|
||||
/// from `config.embedding_backend` at startup; [`SystemSettings::sync_from_embedding_provider`]
|
||||
/// persists the resolved backend to the database.
|
||||
pub async fn from_system_settings(
|
||||
settings: &SystemSettings,
|
||||
config: &AppConfig,
|
||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
) -> Result<Self> {
|
||||
use crate::utils::config::EmbeddingBackend;
|
||||
|
||||
let dimensions = settings.embedding_dimensions;
|
||||
match config.embedding_backend {
|
||||
EmbeddingBackend::OpenAI => {
|
||||
let client = openai_client
|
||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
||||
// Use defaults that match SystemSettings initial values
|
||||
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
|
||||
Self::new_openai(client, settings.embedding_model.clone(), dimensions)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
// Use nomic-embed-text-v1.5 as the default FastEmbed model
|
||||
Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await
|
||||
Self::new_fastembed(Some(settings.embedding_model.clone())).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => {
|
||||
let dimension = usize::try_from(dimensions)
|
||||
.map_err(|_| anyhow!("embedding_dimensions exceeds usize::MAX"))?;
|
||||
Self::new_hashed(dimension)
|
||||
}
|
||||
EmbeddingBackend::Hashed => Self::new_hashed(384),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -460,6 +417,14 @@ mod tests {
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn embedding_backend_defaults_to_fastembed() {
|
||||
assert_eq!(
|
||||
EmbeddingBackend::default(),
|
||||
EmbeddingBackend::FastEmbed
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn embedding_backend_as_str_matches_serde_names() {
|
||||
assert_eq!(EmbeddingBackend::OpenAI.as_str(), "openai");
|
||||
|
||||
@@ -29,6 +29,14 @@ pub fn validate_ingest_input(
|
||||
category: &str,
|
||||
file_count: usize,
|
||||
) -> Result<(), IngestValidationError> {
|
||||
let text_field_bytes = content.map(str::len).unwrap_or(0) + ctx.len() + category.len();
|
||||
if text_field_bytes > config.ingest_max_body_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"request text fields exceed maximum allowed body size of {} bytes",
|
||||
config.ingest_max_body_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
if file_count > config.ingest_max_files {
|
||||
return Err(IngestValidationError::BadRequest(format!(
|
||||
"too many files: maximum allowed is {}",
|
||||
@@ -127,4 +135,18 @@ mod tests {
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_rejects_oversized_text_fields() {
|
||||
let config = AppConfig {
|
||||
ingest_max_body_bytes: 10,
|
||||
..Default::default()
|
||||
};
|
||||
let result = validate_ingest_input(&config, Some("123456"), "ctx", "cat", 0);
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(IngestValidationError::PayloadTooLarge(_))
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user