evals: v3, embeddings at the side

additional indexes
This commit is contained in:
Per Stark
2025-11-26 15:00:55 +01:00
parent 226b2db43a
commit 030f0fc17d
63 changed files with 3859 additions and 1124 deletions

View File

@@ -1,5 +1,8 @@
use super::types::StoredObject;
use crate::error::AppError;
use crate::{
error::AppError,
storage::{indexes::ensure_runtime_indexes, types::system_settings::SystemSettings},
};
use axum_session::{SessionConfig, SessionError, SessionStore};
use axum_session_surreal::SessionSurrealPool;
use futures::Stream;
@@ -96,20 +99,22 @@ impl SurrealDbClient {
}
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
pub async fn rebuild_indexes(&self) -> Result<(), AppError> {
debug!("Rebuilding indexes");
let rebuild_sql = r#"
BEGIN TRANSACTION;
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
COMMIT TRANSACTION;
REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding;
REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding;
"#;
self.client.query(rebuild_sql).await?;
self.client
.query(rebuild_sql)
.await
.map_err(|e| AppError::InternalError(e.to_string()))?;
Ok(())
}

View File

@@ -0,0 +1,589 @@
use std::time::Duration;
use anyhow::{Context, Result};
use serde::Deserialize;
use serde_json::Value;
use tracing::{info, warn};
use crate::{error::AppError, storage::db::SurrealDbClient};
/// How often to re-query index build status while waiting for a build to finish.
const INDEX_POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Name of the shared analyzer used by all BM25 full-text-search indexes.
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
/// Specification for an HNSW vector index over an `embedding` field.
#[derive(Clone, Copy)]
struct HnswIndexSpec {
    /// Name of the index as defined in SurrealDB.
    index_name: &'static str,
    /// Table holding the `embedding` field the index covers.
    table: &'static str,
    /// Trailing index options (distance metric, element type, EFC, M, CONCURRENTLY).
    options: &'static str,
}

impl HnswIndexSpec {
    /// Shared renderer for both DEFINE variants; `verb` is either
    /// `"IF NOT EXISTS"` or `"OVERWRITE"`. Factored out so the two public
    /// builders cannot drift apart.
    fn definition_with(&self, verb: &str, dimension: usize) -> String {
        format!(
            "DEFINE INDEX {verb} {index} ON TABLE {table} \
             FIELDS embedding HNSW DIMENSION {dimension} {options};",
            verb = verb,
            index = self.index_name,
            table = self.table,
            dimension = dimension,
            options = self.options,
        )
    }

    /// `DEFINE INDEX IF NOT EXISTS …` — a no-op when the index already exists.
    fn definition_if_not_exists(&self, dimension: usize) -> String {
        self.definition_with("IF NOT EXISTS", dimension)
    }

    /// `DEFINE INDEX OVERWRITE …` — redefines the index even if present.
    fn definition_overwrite(&self, dimension: usize) -> String {
        self.definition_with("OVERWRITE", dimension)
    }
}
/// Specification for a full-text-search index on a single field.
#[derive(Clone, Copy)]
struct FtsIndexSpec {
    index_name: &'static str,
    table: &'static str,
    field: &'static str,
    /// Optional analyzer; when absent, no SEARCH clause is emitted.
    analyzer: Option<&'static str>,
    /// Ranking method (e.g. BM25); only used when an analyzer is set.
    method: &'static str,
}

