chore: technical maintenance, reduced duplication

This commit is contained in:
Per Stark
2026-06-18 14:58:13 +02:00
parent fb51a8b55f
commit b3d42d2586
12 changed files with 546 additions and 561 deletions
+153 -93
View File
@@ -1,9 +1,11 @@
use super::types::StoredObject;
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
use crate::error::AppError;
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use futures::Stream;
use include_dir::{include_dir, Dir};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::{ops::Deref, sync::Arc};
use surrealdb::{
engine::any::{connect, Any},
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
}
impl SurrealDbClient {
/// Initialize a new database client.
///
/// # Arguments
///
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
/// * `username` — Root username for authentication.
/// * `password` — Root password for authentication.
/// * `namespace` — SurrealDB namespace to use.
/// * `database` — SurrealDB database to use.
///
/// # Errors
///
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
/// In-memory (`mem://`) connections skip authentication.
pub async fn new(
address: &str,
username: &str,
@@ -49,30 +37,15 @@ impl SurrealDbClient {
) -> Result<Self, Error> {
let db = connect(address).await?;
// Skip sign-in for in-memory engine (no auth support)
if !address.starts_with("mem://") {
db.signin(Root { username, password }).await?;
}
// Set namespace
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
/// Initialize a new database client using namespace-level authentication.
///
/// # Arguments
///
/// * `address` — Database connection string.
/// * `namespace` — SurrealDB namespace to use (also used for auth).
/// * `username` — Namespace username for authentication.
/// * `password` — Namespace password for authentication.
/// * `database` — SurrealDB database to use.
///
/// # Errors
///
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
pub async fn new_with_namespace_user(
address: &str,
namespace: &str,
@@ -91,11 +64,6 @@ impl SurrealDbClient {
Ok(SurrealDbClient { client: db })
}
/// Create an Axum session store backed by SurrealDB.
///
/// # Errors
///
/// Returns `SessionError` if the session store configuration or table creation fails.
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
@@ -109,15 +77,6 @@ impl SurrealDbClient {
.await
}
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
///
/// This function should be called during application startup, after connecting to
/// the database and selecting the appropriate namespace and database, but before
/// the application starts performing operations that rely on the schema.
///
/// # Errors
///
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
pub async fn apply_migrations(&self) -> Result<(), AppError> {
debug!("Applying migrations");
MigrationRunner::new(&self.client)
@@ -129,15 +88,6 @@ impl SurrealDbClient {
Ok(())
}
/// Store an object in SurrealDB.
///
/// # Arguments
///
/// * `item` — The item to store. Must implement `StoredObject`.
///
/// # Errors
///
/// Returns `Err` if the database create operation fails.
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
@@ -148,13 +98,6 @@ impl SurrealDbClient {
.await
}
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
///
/// Useful when a single record should be replaced by id (admin updates, embedding rows, etc.).
///
/// # Errors
///
/// Returns `Err` if the database upsert operation fails.
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
@@ -166,11 +109,6 @@ impl SurrealDbClient {
.await
}
/// Retrieve all objects from a table.
///
/// # Errors
///
/// Returns `Err` if the database select operation fails.
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
where
T: for<'de> StoredObject,
@@ -178,16 +116,6 @@ impl SurrealDbClient {
self.client.select(T::table_name()).await
}
/// Retrieve a single object by its ID.
///
/// # Arguments
///
/// * `id` — The ID of the item to retrieve.
///
/// # Errors
///
/// Returns `Err` if the database select operation fails.
/// Returns `Ok(None)` if no record with the given ID exists.
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
@@ -195,16 +123,6 @@ impl SurrealDbClient {
self.client.select((T::table_name(), id)).await
}
/// Delete a single object by its ID.
///
/// # Arguments
///
/// * `id` — The ID of the item to delete.
///
/// # Errors
///
/// Returns `Err` if the database delete operation fails.
/// Returns `Ok(None)` if no record with the given ID exists.
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
where
T: for<'de> StoredObject,
@@ -212,11 +130,6 @@ impl SurrealDbClient {
self.client.delete((T::table_name(), id)).await
}
/// Listen to a table for real-time updates via a live query stream.
///
/// # Errors
///
/// Returns `Err` if the database live query subscription fails.
pub async fn listen<T>(
&self,
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
@@ -225,6 +138,156 @@ impl SurrealDbClient {
{
self.client.select(T::table_name()).live().await
}
/// Atomically store an entity and its embedding vector in a single
/// SurrealDB transaction.
///
/// Creates (or overwrites) the entity row and upserts the linked
/// embedding record. The embedding dimension is validated against
/// `embedding_dimensions` before the query is issued.
pub async fn store_with_embedding<E>(
&self,
entity: E,
embedding: Vec<f32>,
embedding_dimensions: usize,
) -> Result<(), AppError>
where
E: HasEmbedding + Serialize + Send + Sync + 'static,
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
{
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
let entity_id = entity.id().to_string();
let emb = <E as HasEmbedding>::Embedding::new(
&entity_id,
entity.source_id().to_string(),
embedding,
entity.user_id().to_string(),
E::table_name(),
);
let sql = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{et}', $id) CONTENT $entity;
UPSERT type::thing('{emt}', $id) CONTENT $emb;
COMMIT TRANSACTION;
",
et = E::table_name(),
emt = <E as HasEmbedding>::Embedding::table_name(),
);
self.client
.query(sql)
.bind(("id", entity_id))
.bind(("entity", entity))
.bind(("emb", emb))
.await?
.check()?;
Ok(())
}
/// Delete all entity and embedding rows matching a given `source_id`.
///
/// Runs inside a SurrealDB transaction so that entity and embedding
/// deletes are atomic.
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
where
E: HasEmbedding,
E::Embedding: Send + Sync,
{
self.client
.query("BEGIN TRANSACTION;")
.query(format!(
"DELETE FROM {} WHERE source_id = $source_id;",
E::Embedding::table_name()
))
.query(format!(
"DELETE FROM {} WHERE source_id = $source_id;",
E::table_name()
))
.query("COMMIT TRANSACTION;")
.bind(("source_id", source_id.to_owned()))
.await?
.check()?;
Ok(())
}
/// Vector similarity search over entities using HNSW index.
///
/// Performs a cosine-similarity search against the embedding table,
/// fetches the corresponding entity rows server-side via `FETCH`,
/// and returns `(entity, score)` pairs ordered by descending
/// similarity. Orphaned embeddings (entity deleted but its
/// embedding row remains) are logged as a warning and dropped.
///
/// This is a single round-trip — SurrealDB resolves the link field
/// (`entity_id` or `chunk_id`) inside the query engine.
pub async fn vector_search<E, Emb>(
&self,
take: usize,
query_embedding: &[f32],
user_id: &str,
) -> Result<Vec<(E, f32)>, AppError>
where
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
Emb: EmbeddingRecord + Send + Sync,
{
// Generic row that works with both `entity_id` and `chunk_id` link
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
// server-side so we get the full entity in a single round-trip.
#[derive(serde::Deserialize)]
struct FetchRow<Ent> {
score: f32,
#[serde(alias = "entity_id", alias = "chunk_id")]
entity: Option<Ent>,
}
let link_field = Emb::link_field();
let sql = format!(
r#"
SELECT
{link_field},
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH {link_field}
"#,
link_field = link_field,
emb_table = Emb::table_name(),
take = take,
);
let mut response = self
.client
.query(sql)
.bind(("embedding", query_embedding.to_vec()))
.bind(("user_id", user_id.to_string()))
.await?;
response = response.check()?;
let rows: Vec<FetchRow<E>> = response.take(0)?;
let mut results = Vec::with_capacity(rows.len());
for r in rows {
if let Some(entity) = r.entity {
results.push((entity, r.score));
} else {
tracing::warn!(
"Vector search hit orphaned {} row with missing {link_field}",
Emb::table_name()
);
}
}
Ok(results)
}
}
impl Deref for SurrealDbClient {
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
#[cfg(any(test, feature = "test-utils"))]
impl SurrealDbClient {
/// Create an in-memory SurrealDB client for testing.
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
let db = connect("mem://").await?;
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
}