diff --git a/CHANGELOG.md b/CHANGELOG.md index 6278ba5..c982aaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog + ## Unreleased + - Evaluations: simplified crate layout — linear pipeline, sharded-only converted store, in-memory ingestion, `db/` and `cli/` modules; namespace reuse state in corpus manifest (removed `cache/snapshots/`); no legacy JSON/history compatibility (re-run `--warm` after upgrade) - Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables) - Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id) @@ -7,10 +9,10 @@ - Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed - Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children` - Fix: regression re suggestion of relationships -- Internal: eval corpus DB seed uses `persist_artifacts` instead of a separate batched insert path -- Internal: removed unused `entity_embedding_concurrency` ingest tuning knob +- Internal: extracted duplicate entity+embedding patterns into `HasEmbedding` and `EmbeddingRecord` traits with generic `store_with_embedding`, `delete_by_source_id`, and `vector_search` on `SurrealDbClient`. ## 1.0.3 (2026-06-12) + - Search: filter results by type — knowledge entities, ingested content, or both - Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes) - Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search @@ -20,6 +22,7 @@ - Fix: API key revocation now correctly clears the stored key ## 1.0.2 (2026-02-15) + - Fix: edge case where navigation back to a chat page could trigger a new response generation - Fix: chat references now validate and render more reliably - Fix: improved admin access checks for restricted routes @@ -28,73 +31,88 @@ - Security: hardened query handling and ingestion logging to reduce injection and data exposure risk ## 1.0.1 (2026-02-11) + - Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments. - Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling. - Fixed edge cases, including content deletion behavior and compatibility for older user records. ## 1.0.0 (2026-01-02) -- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms. + +- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms. - Added a benchmarks create for evaluating the retrieval process - Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms. - Embeddings stored on own table. - Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details. ## Version 0.2.7 (2025-12-04) + - Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features. - Fix: timezone aware info in scratchpad ## Version 0.2.6 (2025-10-29) + - Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results. - Fix: default name for relationships harmonized across application ## Version 0.2.5 (2025-10-24) + - Added manual knowledge entity creation flows using a modal, with the option for suggested relationships - Scratchpad feature, with the feature to convert scratchpads to content. - Added knowledge entity search results to the global search - Backend fixes for improved performance when ingesting and retrieval ## Version 0.2.4 (2025-10-15) + - Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal. - Ingestion task archive ## Version 0.2.3 (2025-10-12) + - Fix changing vector dimensions on a fresh database (#3) ## Version 0.2.2 (2025-10-07) + - Support for ingestion of PDF files - Improved ingestion speed - Fix deletion of items work as expected - Fix enabling GPT-5 use via OpenAI API ## Version 0.2.1 (2025-09-24) + - Fixed API JSON responses so iOS Shortcuts integrations keep working. ## Version 0.2.0 (2025-09-23) + - Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph. - Added pagination for entities and content plus new observability metrics on the dashboard. - Enabled audio ingestion and merged the new storage backend. - Improved performance, request filtering, and journalctl/systemd compatibility. ## Version 0.1.4 (2025-07-01) + - Added image ingestion with configurable system settings and updated Docker Compose docs. - Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses. ## Version 0.1.3 (2025-06-08) + - Added support for AI providers beyond OpenAI. - Made the HTTP port configurable for deployments. - Smoothed graph mapper failures, long content tiles, and refreshed project documentation. ## Version 0.1.2 (2025-05-26) + - Introduced full-text search across indexed knowledge. - Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling. - Fixed search result links and SurrealDB vector formatting glitches. ## Version 0.1.1 (2025-05-13) + - Added streaming feedback to ingestion tasks for clearer progress updates. - Made the data storage path configurable. - Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes. ## Version 0.1.0 (2025-05-06) + - Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage. - Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts. - Introduced an admin console with analytics, registration and timezone controls, and job monitoring. diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index c48ed32..5245ebd 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -1,9 +1,11 @@ -use super::types::StoredObject; +use super::types::{EmbeddingRecord, HasEmbedding, StoredObject}; use crate::error::AppError; use axum_session::{SessionConfig, SessionError, SessionStore}; use axum_session_surreal::SessionSurrealPool; use futures::Stream; use include_dir::{include_dir, Dir}; +use serde::de::DeserializeOwned; +use serde::Serialize; use std::{ops::Deref, sync::Arc}; use surrealdb::{ engine::any::{connect, Any}, @@ -26,20 +28,6 @@ pub trait ProvidesDb { } impl SurrealDbClient { - /// Initialize a new database client. - /// - /// # Arguments - /// - /// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`). - /// * `username` — Root username for authentication. - /// * `password` — Root password for authentication. - /// * `namespace` — SurrealDB namespace to use. - /// * `database` — SurrealDB database to use. - /// - /// # Errors - /// - /// Returns `Err` if the connection, authentication, or namespace/database selection fails. - /// In-memory (`mem://`) connections skip authentication. pub async fn new( address: &str, username: &str, @@ -49,30 +37,15 @@ impl SurrealDbClient { ) -> Result { let db = connect(address).await?; - // Skip sign-in for in-memory engine (no auth support) if !address.starts_with("mem://") { db.signin(Root { username, password }).await?; } - // Set namespace db.use_ns(namespace).use_db(database).await?; Ok(SurrealDbClient { client: db }) } - /// Initialize a new database client using namespace-level authentication. - /// - /// # Arguments - /// - /// * `address` — Database connection string. - /// * `namespace` — SurrealDB namespace to use (also used for auth). - /// * `username` — Namespace username for authentication. - /// * `password` — Namespace password for authentication. - /// * `database` — SurrealDB database to use. - /// - /// # Errors - /// - /// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails. pub async fn new_with_namespace_user( address: &str, namespace: &str, @@ -91,11 +64,6 @@ impl SurrealDbClient { Ok(SurrealDbClient { client: db }) } - /// Create an Axum session store backed by SurrealDB. - /// - /// # Errors - /// - /// Returns `SessionError` if the session store configuration or table creation fails. pub async fn create_session_store( &self, ) -> Result>, SessionError> { @@ -109,15 +77,6 @@ impl SurrealDbClient { .await } - /// Applies all pending database migrations found in the embedded MIGRATIONS_DIR. - /// - /// This function should be called during application startup, after connecting to - /// the database and selecting the appropriate namespace and database, but before - /// the application starts performing operations that rely on the schema. - /// - /// # Errors - /// - /// Returns `AppError::InternalError` if the migration runner fails to apply any migration. pub async fn apply_migrations(&self) -> Result<(), AppError> { debug!("Applying migrations"); MigrationRunner::new(&self.client) @@ -129,15 +88,6 @@ impl SurrealDbClient { Ok(()) } - /// Store an object in SurrealDB. - /// - /// # Arguments - /// - /// * `item` — The item to store. Must implement `StoredObject`. - /// - /// # Errors - /// - /// Returns `Err` if the database create operation fails. pub async fn store_item(&self, item: T) -> Result, Error> where T: StoredObject + Send + Sync + 'static, @@ -148,13 +98,6 @@ impl SurrealDbClient { .await } - /// Upsert an object in SurrealDB, replacing any existing record with the same ID. - /// - /// Useful when a single record should be replaced by id (admin updates, embedding rows, etc.). - /// - /// # Errors - /// - /// Returns `Err` if the database upsert operation fails. pub async fn upsert_item(&self, item: T) -> Result, Error> where T: StoredObject + Send + Sync + 'static, @@ -166,11 +109,6 @@ impl SurrealDbClient { .await } - /// Retrieve all objects from a table. - /// - /// # Errors - /// - /// Returns `Err` if the database select operation fails. pub async fn get_all_stored_items(&self) -> Result, Error> where T: for<'de> StoredObject, @@ -178,16 +116,6 @@ impl SurrealDbClient { self.client.select(T::table_name()).await } - /// Retrieve a single object by its ID. - /// - /// # Arguments - /// - /// * `id` — The ID of the item to retrieve. - /// - /// # Errors - /// - /// Returns `Err` if the database select operation fails. - /// Returns `Ok(None)` if no record with the given ID exists. pub async fn get_item(&self, id: &str) -> Result, Error> where T: for<'de> StoredObject, @@ -195,16 +123,6 @@ impl SurrealDbClient { self.client.select((T::table_name(), id)).await } - /// Delete a single object by its ID. - /// - /// # Arguments - /// - /// * `id` — The ID of the item to delete. - /// - /// # Errors - /// - /// Returns `Err` if the database delete operation fails. - /// Returns `Ok(None)` if no record with the given ID exists. pub async fn delete_item(&self, id: &str) -> Result, Error> where T: for<'de> StoredObject, @@ -212,11 +130,6 @@ impl SurrealDbClient { self.client.delete((T::table_name(), id)).await } - /// Listen to a table for real-time updates via a live query stream. - /// - /// # Errors - /// - /// Returns `Err` if the database live query subscription fails. pub async fn listen( &self, ) -> Result, Error>>, Error> @@ -225,6 +138,156 @@ impl SurrealDbClient { { self.client.select(T::table_name()).live().await } + + /// Atomically store an entity and its embedding vector in a single + /// SurrealDB transaction. + /// + /// Creates (or overwrites) the entity row and upserts the linked + /// embedding record. The embedding dimension is validated against + /// `embedding_dimensions` before the query is issued. + pub async fn store_with_embedding( + &self, + entity: E, + embedding: Vec, + embedding_dimensions: usize, + ) -> Result<(), AppError> + where + E: HasEmbedding + Serialize + Send + Sync + 'static, + ::Embedding: Serialize + Send + Sync, + { + E::Embedding::validate_dimension(&embedding, embedding_dimensions)?; + + let entity_id = entity.id().to_string(); + let emb = ::Embedding::new( + &entity_id, + entity.source_id().to_string(), + embedding, + entity.user_id().to_string(), + E::table_name(), + ); + + let sql = format!( + " + BEGIN TRANSACTION; + CREATE type::thing('{et}', $id) CONTENT $entity; + UPSERT type::thing('{emt}', $id) CONTENT $emb; + COMMIT TRANSACTION; + ", + et = E::table_name(), + emt = ::Embedding::table_name(), + ); + + self.client + .query(sql) + .bind(("id", entity_id)) + .bind(("entity", entity)) + .bind(("emb", emb)) + .await? + .check()?; + + Ok(()) + } + + /// Delete all entity and embedding rows matching a given `source_id`. + /// + /// Runs inside a SurrealDB transaction so that entity and embedding + /// deletes are atomic. + pub async fn delete_by_source_id(&self, source_id: &str) -> Result<(), AppError> + where + E: HasEmbedding, + E::Embedding: Send + Sync, + { + self.client + .query("BEGIN TRANSACTION;") + .query(format!( + "DELETE FROM {} WHERE source_id = $source_id;", + E::Embedding::table_name() + )) + .query(format!( + "DELETE FROM {} WHERE source_id = $source_id;", + E::table_name() + )) + .query("COMMIT TRANSACTION;") + .bind(("source_id", source_id.to_owned())) + .await? + .check()?; + + Ok(()) + } + + /// Vector similarity search over entities using HNSW index. + /// + /// Performs a cosine-similarity search against the embedding table, + /// fetches the corresponding entity rows server-side via `FETCH`, + /// and returns `(entity, score)` pairs ordered by descending + /// similarity. Orphaned embeddings (entity deleted but its + /// embedding row remains) are logged as a warning and dropped. + /// + /// This is a single round-trip — SurrealDB resolves the link field + /// (`entity_id` or `chunk_id`) inside the query engine. + pub async fn vector_search( + &self, + take: usize, + query_embedding: &[f32], + user_id: &str, + ) -> Result, AppError> + where + E: StoredObject + DeserializeOwned + Clone + Send + Sync, + Emb: EmbeddingRecord + Send + Sync, + { + // Generic row that works with both `entity_id` and `chunk_id` link + // fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link + // server-side so we get the full entity in a single round-trip. + #[derive(serde::Deserialize)] + struct FetchRow { + score: f32, + #[serde(alias = "entity_id", alias = "chunk_id")] + entity: Option, + } + + let link_field = Emb::link_field(); + let sql = format!( + r#" + SELECT + {link_field}, + vector::similarity::cosine(embedding, $embedding) AS score + FROM {emb_table} + WHERE user_id = $user_id + AND embedding <|{take},100|> $embedding + ORDER BY score DESC + LIMIT {take} + FETCH {link_field} + "#, + link_field = link_field, + emb_table = Emb::table_name(), + take = take, + ); + + let mut response = self + .client + .query(sql) + .bind(("embedding", query_embedding.to_vec())) + .bind(("user_id", user_id.to_string())) + .await?; + + response = response.check()?; + + let rows: Vec> = response.take(0)?; + + let mut results = Vec::with_capacity(rows.len()); + for r in rows { + if let Some(entity) = r.entity { + results.push((entity, r.score)); + } else { + tracing::warn!( + "Vector search hit orphaned {} row with missing {link_field}", + Emb::table_name() + ); + } + } + + Ok(results) + } } impl Deref for SurrealDbClient { @@ -237,12 +300,9 @@ impl Deref for SurrealDbClient { #[cfg(any(test, feature = "test-utils"))] impl SurrealDbClient { - /// Create an in-memory SurrealDB client for testing. pub async fn memory(namespace: &str, database: &str) -> Result { let db = connect("mem://").await?; - db.use_ns(namespace).use_db(database).await?; - Ok(SurrealDbClient { client: db }) } } diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs index 1aac6bb..bbcc383 100644 --- a/common/src/storage/indexes.rs +++ b/common/src/storage/indexes.rs @@ -9,10 +9,7 @@ use tracing::{debug, error, info, warn}; use crate::{ error::AppError, - storage::{ - db::SurrealDbClient, - types::system_settings::SystemSettings, - }, + storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, }; const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50); @@ -231,9 +228,7 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> { /// /// Returns `AppError::InternalError` if any rebuild operation fails. pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> { - rebuild_runtime_inner(db) - .await - .map_err(AppError::internal) + rebuild_runtime_inner(db).await.map_err(AppError::internal) } /// Returns whether a scheduled index rebuild is due based on the persisted last-run time. @@ -525,8 +520,7 @@ async fn rebuild_existing_index_in_place( if !index_exists(db, table, index_name).await? { debug!( index = index_name, - table, - "Skipping in-place rebuild because index is missing" + table, "Skipping in-place rebuild because index is missing" ); return Ok(()); } @@ -1074,7 +1068,11 @@ mod tests { assert!(!scheduled_index_rebuild_due(None, 86_400, now)); assert!(!scheduled_index_rebuild_due(Some(last), 0, now)); - assert!(!scheduled_index_rebuild_due(Some(now - chrono::Duration::hours(1)), 86_400, now)); + assert!(!scheduled_index_rebuild_due( + Some(now - chrono::Duration::hours(1)), + 86_400, + now + )); assert!(scheduled_index_rebuild_due(Some(last), 86_400, now)); } @@ -1087,7 +1085,9 @@ mod tests { .context("in-memory db")?; db.apply_migrations().await.context("migrations")?; - ensure_runtime(&db, 8).await.context("ensure runtime indexes")?; + ensure_runtime(&db, 8) + .await + .context("ensure runtime indexes")?; rebuild_runtime(&db) .await diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 76ef2e4..d72b212 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -4,10 +4,13 @@ use std::fmt::Write; use crate::{ error::AppError, - storage::db::SurrealDbClient, - storage::indexes::hnsw_index_overwrite_sql, - storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, - storage::types::system_settings::SystemSettings, + storage::{ + db::SurrealDbClient, + indexes::hnsw_index_overwrite_sql, + types::knowledge_entity_embedding::KnowledgeEntityEmbedding, + types::system_settings::SystemSettings, + types::{EmbeddingRecord, HasEmbedding}, + }, stored_object, utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE}, }; @@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", { user_id: String }); +impl HasEmbedding for KnowledgeEntity { + type Embedding = KnowledgeEntityEmbedding; + + fn source_id(&self) -> &str { + &self.source_id + } + + fn user_id(&self) -> &str { + &self.user_id + } +} + impl KnowledgeEntity { #[must_use] pub fn new( @@ -227,22 +242,9 @@ impl KnowledgeEntity { pub async fn delete_by_source_id( source_id: &str, - db_client: &SurrealDbClient, + db: &SurrealDbClient, ) -> Result<(), AppError> { - // Delete embeddings first, while we can still look them up via the entity's source_id - KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?; - - db_client - .client - .query("DELETE FROM type::table($table) WHERE source_id = $source_id") - .bind(("table", Self::table_name())) - .bind(("source_id", source_id.to_owned())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) + db.delete_by_source_id::(source_id).await } /// Atomically store one knowledge entity and its embedding (single-record path). @@ -254,38 +256,8 @@ impl KnowledgeEntity { embedding_dimensions: usize, db: &SurrealDbClient, ) -> Result<(), AppError> { - KnowledgeEntityEmbedding::validate_dimension(&embedding, embedding_dimensions)?; - - let entity_id = entity.id.clone(); - let emb = KnowledgeEntityEmbedding::new( - &entity_id, - entity.source_id.clone(), - embedding, - entity.user_id.clone(), - ); - - let query = format!( - " - BEGIN TRANSACTION; - CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity; - UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb; - COMMIT TRANSACTION; - ", - entity_table = Self::table_name(), - emb_table = KnowledgeEntityEmbedding::table_name(), - ); - - db.client - .query(query) - .bind(("entity_id", entity_id)) - .bind(("entity", entity)) - .bind(("emb", emb)) + db.store_with_embedding(entity, embedding, embedding_dimensions) .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) } /// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores. @@ -295,48 +267,14 @@ impl KnowledgeEntity { db: &SurrealDbClient, user_id: &str, ) -> Result, AppError> { - #[derive(Deserialize)] - struct Row { - entity_id: Option, - score: f32, - } - - let sql = format!( - r#" - SELECT - entity_id, - vector::similarity::cosine(embedding, $embedding) AS score - FROM {emb_table} - WHERE user_id = $user_id - AND embedding <|{take},100|> $embedding - ORDER BY score DESC - LIMIT {take} - FETCH entity_id; - "#, - emb_table = KnowledgeEntityEmbedding::table_name(), - take = take - ); - - let mut response = db - .query(&sql) - .bind(("embedding", query_embedding.to_vec())) - .bind(("user_id", user_id.to_string())) + db.vector_search::(take, query_embedding, user_id) .await - .map_err(AppError::from)?; - - response = response.check().map_err(AppError::from)?; - - let rows: Vec = response.take::>(0).map_err(AppError::from)?; - - Ok(rows - .into_iter() - .filter_map(|r| { - r.entity_id.map(|entity| KnowledgeEntitySearchResult { - entity, - score: r.score, - }) + .map(|results| { + results + .into_iter() + .map(|(entity, score)| KnowledgeEntitySearchResult { entity, score }) + .collect() }) - .collect()) } pub async fn patch( @@ -362,7 +300,13 @@ impl KnowledgeEntity { settings.embedding_dimensions as usize, )?; - let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id); + let emb = KnowledgeEntityEmbedding::new( + id, + entity.source_id, + embedding, + entity.user_id, + Self::table_name(), + ); let now = Utc::now(); @@ -916,7 +860,7 @@ mod tests { assert_eq!(stored_embeddings.len(), 1); let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); - let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db) + let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid) .await .with_context(|| "fetch embedding".to_string())?; assert!(fetched_emb.is_some()); @@ -999,11 +943,11 @@ mod tests { let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id); let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id); - assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db) + assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1) .await .with_context(|| "get embedding e1".to_string())? .is_some()); - assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db) + assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2) .await .with_context(|| "get embedding e2".to_string())? .is_some()); diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs index d43d27e..18dccba 100644 --- a/common/src/storage/types/knowledge_entity_embedding.rs +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -4,7 +4,7 @@ use surrealdb::RecordId; use crate::{ error::AppError, - storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql}, + storage::{db::SurrealDbClient, types::EmbeddingRecord}, stored_object, }; @@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", { user_id: String }); -impl KnowledgeEntityEmbedding { - /// Recreate the HNSW index with a new embedding dimension. - pub async fn redefine_hnsw_index( - db: &SurrealDbClient, - dimension: usize, - ) -> Result<(), AppError> { - let query = hnsw_index_redefine_transaction_sql( - "idx_embedding_knowledge_entity_embedding", - Self::table_name(), - dimension, - ); - - let res = db.client.query(query).await.map_err(AppError::from)?; - res.check().map_err(AppError::from)?; - - Ok(()) +impl EmbeddingRecord for KnowledgeEntityEmbedding { + fn link_field() -> &'static str { + "entity_id" } - /// Validates that an embedding vector matches the configured HNSW dimension. - #[allow(clippy::result_large_err)] - pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> { - if embedding.len() != expected { - return Err(AppError::Validation(format!( - "embedding dimension mismatch: got {}, expected {expected}", - embedding.len() - ))); - } - Ok(()) + fn index_name() -> &'static str { + "idx_embedding_knowledge_entity_embedding" } - /// Create a new knowledge entity embedding. - /// - /// The embedding record id equals `entity_id` so each entity has at most one embedding row. - #[must_use] - pub fn new(entity_id: &str, source_id: String, embedding: Vec, user_id: String) -> Self { + fn source_id(&self) -> &str { + &self.source_id + } + + fn user_id(&self) -> &str { + &self.user_id + } + + fn embedding(&self) -> &[f32] { + &self.embedding + } + + fn new( + entity_id: &str, + source_id: String, + embedding: Vec, + user_id: String, + entity_table: &str, + ) -> Self { let now = Utc::now(); Self { id: entity_id.to_owned(), created_at: now, updated_at: now, - entity_id: RecordId::from_table_key("knowledge_entity", entity_id), + entity_id: RecordId::from_table_key(entity_table, entity_id), embedding, source_id, user_id, } } +} - /// Get embedding by entity ID - pub async fn get_by_entity_id( - entity_id: &RecordId, - db: &SurrealDbClient, - ) -> Result, AppError> { - let query = format!( - "SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1", - Self::table_name() - ); - let mut result = db - .client - .query(query) - .bind(("entity_id", entity_id.clone())) - .await - .map_err(AppError::from)?; - let embeddings: Vec = result.take(0).map_err(AppError::from)?; - Ok(embeddings.into_iter().next()) - } - +impl KnowledgeEntityEmbedding { /// Get embeddings for multiple entities in batch pub async fn get_by_entity_ids( entity_ids: &[RecordId], @@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding { .map(|e| (e.entity_id.key().to_string(), e.embedding)) .collect()) } - - /// Delete embedding by entity ID - pub async fn delete_by_entity_id( - entity_id: &RecordId, - db: &SurrealDbClient, - ) -> Result<(), AppError> { - let query = format!( - "DELETE FROM {} WHERE entity_id = $entity_id", - Self::table_name() - ); - db.client - .query(query) - .bind(("entity_id", entity_id.clone())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - Ok(()) - } - - /// Delete all embeddings with the given denormalized `source_id`. - pub async fn delete_by_source_id( - source_id: &str, - db: &SurrealDbClient, - ) -> Result<(), AppError> { - let query = format!( - "DELETE FROM {} WHERE source_id = $source_id", - Self::table_name() - ); - db.client - .query(query) - .bind(("source_id", source_id.to_owned())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - Ok(()) - } } #[cfg(test)] @@ -184,6 +122,7 @@ mod tests { "source-1".to_owned(), vec![0.1, 0.2], "user-1".to_owned(), + KnowledgeEntity::table_name(), ); assert_eq!(emb.id, "entity-abc"); } @@ -211,7 +150,7 @@ mod tests { let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); - let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid) .await .with_context(|| "Failed to get embedding by entity_id".to_string())? .ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?; @@ -240,16 +179,16 @@ mod tests { let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); - let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid) .await .with_context(|| "Failed to get embedding before delete".to_string())?; assert!(existing.is_some()); - KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db) + KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid) .await .with_context(|| "Failed to delete by entity_id".to_string())?; - let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid) .await .with_context(|| "Failed to get embedding after delete".to_string())?; assert!(after.is_none()); @@ -277,7 +216,7 @@ mod tests { assert!(stored_entity.is_some()); let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); - let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid) .await .with_context(|| "Failed to fetch embedding".to_string())?; let stored_embedding = @@ -319,9 +258,14 @@ mod tests { KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db) .await .with_context(|| "Failed to store entity with embedding".to_string())?; - KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], 3, &db) - .await - .with_context(|| "Failed to store entity with embedding".to_string())?; + KnowledgeEntity::store_with_embedding( + entity_other.clone(), + vec![3.0_f32, 3.1, 3.2], + 3, + &db, + ) + .await + .with_context(|| "Failed to store entity with embedding".to_string())?; let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id); let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id); @@ -332,18 +276,18 @@ mod tests { .with_context(|| "Failed to delete by source_id".to_string())?; assert!( - KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db) + KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid) .await .with_context(|| "get entity1 embedding after delete".to_string())? .is_none() ); assert!( - KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db) + KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid) .await .with_context(|| "get entity2 embedding after delete".to_string())? .is_none() ); - assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db) + assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid) .await .with_context(|| "get other embedding after delete".to_string())? .is_some()); @@ -450,6 +394,7 @@ mod tests { source_id.to_owned(), vec![0.0, 1.0, 0.0], user_id.to_owned(), + KnowledgeEntity::table_name(), ); db.upsert_item(replacement) .await diff --git a/common/src/storage/types/mod.rs b/common/src/storage/types/mod.rs index f9a4d62..6364817 100644 --- a/common/src/storage/types/mod.rs +++ b/common/src/storage/types/mod.rs @@ -1,4 +1,5 @@ #![allow(clippy::unsafe_derive_deserialize)] +#![allow(async_fn_in_trait)] use serde::{Deserialize, Serialize}; pub mod analytics; pub mod conversation; @@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> { fn id(&self) -> &str; } +/// An entity that has an associated embedding record for vector search. +pub trait HasEmbedding: StoredObject { + /// The embedding record type paired with this entity. + type Embedding: EmbeddingRecord; + + fn source_id(&self) -> &str; + fn user_id(&self) -> &str; +} + +/// An embedding record linked to a `HasEmbedding` entity. +pub trait EmbeddingRecord: StoredObject { + /// The field name in the embedding table that links back to the entity + /// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses. + fn link_field() -> &'static str; + + /// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`). + fn index_name() -> &'static str; + + fn source_id(&self) -> &str; + fn user_id(&self) -> &str; + fn embedding(&self) -> &[f32]; + + /// Construct a new embedding record. + /// + /// * `id` – shared record id (same as the entity id). + /// * `source_id` – denormalised source id for bulk deletes. + /// * `embedding` – the embedding vector. + /// * `user_id` – denormalised user id for query scoping. + /// * `entity_table` – the entity's table name (used to build the link `RecordId`). + fn new( + id: &str, + source_id: String, + embedding: Vec, + user_id: String, + entity_table: &str, + ) -> Self; + + /// Validate that an embedding vector matches the expected dimension. + fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError> + where + Self: Sized, + { + if embedding.len() != expected { + return Err(crate::error::AppError::Validation(format!( + "embedding dimension mismatch: got {}, expected {expected}", + embedding.len() + ))); + } + Ok(()) + } + + /// Recreate the HNSW vector index with a new dimension. + /// + /// This drops and recreates the index inside a transaction. + async fn redefine_hnsw_index( + db: &crate::storage::db::SurrealDbClient, + dimension: usize, + ) -> Result<(), crate::error::AppError> + where + Self: Sized, + { + let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql( + Self::index_name(), + Self::table_name(), + dimension, + ); + db.client.query(query).await?.check()?; + Ok(()) + } + + /// Fetch a single embedding record by its link `RecordId`. + async fn get_by_record_id( + db: &crate::storage::db::SurrealDbClient, + rid: &surrealdb::RecordId, + ) -> Result, crate::error::AppError> + where + Self: Sized + serde::de::DeserializeOwned, + { + let query = format!( + "SELECT * FROM {} WHERE {} = $rid LIMIT 1", + Self::table_name(), + Self::link_field(), + ); + let mut result = db.client.query(query).bind(("rid", rid.clone())).await?; + Ok(result.take(0)?) + } + + /// Delete an embedding record by its link `RecordId`. + async fn delete_by_record_id( + db: &crate::storage::db::SurrealDbClient, + rid: &surrealdb::RecordId, + ) -> Result<(), crate::error::AppError> + where + Self: Sized, + { + let query = format!( + "DELETE FROM {} WHERE {} = $rid", + Self::table_name(), + Self::link_field(), + ); + db.client + .query(query) + .bind(("rid", rid.clone())) + .await? + .check()?; + Ok(()) + } + + /// Delete all embedding records with a given `source_id`. + async fn delete_by_source_id( + source_id: &str, + db: &crate::storage::db::SurrealDbClient, + ) -> Result<(), crate::error::AppError> + where + Self: Sized, + { + let query = format!( + "DELETE FROM {} WHERE source_id = $source_id", + Self::table_name(), + ); + db.client + .query(query) + .bind(("source_id", source_id.to_owned())) + .await? + .check()?; + Ok(()) + } +} + #[macro_export] macro_rules! stored_object { ($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => { diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index a9687ff..215da0e 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -910,13 +910,11 @@ mod tests { db.apply_migrations().await.context("migrations")?; assert!( - SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a") - .await?, + SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?, "first lease claim should succeed" ); assert!( - !SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b") - .await?, + !SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?, "second lease claim should fail while lease is held" ); diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 7d1bbdd..b632489 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -3,11 +3,13 @@ use std::collections::HashMap; use std::fmt::Write; use crate::storage::indexes::hnsw_index_overwrite_sql; -use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; +use crate::storage::types::{ + text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, HasEmbedding, +}; use crate::utils::embedding::RE_EMBED_BATCH_SIZE; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; -use tracing::{error, info, warn}; +use tracing::{error, info}; use uuid::Uuid; stored_object!(TextChunk, "text_chunk", { @@ -24,6 +26,18 @@ pub struct TextChunkSearchResult { pub score: f32, } +impl HasEmbedding for TextChunk { + type Embedding = TextChunkEmbedding; + + fn source_id(&self) -> &str { + &self.source_id + } + + fn user_id(&self) -> &str { + &self.user_id + } +} + impl TextChunk { #[must_use] pub fn new(source_id: String, chunk: String, user_id: String) -> Self { @@ -40,25 +54,9 @@ impl TextChunk { pub async fn delete_by_source_id( source_id: &str, - db_client: &SurrealDbClient, + db: &SurrealDbClient, ) -> Result<(), AppError> { - db_client - .client - .query("BEGIN TRANSACTION;") - .query(format!( - "DELETE FROM {} WHERE source_id = $source_id;", - TextChunkEmbedding::table_name() - )) - .query("DELETE FROM type::table($table) WHERE source_id = $source_id;") - .query("COMMIT TRANSACTION;") - .bind(("source_id", source_id.to_owned())) - .bind(("table", Self::table_name())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) + db.delete_by_source_id::(source_id).await } /// Atomically store one text chunk and its embedding (single-record path). @@ -70,38 +68,8 @@ impl TextChunk { embedding_dimensions: usize, db: &SurrealDbClient, ) -> Result<(), AppError> { - TextChunkEmbedding::validate_dimension(&embedding, embedding_dimensions)?; - - let chunk_id = chunk.id.clone(); - let emb = TextChunkEmbedding::new( - &chunk_id, - chunk.source_id.clone(), - embedding, - chunk.user_id.clone(), - ); - - let query = format!( - " - BEGIN TRANSACTION; - CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk; - UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb; - COMMIT TRANSACTION; - ", - chunk_table = Self::table_name(), - emb_table = TextChunkEmbedding::table_name(), - ); - - db.client - .query(query) - .bind(("chunk_id", chunk_id)) - .bind(("chunk", chunk)) - .bind(("emb", emb)) + db.store_with_embedding(chunk, embedding, embedding_dimensions) .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) } /// Vector search over text chunks using the embedding table, fetching full chunk rows and scores. @@ -111,52 +79,14 @@ impl TextChunk { db: &SurrealDbClient, user_id: &str, ) -> Result, AppError> { - #[allow(clippy::missing_docs_in_private_items)] - #[derive(Deserialize)] - struct Row { - chunk_id: Option, - score: f32, - } - - let sql = format!( - r#" - SELECT - chunk_id, - vector::similarity::cosine(embedding, $embedding) AS score - FROM {emb_table} - WHERE user_id = $user_id - AND embedding <|{take},100|> $embedding - ORDER BY score DESC - LIMIT {take} - FETCH chunk_id; - "#, - emb_table = TextChunkEmbedding::table_name(), - take = take - ); - - let mut response = db - .query(&sql) - .bind(("embedding", query_embedding.to_vec())) - .bind(("user_id", user_id.to_string())) + db.vector_search::(take, query_embedding, user_id) .await - .map_err(AppError::from)?; - - response = response.check().map_err(AppError::from)?; - - let rows: Vec = response.take::>(0).map_err(AppError::from)?; - - Ok(rows - .into_iter() - .filter_map(|r| { - r.chunk_id.map(|chunk| TextChunkSearchResult { - chunk, - score: r.score, - }).or_else(|| { - warn!("vector search hit orphaned text_chunk_embedding row with missing chunk"); - None - }) + .map(|results| { + results + .into_iter() + .map(|(chunk, score)| TextChunkSearchResult { chunk, score }) + .collect() }) - .collect()) } /// Full-text search over text chunks using the BM25 FTS index. @@ -645,7 +575,7 @@ mod tests { assert_eq!(stored_chunk.user_id, user_id); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); - let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) + let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid) .await .with_context(|| "get embedding".to_string())? .with_context(|| "expected embedding".to_string())?; @@ -695,7 +625,7 @@ mod tests { assert!(stored_chunk.id == chunk.id, "chunk should be stored"); let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); - let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) + let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid) .await .with_context(|| "get embedding".to_string())? .with_context(|| "embedding should exist".to_string())?; diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs index 9f85841..d61bc0a 100644 --- a/common/src/storage/types/text_chunk_embedding.rs +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -1,11 +1,9 @@ use surrealdb::RecordId; -use crate::storage::types::text_chunk::TextChunk; -use crate::{ - error::AppError, - storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql}, - stored_object, -}; +use crate::{storage::types::EmbeddingRecord, stored_object}; + +#[cfg(test)] +use crate::error::AppError; stored_object!(TextChunkEmbedding, "text_chunk_embedding", { /// Record link to the owning text_chunk @@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", { user_id: String }); -impl TextChunkEmbedding { - /// Recreate the HNSW index with a new embedding dimension. - /// - /// This is useful when the embedding length changes; Surreal requires the - /// index definition to be recreated with the updated dimension. - pub async fn redefine_hnsw_index( - db: &SurrealDbClient, - dimension: usize, - ) -> Result<(), AppError> { - let query = hnsw_index_redefine_transaction_sql( - "idx_embedding_text_chunk_embedding", - Self::table_name(), - dimension, - ); - - let res = db.client.query(query).await.map_err(AppError::from)?; - res.check().map_err(AppError::from)?; - - Ok(()) +impl EmbeddingRecord for TextChunkEmbedding { + fn link_field() -> &'static str { + "chunk_id" } - /// Validates that an embedding vector matches the configured HNSW dimension. - #[allow(clippy::result_large_err)] - pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> { - if embedding.len() != expected { - return Err(AppError::Validation(format!( - "embedding dimension mismatch: got {}, expected {expected}", - embedding.len() - ))); - } - Ok(()) + fn index_name() -> &'static str { + "idx_embedding_text_chunk_embedding" } - /// Create a new text chunk embedding. - /// - /// The embedding record id equals `chunk_id` so each chunk has at most one embedding row. - /// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid". - #[must_use] - pub fn new(chunk_id: &str, source_id: String, embedding: Vec, user_id: String) -> Self { + fn source_id(&self) -> &str { + &self.source_id + } + + fn user_id(&self) -> &str { + &self.user_id + } + + fn embedding(&self) -> &[f32] { + &self.embedding + } + + fn new( + chunk_id: &str, + source_id: String, + embedding: Vec, + user_id: String, + entity_table: &str, + ) -> Self { let now = Utc::now(); Self { id: chunk_id.to_owned(), created_at: now, updated_at: now, - chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id), + chunk_id: RecordId::from_table_key(entity_table, chunk_id), source_id, embedding, user_id, } } - - /// Get a single embedding by its chunk RecordId - pub async fn get_by_chunk_id( - chunk_id: &RecordId, - db: &SurrealDbClient, - ) -> Result, AppError> { - let query = format!( - "SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1", - Self::table_name() - ); - - let mut result = db - .client - .query(query) - .bind(("chunk_id", chunk_id.clone())) - .await - .map_err(AppError::from)?; - - let embeddings: Vec = result.take(0).map_err(AppError::from)?; - - Ok(embeddings.into_iter().next()) - } - - /// Delete embeddings for a given chunk RecordId - pub async fn delete_by_chunk_id( - chunk_id: &RecordId, - db: &SurrealDbClient, - ) -> Result<(), AppError> { - let query = format!( - "DELETE FROM {} WHERE chunk_id = $chunk_id", - Self::table_name() - ); - - db.client - .query(query) - .bind(("chunk_id", chunk_id.clone())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) - } - - /// Delete all embeddings that belong to chunks with a given `source_id` - /// - /// This uses the denormalized `source_id` on the embedding table. - pub async fn delete_by_source_id( - source_id: &str, - db: &SurrealDbClient, - ) -> Result<(), AppError> { - let query = format!( - "DELETE FROM {} WHERE source_id = $source_id", - Self::table_name() - ); - - db.client - .query(query) - .bind(("source_id", source_id.to_owned())) - .await - .map_err(AppError::from)? - .check() - .map_err(AppError::from)?; - - Ok(()) - } } #[cfg(test)] @@ -144,8 +65,31 @@ mod tests { use super::*; use crate::storage::db::SurrealDbClient; + use crate::storage::types::text_chunk::TextChunk; use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db}; - use surrealdb::Value as SurrealValue; + + async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result { + let mut info_res = db + .client + .query("INFO FOR TABLE text_chunk_embedding;") + .await + .with_context(|| "info query failed".to_string())?; + let info: surrealdb::Value = info_res + .take(0) + .with_context(|| "failed to take info result".to_string())?; + let info_json: serde_json::Value = serde_json::to_value(info) + .with_context(|| "failed to convert info to json".to_string())?; + let idx_sql = info_json + .get("Object") + .and_then(|v| v.get("indexes")) + .and_then(|v| v.get("Object")) + .and_then(|v| v.get("idx_embedding_text_chunk_embedding")) + .and_then(|v| v.get("Strand")) + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + Ok(idx_sql) + } async fn create_text_chunk_with_id( db: &SurrealDbClient, @@ -169,29 +113,6 @@ mod tests { Ok(RecordId::from_table_key(TextChunk::table_name(), key)) } - async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result { - let mut info_res = db - .client - .query("INFO FOR TABLE text_chunk_embedding;") - .await - .with_context(|| "info query failed".to_string())?; - let info: SurrealValue = info_res - .take(0) - .with_context(|| "failed to take info result".to_string())?; - let info_json: serde_json::Value = serde_json::to_value(info) - .with_context(|| "failed to convert info to json".to_string())?; - let idx_sql = info_json - .get("Object") - .and_then(|v| v.get("indexes")) - .and_then(|v| v.get("Object")) - .and_then(|v| v.get("idx_embedding_text_chunk_embedding")) - .and_then(|v| v.get("Strand")) - .and_then(|v| v.as_str()) - .unwrap_or_default() - .to_string(); - Ok(idx_sql) - } - #[test] fn new_uses_chunk_id_as_record_id() { let emb = TextChunkEmbedding::new( @@ -199,6 +120,7 @@ mod tests { "source-1".to_owned(), vec![0.1, 0.2], "user-1".to_owned(), + TextChunk::table_name(), ); assert_eq!(emb.id, "chunk-abc"); } @@ -226,13 +148,14 @@ mod tests { source_id.to_string(), embedding_vec.clone(), user_id.to_string(), + TextChunk::table_name(), ); db.upsert_item(emb) .await .with_context(|| "Failed to store embedding".to_string())?; - let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) + let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid) .await .with_context(|| "Failed to get embedding by chunk_id".to_string())? .with_context(|| "Expected an embedding to be found".to_string())?; @@ -259,22 +182,23 @@ mod tests { source_id.to_string(), vec![0.4_f32, 0.5, 0.6], user_id.to_string(), + TextChunk::table_name(), ); db.upsert_item(emb) .await .with_context(|| "Failed to store embedding".to_string())?; - let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) + let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid) .await .with_context(|| "Failed to get embedding before delete".to_string())?; assert!(existing.is_some(), "Embedding should exist before delete"); - TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db) + TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid) .await .with_context(|| "Failed to delete by chunk_id".to_string())?; - let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db) + let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid) .await .with_context(|| "Failed to get embedding after delete".to_string())?; assert!(after.is_none(), "Embedding should have been deleted"); @@ -299,21 +223,27 @@ mod tests { ("chunk-s2", source_id, vec![0.2]), ("chunk-other", other_source, vec![0.3]), ] { - let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string()); + let emb = TextChunkEmbedding::new( + key, + src.to_string(), + vec, + user_id.to_string(), + TextChunk::table_name(), + ); db.upsert_item(emb) .await .with_context(|| format!("store embedding for {key}"))?; } - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid) .await .with_context(|| "get chunk1".to_string())? .is_some()); - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid) .await .with_context(|| "get chunk2".to_string())? .is_some()); - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid) .await .with_context(|| "get chunk_other".to_string())? .is_some()); @@ -322,15 +252,15 @@ mod tests { .await .with_context(|| "Failed to delete by source_id".to_string())?; - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid) .await .with_context(|| "check chunk1".to_string())? .is_none()); - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid) .await .with_context(|| "check chunk2".to_string())? .is_none()); - assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db) + assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid) .await .with_context(|| "check chunk_other".to_string())? .is_some()); @@ -352,6 +282,7 @@ mod tests { source_id.to_owned(), vec![1.0_f32, 0.0, 0.0], user_id.to_owned(), + TextChunk::table_name(), ); db.upsert_item(initial) .await @@ -362,6 +293,7 @@ mod tests { source_id.to_owned(), vec![0.0, 1.0, 0.0], user_id.to_owned(), + TextChunk::table_name(), ); db.upsert_item(replacement) .await diff --git a/common/src/test_utils.rs b/common/src/test_utils.rs index 26e5127..caf7800 100644 --- a/common/src/test_utils.rs +++ b/common/src/test_utils.rs @@ -9,7 +9,7 @@ use crate::storage::{ indexes::{ensure_runtime, rebuild}, types::{ knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings, - text_chunk_embedding::TextChunkEmbedding, + text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, }, }; diff --git a/evaluations/src/context_stats.rs b/evaluations/src/context_stats.rs index ffcd425..e620252 100644 --- a/evaluations/src/context_stats.rs +++ b/evaluations/src/context_stats.rs @@ -1,3 +1,5 @@ +#![allow(clippy::arithmetic_side_effects)] + use serde::{Deserialize, Serialize}; use common::storage::types::StoredObject; @@ -6,6 +8,7 @@ use crate::types::EvaluationCandidate; const TOKENIZER_LABEL: &str = "estimated (~chars/4; ingestion uses bert-base-cased)"; +#[allow(clippy::struct_field_names)] #[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct RetrievedContextStats { pub chunk_count: usize, @@ -48,6 +51,7 @@ pub fn stats_for_candidates(candidates: &[EvaluationCandidate]) -> RetrievedCont stats } +#[allow(clippy::cast_precision_loss)] pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> RetrievalContextStats { let queries = per_query.len(); if queries == 0 { @@ -69,19 +73,28 @@ pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> Retrieval let total_chunks: usize = per_query.iter().map(|stats| stats.chunk_count).sum(); let total_chars: usize = per_query.iter().map(|stats| stats.char_count).sum(); let total_tokens: usize = per_query.iter().map(|stats| stats.token_count).sum(); - let mut tokens_per_query: Vec = per_query.iter().map(|stats| stats.token_count).collect(); + let mut tokens_per_query: Vec = + per_query.iter().map(|stats| stats.token_count).collect(); tokens_per_query.sort_unstable(); let max_tokens_per_query = *tokens_per_query.last().unwrap_or(&0); + let total_chunks_f = total_chunks as f64; + let total_chars_f = total_chars as f64; + let total_tokens_f = total_tokens as f64; + let queries_f = queries as f64; + let avg_chunks_per_query = total_chunks_f / queries_f; + let avg_chars_per_query = total_chars_f / queries_f; + let avg_tokens_per_query = total_tokens_f / queries_f; + RetrievalContextStats { tokenizer: TOKENIZER_LABEL.to_string(), queries, total_chunks, total_chars, total_tokens, - avg_chunks_per_query: total_chunks as f64 / queries as f64, - avg_chars_per_query: total_chars as f64 / queries as f64, - avg_tokens_per_query: total_tokens as f64 / queries as f64, + avg_chunks_per_query, + avg_chars_per_query, + avg_tokens_per_query, p50_tokens_per_query: percentile_usize(&tokens_per_query, 0.50), p95_tokens_per_query: percentile_usize(&tokens_per_query, 0.95), max_tokens_per_query, @@ -96,7 +109,13 @@ fn estimate_ingestion_tokens(text: &str) -> usize { chars.div_ceil(4) } -#[allow(clippy::cast_precision_loss, clippy::indexing_slicing, clippy::arithmetic_side_effects)] +#[allow( + clippy::cast_precision_loss, + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + clippy::indexing_slicing, + clippy::arithmetic_side_effects +)] fn percentile_usize(sorted: &[usize], fraction: f64) -> usize { if sorted.is_empty() { return 0; diff --git a/ingestion-pipeline/src/pipeline/persistence.rs b/ingestion-pipeline/src/pipeline/persistence.rs index 7eb9ab5..89cbd00 100644 --- a/ingestion-pipeline/src/pipeline/persistence.rs +++ b/ingestion-pipeline/src/pipeline/persistence.rs @@ -13,10 +13,9 @@ use common::{ db::SurrealDbClient, types::{ knowledge_entity::KnowledgeEntity, - knowledge_entity_embedding::KnowledgeEntityEmbedding, - text_chunk::TextChunk, - text_chunk_embedding::TextChunkEmbedding, - text_content::TextContent, + knowledge_entity_embedding::KnowledgeEntityEmbedding, text_chunk::TextChunk, + text_chunk_embedding::TextChunkEmbedding, text_content::TextContent, EmbeddingRecord, + StoredObject, }, }, }; @@ -116,8 +115,7 @@ struct PersistPayload { entity_embeddings: Arc<[KnowledgeEntityEmbedding]>, chunks: Arc<[TextChunk]>, chunk_embeddings: Arc<[TextChunkEmbedding]>, - relationships: - Arc<[common::storage::types::knowledge_relationship::KnowledgeRelationship]>, + relationships: Arc<[common::storage::types::knowledge_relationship::KnowledgeRelationship]>, } async fn execute_persist_transaction( @@ -208,6 +206,7 @@ fn prepare_entity_rows( entity.source_id.clone(), item.embedding, entity.user_id.clone(), + KnowledgeEntity::table_name(), )); entities.push(entity); } @@ -230,6 +229,7 @@ fn prepare_chunk_rows( chunk.source_id.clone(), item.embedding, chunk.user_id.clone(), + TextChunk::table_name(), )); chunks.push(chunk); } @@ -244,7 +244,8 @@ fn is_retryable_conflict(error: &AppError) -> bool { } #[cfg(test)] -static TEST_PERSIST_FAILURES: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); +static TEST_PERSIST_FAILURES: std::sync::atomic::AtomicUsize = + std::sync::atomic::AtomicUsize::new(0); #[cfg(test)] fn set_test_persist_failures(count: usize) { @@ -348,8 +349,16 @@ mod tests { let source_b = uuid::Uuid::new_v4().to_string(); let user_id = "persist-isolation"; - persist(&db, large_artifacts(&source_a, user_id, 5, 3, 4, TEST_EMBEDDING_DIM)).await?; - persist(&db, large_artifacts(&source_b, user_id, 2, 1, 1, TEST_EMBEDDING_DIM)).await?; + persist( + &db, + large_artifacts(&source_a, user_id, 5, 3, 4, TEST_EMBEDDING_DIM), + ) + .await?; + persist( + &db, + large_artifacts(&source_b, user_id, 2, 1, 1, TEST_EMBEDDING_DIM), + ) + .await?; persist( &db, large_artifacts(&source_a, user_id, 7, 4, 6, TEST_EMBEDDING_DIM),