impl FtsIndexSpec {
    /// Render the `DEFINE INDEX … CONCURRENTLY;` statement for this spec.
    fn definition(&self) -> String {
        let search_clause = match self.analyzer {
            Some(name) => format!(" SEARCH ANALYZER {name} {}", self.method),
            None => String::new(),
        };
        format!(
            "DEFINE INDEX IF NOT EXISTS {} ON TABLE {} FIELDS {}{} CONCURRENTLY;",
            self.index_name, self.table, self.field, search_clause,
        )
    }
}
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
pub async fn ensure_runtime_indexes(
    db: &SurrealDbClient,
    embedding_dimension: usize,
) -> Result<(), AppError> {
    match ensure_runtime_indexes_inner(db, embedding_dimension).await {
        Ok(()) => Ok(()),
        // Flatten the anyhow error chain into the app-level error type.
        Err(err) => Err(AppError::InternalError(err.to_string())),
    }
}
/// Inner worker: creates the shared analyzer, then every FTS index, then
/// every HNSW index, polling each index build to completion.
async fn ensure_runtime_indexes_inner(
    db: &SurrealDbClient,
    embedding_dimension: usize,
) -> Result<()> {
    // The FTS indexes reference the analyzer, so it must exist first.
    create_fts_analyzer(db).await?;
    for fts in fts_index_specs() {
        let statement = fts.definition();
        create_index_with_polling(db, statement, fts.index_name, fts.table, Some(fts.table))
            .await?;
    }
    for hnsw in hnsw_index_specs() {
        ensure_hnsw_index(db, &hnsw, embedding_dimension).await?;
    }
    Ok(())
}
/// Create or refresh a single HNSW index, overwriting it only when the stored
/// dimension differs from the requested one.
async fn ensure_hnsw_index(
    db: &SurrealDbClient,
    spec: &HnswIndexSpec,
    dimension: usize,
) -> Result<()> {
    let state = hnsw_index_state(db, spec, dimension).await?;
    let definition = if let HnswIndexState::Different(existing) = state {
        info!(
            index = spec.index_name,
            table = spec.table,
            existing_dimension = existing,
            target_dimension = dimension,
            "Overwriting HNSW index to match new embedding dimension"
        );
        spec.definition_overwrite(dimension)
    } else {
        // Missing or already matching: IF NOT EXISTS creates it when absent
        // and is a no-op when the dimension already matches.
        spec.definition_if_not_exists(dimension)
    };
    create_index_with_polling(db, definition, spec.index_name, spec.table, Some(spec.table)).await
}
/// Inspect `INFO FOR TABLE` output and classify the current HNSW index state
/// relative to `expected_dimension`.
async fn hnsw_index_state(
    db: &SurrealDbClient,
    spec: &HnswIndexSpec,
    expected_dimension: usize,
) -> Result<HnswIndexState> {
    let info_query = format!("INFO FOR TABLE {table};", table = spec.table);
    let mut response = db
        .client
        .query(info_query)
        .await
        .with_context(|| format!("fetching table info for {}", spec.table))?;
    let info: surrealdb::Value = response
        .take(0)
        .context("failed to take table info response")?;
    // Serialize Surreal's value into JSON so we can walk its internal shape:
    // {"Object": {"indexes": {"Object": {<name>: {"Strand": <definition>}}}}}.
    let info_json: Value =
        serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
    let Some(indexes) = info_json
        .get("Object")
        .and_then(|o| o.get("indexes"))
        .and_then(|i| i.get("Object"))
        .and_then(|i| i.as_object())
    else {
        // No indexes object at all: treat as missing so the index gets defined.
        return Ok(HnswIndexState::Missing);
    };
    // The index definition is stored as a raw SurrealQL string ("Strand").
    let Some(definition) = indexes
        .get(spec.index_name)
        .and_then(|details| details.get("Strand"))
        .and_then(|v| v.as_str())
    else {
        return Ok(HnswIndexState::Missing);
    };
    // A definition with no parseable DIMENSION is also treated as missing so
    // it gets (re)defined rather than erroring out.
    let Some(current_dimension) = extract_dimension(definition) else {
        return Ok(HnswIndexState::Missing);
    };
    if current_dimension == expected_dimension as u64 {
        Ok(HnswIndexState::Matches(current_dimension))
    } else {
        Ok(HnswIndexState::Different(current_dimension))
    }
}
/// Result of probing an HNSW index definition on a table.
enum HnswIndexState {
    /// Index not found (or its DIMENSION could not be parsed).
    Missing,
    /// Index exists with the expected dimension (carried value currently unread).
    Matches(u64),
    /// Index exists but with a different dimension; carries the existing one.
    Different(u64),
}
/// Pull the numeric value that follows the first `DIMENSION` keyword out of an
/// index definition string, e.g. `… HNSW DIMENSION 1536 DIST COSINE …`.
/// Returns `None` when the keyword or a parseable number is absent.
fn extract_dimension(definition: &str) -> Option<u64> {
    let (_, after) = definition.split_once("DIMENSION")?;
    let token = after.split_whitespace().next()?;
    token.trim_end_matches(';').parse::<u64>().ok()
}
/// Define the shared FTS analyzer (idempotent via `IF NOT EXISTS`).
/// All BM25 indexes in `fts_index_specs` reference this analyzer by name.
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
    let analyzer_query = format!(
        "DEFINE ANALYZER IF NOT EXISTS {analyzer}
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);",
        analyzer = FTS_ANALYZER_NAME
    );
    let res = db
        .client
        .query(analyzer_query)
        .await
        .context("creating FTS analyzer")?;
    // check() surfaces per-statement errors the query() call alone does not.
    res.check().context("failed to create FTS analyzer")?;
    Ok(())
}
/// Submit an index definition and poll until the build finishes.
///
/// `progress_table` (when given) is row-counted up front so the polling loop
/// can report a completion percentage; pass `None` to skip the count.
async fn create_index_with_polling(
    db: &SurrealDbClient,
    definition: String,
    index_name: &str,
    table: &str,
    progress_table: Option<&str>,
) -> Result<()> {
    // Snapshot the row count before defining the index so progress reporting
    // has a denominator.
    let expected_total = match progress_table {
        Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
            format!("counting rows in {table} for index {index_name} progress")
        })?),
        None => None,
    };
    let res = db
        .client
        .query(definition)
        .await
        .with_context(|| format!("creating index {index_name} on table {table}"))?;
    // check() surfaces per-statement errors the query() call alone does not.
    res.check()
        .with_context(|| format!("index definition failed for {index_name} on {table}"))?;
    info!(
        index = %index_name,
        table = %table,
        expected_rows = ?expected_total,
        "Index definition submitted; waiting for build to finish"
    );
    poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await
}
/// Poll `INFO FOR INDEX` until the build reaches a terminal state.
///
/// Terminal states: ready, an "error" status, or no info returned at all.
/// Progress is logged on every poll.
/// NOTE(review): there is no overall timeout — a build stuck in "indexing"
/// keeps this loop (and its caller) alive indefinitely; confirm intended.
async fn poll_index_build_status(
    db: &SurrealDbClient,
    index_name: &str,
    table: &str,
    total_rows: Option<u64>,
    poll_every: Duration,
) -> Result<()> {
    let started_at = std::time::Instant::now();
    loop {
        // Sleep first: the definition was only just submitted.
        tokio::time::sleep(poll_every).await;
        let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
        let mut info_res = db.client.query(info_query).await.with_context(|| {
            format!("checking index build status for {index_name} on {table}")
        })?;
        let info: Option<Value> = info_res
            .take(0)
            .context("failed to deserialize INFO FOR INDEX result")?;
        let Some(snapshot) = parse_index_build_info(info, total_rows) else {
            warn!(
                index = %index_name,
                table = %table,
                "INFO FOR INDEX returned no data; assuming index definition might be missing"
            );
            break;
        };
        // Two log shapes: with a percentage only when the row total is known.
        match snapshot.progress_pct {
            Some(pct) => info!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                initial = snapshot.initial,
                pending = snapshot.pending,
                updated = snapshot.updated,
                processed = snapshot.processed,
                total = snapshot.total_rows,
                progress_pct = format_args!("{pct:.1}"),
                "Index build status"
            ),
            None => info!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                initial = snapshot.initial,
                pending = snapshot.pending,
                updated = snapshot.updated,
                processed = snapshot.processed,
                "Index build status"
            ),
        }
        if snapshot.is_ready() {
            info!(
                index = %index_name,
                table = %table,
                elapsed = ?started_at.elapsed(),
                processed = snapshot.processed,
                total = snapshot.total_rows,
                "Index is ready"
            );
            break;
        }
        // On an error status, stop polling but do not fail the caller.
        if snapshot.status.eq_ignore_ascii_case("error") {
            warn!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                "Index build reported error status; stopping polling"
            );
            break;
        }
    }
    Ok(())
}
/// Point-in-time view of an index build derived from `INFO FOR INDEX`.
#[derive(Debug, PartialEq)]
struct IndexBuildSnapshot {
    /// Build status string ("indexing", "ready", "error", ...).
    status: String,
    /// Rows present when the build started.
    initial: u64,
    /// Rows not yet processed — presumably awaiting indexing; confirm against SurrealDB docs.
    pending: u64,
    /// Rows written after the build started.
    updated: u64,
    /// `initial + updated` (saturating).
    processed: u64,
    /// Denominator for progress, when the caller counted the table up front.
    total_rows: Option<u64>,
    /// `processed / total_rows` as a percentage, capped at 100.
    progress_pct: Option<f64>,
}
impl IndexBuildSnapshot {
    /// True once Surreal reports (or implies, via a missing `building` block)
    /// a "ready" status.
    fn is_ready(&self) -> bool {
        self.status.eq_ignore_ascii_case("ready")
    }
}
/// Interpret the `INFO FOR INDEX` payload into a progress snapshot.
///
/// Returns `None` only when no info row was returned at all (the index is
/// likely missing); an empty object is interpreted as "ready".
fn parse_index_build_info(
    info: Option<Value>,
    total_rows: Option<u64>,
) -> Option<IndexBuildSnapshot> {
    let info = info?;
    let building = info.get("building");
    // Shared extractor for the numeric counters in the `building` block;
    // missing or non-numeric fields default to 0.
    let counter = |key: &str| -> u64 {
        building
            .and_then(|b| b.get(key))
            .and_then(|v| v.as_u64())
            .unwrap_or(0)
    };
    let status = building
        .and_then(|b| b.get("status"))
        .and_then(|s| s.as_str())
        // If there's no `building` block at all, treat as "ready" (index not building anymore)
        .unwrap_or("ready")
        .to_string();
    let initial = counter("initial");
    let pending = counter("pending");
    let updated = counter("updated");
    // `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
    let processed = initial.saturating_add(updated);
    let progress_pct = total_rows.map(|total| {
        if total == 0 {
            0.0
        } else {
            // Cap at 100%: `processed` can exceed the pre-build row count.
            ((processed as f64 / total as f64).min(1.0)) * 100.0
        }
    });
    Some(IndexBuildSnapshot {
        status,
        initial,
        pending,
        updated,
        processed,
        total_rows,
        progress_pct,
    })
}
/// Row shape for the `SELECT count() … GROUP ALL` aggregate below.
#[derive(Debug, Deserialize)]
struct CountRow {
    count: u64,
}
/// Count all rows in `table`; returns 0 when the table is empty (or yields
/// no aggregate row).
async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
    let query = format!("SELECT count() AS count FROM {table} GROUP ALL;");
    let mut response = db
        .client
        .query(query)
        .await
        .with_context(|| format!("counting rows in {table}"))?;
    let rows: Vec<CountRow> = response
        .take(0)
        .context("failed to deserialize count() response")?;
    // GROUP ALL yields at most one row; an empty result means zero rows.
    Ok(rows.first().map(|r| r.count).unwrap_or(0))
}
/// Static list of the HNSW vector indexes maintained at runtime.
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
    // Both embedding tables share identical tuning options.
    const OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";
    [
        HnswIndexSpec {
            index_name: "idx_embedding_text_chunk_embedding",
            table: "text_chunk_embedding",
            options: OPTIONS,
        },
        HnswIndexSpec {
            index_name: "idx_embedding_knowledge_entity_embedding",
            table: "knowledge_entity_embedding",
            options: OPTIONS,
        },
    ]
}
/// Static list of the BM25 full-text indexes, all bound to the shared analyzer.
const fn fts_index_specs() -> [FtsIndexSpec; 9] {
    /// Build a BM25 spec wired to the app analyzer — every entry below is
    /// identical except for name/table/field.
    const fn bm25(
        index_name: &'static str,
        table: &'static str,
        field: &'static str,
    ) -> FtsIndexSpec {
        FtsIndexSpec {
            index_name,
            table,
            field,
            analyzer: Some(FTS_ANALYZER_NAME),
            method: "BM25",
        }
    }
    [
        bm25("text_content_fts_idx", "text_content", "text"),
        bm25("text_content_category_fts_idx", "text_content", "category"),
        bm25("text_content_context_fts_idx", "text_content", "context"),
        bm25("text_content_file_name_fts_idx", "text_content", "file_info.file_name"),
        bm25("text_content_url_fts_idx", "text_content", "url_info.url"),
        bm25("text_content_url_title_fts_idx", "text_content", "url_info.title"),
        bm25("knowledge_entity_fts_name_idx", "knowledge_entity", "name"),
        bm25("knowledge_entity_fts_description_idx", "knowledge_entity", "description"),
        bm25("text_chunk_fts_chunk_idx", "text_chunk", "chunk"),
    ]
}
//  Unit tests for the pure parsing helpers plus end-to-end index creation
//  against an in-memory SurrealDB instance.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    /// Progress math: processed = initial + updated, pct = processed / total.
    #[test]
    fn parse_index_build_info_reports_progress() {
        let info = json!({
            "building": {
                "initial": 56894,
                "pending": 0,
                "status": "indexing",
                "updated": 0
            }
        });
        let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
        assert_eq!(
            snapshot,
            IndexBuildSnapshot {
                status: "indexing".to_string(),
                initial: 56894,
                pending: 0,
                updated: 0,
                processed: 56894,
                total_rows: Some(61081),
                progress_pct: Some((56894_f64 / 61081_f64) * 100.0),
            }
        );
        assert!(!snapshot.is_ready());
    }

    #[test]
    fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
        // Surreal returns `{}` when the index exists but isn't building.
        let info = json!({});
        let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
        assert!(snapshot.is_ready());
        assert_eq!(snapshot.processed, 0);
        assert_eq!(snapshot.progress_pct, Some(0.0));
    }

    /// DIMENSION value is parsed out of a full index definition string.
    #[test]
    fn extract_dimension_parses_value() {
        let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
        assert_eq!(extract_dimension(definition), Some(1536));
    }

    /// End-to-end: creating all runtime indexes twice must succeed both times.
    #[tokio::test]
    async fn ensure_runtime_indexes_is_idempotent() {
        let namespace = "indexes_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("in-memory db");
        db.apply_migrations()
            .await
            .expect("migrations should succeed");
        // First run creates everything
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("initial index creation");
        // Second run should be a no-op and still succeed
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("second index creation");
    }

    /// Changing the dimension must take the DEFINE INDEX OVERWRITE path.
    #[tokio::test]
    async fn ensure_hnsw_index_overwrites_dimension() {
        let namespace = "indexes_dim";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("in-memory db");
        db.apply_migrations()
            .await
            .expect("migrations should succeed");
        // Create initial index with default dimension
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("initial index creation");
        // Change dimension and ensure overwrite path is exercised
        ensure_runtime_indexes(&db, 128)
            .await
            .expect("overwritten index creation");
    }
}

