mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-28 04:46:35 +02:00
chore: technical maintenance, reduced duplication
This commit is contained in:
+153
-93
@@ -1,9 +1,11 @@
|
||||
use super::types::StoredObject;
|
||||
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
|
||||
use crate::error::AppError;
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::Stream;
|
||||
use include_dir::{include_dir, Dir};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// Initialize a new database client.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||
/// * `username` — Root username for authentication.
|
||||
/// * `password` — Root password for authentication.
|
||||
/// * `namespace` — SurrealDB namespace to use.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||
/// In-memory (`mem://`) connections skip authentication.
|
||||
pub async fn new(
|
||||
address: &str,
|
||||
username: &str,
|
||||
@@ -49,30 +37,15 @@ impl SurrealDbClient {
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
|
||||
// Skip sign-in for in-memory engine (no auth support)
|
||||
if !address.starts_with("mem://") {
|
||||
db.signin(Root { username, password }).await?;
|
||||
}
|
||||
|
||||
// Set namespace
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Initialize a new database client using namespace-level authentication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string.
|
||||
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||
/// * `username` — Namespace username for authentication.
|
||||
/// * `password` — Namespace password for authentication.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
@@ -91,11 +64,6 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Create an Axum session store backed by SurrealDB.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
@@ -109,15 +77,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
|
||||
///
|
||||
/// This function should be called during application startup, after connecting to
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
@@ -129,15 +88,6 @@ impl SurrealDbClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store an object in SurrealDB.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database create operation fails.
|
||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -148,13 +98,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||
///
|
||||
/// Useful when a single record should be replaced by id (admin updates, embedding rows, etc.).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database upsert operation fails.
|
||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -166,11 +109,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieve all objects from a table.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -178,16 +116,6 @@ impl SurrealDbClient {
|
||||
self.client.select(T::table_name()).await
|
||||
}
|
||||
|
||||
/// Retrieve a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to retrieve.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -195,16 +123,6 @@ impl SurrealDbClient {
|
||||
self.client.select((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Delete a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to delete.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database delete operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -212,11 +130,6 @@ impl SurrealDbClient {
|
||||
self.client.delete((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Listen to a table for real-time updates via a live query stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database live query subscription fails.
|
||||
pub async fn listen<T>(
|
||||
&self,
|
||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||
@@ -225,6 +138,156 @@ impl SurrealDbClient {
|
||||
{
|
||||
self.client.select(T::table_name()).live().await
|
||||
}
|
||||
|
||||
/// Atomically store an entity and its embedding vector in a single
|
||||
/// SurrealDB transaction.
|
||||
///
|
||||
/// Creates (or overwrites) the entity row and upserts the linked
|
||||
/// embedding record. The embedding dimension is validated against
|
||||
/// `embedding_dimensions` before the query is issued.
|
||||
pub async fn store_with_embedding<E>(
|
||||
&self,
|
||||
entity: E,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding + Serialize + Send + Sync + 'static,
|
||||
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
|
||||
{
|
||||
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let entity_id = entity.id().to_string();
|
||||
let emb = <E as HasEmbedding>::Embedding::new(
|
||||
&entity_id,
|
||||
entity.source_id().to_string(),
|
||||
embedding,
|
||||
entity.user_id().to_string(),
|
||||
E::table_name(),
|
||||
);
|
||||
|
||||
let sql = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{et}', $id) CONTENT $entity;
|
||||
UPSERT type::thing('{emt}', $id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
et = E::table_name(),
|
||||
emt = <E as HasEmbedding>::Embedding::table_name(),
|
||||
);
|
||||
|
||||
self.client
|
||||
.query(sql)
|
||||
.bind(("id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all entity and embedding rows matching a given `source_id`.
|
||||
///
|
||||
/// Runs inside a SurrealDB transaction so that entity and embedding
|
||||
/// deletes are atomic.
|
||||
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding,
|
||||
E::Embedding: Send + Sync,
|
||||
{
|
||||
self.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::Embedding::table_name()
|
||||
))
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::table_name()
|
||||
))
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector similarity search over entities using HNSW index.
|
||||
///
|
||||
/// Performs a cosine-similarity search against the embedding table,
|
||||
/// fetches the corresponding entity rows server-side via `FETCH`,
|
||||
/// and returns `(entity, score)` pairs ordered by descending
|
||||
/// similarity. Orphaned embeddings (entity deleted but its
|
||||
/// embedding row remains) are logged as a warning and dropped.
|
||||
///
|
||||
/// This is a single round-trip — SurrealDB resolves the link field
|
||||
/// (`entity_id` or `chunk_id`) inside the query engine.
|
||||
pub async fn vector_search<E, Emb>(
|
||||
&self,
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
user_id: &str,
|
||||
) -> Result<Vec<(E, f32)>, AppError>
|
||||
where
|
||||
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
|
||||
Emb: EmbeddingRecord + Send + Sync,
|
||||
{
|
||||
// Generic row that works with both `entity_id` and `chunk_id` link
|
||||
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
|
||||
// server-side so we get the full entity in a single round-trip.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct FetchRow<Ent> {
|
||||
score: f32,
|
||||
#[serde(alias = "entity_id", alias = "chunk_id")]
|
||||
entity: Option<Ent>,
|
||||
}
|
||||
|
||||
let link_field = Emb::link_field();
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
{link_field},
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH {link_field}
|
||||
"#,
|
||||
link_field = link_field,
|
||||
emb_table = Emb::table_name(),
|
||||
take = take,
|
||||
);
|
||||
|
||||
let mut response = self
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?;
|
||||
|
||||
response = response.check()?;
|
||||
|
||||
let rows: Vec<FetchRow<E>> = response.take(0)?;
|
||||
|
||||
let mut results = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
if let Some(entity) = r.entity {
|
||||
results.push((entity, r.score));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Vector search hit orphaned {} row with missing {link_field}",
|
||||
Emb::table_name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
impl SurrealDbClient {
    /// Build a client backed by the `mem://` engine, for use in tests.
    ///
    /// Connects to `mem://` and selects the given namespace and database.
    /// No sign-in call is made.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the connection or the namespace/database selection fails.
    pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
        let client = connect("mem://").await?;
        client.use_ns(namespace).use_db(database).await?;
        Ok(Self { client })
    }
}
|
||||
|
||||
Reference in New Issue
Block a user