mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-24 19:06:30 +02:00
chore: technical maintenance, reduced duplication
This commit is contained in:
+153
-93
@@ -1,9 +1,11 @@
|
||||
use super::types::StoredObject;
|
||||
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
|
||||
use crate::error::AppError;
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::Stream;
|
||||
use include_dir::{include_dir, Dir};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// Initialize a new database client.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||
/// * `username` — Root username for authentication.
|
||||
/// * `password` — Root password for authentication.
|
||||
/// * `namespace` — SurrealDB namespace to use.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||
/// In-memory (`mem://`) connections skip authentication.
|
||||
pub async fn new(
|
||||
address: &str,
|
||||
username: &str,
|
||||
@@ -49,30 +37,15 @@ impl SurrealDbClient {
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
|
||||
// Skip sign-in for in-memory engine (no auth support)
|
||||
if !address.starts_with("mem://") {
|
||||
db.signin(Root { username, password }).await?;
|
||||
}
|
||||
|
||||
// Set namespace
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Initialize a new database client using namespace-level authentication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string.
|
||||
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||
/// * `username` — Namespace username for authentication.
|
||||
/// * `password` — Namespace password for authentication.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
@@ -91,11 +64,6 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Create an Axum session store backed by SurrealDB.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
@@ -109,15 +77,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
|
||||
///
|
||||
/// This function should be called during application startup, after connecting to
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
@@ -129,15 +88,6 @@ impl SurrealDbClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store an object in SurrealDB.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database create operation fails.
|
||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -148,13 +98,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||
///
|
||||
/// Useful when a single record should be replaced by id (admin updates, embedding rows, etc.).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database upsert operation fails.
|
||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -166,11 +109,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieve all objects from a table.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -178,16 +116,6 @@ impl SurrealDbClient {
|
||||
self.client.select(T::table_name()).await
|
||||
}
|
||||
|
||||
/// Retrieve a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to retrieve.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -195,16 +123,6 @@ impl SurrealDbClient {
|
||||
self.client.select((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Delete a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to delete.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database delete operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -212,11 +130,6 @@ impl SurrealDbClient {
|
||||
self.client.delete((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Listen to a table for real-time updates via a live query stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database live query subscription fails.
|
||||
pub async fn listen<T>(
|
||||
&self,
|
||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||
@@ -225,6 +138,156 @@ impl SurrealDbClient {
|
||||
{
|
||||
self.client.select(T::table_name()).live().await
|
||||
}
|
||||
|
||||
/// Atomically store an entity and its embedding vector in a single
|
||||
/// SurrealDB transaction.
|
||||
///
|
||||
/// Creates (or overwrites) the entity row and upserts the linked
|
||||
/// embedding record. The embedding dimension is validated against
|
||||
/// `embedding_dimensions` before the query is issued.
|
||||
pub async fn store_with_embedding<E>(
|
||||
&self,
|
||||
entity: E,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding + Serialize + Send + Sync + 'static,
|
||||
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
|
||||
{
|
||||
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let entity_id = entity.id().to_string();
|
||||
let emb = <E as HasEmbedding>::Embedding::new(
|
||||
&entity_id,
|
||||
entity.source_id().to_string(),
|
||||
embedding,
|
||||
entity.user_id().to_string(),
|
||||
E::table_name(),
|
||||
);
|
||||
|
||||
let sql = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{et}', $id) CONTENT $entity;
|
||||
UPSERT type::thing('{emt}', $id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
et = E::table_name(),
|
||||
emt = <E as HasEmbedding>::Embedding::table_name(),
|
||||
);
|
||||
|
||||
self.client
|
||||
.query(sql)
|
||||
.bind(("id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all entity and embedding rows matching a given `source_id`.
|
||||
///
|
||||
/// Runs inside a SurrealDB transaction so that entity and embedding
|
||||
/// deletes are atomic.
|
||||
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding,
|
||||
E::Embedding: Send + Sync,
|
||||
{
|
||||
self.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::Embedding::table_name()
|
||||
))
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::table_name()
|
||||
))
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector similarity search over entities using HNSW index.
|
||||
///
|
||||
/// Performs a cosine-similarity search against the embedding table,
|
||||
/// fetches the corresponding entity rows server-side via `FETCH`,
|
||||
/// and returns `(entity, score)` pairs ordered by descending
|
||||
/// similarity. Orphaned embeddings (entity deleted but its
|
||||
/// embedding row remains) are logged as a warning and dropped.
|
||||
///
|
||||
/// This is a single round-trip — SurrealDB resolves the link field
|
||||
/// (`entity_id` or `chunk_id`) inside the query engine.
|
||||
pub async fn vector_search<E, Emb>(
|
||||
&self,
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
user_id: &str,
|
||||
) -> Result<Vec<(E, f32)>, AppError>
|
||||
where
|
||||
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
|
||||
Emb: EmbeddingRecord + Send + Sync,
|
||||
{
|
||||
// Generic row that works with both `entity_id` and `chunk_id` link
|
||||
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
|
||||
// server-side so we get the full entity in a single round-trip.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct FetchRow<Ent> {
|
||||
score: f32,
|
||||
#[serde(alias = "entity_id", alias = "chunk_id")]
|
||||
entity: Option<Ent>,
|
||||
}
|
||||
|
||||
let link_field = Emb::link_field();
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
{link_field},
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH {link_field}
|
||||
"#,
|
||||
link_field = link_field,
|
||||
emb_table = Emb::table_name(),
|
||||
take = take,
|
||||
);
|
||||
|
||||
let mut response = self
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?;
|
||||
|
||||
response = response.check()?;
|
||||
|
||||
let rows: Vec<FetchRow<E>> = response.take(0)?;
|
||||
|
||||
let mut results = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
if let Some(entity) = r.entity {
|
||||
results.push((entity, r.score));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Vector search hit orphaned {} row with missing {link_field}",
|
||||
Emb::table_name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
impl SurrealDbClient {
|
||||
/// Create an in-memory SurrealDB client for testing.
|
||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||
let db = connect("mem://").await?;
|
||||
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,7 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::system_settings::SystemSettings,
|
||||
},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
@@ -231,9 +228,7 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any rebuild operation fails.
|
||||
pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_runtime_inner(db)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
rebuild_runtime_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Returns whether a scheduled index rebuild is due based on the persisted last-run time.
|
||||
@@ -525,8 +520,7 @@ async fn rebuild_existing_index_in_place(
|
||||
if !index_exists(db, table, index_name).await? {
|
||||
debug!(
|
||||
index = index_name,
|
||||
table,
|
||||
"Skipping in-place rebuild because index is missing"
|
||||
table, "Skipping in-place rebuild because index is missing"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
@@ -1074,7 +1068,11 @@ mod tests {
|
||||
|
||||
assert!(!scheduled_index_rebuild_due(None, 86_400, now));
|
||||
assert!(!scheduled_index_rebuild_due(Some(last), 0, now));
|
||||
assert!(!scheduled_index_rebuild_due(Some(now - chrono::Duration::hours(1)), 86_400, now));
|
||||
assert!(!scheduled_index_rebuild_due(
|
||||
Some(now - chrono::Duration::hours(1)),
|
||||
86_400,
|
||||
now
|
||||
));
|
||||
assert!(scheduled_index_rebuild_due(Some(last), 86_400, now));
|
||||
}
|
||||
|
||||
@@ -1087,7 +1085,9 @@ mod tests {
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations().await.context("migrations")?;
|
||||
ensure_runtime(&db, 8).await.context("ensure runtime indexes")?;
|
||||
ensure_runtime(&db, 8)
|
||||
.await
|
||||
.context("ensure runtime indexes")?;
|
||||
|
||||
rebuild_runtime(&db)
|
||||
.await
|
||||
|
||||
@@ -4,10 +4,13 @@ use std::fmt::Write;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::SurrealDbClient,
|
||||
storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::hnsw_index_overwrite_sql,
|
||||
types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
types::system_settings::SystemSettings,
|
||||
types::{EmbeddingRecord, HasEmbedding},
|
||||
},
|
||||
stored_object,
|
||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||
};
|
||||
@@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl HasEmbedding for KnowledgeEntity {
|
||||
type Embedding = KnowledgeEntityEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl KnowledgeEntity {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
@@ -227,22 +242,9 @@ impl KnowledgeEntity {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store one knowledge entity and its embedding (single-record path).
|
||||
@@ -254,38 +256,8 @@ impl KnowledgeEntity {
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
KnowledgeEntityEmbedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let entity_id = entity.id.clone();
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
&entity_id,
|
||||
entity.source_id.clone(),
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
||||
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
entity_table = Self::table_name(),
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(entity, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||
@@ -295,48 +267,14 @@ impl KnowledgeEntity {
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: Option<KnowledgeEntity>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
entity_id,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH entity_id;
|
||||
"#,
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, KnowledgeEntityEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.entity_id.map(|entity| KnowledgeEntitySearchResult {
|
||||
entity,
|
||||
score: r.score,
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(entity, score)| KnowledgeEntitySearchResult { entity, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
@@ -362,7 +300,13 @@ impl KnowledgeEntity {
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
id,
|
||||
entity.source_id,
|
||||
embedding,
|
||||
entity.user_id,
|
||||
Self::table_name(),
|
||||
);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
@@ -916,7 +860,7 @@ mod tests {
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
@@ -999,11 +943,11 @@ mod tests {
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1)
|
||||
.await
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2)
|
||||
.await
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
|
||||
@@ -4,7 +4,7 @@ use surrealdb::RecordId;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
storage::{db::SurrealDbClient, types::EmbeddingRecord},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
@@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for KnowledgeEntityEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"entity_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_knowledge_entity_embedding"
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding.
|
||||
///
|
||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||
#[must_use]
|
||||
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
entity_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: entity_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
entity_id: RecordId::from_table_key(entity_table, entity_id),
|
||||
embedding,
|
||||
source_id,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding by entity ID
|
||||
pub async fn get_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Get embeddings for multiple entities in batch
|
||||
pub async fn get_by_entity_ids(
|
||||
entity_ids: &[RecordId],
|
||||
@@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding {
|
||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete embedding by entity ID
|
||||
pub async fn delete_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings with the given denormalized `source_id`.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -184,6 +122,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "entity-abc");
|
||||
}
|
||||
@@ -211,7 +150,7 @@ mod tests {
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
@@ -240,16 +179,16 @@ mod tests {
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none());
|
||||
@@ -277,7 +216,7 @@ mod tests {
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding =
|
||||
@@ -319,9 +258,14 @@ mod tests {
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(
|
||||
entity_other.clone(),
|
||||
vec![3.0_f32, 3.1, 3.2],
|
||||
3,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
@@ -332,18 +276,18 @@ mod tests {
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid)
|
||||
.await
|
||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid)
|
||||
.await
|
||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid)
|
||||
.await
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some());
|
||||
@@ -450,6 +394,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![allow(clippy::unsafe_derive_deserialize)]
|
||||
#![allow(async_fn_in_trait)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod analytics;
|
||||
pub mod conversation;
|
||||
@@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||
fn id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An entity that has an associated embedding record for vector search.
///
/// Implementors name their embedding record type and expose the two
/// identifiers that are copied onto that record when it is constructed.
|
||||
pub trait HasEmbedding: StoredObject {
|
||||
/// The embedding record type paired with this entity.
|
||||
type Embedding: EmbeddingRecord;
|
||||
|
||||
/// The entity's source id, passed to the embedding record's constructor
/// as its `source_id`.
fn source_id(&self) -> &str;
|
||||
/// The entity's user id, passed to the embedding record's constructor
/// as its `user_id`.
fn user_id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An embedding record linked to a `HasEmbedding` entity.
|
||||
pub trait EmbeddingRecord: StoredObject {
|
||||
/// The field name in the embedding table that links back to the entity
|
||||
/// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses.
|
||||
fn link_field() -> &'static str;
|
||||
|
||||
/// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`).
|
||||
fn index_name() -> &'static str;
|
||||
|
||||
fn source_id(&self) -> &str;
|
||||
fn user_id(&self) -> &str;
|
||||
fn embedding(&self) -> &[f32];
|
||||
|
||||
/// Construct a new embedding record.
|
||||
///
|
||||
/// * `id` – shared record id (same as the entity id).
|
||||
/// * `source_id` – denormalised source id for bulk deletes.
|
||||
/// * `embedding` – the embedding vector.
|
||||
/// * `user_id` – denormalised user id for query scoping.
|
||||
/// * `entity_table` – the entity's table name (used to build the link `RecordId`).
|
||||
fn new(
|
||||
id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self;
|
||||
|
||||
/// Validate that an embedding vector matches the expected dimension.
|
||||
fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
if embedding.len() != expected {
|
||||
return Err(crate::error::AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recreate the HNSW vector index with a new dimension.
|
||||
///
|
||||
/// This drops and recreates the index inside a transaction.
|
||||
async fn redefine_hnsw_index(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql(
|
||||
Self::index_name(),
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
db.client.query(query).await?.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch a single embedding record by its link `RecordId`.
|
||||
async fn get_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<Option<Self>, crate::error::AppError>
|
||||
where
|
||||
Self: Sized + serde::de::DeserializeOwned,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE {} = $rid LIMIT 1",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
let mut result = db.client.query(query).bind(("rid", rid.clone())).await?;
|
||||
Ok(result.take(0)?)
|
||||
}
|
||||
|
||||
/// Delete an embedding record by its link `RecordId`.
|
||||
async fn delete_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE {} = $rid",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("rid", rid.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embedding records with a given `source_id`.
|
||||
async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! stored_object {
|
||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
|
||||
@@ -910,13 +910,11 @@ mod tests {
|
||||
db.apply_migrations().await.context("migrations")?;
|
||||
|
||||
assert!(
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a")
|
||||
.await?,
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?,
|
||||
"first lease claim should succeed"
|
||||
);
|
||||
assert!(
|
||||
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b")
|
||||
.await?,
|
||||
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?,
|
||||
"second lease claim should fail while lease is held"
|
||||
);
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@ use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::storage::types::{
|
||||
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, HasEmbedding,
|
||||
};
|
||||
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
@@ -24,6 +26,18 @@ pub struct TextChunkSearchResult {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl HasEmbedding for TextChunk {
|
||||
type Embedding = TextChunkEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl TextChunk {
|
||||
#[must_use]
|
||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||
@@ -40,25 +54,9 @@ impl TextChunk {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db_client
|
||||
.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
TextChunkEmbedding::table_name()
|
||||
))
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id;")
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.bind(("table", Self::table_name()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store one text chunk and its embedding (single-record path).
|
||||
@@ -70,38 +68,8 @@ impl TextChunk {
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
TextChunkEmbedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let chunk_id = chunk.id.clone();
|
||||
let emb = TextChunkEmbedding::new(
|
||||
&chunk_id,
|
||||
chunk.source_id.clone(),
|
||||
embedding,
|
||||
chunk.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
|
||||
UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
chunk_table = Self::table_name(),
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id))
|
||||
.bind(("chunk", chunk))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(chunk, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and scores.
|
||||
@@ -111,52 +79,14 @@ impl TextChunk {
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
chunk_id: Option<TextChunk>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
chunk_id,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH chunk_id;
|
||||
"#,
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, TextChunkEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.chunk_id.map(|chunk| TextChunkSearchResult {
|
||||
chunk,
|
||||
score: r.score,
|
||||
}).or_else(|| {
|
||||
warn!("vector search hit orphaned text_chunk_embedding row with missing chunk");
|
||||
None
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(chunk, score)| TextChunkSearchResult { chunk, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Full-text search over text chunks using the BM25 FTS index.
|
||||
@@ -645,7 +575,7 @@ mod tests {
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "expected embedding".to_string())?;
|
||||
@@ -695,7 +625,7 @@ mod tests {
|
||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "embedding should exist".to_string())?;
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
stored_object,
|
||||
};
|
||||
use crate::{storage::types::EmbeddingRecord, stored_object};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::error::AppError;
|
||||
|
||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
/// Record link to the owning text_chunk
|
||||
@@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunkEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
///
|
||||
/// This is useful when the embedding length changes; Surreal requires the
|
||||
/// index definition to be recreated with the updated dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for TextChunkEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"chunk_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_text_chunk_embedding"
|
||||
}
|
||||
|
||||
/// Create a new text chunk embedding.
|
||||
///
|
||||
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
||||
#[must_use]
|
||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
chunk_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
id: chunk_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||
chunk_id: RecordId::from_table_key(entity_table, chunk_id),
|
||||
source_id,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a single embedding by its chunk RecordId
|
||||
pub async fn get_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Delete embeddings for a given chunk RecordId
|
||||
pub async fn delete_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||
///
|
||||
/// This uses the denormalized `source_id` on the embedding table.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -144,8 +65,31 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||
use surrealdb::Value as SurrealValue;
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: surrealdb::Value = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
@@ -169,29 +113,6 @@ mod tests {
|
||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||
}
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_chunk_id_as_record_id() {
|
||||
let emb = TextChunkEmbedding::new(
|
||||
@@ -199,6 +120,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "chunk-abc");
|
||||
}
|
||||
@@ -226,13 +148,14 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
embedding_vec.clone(),
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||
@@ -259,22 +182,23 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
vec![0.4_f32, 0.5, 0.6],
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
@@ -299,21 +223,27 @@ mod tests {
|
||||
("chunk-s2", source_id, vec![0.2]),
|
||||
("chunk-other", other_source, vec![0.3]),
|
||||
] {
|
||||
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||
let emb = TextChunkEmbedding::new(
|
||||
key,
|
||||
src.to_string(),
|
||||
vec,
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| format!("store embedding for {key}"))?;
|
||||
}
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some());
|
||||
@@ -322,15 +252,15 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some());
|
||||
@@ -352,6 +282,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![1.0_f32, 0.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(initial)
|
||||
.await
|
||||
@@ -362,6 +293,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::storage::{
|
||||
indexes::{ensure_runtime, rebuild},
|
||||
types::{
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user