View File

@@ -1,3 +1,4 @@
pub mod db;
pub mod indexes;
pub mod store;
pub mod types;

View File

@@ -1,7 +1,8 @@
use std::collections::HashMap;
use crate::{
error::AppError, storage::db::SurrealDbClient, stored_object,
error::AppError, storage::db::SurrealDbClient,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
@@ -78,10 +79,16 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String
});
/// Vector search result including hydrated entity.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct KnowledgeEntityVectorResult {
pub entity: KnowledgeEntity,
pub score: f32,
}
impl KnowledgeEntity {
pub fn new(
source_id: String,
@@ -89,7 +96,6 @@ impl KnowledgeEntity {
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now();
@@ -102,7 +108,6 @@ impl KnowledgeEntity {
description,
entity_type,
metadata,
embedding,
user_id,
}
}
@@ -165,6 +170,89 @@ impl KnowledgeEntity {
Ok(())
}
/// Atomically store a knowledge entity and its embedding.
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
pub async fn store_with_embedding(
entity: KnowledgeEntity,
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
entity_table = Self::table_name(),
emb_table = KnowledgeEntityEmbedding::table_name(),
);
db.client
.query(query)
.bind(("entity_id", entity.id.clone()))
.bind(("entity", entity))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
pub async fn vector_search(
take: usize,
query_embedding: Vec<f32>,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
score: f32,
}
let sql = format!(
r#"
SELECT
entity_id,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH entity_id;
"#,
emb_table = KnowledgeEntityEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
response = response.check().map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
Ok(rows
.into_iter()
.map(|r| KnowledgeEntityVectorResult {
entity: r.entity_id,
score: r.score,
})
.collect())
}
pub async fn patch(
id: &str,
name: &str,
@@ -178,32 +266,55 @@ impl KnowledgeEntity {
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
let user_id = Self::get_user_id_by_id(id, db_client).await?;
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
let now = Utc::now();
db_client
.client
.query(
"UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type,
embedding = $embedding
RETURN AFTER",
"BEGIN TRANSACTION;
UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type;
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
COMMIT TRANSACTION;",
)
.bind(("table", Self::table_name()))
.bind(("emb_table", KnowledgeEntityEmbedding::table_name()))
.bind(("id", id.to_string()))
.bind(("name", name.to_string()))
.bind(("updated_at", surrealdb::Datetime::from(now)))
.bind(("entity_type", entity_type.to_owned()))
.bind(("embedding", embedding))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.bind(("description", description.to_string()))
.await?;
Ok(())
}
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
let mut response = db_client
.client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
.bind(("table", Self::table_name()))
.bind(("id", id.to_string()))
.await
.map_err(AppError::Database)?;
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.get(0)
.map(|r| r.user_id.clone())
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It follows the same
@@ -228,22 +339,13 @@ impl KnowledgeEntity {
if total_entities == 0 {
info!("No knowledge entities to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {} entities to process.", total_entities);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in all_entities.iter() {
let embedding_input = format!(
@@ -271,17 +373,16 @@ impl KnowledgeEntity {
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), embedding);
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
// We must properly serialize the vector for the SurrealQL query string
// Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
@@ -291,18 +392,22 @@ impl KnowledgeEntity {
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id
));
}
// Re-create the index after updating the data that it will index
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
@@ -317,7 +422,9 @@ impl KnowledgeEntity {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use serde_json::json;
use uuid::Uuid;
#[tokio::test]
async fn test_knowledge_entity_creation() {
@@ -327,7 +434,6 @@ mod tests {
let description = "Test Description".to_string();
let entity_type = KnowledgeEntityType::Document;
let metadata = Some(json!({"key": "value"}));
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new(
@@ -336,7 +442,6 @@ mod tests {
description.clone(),
entity_type.clone(),
metadata.clone(),
embedding.clone(),
user_id.clone(),
);
@@ -346,7 +451,6 @@ mod tests {
assert_eq!(entity.description, description);
assert_eq!(entity.entity_type, entity_type);
assert_eq!(entity.metadata, metadata);
assert_eq!(entity.embedding, embedding);
assert_eq!(entity.user_id, user_id);
assert!(!entity.id.is_empty());
}
@@ -410,20 +514,25 @@ mod tests {
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// Create two entities with the same source_id
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("Failed to redefine index length");
let entity1 = KnowledgeEntity::new(
source_id.clone(),
"Entity 1".to_string(),
"Description 1".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -433,7 +542,6 @@ mod tests {
"Description 2".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -445,18 +553,18 @@ mod tests {
"Different Description".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
// Store the entities
db.store_item(entity1)
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 1");
db.store_item(entity2)
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 2");
db.store_item(different_entity.clone())
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
.await
.expect("Failed to store different entity");
@@ -505,6 +613,162 @@ mod tests {
assert_eq!(different_remaining[0].id, different_entity.id);
}
// Note: We can't easily test the patch method without mocking the OpenAI client
// and the generate_embedding function. This would require more complex setup.
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.expect("vector search");
assert!(results.is_empty());
}
#[tokio::test]
async fn test_vector_search_single_result() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let user_id = "user".to_string();
let source_id = "src".to_string();
let entity = KnowledgeEntity::new(
source_id.clone(),
"hello".to_string(),
"world".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store entity with embedding");
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
assert!(stored_entity.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {}",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.take(0)
.expect("take embeddings");
assert_eq!(stored_embeddings.len(), 1);
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
.await
.expect("fetch embedding");
assert!(fetched_emb.is_some());
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.expect("vector search");
assert_eq!(results.len(), 1);
let res = &results[0];
assert_eq!(res.entity.id, entity.id);
assert_eq!(res.entity.source_id, source_id);
assert_eq!(res.entity.name, "hello");
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let user_id = "user".to_string();
let e1 = KnowledgeEntity::new(
"s1".to_string(),
"entity one".to_string(),
"desc".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
let e2 = KnowledgeEntity::new(
"s2".to_string(),
"entity two".to_string(),
"desc".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store e1");
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store e2");
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
assert!(stored_e1.is_some() && stored_e2.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {}",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.take(0)
.expect("take embeddings");
assert_eq!(stored_embeddings.len(), 2);
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
.await
.unwrap()
.is_some());
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
.await
.unwrap()
.is_some());
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.expect("vector search");
assert_eq!(results.len(), 2);
assert_eq!(results[0].entity.id, e2.id);
assert_eq!(results[1].entity.id, e1.id);
}
}

View File

@@ -0,0 +1,385 @@
use std::collections::HashMap;
use surrealdb::RecordId;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
// Embedding row stored beside `knowledge_entity`: holds the HNSW-indexed
// vector, a RecordId link back to the entity, and a denormalized user id.
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
    entity_id: RecordId,
    embedding: Vec<f32>,
    /// Denormalized user id for query scoping
    user_id: String
});
impl KnowledgeEntityEmbedding {
    /// Recreate the HNSW index with a new embedding dimension.
    ///
    /// NOTE(review): REMOVE + DEFINE are wrapped in a transaction; confirm
    /// SurrealDB actually rolls back DDL statements on failure.
    pub async fn redefine_hnsw_index(
        db: &SurrealDbClient,
        dimension: usize,
    ) -> Result<(), AppError> {
        let query = format!(
            "BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
COMMIT TRANSACTION;",
            table = Self::table_name(),
        );
        let res = db.client.query(query).await.map_err(AppError::Database)?;
        // check() surfaces per-statement errors the query() call alone does not.
        res.check().map_err(AppError::Database)?;
        Ok(())
    }

    /// Create a new knowledge entity embedding.
    ///
    /// `entity_id` is the plain record key; it is expanded into a full
    /// `knowledge_entity:<key>` RecordId here.
    pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
        let now = Utc::now();
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
            embedding,
            user_id,
        }
    }

    /// Get embedding by entity ID.
    pub async fn get_by_entity_id(
        entity_id: &RecordId,
        db: &SurrealDbClient,
    ) -> Result<Option<Self>, AppError> {
        let query = format!(
            "SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
            Self::table_name()
        );
        let mut result = db
            .client
            .query(query)
            .bind(("entity_id", entity_id.clone()))
            .await
            .map_err(AppError::Database)?;
        let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
        // LIMIT 1 guarantees at most one element in the vec.
        Ok(embeddings.into_iter().next())
    }

    /// Get embeddings for multiple entities in batch.
    ///
    /// Returns a map keyed by the entity's record *key* (not the full
    /// `table:key` id).
    pub async fn get_by_entity_ids(
        entity_ids: &[RecordId],
        db: &SurrealDbClient,
    ) -> Result<HashMap<String, Vec<f32>>, AppError> {
        if entity_ids.is_empty() {
            return Ok(HashMap::new());
        }
        let ids_list: Vec<RecordId> = entity_ids.iter().cloned().collect();
        let query = format!(
            "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
            Self::table_name()
        );
        let mut result = db
            .client
            .query(query)
            .bind(("entity_ids", ids_list))
            .await
            .map_err(AppError::Database)?;
        let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
        Ok(embeddings
            .into_iter()
            .map(|e| (e.entity_id.key().to_string(), e.embedding))
            .collect())
    }

    /// Delete embedding by entity ID.
    pub async fn delete_by_entity_id(
        entity_id: &RecordId,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = format!(
            "DELETE FROM {} WHERE entity_id = $entity_id",
            Self::table_name()
        );
        db.client
            .query(query)
            .bind(("entity_id", entity_id.clone()))
            .await
            .map_err(AppError::Database)?;
        Ok(())
    }

    /// Delete embeddings by source_id (via joining to knowledge_entity table).
    ///
    /// NOTE(review): this issues one DELETE round trip per matching entity
    /// (N+1); a single `DELETE … WHERE entity_id INSIDE $ids` would do.
    pub async fn delete_by_source_id(
        source_id: &str,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
        let mut res = db
            .client
            .query(query)
            .bind(("source_id", source_id.to_owned()))
            .await
            .map_err(AppError::Database)?;
        // Minimal row shape for deserializing just the matching ids.
        #[derive(Deserialize)]
        struct IdRow {
            id: RecordId,
        }
        let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
        for row in ids {
            Self::delete_by_entity_id(&row.id, db).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Integration tests running against a fresh in-memory SurrealDB instance
    // per test (unique database name), with migrations applied up front.
    use super::*;
    use crate::storage::db::SurrealDbClient;
    use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
    use chrono::Utc;
    use surrealdb::Value as SurrealValue;
    use uuid::Uuid;
    /// Spin up an in-memory DB with a unique database name and run migrations.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("Failed to start in-memory surrealdb");
        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");
        db
    }
    /// Build a KnowledgeEntity with a caller-chosen key so record ids are
    /// predictable across the test.
    fn build_knowledge_entity_with_id(
        key: &str,
        source_id: &str,
        user_id: &str,
    ) -> KnowledgeEntity {
        KnowledgeEntity {
            id: key.to_owned(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            source_id: source_id.to_owned(),
            name: "Test entity".to_owned(),
            description: "Desc".to_owned(),
            entity_type: KnowledgeEntityType::Document,
            metadata: None,
            user_id: user_id.to_owned(),
        }
    }
    #[tokio::test]
    async fn test_create_and_get_by_entity_id() {
        let db = setup_test_db().await;
        // Shrink the HNSW index to the 3-dim vectors used below.
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-1";
        let source_id = "source-ke";
        let embedding_vec = vec![0.11_f32, 0.22, 0.33];
        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
        KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
            .await
            .expect("Failed to store entity with embedding");
        let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
        let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
            .await
            .expect("Failed to get embedding by entity_id")
            .expect("Expected embedding to exist");
        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.entity_id, entity_rid);
        assert_eq!(fetched.embedding, embedding_vec);
    }
    #[tokio::test]
    async fn test_delete_by_entity_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-delete";
        let source_id = "source-del";
        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
        KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
            .await
            .expect("Failed to store entity with embedding");
        let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
        // Sanity check: embedding exists before the delete.
        let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
            .await
            .expect("Failed to get embedding before delete");
        assert!(existing.is_some());
        KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
            .await
            .expect("Failed to delete by entity_id");
        let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
            .await
            .expect("Failed to get embedding after delete");
        assert!(after.is_none());
    }
    #[tokio::test]
    async fn test_store_with_embedding_creates_entity_and_embedding() {
        let db = setup_test_db().await;
        let user_id = "user_store";
        let source_id = "source_store";
        let embedding = vec![0.2_f32, 0.3, 0.4];
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
            .await
            .expect("set test index dimension");
        let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
        KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
            .await
            .expect("Failed to store entity with embedding");
        // Both the entity row and its embedding row must exist.
        let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
        assert!(stored_entity.is_some());
        let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
        let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
            .await
            .expect("Failed to fetch embedding");
        assert!(stored_embedding.is_some());
        let stored_embedding = stored_embedding.unwrap();
        assert_eq!(stored_embedding.user_id, user_id);
        assert_eq!(stored_embedding.entity_id, entity_rid);
    }
    #[tokio::test]
    async fn test_delete_by_source_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let source_id = "shared-ke";
        let other_source = "other-ke";
        // Two entities share `source_id`; a third has a different source and
        // must survive the delete.
        let entity1 = build_knowledge_entity_with_id("entity-s1", source_id, user_id);
        let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
        let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
        KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
            .await
            .expect("Failed to store entity with embedding");
        KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
            .await
            .expect("Failed to store entity with embedding");
        KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
            .await
            .expect("Failed to store entity with embedding");
        let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
        let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
        let other_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity_other.id);
        KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
            .await
            .expect("Failed to delete by source_id");
        assert!(
            KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
                .await
                .unwrap()
                .is_none()
        );
        assert!(
            KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
                .await
                .unwrap()
                .is_none()
        );
        assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
            .await
            .unwrap()
            .is_some());
    }
    #[tokio::test]
    async fn test_redefine_hnsw_index_updates_dimension() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
            .await
            .expect("failed to redefine index");
        // Inspect the table's index definitions via INFO FOR TABLE and check
        // the re-defined dimension appears in the index DDL string.
        let mut info_res = db
            .client
            .query("INFO FOR TABLE knowledge_entity_embedding;")
            .await
            .expect("info query failed");
        let info: SurrealValue = info_res.take(0).expect("failed to take info result");
        let info_json: serde_json::Value =
            serde_json::to_value(info).expect("failed to convert info to json");
        let idx_sql = info_json["Object"]["indexes"]["Object"]
            ["idx_embedding_knowledge_entity_embedding"]["Strand"]
            .as_str()
            .unwrap_or_default();
        assert!(
            idx_sql.contains("DIMENSION 16"),
            "expected index definition to contain new dimension, got: {idx_sql}"
        );
    }
    #[tokio::test]
    async fn test_fetch_entity_via_record_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-fetch";
        let source_id = "source-fetch";
        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
        KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
            .await
            .expect("Failed to store entity with embedding");
        let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
        // FETCH should hydrate the record link into a full KnowledgeEntity.
        #[derive(Deserialize)]
        struct Row {
            entity_id: KnowledgeEntity,
        }
        let mut res = db
            .client
            .query(
                "SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;",
            )
            .bind(("id", entity_rid.clone()))
            .await
            .expect("failed to fetch embedding with FETCH");
        let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
        assert_eq!(rows.len(), 1);
        let fetched_entity = &rows[0].entity_id;
        assert_eq!(fetched_entity.id, entity_key);
        assert_eq!(fetched_entity.name, "Test entity");
        assert_eq!(fetched_entity.user_id, user_id);
    }
}

View File

@@ -119,7 +119,6 @@ mod tests {
let source_id = "source123".to_string();
let description = format!("Description for {}", name);
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3];
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new(
@@ -128,7 +127,6 @@ mod tests {
description,
entity_type,
None,
embedding,
user_id,
);

View File

@@ -5,12 +5,14 @@ pub mod file_info;
pub mod ingestion_payload;
pub mod ingestion_task;
pub mod knowledge_entity;
pub mod knowledge_entity_embedding;
pub mod knowledge_relationship;
pub mod message;
pub mod scratchpad;
pub mod system_prompts;
pub mod system_settings;
pub mod text_chunk;
pub mod text_chunk_embedding;
pub mod text_content;
pub mod user;

View File

@@ -71,25 +71,22 @@ mod tests {
.await
.expect("Failed to fetch table info");
let info: Option<serde_json::Value> = response
let info: surrealdb::Value = response
.take(0)
.expect("Failed to extract table info response");
let info = info.expect("Table info result missing");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("Failed to convert info to json");
let indexes = info
.get("indexes")
.or_else(|| {
info.get("tables")
.and_then(|tables| tables.get(table_name))
.and_then(|table| table.get("indexes"))
})
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
let indexes = info_json["Object"]["indexes"]["Object"]
.as_object()
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
let definition = indexes
.get(index_name)
.and_then(|definition| definition.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
.and_then(|definition| definition.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
let dimension_part = definition
.split("DIMENSION")
@@ -261,48 +258,56 @@ mod tests {
let initial_chunk = TextChunk::new(
"source1".into(),
"This chunk has the original dimension".into(),
vec![0.1; 1536],
"user1".into(),
);
db.store_item(initial_chunk.clone())
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
.await
.expect("Failed to store initial chunk");
.expect("Failed to store initial chunk with embedding");
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
.await
.unwrap();
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
simulate_reembedding(&db, 768, initial_chunk).await;
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
let target_dimension = 1536usize;
simulate_reembedding(&db, target_dimension, initial_chunk).await;
let migration_result = db.apply_migrations().await;
assert!(migration_result.is_ok(), "Migrations should not fail");
assert!(
migration_result.is_ok(),
"Migrations should not fail: {:?}",
migration_result.err()
);
}
#[tokio::test]
@@ -320,8 +325,12 @@ mod tests {
.await
.expect("Failed to load current settings");
let initial_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
let initial_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
assert_eq!(
initial_chunk_dimension, current_settings.embedding_dimensions,
@@ -352,10 +361,18 @@ mod tests {
.await
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
let text_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
let knowledge_dimension =
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
let text_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
let knowledge_dimension = get_hnsw_index_dimension(
&db,
"knowledge_entity_embedding",
"idx_embedding_knowledge_entity_embedding",
)
.await;
assert_eq!(
text_chunk_dimension, new_dimension,

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
@@ -13,12 +14,18 @@ use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
source_id: String,
chunk: String,
embedding: Vec<f32>,
user_id: String
});
/// Vector search result including hydrated chunk.
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct TextChunkVectorResult {
pub chunk: TextChunk,
pub score: f32,
}
impl TextChunk {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
@@ -26,7 +33,6 @@ impl TextChunk {
updated_at: now,
source_id,
chunk,
embedding,
user_id,
}
}
@@ -45,6 +51,94 @@ impl TextChunk {
Ok(())
}
/// Atomically store a text chunk and its embedding.
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
pub async fn store_with_embedding(
chunk: TextChunk,
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = TextChunkEmbedding::new(
&chunk.id,
chunk.source_id.clone(),
embedding,
chunk.user_id.clone(),
);
// Create both records in a single query
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
chunk_table = Self::table_name(),
emb_table = TextChunkEmbedding::table_name(),
);
db.client
.query(query)
.bind(("chunk_id", chunk.id.clone()))
.bind(("chunk", chunk))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
pub async fn vector_search(
take: usize,
query_embedding: Vec<f32>,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkVectorResult>, AppError> {
#[derive(Deserialize)]
struct Row {
chunk_id: TextChunk,
score: f32,
}
let sql = format!(
r#"
SELECT
chunk_id,
embedding,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH chunk_id;
"#,
emb_table = TextChunkEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
Ok(rows
.into_iter()
.map(|r| TextChunkVectorResult {
chunk: r.chunk_id,
score: r.score,
})
.collect())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
@@ -70,21 +164,14 @@ impl TextChunk {
if total_chunks == 0 {
info!("No text chunks to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
@@ -108,16 +195,18 @@ impl TextChunk {
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(chunk.id.clone(), embedding);
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
// Perform DB updates in a single transaction against the embedding table
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
@@ -126,22 +215,29 @@ impl TextChunk {
.collect::<Vec<_>>()
.join(",")
);
// Use the chunk id as the embedding record id to keep a 1:1 mapping
transaction_query.push_str(&format!(
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
));
}
// Re-create the index inside the same transaction
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
@@ -152,171 +248,269 @@ impl TextChunk {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId;
use uuid::Uuid;
#[tokio::test]
async fn test_text_chunk_creation() {
// Test basic object creation
let source_id = "source123".to_string();
let chunk = "This is a text chunk for testing embeddings".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let text_chunk = TextChunk::new(
source_id.clone(),
chunk.clone(),
embedding.clone(),
user_id.clone(),
);
let text_chunk = TextChunk::new(source_id.clone(), chunk.clone(), user_id.clone());
// Check that the fields are set correctly
assert_eq!(text_chunk.source_id, source_id);
assert_eq!(text_chunk.chunk, chunk);
assert_eq!(text_chunk.embedding, embedding);
assert_eq!(text_chunk.user_id, user_id);
assert!(!text_chunk.id.is_empty());
}
#[tokio::test]
async fn test_delete_by_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
// Create test data
let source_id = "source123".to_string();
let chunk1 = "First chunk from the same source".to_string();
let chunk2 = "Second chunk from the same source".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
// Create two chunks with the same source_id
let text_chunk1 = TextChunk::new(
let chunk1 = TextChunk::new(
source_id.clone(),
chunk1,
embedding.clone(),
"First chunk from the same source".to_string(),
user_id.clone(),
);
let text_chunk2 = TextChunk::new(
let chunk2 = TextChunk::new(
source_id.clone(),
chunk2,
embedding.clone(),
"Second chunk from the same source".to_string(),
user_id.clone(),
);
// Create a chunk with a different source_id
let different_source_id = "different_source".to_string();
let different_chunk = TextChunk::new(
different_source_id.clone(),
"different_source".to_string(),
"Different source chunk".to_string(),
embedding.clone(),
user_id.clone(),
);
// Store the chunks
db.store_item(text_chunk1)
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("Failed to store text chunk 1");
db.store_item(text_chunk2)
.expect("store chunk1");
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("Failed to store text chunk 2");
db.store_item(different_chunk.clone())
.await
.expect("Failed to store different chunk");
.expect("store chunk2");
TextChunk::store_with_embedding(
different_chunk.clone(),
vec![0.1, 0.2, 0.3, 0.4, 0.5],
&db,
)
.await
.expect("store different chunk");
// Delete by source_id
TextChunk::delete_by_source_id(&source_id, &db)
.await
.expect("Failed to delete chunks by source_id");
// Verify all chunks with the original source_id are deleted
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
source_id
);
let remaining: Vec<TextChunk> = db
.client
.query(query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
source_id
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(
remaining.len(),
0,
"All chunks with the source_id should be deleted"
);
assert_eq!(remaining.len(), 0);
// Verify the different source_id chunk still exists
let different_query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
different_source_id
);
let different_remaining: Vec<TextChunk> = db
.client
.query(different_query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
"different_source"
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(
different_remaining.len(),
1,
"Chunk with different source_id should still exist"
);
assert_eq!(different_remaining.len(), 1);
assert_eq!(different_remaining[0].id, different_chunk.id);
}
#[tokio::test]
async fn test_delete_by_nonexistent_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
// Create a chunk with a real source_id
let real_source_id = "real_source".to_string();
let chunk = "Test chunk".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
// Store the chunk
db.store_item(text_chunk)
.await
.expect("Failed to store text chunk");
// Delete using nonexistent source_id
let nonexistent_source_id = "nonexistent_source";
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
.await
.expect("Delete operation with nonexistent source_id should not fail");
// Verify the real chunk still exists
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
real_source_id
let chunk = TextChunk::new(
real_source_id.clone(),
"Test chunk".to_string(),
"user123".to_string(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk");
TextChunk::delete_by_source_id("nonexistent_source", &db)
.await
.expect("Delete should succeed");
let remaining: Vec<TextChunk> = db
.client
.query(query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
real_source_id
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(
remaining.len(),
1,
"Chunk with real source_id should still exist"
assert_eq!(remaining.len(), 1);
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
let source_id = "store-src".to_string();
let user_id = "user_store".to_string();
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some());
let stored_chunk = stored_chunk.unwrap();
assert_eq!(stored_chunk.source_id, source_id);
assert_eq!(stored_chunk.user_id, user_id);
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some());
let embedding = embedding.unwrap();
assert_eq!(embedding.chunk_id, rid);
assert_eq!(embedding.user_id, user_id);
assert_eq!(embedding.source_id, source_id);
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let results: Vec<TextChunkVectorResult> =
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn test_vector_search_single_result() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let source_id = "src".to_string();
let user_id = "user".to_string();
let chunk = TextChunk::new(
source_id.clone(),
"hello world".to_string(),
user_id.clone(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store");
let results: Vec<TextChunkVectorResult> =
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.unwrap();
assert_eq!(results.len(), 1);
let res = &results[0];
assert_eq!(res.chunk.id, chunk.id);
assert_eq!(res.chunk.source_id, source_id);
assert_eq!(res.chunk.chunk, "hello world");
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let user_id = "user".to_string();
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store chunk1");
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store chunk2");
let results: Vec<TextChunkVectorResult> =
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk.id, chunk2.id);
assert_eq!(results[1].chunk.id, chunk1.id);
assert!(results[0].score >= results[1].score);
}
}

View File

@@ -0,0 +1,435 @@
use surrealdb::RecordId;
use crate::storage::types::text_chunk::TextChunk;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
// Embedding row for a text chunk. Kept in its own `text_chunk_embedding`
// table so the HNSW vector index lives beside, not inside, the chunk row.
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
    /// Record link to the owning text_chunk
    chunk_id: RecordId,
    /// Denormalized source id for bulk deletes
    source_id: String,
    /// Embedding vector (length must match the HNSW index dimension)
    embedding: Vec<f32>,
    /// Denormalized user id (for scoping + permissions)
    user_id: String
});
impl TextChunkEmbedding {
    /// Recreate the HNSW index with a new embedding dimension.
    ///
    /// This is useful when the embedding length changes; Surreal requires the
    /// index definition to be recreated with the updated dimension. Both
    /// statements run in one transaction.
    pub async fn redefine_hnsw_index(
        db: &SurrealDbClient,
        dimension: usize,
    ) -> Result<(), AppError> {
        let query = format!(
            "BEGIN TRANSACTION;
            REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table};
            DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
            COMMIT TRANSACTION;",
            table = Self::table_name(),
        );
        // `check` surfaces per-statement errors that `query` alone swallows.
        let res = db.client.query(query).await.map_err(AppError::Database)?;
        res.check().map_err(AppError::Database)?;
        Ok(())
    }
    /// Create a new text chunk embedding.
    ///
    /// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID),
    /// not "text_chunk:uuid".
    pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
        let now = Utc::now();
        Self {
            // NOTE: `stored_object!` macro defines `id` as `String`
            id: uuid::Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            // Create a record<text_chunk> link: text_chunk:<chunk_id>
            chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
            source_id,
            embedding,
            user_id,
        }
    }
    /// Get a single embedding by its chunk RecordId.
    /// Returns `None` when no embedding row exists for the chunk.
    pub async fn get_by_chunk_id(
        chunk_id: &RecordId,
        db: &SurrealDbClient,
    ) -> Result<Option<Self>, AppError> {
        let query = format!(
            "SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
            Self::table_name()
        );
        let mut result = db
            .client
            .query(query)
            .bind(("chunk_id", chunk_id.clone()))
            .await
            .map_err(AppError::Database)?;
        let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
        Ok(embeddings.into_iter().next())
    }
    /// Delete embeddings for a given chunk RecordId.
    pub async fn delete_by_chunk_id(
        chunk_id: &RecordId,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let query = format!(
            "DELETE FROM {} WHERE chunk_id = $chunk_id",
            Self::table_name()
        );
        db.client
            .query(query)
            .bind(("chunk_id", chunk_id.clone()))
            .await
            .map_err(AppError::Database)?
            // `check` makes a failing DELETE statement an error, not a no-op.
            .check()
            .map_err(AppError::Database)?;
        Ok(())
    }
    /// Delete all embeddings that belong to chunks with a given `source_id`.
    ///
    /// Two-step operation: first fetch the ids of every `text_chunk` row with
    /// the given `source_id`, then issue ONE batched delete against this
    /// table:
    ///
    /// DELETE FROM text_chunk_embedding WHERE chunk_id IN $chunk_ids
    ///
    /// NOTE(review): the two statements are not wrapped in a transaction, so
    /// chunks created between them are not covered — fine for cleanup, but
    /// not strictly atomic.
    pub async fn delete_by_source_id(
        source_id: &str,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let ids_query = format!(
            "SELECT id FROM {} WHERE source_id = $source_id",
            TextChunk::table_name()
        );
        let mut res = db
            .client
            .query(ids_query)
            .bind(("source_id", source_id.to_owned()))
            .await
            .map_err(AppError::Database)?;
        #[derive(Deserialize)]
        struct IdRow {
            id: RecordId,
        }
        let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
        // No matching chunks: nothing to delete, skip the second round trip.
        if ids.is_empty() {
            return Ok(());
        }
        let delete_query = format!(
            "DELETE FROM {} WHERE chunk_id IN $chunk_ids",
            Self::table_name()
        );
        db.client
            .query(delete_query)
            .bind((
                "chunk_ids",
                ids.into_iter().map(|row| row.id).collect::<Vec<_>>(),
            ))
            .await
            .map_err(AppError::Database)?
            .check()
            .map_err(AppError::Database)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::db::SurrealDbClient;
    use surrealdb::Value as SurrealValue;
    use uuid::Uuid;
    /// Helper to create an in-memory DB and apply migrations
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        // Fresh random database name per test keeps tests isolated from each
        // other even when they share the process-wide in-memory engine.
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("Failed to start in-memory surrealdb");
        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");
        db
    }
    /// Helper: create a text_chunk with a known key, return its RecordId
    async fn create_text_chunk_with_id(
        db: &SurrealDbClient,
        key: &str,
        source_id: &str,
        user_id: &str,
    ) -> RecordId {
        let chunk = TextChunk {
            id: key.to_owned(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            source_id: source_id.to_owned(),
            chunk: "Some test chunk text".to_owned(),
            user_id: user_id.to_owned(),
        };
        db.store_item(chunk)
            .await
            .expect("Failed to create text_chunk");
        RecordId::from_table_key(TextChunk::table_name(), key)
    }
    /// Round trip: store an embedding for a chunk, then fetch it back via
    /// `get_by_chunk_id` and verify the stored fields.
    #[tokio::test]
    async fn test_create_and_get_by_chunk_id() {
        let db = setup_test_db().await;
        let user_id = "user_a";
        let chunk_key = "chunk-123";
        let source_id = "source-1";
        // 1) Create a text_chunk with a known key
        let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
        // 2) Create and store an embedding for that chunk
        let embedding_vec = vec![0.1_f32, 0.2, 0.3];
        let emb = TextChunkEmbedding::new(
            chunk_key,
            source_id.to_string(),
            embedding_vec.clone(),
            user_id.to_string(),
        );
        // Index dimension must match the vector length before inserting.
        TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
            .await
            .expect("Failed to redefine index length");
        let _: Option<TextChunkEmbedding> = db
            .client
            .create(TextChunkEmbedding::table_name())
            .content(emb)
            .await
            .expect("Failed to store embedding")
            .take()
            .expect("Failed to deserialize stored embedding");
        // 3) Fetch it via get_by_chunk_id
        let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding by chunk_id");
        assert!(fetched.is_some(), "Expected an embedding to be found");
        let fetched = fetched.unwrap();
        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.chunk_id, chunk_rid);
        assert_eq!(fetched.embedding, embedding_vec);
    }
    /// `delete_by_chunk_id` removes exactly the embedding belonging to the
    /// given chunk record.
    #[tokio::test]
    async fn test_delete_by_chunk_id() {
        let db = setup_test_db().await;
        let user_id = "user_b";
        let chunk_key = "chunk-delete";
        let source_id = "source-del";
        let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
        let emb = TextChunkEmbedding::new(
            chunk_key,
            source_id.to_string(),
            vec![0.4_f32, 0.5, 0.6],
            user_id.to_string(),
        );
        TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
            .await
            .expect("Failed to redefine index length");
        let _: Option<TextChunkEmbedding> = db
            .client
            .create(TextChunkEmbedding::table_name())
            .content(emb)
            .await
            .expect("Failed to store embedding")
            .take()
            .expect("Failed to deserialize stored embedding");
        // Ensure it exists
        let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding before delete");
        assert!(existing.is_some(), "Embedding should exist before delete");
        // Delete by chunk_id
        TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to delete by chunk_id");
        // Ensure it no longer exists
        let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding after delete");
        assert!(after.is_none(), "Embedding should have been deleted");
    }
    /// `delete_by_source_id` removes embeddings for every chunk of the given
    /// source while leaving embeddings of other sources intact.
    #[tokio::test]
    async fn test_delete_by_source_id() {
        let db = setup_test_db().await;
        let user_id = "user_c";
        let source_id = "shared-source";
        let other_source = "other-source";
        // Two chunks with the same source_id
        let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
        let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
        // One chunk with a different source_id
        let chunk_other_rid =
            create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
        // Create embeddings for all three
        let emb1 = TextChunkEmbedding::new(
            "chunk-s1",
            source_id.to_string(),
            vec![0.1],
            user_id.to_string(),
        );
        let emb2 = TextChunkEmbedding::new(
            "chunk-s2",
            source_id.to_string(),
            vec![0.2],
            user_id.to_string(),
        );
        let emb3 = TextChunkEmbedding::new(
            "chunk-other",
            other_source.to_string(),
            vec![0.3],
            user_id.to_string(),
        );
        // Update length on index
        TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
            .await
            .expect("Failed to redefine index length");
        for emb in [emb1, emb2, emb3] {
            let _: Option<TextChunkEmbedding> = db
                .client
                .create(TextChunkEmbedding::table_name())
                .content(emb)
                .await
                .expect("Failed to store embedding")
                .take()
                .expect("Failed to deserialize stored embedding");
        }
        // Sanity check: they all exist
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
            .await
            .unwrap()
            .is_some());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
            .await
            .unwrap()
            .is_some());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
            .await
            .unwrap()
            .is_some());
        // Delete embeddings by source_id (shared-source)
        TextChunkEmbedding::delete_by_source_id(source_id, &db)
            .await
            .expect("Failed to delete by source_id");
        // Chunks from shared-source should have no embeddings
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
            .await
            .unwrap()
            .is_none());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
            .await
            .unwrap()
            .is_none());
        // The other chunk should still have its embedding
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
            .await
            .unwrap()
            .is_some());
    }
    /// Redefining the index changes the DIMENSION in the stored definition,
    /// as observed through `INFO FOR TABLE`.
    #[tokio::test]
    async fn test_redefine_hnsw_index_updates_dimension() {
        let db = setup_test_db().await;
        // Change the index dimension from default (1536) to a smaller test value.
        TextChunkEmbedding::redefine_hnsw_index(&db, 8)
            .await
            .expect("failed to redefine index");
        let mut info_res = db
            .client
            .query("INFO FOR TABLE text_chunk_embedding;")
            .await
            .expect("info query failed");
        let info: SurrealValue = info_res.take(0).expect("failed to take info result");
        let info_json: serde_json::Value =
            serde_json::to_value(info).expect("failed to convert info to json");
        // The JSON path mirrors how surrealdb::Value serializes: objects nest
        // under "Object" and strings under "Strand".
        let idx_sql = info_json["Object"]["indexes"]["Object"]
            ["idx_embedding_text_chunk_embedding"]["Strand"]
            .as_str()
            .unwrap_or_default();
        assert!(
            idx_sql.contains("DIMENSION 8"),
            "expected index definition to contain new dimension, got: {idx_sql}"
        );
    }
    /// Calling `redefine_hnsw_index` twice with the same dimension must not
    /// fail (REMOVE ... IF EXISTS makes it idempotent).
    #[tokio::test]
    async fn test_redefine_hnsw_index_is_idempotent() {
        let db = setup_test_db().await;
        TextChunkEmbedding::redefine_hnsw_index(&db, 4)
            .await
            .expect("first redefine failed");
        TextChunkEmbedding::redefine_hnsw_index(&db, 4)
            .await
            .expect("second redefine failed");
        let mut info_res = db
            .client
            .query("INFO FOR TABLE text_chunk_embedding;")
            .await
            .expect("info query failed");
        let info: SurrealValue = info_res.take(0).expect("failed to take info result");
        let info_json: serde_json::Value =
            serde_json::to_value(info).expect("failed to convert info to json");
        let idx_sql = info_json["Object"]["indexes"]["Object"]
            ["idx_embedding_text_chunk_embedding"]["Strand"]
            .as_str()
            .unwrap_or_default();
        assert!(
            idx_sql.contains("DIMENSION 4"),
            "expected index definition to retain dimension 4, got: {idx_sql}"
        );
    }
}

View File

@@ -146,12 +146,12 @@ impl TextContent {
search::highlight('<b>', '</b>', 4) AS highlighted_url,
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
(
search::score(0) +
search::score(1) +
search::score(2) +
search::score(3) +
search::score(4) +
search::score(5)
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END +
IF search::score(2) != NONE THEN search::score(2) ELSE 0 END +
IF search::score(3) != NONE THEN search::score(3) ELSE 0 END +
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
) AS score
FROM text_content
WHERE

View File

@@ -1,10 +1,279 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
str::FromStr,
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex;
use tracing::debug;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Selects which embedding implementation the application uses.
///
/// `FastEmbed` is the default backend (see the `#[default]` marker).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
    /// Remote embeddings via the OpenAI API.
    OpenAI,
    /// Local model inference through the `fastembed` crate.
    #[default]
    FastEmbed,
    /// Deterministic bag-of-words hashing (see `hashed_embedding`), no model.
    Hashed,
}
impl std::str::FromStr for EmbeddingBackend {
    type Err = anyhow::Error;
    /// Parses a case-insensitive backend name.
    ///
    /// Accepted values: "openai", "hashed", and "fastembed" (with the
    /// aliases "fast-embed" and "fast").
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "openai" => Ok(Self::OpenAI),
            "hashed" => Ok(Self::Hashed),
            "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
            other => Err(anyhow!(
                "unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
            )),
        }
    }
}
/// High-level embedding generator that hides which backend is in use.
///
/// Cheap to clone: the FastEmbed variant shares its model behind an
/// `Arc<Mutex<_>>` and the OpenAI variant shares its client via `Arc`
/// (see `EmbeddingInner`).
#[derive(Clone)]
pub struct EmbeddingProvider {
    // Backend-specific state; the enum is private so callers can only go
    // through the provider's methods.
    inner: EmbeddingInner,
}
/// Backend-specific state behind [`EmbeddingProvider`].
#[derive(Clone)]
enum EmbeddingInner {
    /// Remote embeddings via the OpenAI embeddings endpoint.
    OpenAI {
        // Shared HTTP client; Arc keeps clones cheap.
        client: Arc<Client<async_openai::config::OpenAIConfig>>,
        // Model name sent with every request.
        model: String,
        // Output dimensionality requested from the API.
        dimensions: u32,
    },
    /// Deterministic local hashing (see `hashed_embedding`); no model.
    Hashed {
        dimension: usize,
    },
    /// Local FastEmbed model. The mutex serializes access because embedding
    /// requires `&mut` access to the model.
    FastEmbed {
        model: Arc<Mutex<TextEmbedding>>,
        model_name: EmbeddingModel,
        dimension: usize,
    },
}
impl EmbeddingProvider {
    /// Short static identifier of the active backend: "hashed", "fastembed",
    /// or "openai". Useful for logging and settings display.
    pub fn backend_label(&self) -> &'static str {
        match self.inner {
            EmbeddingInner::Hashed { .. } => "hashed",
            EmbeddingInner::FastEmbed { .. } => "fastembed",
            EmbeddingInner::OpenAI { .. } => "openai",
        }
    }
    /// Length of the vectors this provider produces.
    pub fn dimension(&self) -> usize {
        match &self.inner {
            EmbeddingInner::Hashed { dimension } => *dimension,
            EmbeddingInner::FastEmbed { dimension, .. } => *dimension,
            EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
        }
    }
    /// Model identifier for model-backed providers; `None` for the hashed
    /// backend, which has no model.
    pub fn model_code(&self) -> Option<String> {
        match &self.inner {
            EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
            EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
            EmbeddingInner::Hashed { .. } => None,
        }
    }
    /// Embed a single text and return its vector.
    ///
    /// # Errors
    /// Fails if the FastEmbed model returns no vector, if the OpenAI request
    /// cannot be built or sent, or if the OpenAI response carries no data.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        match &self.inner {
            EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
            EmbeddingInner::FastEmbed { model, .. } => {
                // The model needs exclusive access, so requests serialize here.
                let mut guard = model.lock().await;
                let embeddings = guard
                    .embed(vec![text.to_owned()], None)
                    .context("generating fastembed vector")?;
                embeddings
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
            }
            EmbeddingInner::OpenAI {
                client,
                model,
                dimensions,
            } => {
                let request = CreateEmbeddingRequestArgs::default()
                    .model(model.clone())
                    .input([text])
                    .dimensions(*dimensions)
                    .build()?;
                let response = client.embeddings().create(request).await?;
                let embedding = response
                    .data
                    .first()
                    .ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
                    .embedding
                    .clone();
                Ok(embedding)
            }
        }
    }
    /// Embed many texts in one call; returns one vector per input text.
    ///
    /// An empty input returns an empty result without touching the backend.
    /// NOTE(review): for the OpenAI backend the response is assumed to carry
    /// one embedding per input in input order — confirm against the API docs.
    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        match &self.inner {
            EmbeddingInner::Hashed { dimension } => Ok(texts
                .into_iter()
                .map(|text| hashed_embedding(&text, *dimension))
                .collect()),
            EmbeddingInner::FastEmbed { model, .. } => {
                if texts.is_empty() {
                    return Ok(Vec::new());
                }
                let mut guard = model.lock().await;
                guard
                    .embed(texts, None)
                    .context("generating fastembed batch embeddings")
            }
            EmbeddingInner::OpenAI {
                client,
                model,
                dimensions,
            } => {
                if texts.is_empty() {
                    return Ok(Vec::new());
                }
                let request = CreateEmbeddingRequestArgs::default()
                    .model(model.clone())
                    .input(texts)
                    .dimensions(*dimensions)
                    .build()?;
                let response = client.embeddings().create(request).await?;
                let embeddings: Vec<Vec<f32>> = response
                    .data
                    .into_iter()
                    .map(|item| item.embedding)
                    .collect();
                Ok(embeddings)
            }
        }
    }
    /// Build an OpenAI-backed provider from a shared client, model name, and
    /// requested output dimensionality. Performs no network call here.
    pub async fn new_openai(
        client: Arc<Client<async_openai::config::OpenAIConfig>>,
        model: String,
        dimensions: u32,
    ) -> Result<Self> {
        Ok(EmbeddingProvider {
            inner: EmbeddingInner::OpenAI {
                client,
                model,
                dimensions,
            },
        })
    }
    /// Build a FastEmbed-backed provider, initialising the model on a
    /// blocking thread (initialisation may download model files — progress
    /// display is enabled).
    ///
    /// `model_override` selects a specific FastEmbed model by its string
    /// code; otherwise the crate's default model is used. The embedding
    /// dimension is taken from the model's metadata.
    ///
    /// # Errors
    /// Fails on an unknown model code, model initialisation failure, missing
    /// model metadata, or if the blocking task cannot be joined.
    pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
        let model_name = if let Some(code) = model_override {
            EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
        } else {
            EmbeddingModel::default()
        };
        let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
        let model_name_for_task = model_name.clone();
        let model_name_code = model_name.to_string();
        let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
            let model =
                TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
            let info = EmbeddingModel::get_model_info(&model_name_for_task)
                .ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
            Ok((model, info.dim))
        })
        .await
        .context("joining FastEmbed initialisation task")??;
        Ok(EmbeddingProvider {
            inner: EmbeddingInner::FastEmbed {
                model: Arc::new(Mutex::new(model)),
                model_name,
                dimension,
            },
        })
    }
    /// Build a hashing-based provider with the given dimension (clamped to a
    /// minimum of 1). Infallible in practice; `Result` kept for a uniform
    /// constructor signature.
    pub fn new_hashed(dimension: usize) -> Result<Self> {
        Ok(EmbeddingProvider {
            inner: EmbeddingInner::Hashed {
                dimension: dimension.max(1),
            },
        })
    }
}
// Helper functions for hashed embeddings

/// Produce a deterministic, L2-normalized bag-of-words vector for `text`.
///
/// Each token is hashed into one of `dimension` buckets and counted; the
/// count vector is then scaled to unit length. Texts with no alphanumeric
/// tokens (including the empty string) yield the zero vector. A `dimension`
/// of 0 is clamped to 1.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    let dim = dimension.max(1);
    let mut vector = vec![0.0f32; dim];
    for token in tokens(text) {
        vector[bucket(&token, dim)] += 1.0;
    }
    // Normalize to unit length. When no tokens were found the norm is zero
    // and the zero vector is returned unchanged, which also covers the
    // empty-text case — no separate early return is needed.
    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for value in &mut vector {
            *value /= norm;
        }
    }
    vector
}

/// Split `text` into lowercase ASCII-alphanumeric tokens, dropping
/// punctuation, whitespace, and empty fragments.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| token.to_ascii_lowercase())
}

/// Map a token to a bucket index in `[0, dimension)` using the standard
/// library's default hasher (deterministic within one build of the binary).
fn bucket(token: &str, dimension: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    (hasher.finish() as usize) % dimension
}
// Backward compatibility function
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider.embed(input).await.map_err(AppError::from)
}
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)