chore: technical maintenance, reduced duplication

This commit is contained in:
Per Stark
2026-06-18 14:58:13 +02:00
parent fb51a8b55f
commit b3d42d2586
12 changed files with 546 additions and 561 deletions
+153 -93
View File
@@ -1,9 +1,11 @@
use super::types::StoredObject;
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
use crate::error::AppError;
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use futures::Stream;
use include_dir::{include_dir, Dir};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::{ops::Deref, sync::Arc};
use surrealdb::{
engine::any::{connect, Any},
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
}
impl SurrealDbClient {
/// Initialize a new database client.
///
/// # Arguments
///
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
/// * `username` — Root username for authentication.
/// * `password` — Root password for authentication.
/// * `namespace` — SurrealDB namespace to use.
/// * `database` — SurrealDB database to use.
///
/// # Errors
///
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
/// In-memory (`mem://`) connections skip authentication.
pub async fn new(
address: &str,
username: &str,
@@ -49,30 +37,15 @@ impl SurrealDbClient {
) -> Result<Self, Error> {
let db = connect(address).await?;
// Skip sign-in for in-memory engine (no auth support)
if !address.starts_with("mem://") {
db.signin(Root { username, password }).await?;
}
// Set namespace
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
/// Initialize a new database client using namespace-level authentication.
///
/// # Arguments
///
/// * `address` — Database connection string.
/// * `namespace` — SurrealDB namespace to use (also used for auth).
/// * `username` — Namespace username for authentication.
/// * `password` — Namespace password for authentication.
/// * `database` — SurrealDB database to use.
///
/// # Errors
///
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
pub async fn new_with_namespace_user(
address: &str,
namespace: &str,
@@ -91,11 +64,6 @@ impl SurrealDbClient {
Ok(SurrealDbClient { client: db })
}
/// Create an Axum session store backed by SurrealDB.
///
/// # Errors
///
/// Returns `SessionError` if the session store configuration or table creation fails.
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
@@ -109,15 +77,6 @@ impl SurrealDbClient {
.await
}
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
///
/// This function should be called during application startup, after connecting to
/// the database and selecting the appropriate namespace and database, but before
/// the application starts performing operations that rely on the schema.
///
/// # Errors
///
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
pub async fn apply_migrations(&self) -> Result<(), AppError> {
debug!("Applying migrations");
MigrationRunner::new(&self.client)
@@ -129,15 +88,6 @@ impl SurrealDbClient {
Ok(())
}
/// Store an object in SurrealDB.
///
/// # Arguments
///
/// * `item` — The item to store. Must implement `StoredObject`.
///
/// # Errors
///
/// Returns `Err` if the database create operation fails.
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
@@ -148,13 +98,6 @@ impl SurrealDbClient {
.await
}
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
///
/// Useful when a single record should be replaced by id (admin updates, embedding rows, etc.).
///
/// # Errors
///
/// Returns `Err` if the database upsert operation fails.
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
@@ -166,11 +109,6 @@ impl SurrealDbClient {
.await
}
/// Retrieve all objects from a table.
///
/// # Errors
///
/// Returns `Err` if the database select operation fails.
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
where
T: for<'de> StoredObject,
@@ -178,16 +116,6 @@ impl SurrealDbClient {
self.client.select(T::table_name()).await
}
/// Retrieve a single object by its ID.
///
/// # Arguments
///
/// * `id` — The ID of the item to retrieve.
///
/// # Errors
///
/// Returns `Err` if the database select operation fails.
/// Returns `Ok(None)` if no record with the given ID exists.
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
@@ -195,16 +123,6 @@ impl SurrealDbClient {
self.client.select((T::table_name(), id)).await
}
/// Delete a single object by its ID.
///
/// # Arguments
///
/// * `id` — The ID of the item to delete.
///
/// # Errors
///
/// Returns `Err` if the database delete operation fails.
/// Returns `Ok(None)` if no record with the given ID exists.
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
@@ -212,11 +130,6 @@ impl SurrealDbClient {
self.client.delete((T::table_name(), id)).await
}
/// Listen to a table for real-time updates via a live query stream.
///
/// # Errors
///
/// Returns `Err` if the database live query subscription fails.
pub async fn listen<T>(
&self,
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
@@ -225,6 +138,156 @@ impl SurrealDbClient {
{
self.client.select(T::table_name()).live().await
}
/// Atomically store an entity and its embedding vector in a single
/// SurrealDB transaction.
///
/// The entity row is written with `CREATE` and the linked embedding
/// record with `UPSERT`. Both statements use the entity's id as the
/// record key. The embedding dimension is validated against
/// `embedding_dimensions` before the query is issued.
///
/// NOTE(review): because the entity statement is `CREATE` rather than
/// `UPSERT`, an already-existing entity id is expected to be rejected
/// instead of overwritten — confirm against the SurrealDB version in use.
///
/// # Errors
///
/// Returns the error produced by `validate_dimension` when the vector
/// length does not match `embedding_dimensions`, and an error if the
/// query fails or any statement in the response reports an error.
pub async fn store_with_embedding<E>(
&self,
entity: E,
embedding: Vec<f32>,
embedding_dimensions: usize,
) -> Result<(), AppError>
where
E: HasEmbedding + Serialize + Send + Sync + 'static,
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
{
// Reject a wrongly sized vector before touching the database.
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
let entity_id = entity.id().to_string();
// The embedding record reuses the entity id and copies the entity's
// source and user ids.
let emb = <E as HasEmbedding>::Embedding::new(
&entity_id,
entity.source_id().to_string(),
embedding,
entity.user_id().to_string(),
E::table_name(),
);
// Table names are interpolated into the statement text; the id and both
// record contents are passed as bound parameters.
let sql = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{et}', $id) CONTENT $entity;
UPSERT type::thing('{emt}', $id) CONTENT $emb;
COMMIT TRANSACTION;
",
et = E::table_name(),
emt = <E as HasEmbedding>::Embedding::table_name(),
);
self.client
.query(sql)
.bind(("id", entity_id))
.bind(("entity", entity))
.bind(("emb", emb))
.await?
// Surface per-statement errors from the response.
.check()?;
Ok(())
}
/// Remove every `E` row and every linked embedding row that carries the
/// given `source_id`.
///
/// Both `DELETE` statements are issued between `BEGIN TRANSACTION` and
/// `COMMIT TRANSACTION` in a single query chain; the embedding table is
/// cleared before the entity table.
///
/// # Errors
///
/// Returns an error if the query fails or if any statement in the
/// response reports an error.
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
where
    E: HasEmbedding,
    E::Embedding: Send + Sync,
{
    let delete_embeddings = format!(
        "DELETE FROM {} WHERE source_id = $source_id;",
        E::Embedding::table_name()
    );
    let delete_entities = format!(
        "DELETE FROM {} WHERE source_id = $source_id;",
        E::table_name()
    );

    let response = self
        .client
        .query("BEGIN TRANSACTION;")
        .query(delete_embeddings)
        .query(delete_entities)
        .query("COMMIT TRANSACTION;")
        .bind(("source_id", source_id.to_owned()))
        .await?;
    response.check()?;
    Ok(())
}
/// Vector similarity search over entities using HNSW index.
///
/// Performs a cosine-similarity search against the embedding table,
/// fetches the corresponding entity rows server-side via `FETCH`,
/// and returns `(entity, score)` pairs ordered by descending
/// similarity. Orphaned embeddings (entity deleted but its
/// embedding row remains) are logged as a warning and dropped.
///
/// This is a single round-trip — SurrealDB resolves the link field
/// (`entity_id` or `chunk_id`) inside the query engine.
///
/// # Arguments
///
/// * `take` — used both as the nearest-neighbour count in the KNN operator
///   and as the final `LIMIT`.
/// * `query_embedding` — the vector to compare stored embeddings against.
/// * `user_id` — only embedding rows with this `user_id` are considered.
///
/// Because orphaned rows are dropped after the query, the result may hold
/// fewer than `take` pairs.
///
/// # Errors
///
/// Returns an error if the query fails, if a statement in the response
/// reports an error, or if the rows cannot be deserialized.
pub async fn vector_search<E, Emb>(
&self,
take: usize,
query_embedding: &[f32],
user_id: &str,
) -> Result<Vec<(E, f32)>, AppError>
where
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
Emb: EmbeddingRecord + Send + Sync,
{
// Generic row that works with both `entity_id` and `chunk_id` link
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
// server-side so we get the full entity in a single round-trip.
#[derive(serde::Deserialize)]
struct FetchRow<Ent> {
score: f32,
// `None` when the fetched link did not resolve to an entity row.
#[serde(alias = "entity_id", alias = "chunk_id")]
entity: Option<Ent>,
}
let link_field = Emb::link_field();
// The link field name, table name and `take` are interpolated into the
// statement text; the query vector and user id are bound parameters.
// NOTE(review): the literal 100 in `<|take,100|>` is presumably the HNSW
// search effort — confirm against the SurrealDB KNN operator docs.
let sql = format!(
r#"
SELECT
{link_field},
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH {link_field}
"#,
link_field = link_field,
emb_table = Emb::table_name(),
take = take,
);
let mut response = self
.client
.query(sql)
.bind(("embedding", query_embedding.to_vec()))
.bind(("user_id", user_id.to_string()))
.await?;
// Surface per-statement errors before reading the result set.
response = response.check()?;
let rows: Vec<FetchRow<E>> = response.take(0)?;
let mut results = Vec::with_capacity(rows.len());
for r in rows {
if let Some(entity) = r.entity {
results.push((entity, r.score));
} else {
// Embedding row whose link did not resolve: log it and skip it.
tracing::warn!(
"Vector search hit orphaned {} row with missing {link_field}",
Emb::table_name()
);
}
}
Ok(results)
}
}
impl Deref for SurrealDbClient {
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
#[cfg(any(test, feature = "test-utils"))]
impl SurrealDbClient {
    /// Build a client on the `mem://` engine and select the given namespace
    /// and database. No sign-in is performed.
    ///
    /// # Errors
    ///
    /// Returns `Err` if connecting or selecting the namespace/database fails.
    pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
        let client = connect("mem://").await?;
        client.use_ns(namespace).use_db(database).await?;
        Ok(Self { client })
    }
}
+11 -11
View File
@@ -9,10 +9,7 @@ use tracing::{debug, error, info, warn};
use crate::{
error::AppError,
storage::{
db::SurrealDbClient,
types::system_settings::SystemSettings,
},
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
@@ -231,9 +228,7 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
///
/// Returns `AppError::InternalError` if any rebuild operation fails.
pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_runtime_inner(db)
.await
.map_err(AppError::internal)
rebuild_runtime_inner(db).await.map_err(AppError::internal)
}
/// Returns whether a scheduled index rebuild is due based on the persisted last-run time.
@@ -525,8 +520,7 @@ async fn rebuild_existing_index_in_place(
if !index_exists(db, table, index_name).await? {
debug!(
index = index_name,
table,
"Skipping in-place rebuild because index is missing"
table, "Skipping in-place rebuild because index is missing"
);
return Ok(());
}
@@ -1074,7 +1068,11 @@ mod tests {
assert!(!scheduled_index_rebuild_due(None, 86_400, now));
assert!(!scheduled_index_rebuild_due(Some(last), 0, now));
assert!(!scheduled_index_rebuild_due(Some(now - chrono::Duration::hours(1)), 86_400, now));
assert!(!scheduled_index_rebuild_due(
Some(now - chrono::Duration::hours(1)),
86_400,
now
));
assert!(scheduled_index_rebuild_due(Some(last), 86_400, now));
}
@@ -1087,7 +1085,9 @@ mod tests {
.context("in-memory db")?;
db.apply_migrations().await.context("migrations")?;
ensure_runtime(&db, 8).await.context("ensure runtime indexes")?;
ensure_runtime(&db, 8)
.await
.context("ensure runtime indexes")?;
rebuild_runtime(&db)
.await
+38 -94
View File
@@ -4,10 +4,13 @@ use std::fmt::Write;
use crate::{
error::AppError,
storage::db::SurrealDbClient,
storage::indexes::hnsw_index_overwrite_sql,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
storage::types::system_settings::SystemSettings,
storage::{
db::SurrealDbClient,
indexes::hnsw_index_overwrite_sql,
types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
types::system_settings::SystemSettings,
types::{EmbeddingRecord, HasEmbedding},
},
stored_object,
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
};
@@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
user_id: String
});
/// Pairs `KnowledgeEntity` with `KnowledgeEntityEmbedding` as its embedding record.
impl HasEmbedding for KnowledgeEntity {
type Embedding = KnowledgeEntityEmbedding;
/// Borrow this entity's `source_id` field.
fn source_id(&self) -> &str {
&self.source_id
}
/// Borrow this entity's `user_id` field.
fn user_id(&self) -> &str {
&self.user_id
}
}
impl KnowledgeEntity {
#[must_use]
pub fn new(
@@ -227,22 +242,9 @@ impl KnowledgeEntity {
pub async fn delete_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
db: &SurrealDbClient,
) -> Result<(), AppError> {
// Delete embeddings first, while we can still look them up via the entity's source_id
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
db.delete_by_source_id::<Self>(source_id).await
}
/// Atomically store one knowledge entity and its embedding (single-record path).
@@ -254,38 +256,8 @@ impl KnowledgeEntity {
embedding_dimensions: usize,
db: &SurrealDbClient,
) -> Result<(), AppError> {
KnowledgeEntityEmbedding::validate_dimension(&embedding, embedding_dimensions)?;
let entity_id = entity.id.clone();
let emb = KnowledgeEntityEmbedding::new(
&entity_id,
entity.source_id.clone(),
embedding,
entity.user_id.clone(),
);
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
COMMIT TRANSACTION;
",
entity_table = Self::table_name(),
emb_table = KnowledgeEntityEmbedding::table_name(),
);
db.client
.query(query)
.bind(("entity_id", entity_id))
.bind(("entity", entity))
.bind(("emb", emb))
db.store_with_embedding(entity, embedding, embedding_dimensions)
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
}
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
@@ -295,48 +267,14 @@ impl KnowledgeEntity {
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
#[derive(Deserialize)]
struct Row {
entity_id: Option<KnowledgeEntity>,
score: f32,
}
let sql = format!(
r#"
SELECT
entity_id,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH entity_id;
"#,
emb_table = KnowledgeEntityEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding.to_vec()))
.bind(("user_id", user_id.to_string()))
db.vector_search::<Self, KnowledgeEntityEmbedding>(take, query_embedding, user_id)
.await
.map_err(AppError::from)?;
response = response.check().map_err(AppError::from)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
Ok(rows
.into_iter()
.filter_map(|r| {
r.entity_id.map(|entity| KnowledgeEntitySearchResult {
entity,
score: r.score,
})
.map(|results| {
results
.into_iter()
.map(|(entity, score)| KnowledgeEntitySearchResult { entity, score })
.collect()
})
.collect())
}
pub async fn patch(
@@ -362,7 +300,13 @@ impl KnowledgeEntity {
settings.embedding_dimensions as usize,
)?;
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
let emb = KnowledgeEntityEmbedding::new(
id,
entity.source_id,
embedding,
entity.user_id,
Self::table_name(),
);
let now = Utc::now();
@@ -916,7 +860,7 @@ mod tests {
assert_eq!(stored_embeddings.len(), 1);
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid)
.await
.with_context(|| "fetch embedding".to_string())?;
assert!(fetched_emb.is_some());
@@ -999,11 +943,11 @@ mod tests {
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1)
.await
.with_context(|| "get embedding e1".to_string())?
.is_some());
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2)
.await
.with_context(|| "get embedding e2".to_string())?
.is_some());
@@ -4,7 +4,7 @@ use surrealdb::RecordId;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
storage::{db::SurrealDbClient, types::EmbeddingRecord},
stored_object,
};
@@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
user_id: String
});
impl KnowledgeEntityEmbedding {
/// Recreate the HNSW index with a new embedding dimension.
pub async fn redefine_hnsw_index(
db: &SurrealDbClient,
dimension: usize,
) -> Result<(), AppError> {
let query = hnsw_index_redefine_transaction_sql(
"idx_embedding_knowledge_entity_embedding",
Self::table_name(),
dimension,
);
let res = db.client.query(query).await.map_err(AppError::from)?;
res.check().map_err(AppError::from)?;
Ok(())
impl EmbeddingRecord for KnowledgeEntityEmbedding {
fn link_field() -> &'static str {
"entity_id"
}
/// Validates that an embedding vector matches the configured HNSW dimension.
#[allow(clippy::result_large_err)]
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
if embedding.len() != expected {
return Err(AppError::Validation(format!(
"embedding dimension mismatch: got {}, expected {expected}",
embedding.len()
)));
}
Ok(())
fn index_name() -> &'static str {
"idx_embedding_knowledge_entity_embedding"
}
/// Create a new knowledge entity embedding.
///
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
#[must_use]
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
fn source_id(&self) -> &str {
&self.source_id
}
fn user_id(&self) -> &str {
&self.user_id
}
fn embedding(&self) -> &[f32] {
&self.embedding
}
fn new(
entity_id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
entity_table: &str,
) -> Self {
let now = Utc::now();
Self {
id: entity_id.to_owned(),
created_at: now,
updated_at: now,
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
entity_id: RecordId::from_table_key(entity_table, entity_id),
embedding,
source_id,
user_id,
}
}
}
/// Get embedding by entity ID
pub async fn get_by_entity_id(
entity_id: &RecordId,
db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
let query = format!(
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
Self::table_name()
);
let mut result = db
.client
.query(query)
.bind(("entity_id", entity_id.clone()))
.await
.map_err(AppError::from)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
Ok(embeddings.into_iter().next())
}
impl KnowledgeEntityEmbedding {
/// Get embeddings for multiple entities in batch
pub async fn get_by_entity_ids(
entity_ids: &[RecordId],
@@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding {
.map(|e| (e.entity_id.key().to_string(), e.embedding))
.collect())
}
/// Delete embedding by entity ID
pub async fn delete_by_entity_id(
entity_id: &RecordId,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE FROM {} WHERE entity_id = $entity_id",
Self::table_name()
);
db.client
.query(query)
.bind(("entity_id", entity_id.clone()))
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
}
/// Delete all embeddings with the given denormalized `source_id`.
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE FROM {} WHERE source_id = $source_id",
Self::table_name()
);
db.client
.query(query)
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
}
}
#[cfg(test)]
@@ -184,6 +122,7 @@ mod tests {
"source-1".to_owned(),
vec![0.1, 0.2],
"user-1".to_owned(),
KnowledgeEntity::table_name(),
);
assert_eq!(emb.id, "entity-abc");
}
@@ -211,7 +150,7 @@ mod tests {
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
.await
.with_context(|| "Failed to get embedding by entity_id".to_string())?
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
@@ -240,16 +179,16 @@ mod tests {
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
.await
.with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some());
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid)
.await
.with_context(|| "Failed to delete by entity_id".to_string())?;
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
.await
.with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none());
@@ -277,7 +216,7 @@ mod tests {
assert!(stored_entity.is_some());
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
.await
.with_context(|| "Failed to fetch embedding".to_string())?;
let stored_embedding =
@@ -319,9 +258,14 @@ mod tests {
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db)
.await
.with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], 3, &db)
.await
.with_context(|| "Failed to store entity with embedding".to_string())?;
KnowledgeEntity::store_with_embedding(
entity_other.clone(),
vec![3.0_f32, 3.1, 3.2],
3,
&db,
)
.await
.with_context(|| "Failed to store entity with embedding".to_string())?;
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
@@ -332,18 +276,18 @@ mod tests {
.with_context(|| "Failed to delete by source_id".to_string())?;
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid)
.await
.with_context(|| "get entity1 embedding after delete".to_string())?
.is_none()
);
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid)
.await
.with_context(|| "get entity2 embedding after delete".to_string())?
.is_none()
);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid)
.await
.with_context(|| "get other embedding after delete".to_string())?
.is_some());
@@ -450,6 +394,7 @@ mod tests {
source_id.to_owned(),
vec![0.0, 1.0, 0.0],
user_id.to_owned(),
KnowledgeEntity::table_name(),
);
db.upsert_item(replacement)
.await
+130
View File
@@ -1,4 +1,5 @@
#![allow(clippy::unsafe_derive_deserialize)]
// `EmbeddingRecord` in this module declares `async fn` methods directly in the
// trait, which triggers this lint.
// NOTE(review): the lint warns that callers cannot name auto-trait bounds such
// as `Send` on the returned futures — confirm that is acceptable for all users
// of the trait.
#![allow(async_fn_in_trait)]
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
@@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
fn id(&self) -> &str;
}
/// An entity that has an associated embedding record for vector search.
pub trait HasEmbedding: StoredObject {
/// The embedding record type paired with this entity.
type Embedding: EmbeddingRecord;
/// Source id of this entity; passed as `source_id` to
/// `EmbeddingRecord::new` when the entity is stored with its embedding.
fn source_id(&self) -> &str;
/// User id of this entity; passed as `user_id` to
/// `EmbeddingRecord::new` when the entity is stored with its embedding.
fn user_id(&self) -> &str;
}
/// An embedding record linked to a `HasEmbedding` entity.
///
/// Implementors supply the link field name, the index name, field accessors
/// and a constructor; the remaining methods have default bodies that build
/// their queries from `Self::table_name()` and `Self::link_field()`.
pub trait EmbeddingRecord: StoredObject {
/// The field name in the embedding table that links back to the entity
/// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses.
fn link_field() -> &'static str;
/// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`).
fn index_name() -> &'static str;
/// The denormalised source id stored on this record.
fn source_id(&self) -> &str;
/// The denormalised user id stored on this record.
fn user_id(&self) -> &str;
/// The embedding vector stored on this record.
fn embedding(&self) -> &[f32];
/// Construct a new embedding record.
///
/// * `id` — shared record id (same as the entity id).
/// * `source_id` — denormalised source id for bulk deletes.
/// * `embedding` — the embedding vector.
/// * `user_id` — denormalised user id for query scoping.
/// * `entity_table` — the entity's table name (used to build the link `RecordId`).
fn new(
id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
entity_table: &str,
) -> Self;
/// Validate that an embedding vector matches the expected dimension.
///
/// # Errors
///
/// Returns `AppError::Validation` when `embedding.len()` differs from
/// `expected`.
fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError>
where
Self: Sized,
{
if embedding.len() != expected {
return Err(crate::error::AppError::Validation(format!(
"embedding dimension mismatch: got {}, expected {expected}",
embedding.len()
)));
}
Ok(())
}
/// Recreate the HNSW vector index with a new dimension.
///
/// This drops and recreates the index inside a transaction.
///
/// # Errors
///
/// Returns an error if the query fails or any statement in the response
/// reports an error.
async fn redefine_hnsw_index(
db: &crate::storage::db::SurrealDbClient,
dimension: usize,
) -> Result<(), crate::error::AppError>
where
Self: Sized,
{
// The statement text comes from the shared index helper, parameterised
// by this record type's index and table names.
let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql(
Self::index_name(),
Self::table_name(),
dimension,
);
db.client.query(query).await?.check()?;
Ok(())
}
/// Fetch a single embedding record by its link `RecordId`.
///
/// Matches on the `link_field()` column (not the record's own id) and
/// limits the result to one row.
///
/// # Errors
///
/// Returns an error if the query fails or the result cannot be
/// deserialized.
async fn get_by_record_id(
db: &crate::storage::db::SurrealDbClient,
rid: &surrealdb::RecordId,
) -> Result<Option<Self>, crate::error::AppError>
where
Self: Sized + serde::de::DeserializeOwned,
{
let query = format!(
"SELECT * FROM {} WHERE {} = $rid LIMIT 1",
Self::table_name(),
Self::link_field(),
);
let mut result = db.client.query(query).bind(("rid", rid.clone())).await?;
Ok(result.take(0)?)
}
/// Delete an embedding record by its link `RecordId`.
///
/// Deletes every row whose `link_field()` column equals `rid`.
///
/// # Errors
///
/// Returns an error if the query fails or the statement reports an error.
async fn delete_by_record_id(
db: &crate::storage::db::SurrealDbClient,
rid: &surrealdb::RecordId,
) -> Result<(), crate::error::AppError>
where
Self: Sized,
{
let query = format!(
"DELETE FROM {} WHERE {} = $rid",
Self::table_name(),
Self::link_field(),
);
db.client
.query(query)
.bind(("rid", rid.clone()))
.await?
.check()?;
Ok(())
}
/// Delete all embedding records with a given `source_id`.
///
/// Note the argument order: `source_id` comes first here, whereas the
/// `*_by_record_id` methods take `db` first.
///
/// # Errors
///
/// Returns an error if the query fails or the statement reports an error.
async fn delete_by_source_id(
source_id: &str,
db: &crate::storage::db::SurrealDbClient,
) -> Result<(), crate::error::AppError>
where
Self: Sized,
{
let query = format!(
"DELETE FROM {} WHERE source_id = $source_id",
Self::table_name(),
);
db.client
.query(query)
.bind(("source_id", source_id.to_owned()))
.await?
.check()?;
Ok(())
}
}
#[macro_export]
macro_rules! stored_object {
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
+2 -4
View File
@@ -910,13 +910,11 @@ mod tests {
db.apply_migrations().await.context("migrations")?;
assert!(
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a")
.await?,
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?,
"first lease claim should succeed"
);
assert!(
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b")
.await?,
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?,
"second lease claim should fail while lease is held"
);
+27 -97
View File
@@ -3,11 +3,13 @@ use std::collections::HashMap;
use std::fmt::Write;
use crate::storage::indexes::hnsw_index_overwrite_sql;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::storage::types::{
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, HasEmbedding,
};
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use tracing::{error, info, warn};
use tracing::{error, info};
use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
@@ -24,6 +26,18 @@ pub struct TextChunkSearchResult {
pub score: f32,
}
/// Pairs `TextChunk` with `TextChunkEmbedding` as its embedding record.
impl HasEmbedding for TextChunk {
type Embedding = TextChunkEmbedding;
/// Borrow this chunk's `source_id` field.
fn source_id(&self) -> &str {
&self.source_id
}
/// Borrow this chunk's `user_id` field.
fn user_id(&self) -> &str {
&self.user_id
}
}
impl TextChunk {
#[must_use]
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
@@ -40,25 +54,9 @@ impl TextChunk {
pub async fn delete_by_source_id(
source_id: &str,
db_client: &SurrealDbClient,
db: &SurrealDbClient,
) -> Result<(), AppError> {
db_client
.client
.query("BEGIN TRANSACTION;")
.query(format!(
"DELETE FROM {} WHERE source_id = $source_id;",
TextChunkEmbedding::table_name()
))
.query("DELETE FROM type::table($table) WHERE source_id = $source_id;")
.query("COMMIT TRANSACTION;")
.bind(("source_id", source_id.to_owned()))
.bind(("table", Self::table_name()))
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
db.delete_by_source_id::<Self>(source_id).await
}
/// Atomically store one text chunk and its embedding (single-record path).
@@ -70,38 +68,8 @@ impl TextChunk {
embedding_dimensions: usize,
db: &SurrealDbClient,
) -> Result<(), AppError> {
TextChunkEmbedding::validate_dimension(&embedding, embedding_dimensions)?;
let chunk_id = chunk.id.clone();
let emb = TextChunkEmbedding::new(
&chunk_id,
chunk.source_id.clone(),
embedding,
chunk.user_id.clone(),
);
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb;
COMMIT TRANSACTION;
",
chunk_table = Self::table_name(),
emb_table = TextChunkEmbedding::table_name(),
);
db.client
.query(query)
.bind(("chunk_id", chunk_id))
.bind(("chunk", chunk))
.bind(("emb", emb))
db.store_with_embedding(chunk, embedding, embedding_dimensions)
.await
.map_err(AppError::from)?
.check()
.map_err(AppError::from)?;
Ok(())
}
/// Vector search over text chunks using the embedding table, fetching full chunk rows and scores.
@@ -111,52 +79,14 @@ impl TextChunk {
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct Row {
chunk_id: Option<TextChunk>,
score: f32,
}
let sql = format!(
r#"
SELECT
chunk_id,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH chunk_id;
"#,
emb_table = TextChunkEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding.to_vec()))
.bind(("user_id", user_id.to_string()))
db.vector_search::<Self, TextChunkEmbedding>(take, query_embedding, user_id)
.await
.map_err(AppError::from)?;
response = response.check().map_err(AppError::from)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
Ok(rows
.into_iter()
.filter_map(|r| {
r.chunk_id.map(|chunk| TextChunkSearchResult {
chunk,
score: r.score,
}).or_else(|| {
warn!("vector search hit orphaned text_chunk_embedding row with missing chunk");
None
})
.map(|results| {
results
.into_iter()
.map(|(chunk, score)| TextChunkSearchResult { chunk, score })
.collect()
})
.collect())
}
/// Full-text search over text chunks using the BM25 FTS index.
@@ -645,7 +575,7 @@ mod tests {
assert_eq!(stored_chunk.user_id, user_id);
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
.await
.with_context(|| "get embedding".to_string())?
.with_context(|| "expected embedding".to_string())?;
@@ -695,7 +625,7 @@ mod tests {
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
.await
.with_context(|| "get embedding".to_string())?
.with_context(|| "embedding should exist".to_string())?;
+75 -143
View File
@@ -1,11 +1,9 @@
use surrealdb::RecordId;
use crate::storage::types::text_chunk::TextChunk;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
stored_object,
};
use crate::{storage::types::EmbeddingRecord, stored_object};
#[cfg(test)]
use crate::error::AppError;
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
/// Record link to the owning text_chunk
@@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
user_id: String
});
impl TextChunkEmbedding {
/// Recreate the HNSW index with a new embedding dimension.
///
/// This is useful when the embedding length changes; Surreal requires the
/// index definition to be recreated with the updated dimension.
pub async fn redefine_hnsw_index(
db: &SurrealDbClient,
dimension: usize,
) -> Result<(), AppError> {
let query = hnsw_index_redefine_transaction_sql(
"idx_embedding_text_chunk_embedding",
Self::table_name(),
dimension,
);
let res = db.client.query(query).await.map_err(AppError::from)?;
res.check().map_err(AppError::from)?;
Ok(())
impl EmbeddingRecord for TextChunkEmbedding {
/// Name of the field on this record that holds the record link to the
/// owning row (`chunk_id` for text chunk embeddings).
fn link_field() -> &'static str {
"chunk_id"
}
/// Validates that an embedding vector matches the configured HNSW dimension.
#[allow(clippy::result_large_err)]
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
if embedding.len() != expected {
return Err(AppError::Validation(format!(
"embedding dimension mismatch: got {}, expected {expected}",
embedding.len()
)));
}
Ok(())
/// Name of the index associated with the `text_chunk_embedding` table.
// NOTE(review): presumably consumed by the shared `EmbeddingRecord` index
// helpers when (re)defining the index — confirm against the trait definition.
fn index_name() -> &'static str {
"idx_embedding_text_chunk_embedding"
}
/// Create a new text chunk embedding.
///
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
#[must_use]
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
/// Borrows the `source_id` stored on this embedding row.
fn source_id(&self) -> &str {
&self.source_id
}
/// Borrows the `user_id` stored on this embedding row.
fn user_id(&self) -> &str {
&self.user_id
}
/// Borrows the stored embedding vector as a slice.
fn embedding(&self) -> &[f32] {
&self.embedding
}
fn new(
chunk_id: &str,
source_id: String,
embedding: Vec<f32>,
user_id: String,
entity_table: &str,
) -> Self {
let now = Utc::now();
Self {
id: chunk_id.to_owned(),
created_at: now,
updated_at: now,
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
chunk_id: RecordId::from_table_key(entity_table, chunk_id),
source_id,
embedding,
user_id,
}
}
/// Look up the embedding row whose `chunk_id` link equals the given record id.
///
/// At most one row is read (`LIMIT 1`); `Ok(None)` is returned when nothing matches.
///
/// # Errors
///
/// Returns an [`AppError`] if the query fails or the result cannot be
/// deserialized into `Self`.
pub async fn get_by_chunk_id(
    chunk_id: &RecordId,
    db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
    let table = Self::table_name();
    let sql = format!("SELECT * FROM {table} WHERE chunk_id = $chunk_id LIMIT 1");

    // The bound value must be owned, hence the clone of the record id.
    let link = chunk_id.clone();
    let mut response = db.client.query(sql).bind(("chunk_id", link)).await?;

    let rows: Vec<Self> = response.take(0)?;
    Ok(rows.into_iter().next())
}
/// Remove every embedding row whose `chunk_id` link equals the given record id.
///
/// # Errors
///
/// Returns an [`AppError`] if the query cannot be executed or the database
/// reports an error for the statement.
pub async fn delete_by_chunk_id(
    chunk_id: &RecordId,
    db: &SurrealDbClient,
) -> Result<(), AppError> {
    let table = Self::table_name();
    let sql = format!("DELETE FROM {table} WHERE chunk_id = $chunk_id");

    let response = db
        .client
        .query(sql)
        .bind(("chunk_id", chunk_id.clone()))
        .await?;

    // `check` surfaces statement-level errors that a bare `await` would not.
    response.check()?;
    Ok(())
}
/// Remove every embedding row carrying the given `source_id`.
///
/// The match is made against the `source_id` column stored directly on the
/// embedding table, so no lookup of the owning chunks is performed.
///
/// # Errors
///
/// Returns an [`AppError`] if the query cannot be executed or the database
/// reports an error for the statement.
pub async fn delete_by_source_id(
    source_id: &str,
    db: &SurrealDbClient,
) -> Result<(), AppError> {
    let table = Self::table_name();
    let sql = format!("DELETE FROM {table} WHERE source_id = $source_id");

    let response = db
        .client
        .query(sql)
        .bind(("source_id", source_id.to_owned()))
        .await?;

    // `check` surfaces statement-level errors that a bare `await` would not.
    response.check()?;
    Ok(())
}
}
#[cfg(test)]
@@ -144,8 +65,31 @@ mod tests {
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::text_chunk::TextChunk;
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
use surrealdb::Value as SurrealValue;
/// Test helper: fetch the definition string of the
/// `idx_embedding_text_chunk_embedding` index from `INFO FOR TABLE text_chunk_embedding`.
///
/// Returns an empty string when the index entry (or any intermediate key of
/// the serialized INFO value) is absent, so callers can assert on the content
/// without handling an `Option`.
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
    // Keys to descend through in the JSON form of the INFO result.
    const INDEX_PATH: [&str; 5] = [
        "Object",
        "indexes",
        "Object",
        "idx_embedding_text_chunk_embedding",
        "Strand",
    ];

    let mut response = db
        .client
        .query("INFO FOR TABLE text_chunk_embedding;")
        .await
        .context("info query failed")?;
    let table_info: surrealdb::Value = response
        .take(0)
        .context("failed to take info result")?;
    let as_json: serde_json::Value =
        serde_json::to_value(table_info).context("failed to convert info to json")?;

    // Any missing key short-circuits the walk to `None`.
    let definition = INDEX_PATH
        .iter()
        .try_fold(&as_json, |node, key| node.get(*key))
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
        .unwrap_or_default();

    Ok(definition)
}
async fn create_text_chunk_with_id(
db: &SurrealDbClient,
@@ -169,29 +113,6 @@ mod tests {
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
}
/// Test helper: read the definition string of the
/// `idx_embedding_text_chunk_embedding` index from the output of
/// `INFO FOR TABLE text_chunk_embedding`.
///
/// Returns an empty string if the index entry cannot be found in the
/// serialized INFO value.
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.with_context(|| "info query failed".to_string())?;
let info: SurrealValue = info_res
.take(0)
.with_context(|| "failed to take info result".to_string())?;
// Convert to JSON so the nested structure can be walked with plain key lookups.
let info_json: serde_json::Value = serde_json::to_value(info)
.with_context(|| "failed to convert info to json".to_string())?;
// Descend Object -> indexes -> Object -> <index name> -> Strand; a missing
// key at any level yields `None`, which becomes "" below.
let idx_sql = info_json
.get("Object")
.and_then(|v| v.get("indexes"))
.and_then(|v| v.get("Object"))
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
.and_then(|v| v.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
Ok(idx_sql)
}
#[test]
fn new_uses_chunk_id_as_record_id() {
let emb = TextChunkEmbedding::new(
@@ -199,6 +120,7 @@ mod tests {
"source-1".to_owned(),
vec![0.1, 0.2],
"user-1".to_owned(),
TextChunk::table_name(),
);
assert_eq!(emb.id, "chunk-abc");
}
@@ -226,13 +148,14 @@ mod tests {
source_id.to_string(),
embedding_vec.clone(),
user_id.to_string(),
TextChunk::table_name(),
);
db.upsert_item(emb)
.await
.with_context(|| "Failed to store embedding".to_string())?;
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
.await
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
.with_context(|| "Expected an embedding to be found".to_string())?;
@@ -259,22 +182,23 @@ mod tests {
source_id.to_string(),
vec![0.4_f32, 0.5, 0.6],
user_id.to_string(),
TextChunk::table_name(),
);
db.upsert_item(emb)
.await
.with_context(|| "Failed to store embedding".to_string())?;
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
.await
.with_context(|| "Failed to get embedding before delete".to_string())?;
assert!(existing.is_some(), "Embedding should exist before delete");
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid)
.await
.with_context(|| "Failed to delete by chunk_id".to_string())?;
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
.await
.with_context(|| "Failed to get embedding after delete".to_string())?;
assert!(after.is_none(), "Embedding should have been deleted");
@@ -299,21 +223,27 @@ mod tests {
("chunk-s2", source_id, vec![0.2]),
("chunk-other", other_source, vec![0.3]),
] {
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
let emb = TextChunkEmbedding::new(
key,
src.to_string(),
vec,
user_id.to_string(),
TextChunk::table_name(),
);
db.upsert_item(emb)
.await
.with_context(|| format!("store embedding for {key}"))?;
}
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
.await
.with_context(|| "get chunk1".to_string())?
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
.await
.with_context(|| "get chunk2".to_string())?
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
.await
.with_context(|| "get chunk_other".to_string())?
.is_some());
@@ -322,15 +252,15 @@ mod tests {
.await
.with_context(|| "Failed to delete by source_id".to_string())?;
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
.await
.with_context(|| "check chunk1".to_string())?
.is_none());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
.await
.with_context(|| "check chunk2".to_string())?
.is_none());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
.await
.with_context(|| "check chunk_other".to_string())?
.is_some());
@@ -352,6 +282,7 @@ mod tests {
source_id.to_owned(),
vec![1.0_f32, 0.0, 0.0],
user_id.to_owned(),
TextChunk::table_name(),
);
db.upsert_item(initial)
.await
@@ -362,6 +293,7 @@ mod tests {
source_id.to_owned(),
vec![0.0, 1.0, 0.0],
user_id.to_owned(),
TextChunk::table_name(),
);
db.upsert_item(replacement)
.await
+1 -1
View File
@@ -9,7 +9,7 @@ use crate::storage::{
indexes::{ensure_runtime, rebuild},
types::{
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
text_chunk_embedding::TextChunkEmbedding,
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord,
},
};