retrieval simplified

Per Stark
2025-12-09 20:35:42 +01:00
parent a8d10f265c
commit a090a8c76e
55 changed files with 469 additions and 1208 deletions

View File

@@ -5,6 +5,7 @@ use tokio::task::JoinError;
use crate::storage::types::file_info::FileError;
// Core internal errors
#[allow(clippy::module_name_repetitions)]
#[derive(Error, Debug)]
pub enum AppError {
#[error("Database error: {0}")]

View File

@@ -1,3 +1,5 @@
#![allow(clippy::doc_markdown)]
//! Shared utilities and storage helpers for the workspace crates.
pub mod error;
pub mod storage;
pub mod utils;

View File

@@ -13,12 +13,14 @@ use surrealdb::{
use surrealdb_migrations::MigrationRunner;
use tracing::debug;
/// Embedded SurrealDB migration directory packaged with the crate.
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
#[derive(Clone)]
pub struct SurrealDbClient {
pub client: Surreal<Any>,
}
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesDb {
fn db(&self) -> &Arc<SurrealDbClient>;
}
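Note: `ProvidesDb` is how handlers reach the shared client. A minimal sketch of an implementor, assuming a hypothetical `AppState` that is not part of this diff:

```rust
use std::sync::Arc;

/// Hypothetical application state; only the field needed for ProvidesDb.
struct AppState {
    db: Arc<SurrealDbClient>,
}

impl ProvidesDb for AppState {
    fn db(&self) -> &Arc<SurrealDbClient> {
        &self.db
    }
}
```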

View File

@@ -1,3 +1,13 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss,
clippy::redundant_closure_for_method_calls,
clippy::single_match_else,
clippy::uninlined_format_args
)]
use std::time::Duration;
use anyhow::{Context, Result};
@@ -234,12 +244,25 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
analyzer = FTS_ANALYZER_NAME
);
-db.client
+let res = db
+.client
.query(fallback_query)
.await
-.context("creating fallback FTS analyzer")?
-.check()
-.context("failed to create fallback FTS analyzer")?;
+.context("creating fallback FTS analyzer")?;
+if let Err(err) = res.check() {
+warn!(
+error = %err,
+"Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})",
+FTS_ANALYZER_NAME
+);
+return Err(err).context("failed to create fallback FTS analyzer");
+}
+warn!(
+"Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only",
+FTS_ANALYZER_NAME
+);
Ok(())
}
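Note: the rewritten fallback path runs the query first, then inspects `check()` separately so the failure can be logged before it propagates. A standalone sketch of that log-then-propagate shape using `anyhow`; `run_query` is a stand-in for the SurrealDB call, not the crate's API:

```rust
use anyhow::{Context, Result};

/// Stand-in for db.client.query(...).check().
fn run_query(query: &str) -> Result<()> {
    if query.is_empty() {
        anyhow::bail!("empty query");
    }
    Ok(())
}

fn create_fallback(query: &str) -> Result<()> {
    if let Err(err) = run_query(query) {
        // Log before surfacing the error to the caller.
        eprintln!("fallback analyzer creation failed: {err}");
        return Err(err).context("failed to create fallback FTS analyzer");
    }
    eprintln!("using fallback analyzer (lowercase+ascii only)");
    Ok(())
}

fn main() {
    assert!(create_fallback("DEFINE ANALYZER ...").is_ok());
    assert!(create_fallback("").is_err());
}
```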
@@ -466,7 +489,7 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
let rows: Vec<CountRow> = response
.take(0)
.context("failed to deserialize count() response")?;
-Ok(rows.first().map(|r| r.count).unwrap_or(0))
+Ok(rows.first().map_or(0, |r| r.count))
}
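Note: `map(...).unwrap_or(0)` becomes `map_or(0, ...)` here and in the analytics and user counters below; clippy's `map_unwrap_or` lint prefers the single call. A tiny self-contained illustration, with `CountRow` as a stand-in for the crate's own deserialization struct:

```rust
struct CountRow {
    count: u64,
}

fn total(rows: &[CountRow]) -> u64 {
    // map_or folds the Option in one call instead of .map(...).unwrap_or(0).
    rows.first().map_or(0, |r| r.count)
}

fn main() {
    assert_eq!(total(&[]), 0);
    assert_eq!(total(&[CountRow { count: 42 }]), 42);
}
```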
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {

View File

@@ -183,7 +183,7 @@ impl StorageManager {
while current.starts_with(base) && current.as_path() != base.as_path() {
match tokio::fs::remove_dir(&current).await {
-Ok(_) => {}
+Ok(()) => {}
Err(err) => match err.kind() {
ErrorKind::NotFound => {}
ErrorKind::DirectoryNotEmpty => break,

View File

@@ -71,6 +71,7 @@ impl Analytics {
// We need to use a direct query for COUNT aggregation
#[derive(Debug, Deserialize)]
struct CountResult {
/// Total user count.
count: i64,
}
@@ -81,7 +82,7 @@ impl Analytics {
.await?
.take(0)?;
-Ok(result.map(|r| r.count).unwrap_or(0))
+Ok(result.map_or(0, |r| r.count))
}
}

View File

@@ -3,12 +3,10 @@ use bytes;
use mime_guess::from_path;
use object_store::Error as ObjectStoreError;
use sha2::{Digest, Sha256};
-use std::{
-io::{BufReader, Read},
-path::Path,
-};
+use std::{io::{BufReader, Read}, path::Path};
use tempfile::NamedTempFile;
use thiserror::Error;
+use tokio::task;
use tracing::info;
use uuid::Uuid;
@@ -71,21 +69,29 @@ impl FileInfo {
///
/// # Returns
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
+#[allow(clippy::indexing_slicing)]
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
-let mut reader = BufReader::new(file.as_file());
-let mut hasher = Sha256::new();
-let mut buffer = [0u8; 8192]; // 8KB buffer
-loop {
-let n = reader.read(&mut buffer)?;
-if n == 0 {
-break;
-}
-hasher.update(&buffer[..n]);
-}
-let digest = hasher.finalize();
-Ok(format!("{:x}", digest))
+let mut file_clone = file.as_file().try_clone()?;
+let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> {
+let mut reader = BufReader::new(&mut file_clone);
+let mut hasher = Sha256::new();
+let mut buffer = [0u8; 8192]; // 8KB buffer
+loop {
+let n = reader.read(&mut buffer)?;
+if n == 0 {
+break;
+}
+hasher.update(&buffer[..n]);
+}
+Ok::<_, std::io::Error>(hasher.finalize())
+})
+.await
+.map_err(std::io::Error::other)??;
+Ok(format!("{digest:x}"))
}
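Note: the hashing now runs on a blocking thread instead of the async executor: the file handle is cloned, the read loop moves inside `task::spawn_blocking`, and the double `?` first joins the task and then unwraps the I/O result. A self-contained sketch of the same shape; `hash_file` is an illustrative helper on a plain `std::fs::File`, not the crate's function:

```rust
use sha2::{Digest, Sha256};
use std::io::{BufReader, Read};
use tokio::task;

/// Illustrative helper (not part of the crate): hash a file off the executor.
async fn hash_file(mut file: std::fs::File) -> std::io::Result<String> {
    let digest = task::spawn_blocking(move || -> std::io::Result<_> {
        let mut reader = BufReader::new(&mut file);
        let mut hasher = Sha256::new();
        let mut buffer = [0u8; 8192]; // 8KB buffer
        loop {
            let n = reader.read(&mut buffer)?;
            if n == 0 {
                break;
            }
            hasher.update(&buffer[..n]);
        }
        Ok(hasher.finalize())
    })
    .await
    .map_err(std::io::Error::other)??; // first ? joins the task, second ? is the I/O result
    Ok(format!("{digest:x}"))
}
```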
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
@@ -103,7 +109,7 @@ impl FileInfo {
}
})
.collect();
format!("{}{}", sanitized_name, ext)
format!("{sanitized_name}{ext}")
} else {
// No extension
file_name
@@ -292,7 +298,7 @@ impl FileInfo {
storage: &StorageManager,
) -> Result<String, FileError> {
// Logical object location relative to the store root
-let location = format!("{}/{}/{}", user_id, uuid, file_name);
+let location = format!("{user_id}/{uuid}/{file_name}");
info!("Persisting to object location: {}", location);
let bytes = tokio::fs::read(file.path()).await?;
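Note: most of the remaining churn in this commit is the same mechanical fix, Rust 2021 captured identifiers in format strings (clippy's `uninlined_format_args`). Both forms produce identical output:

```rust
fn main() {
    let user_id = "u1";
    let uuid = "0a1b";
    let file_name = "report.pdf";
    // Old style: positional arguments after the format string.
    let old = format!("{}/{}/{}", user_id, uuid, file_name);
    // New style: identifiers captured directly in the braces.
    let new = format!("{user_id}/{uuid}/{file_name}");
    assert_eq!(old, new);
}
```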

View File

@@ -1,3 +1,9 @@
#![allow(
clippy::result_large_err,
clippy::needless_pass_by_value,
clippy::implicit_clone,
clippy::semicolon_if_nothing_returned
)]
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
@@ -38,6 +44,7 @@ impl IngestionPayload {
/// # Returns
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
/// (one per file/content type). On failure, returns an `AppError`.
#[allow(clippy::similar_names)]
pub fn create_ingestion_payload(
content: Option<String>,
context: String,

View File

@@ -1,3 +1,12 @@
#![allow(
clippy::cast_possible_wrap,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_sign_loss,
clippy::missing_docs_in_private_items,
clippy::trivially_copy_pass_by_ref,
clippy::expect_used
)]
use std::time::Duration;
use chrono::Duration as ChronoDuration;

View File

@@ -1,3 +1,14 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::match_same_arms,
clippy::format_push_string,
clippy::uninlined_format_args,
clippy::explicit_iter_loop,
clippy::items_after_statements,
clippy::get_first,
clippy::redundant_closure_for_method_calls
)]
use std::collections::HashMap;
use crate::{

View File

@@ -72,7 +72,7 @@ impl KnowledgeEntityEmbedding {
return Ok(HashMap::new());
}
-let ids_list: Vec<RecordId> = entity_ids.iter().cloned().collect();
+let ids_list: Vec<RecordId> = entity_ids.to_vec();
let query = format!(
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
@@ -110,6 +110,7 @@ impl KnowledgeEntityEmbedding {
}
/// Delete embeddings by source_id (via joining to knowledge_entity table)
#[allow(clippy::items_after_statements)]
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
@@ -121,6 +122,7 @@ impl KnowledgeEntityEmbedding {
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?;
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,

View File

@@ -65,8 +65,7 @@ impl KnowledgeRelationship {
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
source_id
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
);
db_client.query(query).await?;
@@ -81,15 +80,14 @@ impl KnowledgeRelationship {
) -> Result<(), AppError> {
let mut authorized_result = db_client
.query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'",
id, user_id
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
if authorized.is_empty() {
let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{}`", id))
.query(format!("SELECT * FROM relates_to:`{id}`"))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -98,11 +96,11 @@ impl KnowledgeRelationship {
"Not authorized to delete relationship".into(),
))
} else {
-Err(AppError::NotFound(format!("Relationship {} not found", id)))
+Err(AppError::NotFound(format!("Relationship {id} not found")))
}
} else {
db_client
.query(format!("DELETE relates_to:`{}`", id))
.query(format!("DELETE relates_to:`{id}`"))
.await?;
Ok(())
}

View File

@@ -1,3 +1,4 @@
#![allow(clippy::module_name_repetitions)]
use uuid::Uuid;
use crate::stored_object;
@@ -56,7 +57,7 @@ impl fmt::Display for Message {
pub fn format_history(history: &[Message]) -> String {
history
.iter()
-.map(|msg| format!("{}", msg))
+.map(|msg| format!("{msg}"))
.collect::<Vec<String>>()
.join("\n")
}

View File

@@ -1,3 +1,4 @@
#![allow(clippy::unsafe_derive_deserialize)]
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
@@ -23,7 +24,7 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
#[macro_export]
macro_rules! stored_object {
-($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
+($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject;
@@ -87,6 +88,7 @@ macro_rules! stored_object {
}
#[allow(dead_code)]
#[allow(clippy::ref_option)]
fn serialize_option_datetime<S>(
date: &Option<DateTime<Utc>>,
serializer: S,
@@ -102,6 +104,7 @@ macro_rules! stored_object {
}
#[allow(dead_code)]
#[allow(clippy::ref_option)]
fn deserialize_option_datetime<'de, D>(
deserializer: D,
) -> Result<Option<DateTime<Utc>>, D::Error>
@@ -113,6 +116,7 @@ macro_rules! stored_object {
}
$(#[$struct_attr])*
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct $name {
#[serde(deserialize_with = "deserialize_flexible_id")]
@@ -121,7 +125,7 @@ macro_rules! stored_object {
pub created_at: DateTime<Utc>,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub updated_at: DateTime<Utc>,
-$( $(#[$attr])* pub $field: $ty),*
+$( $(#[$field_attr])* pub $field: $ty),*
}
impl StoredObject for $name {
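Note: the macro arm change threads optional outer attributes through to the generated struct (`$struct_attr`) and renames the per-field capture to `$field_attr`. A sketch of the invocation shape the new arm accepts; the `Note` type is invented for illustration, and the macro itself lives in this crate:

```rust
stored_object!(
    #[allow(clippy::unsafe_derive_deserialize)]
    Note, "note", {
        /// Free-form body text.
        body: String,
        archived: bool
    }
);
```

The `User` table later in this commit adopts exactly this form.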

View File

@@ -1,4 +1,6 @@
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
use std::collections::HashMap;
use std::fmt::Write;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
@@ -18,6 +20,7 @@ stored_object!(TextChunk, "text_chunk", {
});
/// Search result including hydrated chunk.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct TextChunkSearchResult {
pub chunk: TextChunk,
@@ -98,6 +101,7 @@ impl TextChunk {
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct Row {
chunk_id: TextChunk,
@@ -160,6 +164,8 @@ impl TextChunk {
score: f32,
}
let limit = i64::try_from(take).unwrap_or(i64::MAX);
let sql = format!(
r#"
SELECT
@@ -183,7 +189,7 @@ impl TextChunk {
.query(&sql)
.bind(("terms", terms.to_owned()))
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.bind(("limit", limit))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
@@ -245,7 +251,7 @@ impl TextChunk {
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");
-for chunk in all_chunks.iter() {
+for chunk in &all_chunks {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
@@ -283,12 +289,13 @@ impl TextChunk {
"[{}]",
embedding
.iter()
-.map(|f| f.to_string())
+.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
// Use the chunk id as the embedding record id to keep a 1:1 mapping
-transaction_query.push_str(&format!(
+write!(
+&mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \
@@ -300,13 +307,16 @@ impl TextChunk {
embedding = embedding_str,
user_id = user_id,
source_id = source_id
-));
+)
+.map_err(|e| AppError::InternalError(e.to_string()))?;
}
-transaction_query.push_str(&format!(
+write!(
+&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
-));
+)
+.map_err(|e| AppError::InternalError(e.to_string()))?;
transaction_query.push_str("COMMIT TRANSACTION;");

View File

@@ -110,6 +110,11 @@ impl TextChunkEmbedding {
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
+#[allow(clippy::missing_docs_in_private_items)]
+#[derive(Deserialize)]
+struct IdRow {
+id: RecordId,
+}
let ids_query = format!(
"SELECT id FROM {} WHERE source_id = $source_id",
TextChunk::table_name()
@@ -120,10 +125,6 @@ impl TextChunkEmbedding {
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?;
-#[derive(Deserialize)]
-struct IdRow {
-id: RecordId,
-}
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
if ids.is_empty() {

View File

@@ -5,6 +5,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::file_info::FileInfo;
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Deserialize, Serialize)]
pub struct TextContentSearchResult {
#[serde(deserialize_with = "deserialize_flexible_id")]

View File

@@ -1,4 +1,5 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use anyhow::anyhow;
use async_trait::async_trait;
use axum_session_auth::Authentication;
use chrono_tz::Tz;
@@ -17,12 +18,16 @@ use super::{
use chrono::Duration;
use futures::try_join;
/// Result row for returning user category.
#[derive(Deserialize)]
pub struct CategoryResponse {
/// Category name tied to the user.
category: String,
}
-stored_object!(User, "user", {
+stored_object!(
+#[allow(clippy::unsafe_derive_deserialize)]
+User, "user", {
email: String,
password: String,
anonymous: bool,
@@ -35,11 +40,11 @@ stored_object!(User, "user", {
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
-let db = db.unwrap();
+let db = db.ok_or_else(|| anyhow!("Database handle missing"))?;
Ok(db
.select((Self::table_name(), userid.as_str()))
.await?
-.unwrap())
+.ok_or_else(|| anyhow!("User {userid} not found"))?)
}
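Note: both `unwrap()`s in `load_user` become `ok_or_else`, so a missing database handle or unknown user id surfaces as an `anyhow` error instead of a panic. The Option-to-error shape in isolation, with `find_user` as an invented example:

```rust
use anyhow::anyhow;

fn find_user(id: &str, row: Option<String>) -> Result<String, anyhow::Error> {
    // ok_or_else builds the error lazily, only on the None path.
    row.ok_or_else(|| anyhow!("User {id} not found"))
}

fn main() {
    assert!(find_user("u1", None).is_err());
    assert_eq!(find_user("u1", Some("per".into())).unwrap(), "per");
}
```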
fn is_authenticated(&self) -> bool {
@@ -55,14 +60,14 @@ impl Authentication<User, String, Surreal<Any>> for User {
}
}
/// Ensures a timezone string parses, defaulting to UTC when invalid.
fn validate_timezone(input: &str) -> String {
-match input.parse::<Tz>() {
-Ok(_) => input.to_owned(),
-Err(_) => {
-tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
-"UTC".to_owned()
-}
-}
+if input.parse::<Tz>().is_ok() {
+return input.to_owned();
+}
+tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
+"UTC".to_owned()
}
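Note: the early-return form reads flatter than the match and behaves identically. A runnable sketch assuming only `chrono_tz`:

```rust
use chrono_tz::Tz;

fn validate_timezone(input: &str) -> String {
    if input.parse::<Tz>().is_ok() {
        return input.to_owned();
    }
    "UTC".to_owned()
}

fn main() {
    assert_eq!(validate_timezone("Europe/Stockholm"), "Europe/Stockholm");
    assert_eq!(validate_timezone("Not/AZone"), "UTC");
}
```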
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -77,12 +82,15 @@ pub struct DashboardStats {
pub new_text_chunks_week: i64,
}
/// Helper for aggregating `SurrealDB` count responses.
#[derive(Deserialize)]
struct CountResult {
/// Row count returned by the query.
count: i64,
}
impl User {
/// Counts all objects of a given type belonging to the user.
async fn count_total<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient,
user_id: &str,
@@ -94,9 +102,10 @@ impl User {
.bind(("user_id", user_id.to_string()))
.await?
.take(0)?;
-Ok(result.map(|r| r.count).unwrap_or(0))
+Ok(result.map_or(0, |r| r.count))
}
/// Counts objects of a given type created after a specific timestamp.
async fn count_since<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient,
user_id: &str,
@@ -112,14 +121,16 @@ impl User {
.bind(("since", surrealdb::Datetime::from(since)))
.await?
.take(0)?;
-Ok(result.map(|r| r.count).unwrap_or(0))
+Ok(result.map_or(0, |r| r.count))
}
pub async fn get_dashboard_stats(
user_id: &str,
db: &SurrealDbClient,
) -> Result<DashboardStats, AppError> {
-let since = chrono::Utc::now() - Duration::days(7);
+let since = chrono::Utc::now()
+.checked_sub_signed(Duration::days(7))
+.unwrap_or_else(chrono::Utc::now);
let (
total_documents,
@@ -261,7 +272,7 @@ impl User {
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
// Generate a secure random API key
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace('-', ""));
// Update the user record with the new API key
let user: Option<Self> = db
@@ -341,6 +352,7 @@ impl User {
) -> Result<Vec<String>, AppError> {
#[derive(Deserialize)]
struct EntityTypeResponse {
/// Raw entity type value from the database.
entity_type: String,
}
@@ -358,7 +370,7 @@ impl User {
.into_iter()
.map(|item| {
let normalized = KnowledgeEntityType::from(item.entity_type);
format!("{:?}", normalized)
format!("{normalized:?}")
})
.collect();

View File

@@ -9,6 +9,7 @@ pub enum StorageKind {
Memory,
}
/// Default storage backend when none is configured.
fn default_storage_kind() -> StorageKind {
StorageKind::Local
}
@@ -23,10 +24,13 @@ pub enum PdfIngestMode {
LlmFirst,
}
/// Default PDF ingestion mode when unset.
fn default_pdf_ingest_mode() -> PdfIngestMode {
PdfIngestMode::LlmFirst
}
/// Application configuration loaded from files and environment variables.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Deserialize, Debug)]
pub struct AppConfig {
pub openai_api_key: String,
@@ -58,14 +62,17 @@ pub struct AppConfig {
pub retrieval_strategy: Option<String>,
}
/// Default data directory for persisted assets.
fn default_data_dir() -> String {
"./data".to_string()
}
/// Default base URL used for OpenAI-compatible APIs.
fn default_base_url() -> String {
"https://api.openai.com/v1".to_string()
}
/// Whether reranking is enabled by default.
fn default_reranking_enabled() -> bool {
false
}
@@ -124,6 +131,8 @@ impl Default for AppConfig {
}
}
/// Loads the application configuration from the environment and optional config file.
#[allow(clippy::module_name_repetitions)]
pub fn get_config() -> Result<AppConfig, ConfigError> {
ensure_ort_path();

View File

@@ -16,19 +16,16 @@ use crate::{
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+/// Supported embedding backends.
+#[allow(clippy::module_name_repetitions)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
OpenAI,
+#[default]
FastEmbed,
Hashed,
}
-impl Default for EmbeddingBackend {
-fn default() -> Self {
-Self::FastEmbed
-}
-}
impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error;
@@ -44,24 +41,38 @@ impl std::str::FromStr for EmbeddingBackend {
}
}
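Note: the manual `Default` impl goes away in favor of `#[derive(Default)]` with a `#[default]` variant marker, stable since Rust 1.62. A standalone equivalent:

```rust
#[allow(dead_code)] // only FastEmbed is constructed in this sketch
#[derive(Debug, PartialEq, Default)]
enum Backend {
    OpenAI,
    #[default]
    FastEmbed,
    Hashed,
}

fn main() {
    assert_eq!(Backend::default(), Backend::FastEmbed);
}
```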
/// Wrapper around the chosen embedding backend.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct EmbeddingProvider {
/// Concrete backend implementation.
inner: EmbeddingInner,
}
/// Concrete embedding implementations.
#[derive(Clone)]
enum EmbeddingInner {
/// Uses an `OpenAI`-compatible API.
OpenAI {
/// Client used to issue embedding requests.
client: Arc<Client<async_openai::config::OpenAIConfig>>,
/// Model identifier for the API.
model: String,
/// Expected output dimensions.
dimensions: u32,
},
/// Generates deterministic hashed embeddings without external calls.
Hashed {
/// Output vector length.
dimension: usize,
},
/// Uses `FastEmbed` running locally.
FastEmbed {
/// Shared `FastEmbed` model.
model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging.
model_name: EmbeddingModel,
/// Output vector length.
dimension: usize,
},
}
@@ -77,8 +88,9 @@ impl EmbeddingProvider {
pub fn dimension(&self) -> usize {
match &self.inner {
-EmbeddingInner::Hashed { dimension } => *dimension,
-EmbeddingInner::FastEmbed { dimension, .. } => *dimension,
+EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
+*dimension
+}
EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
}
}
@@ -172,12 +184,12 @@ impl EmbeddingProvider {
}
}
-pub async fn new_openai(
+pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String,
dimensions: u32,
) -> Result<Self> {
-Ok(EmbeddingProvider {
+Ok(Self {
inner: EmbeddingInner::OpenAI {
client,
model,
@@ -226,6 +238,7 @@ impl EmbeddingProvider {
}
// Helper functions for hashed embeddings
/// Generates a hashed embedding vector without external dependencies.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
let dim = dimension.max(1);
let mut vector = vec![0.0f32; dim];
@@ -233,15 +246,11 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
return vector;
}
-let mut token_count = 0f32;
for token in tokens(text) {
-token_count += 1.0;
let idx = bucket(&token, dim);
-vector[idx] += 1.0;
+if let Some(slot) = vector.get_mut(idx) {
+*slot += 1.0;
+}
}
-if token_count == 0.0 {
-return vector;
-}
let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
@@ -254,16 +263,22 @@ fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
vector
}
/// Tokenizes the text into alphanumeric lowercase tokens.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
text.split(|c: char| !c.is_ascii_alphanumeric())
.filter(|token| !token.is_empty())
-.map(|token| token.to_ascii_lowercase())
+.map(str::to_ascii_lowercase)
}
/// Buckets a token into the hashed embedding vector.
#[allow(clippy::arithmetic_side_effects)]
fn bucket(token: &str, dimension: usize) -> usize {
+let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new();
token.hash(&mut hasher);
-(hasher.finish() as usize) % dimension
+usize::try_from(hasher.finish())
+.unwrap_or_default()
+% safe_dimension
}
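Note: because `bucket` only hashes the token with `DefaultHasher`, the hashed backend is fully deterministic and needs no model or network. A runnable sketch of the bucketing step:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Map a token to one of `dimension` buckets, guarding against zero.
fn bucket(token: &str, dimension: usize) -> usize {
    let safe_dimension = dimension.max(1);
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}

fn main() {
    let dim = 16;
    let a = bucket("retrieval", dim);
    let b = bucket("retrieval", dim);
    assert_eq!(a, b); // deterministic: equal text always lands in the same bucket
    assert!(a < dim);
}
```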
// Backward compatibility function
@@ -274,15 +289,15 @@ pub async fn generate_embedding_with_provider(
provider.embed(input).await.map_err(AppError::from)
}
-/// Generates an embedding vector for the given input text using OpenAI's embedding model.
+/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
-/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
+/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
/// comparisons, vector search, and other natural language processing tasks.
///
/// # Arguments
///
-/// * `client`: The OpenAI client instance used to make API requests.
+/// * `client`: The `OpenAI` client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
@@ -294,9 +309,10 @@ pub async fn generate_embedding_with_provider(
/// # Errors
///
/// This function can return a `AppError` in the following cases:
-/// * If the OpenAI API request fails
+/// * If the `OpenAI` API request fails
/// * If the request building fails
/// * If no embedding data is received in the response
#[allow(clippy::module_name_repetitions)]
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,

View File

@@ -4,6 +4,7 @@ pub use minijinja_contrib;
pub use minijinja_embed;
use std::sync::Arc;
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesTemplateEngine {
fn template_engine(&self) -> &Arc<TemplateEngine>;
}