mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-26 03:11:34 +01:00
evals: v3, embeddings at the side
additional indexes
This commit is contained in:
@@ -45,6 +45,7 @@ tokio-retry = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
|
||||
[features]
|
||||
|
||||
@@ -14,6 +14,9 @@ CREATE system_settings:current CONTENT {
|
||||
query_model: "gpt-4o-mini",
|
||||
processing_model: "gpt-4o-mini",
|
||||
embedding_model: "text-embedding-3-small",
|
||||
voice_processing_model: "whisper-1",
|
||||
image_processing_model: "gpt-4o-mini",
|
||||
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
|
||||
embedding_dimensions: 1536,
|
||||
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
||||
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. 
Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
||||
|
||||
@@ -1,27 +1,2 @@
|
||||
DEFINE ANALYZER IF NOT EXISTS app_default_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_text_idx ON TABLE text_content
|
||||
FIELDS text
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_category_idx ON TABLE text_content
|
||||
FIELDS category
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_context_idx ON TABLE text_content
|
||||
FIELDS context
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_file_name_idx ON TABLE text_content
|
||||
FIELDS file_info.file_name
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_url_idx ON TABLE text_content
|
||||
FIELDS url_info.url
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_content_fts_url_title_idx ON TABLE text_content
|
||||
FIELDS url_info.title
|
||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
||||
-- Runtime-managed: text_content FTS indexes now created at startup via the shared Surreal helper.
|
||||
-- This migration is intentionally left as a no-op to avoid heavy index builds during migration.
|
||||
|
||||
@@ -1 +1 @@
|
||||
REMOVE TABLE job;
|
||||
-- No-op: legacy `job` table was superseded by `ingestion_task`; kept for migration order compatibility.
|
||||
|
||||
@@ -1,17 +1 @@
|
||||
-- Add FTS indexes for searching name and description on entities
|
||||
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity
|
||||
FIELDS name
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity
|
||||
FIELDS description
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk
|
||||
FIELDS chunk
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
-- Runtime-managed: FTS indexes now built at startup; migration retained as a no-op.
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
-- Move chunk/entity embeddings to dedicated tables for index efficiency.
|
||||
|
||||
-- Text chunk embeddings table
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||
|
||||
-- Knowledge entity embeddings table
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Drop legacy embedding fields from base tables; embeddings now live in *_embedding tables.
|
||||
REMOVE FIELD IF EXISTS embedding ON TABLE text_chunk;
|
||||
REMOVE FIELD IF EXISTS embedding ON TABLE knowledge_entity;
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -51,23 +51,23 @@\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\n-DEFINE TABLE IF NOT EXISTS job SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n # Custom fields from the IngestionTask struct\n # IngestionPayload is complex, store as object\n-DEFINE FIELD IF NOT EXISTS content ON job TYPE object;\n+DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n # IngestionTaskStatus can hold data (InProgress), store as object\n-DEFINE FIELD IF NOT EXISTS status ON job TYPE object;\n-DEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n+DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n # Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\n-DEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;\n-DEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;\n-DEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS status;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;\n\n # Defines the schema for the 'knowledge_entity' table.\n\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -57,10 +57,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n-# Custom fields from the IngestionTask struct\n-# IngestionPayload is complex, store as object\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n-# IngestionTaskStatus can hold data (InProgress), store as object\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n@@ -157,10 +154,12 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -160,6 +160,7 @@\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -18,8 +18,8 @@\n DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n # Custom fields from the Conversation struct\n DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\n@@ -34,8 +34,8 @@\n DEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n # Custom fields from the FileInfo struct\n DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\n@@ -54,8 +54,8 @@\n DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n@@ -71,8 +71,8 @@\n DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n # Custom fields from the KnowledgeEntity struct\n DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE 
string;\n@@ -102,8 +102,8 @@\n DEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n # Custom fields from the Message struct\n DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n@@ -167,8 +167,8 @@\n DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n # Custom fields from the TextChunk struct\n DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\n@@ -191,8 +191,8 @@\n DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n # Custom fields from the TextContent struct\n DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n@@ -215,8 +215,8 @@\n DEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n # Custom fields from the User struct\n DEFINE FIELD IF NOT EXISTS email ON user TYPE string;\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -137,6 +137,30 @@\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\n+# Defines the schema for the 'scratchpad' table.\n+\n+DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;\n+\n+# Standard fields from stored_object! macro\n+DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;\n+\n+# Custom fields from the Scratchpad struct\n+DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;\n+\n+# Indexes based on query patterns\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;\n+DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;\n+DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;\n+\n DEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n","events":null}
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -85,31 +85,30 @@\n\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;\n\n-# Indexes based on build_indexes and query patterns\n-# The INDEX definition correctly specifies the vector properties\n-# HNSW index now defined on knowledge_entity_embedding table for better memory usage \n-# DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes based on build_indexes and query patterns\n+-- HNSW index now defined on knowledge_entity_embedding table for better memory usage\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n-# Defines the schema for the 'knowledge_entity_embedding' table.\n-# Separate table to optimize HNSW index creation memory usage\n+-- Defines the schema for the 'knowledge_entity_embedding' table.\n+-- Separate table to optimize HNSW index creation memory usage\n\n DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;\n\n-# Standard fields\n+-- Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n\n-# Custom fields\n+-- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;\n\n-# Indexes\n-# DEFINE INDEX IF NOT 
EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n\n@@ -220,8 +219,8 @@\n DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;\n\n-# Defines the schema for the 'text_chunk_embedding' table.\n-# Separate table to optimize HNSW index creation memory usage\n+-- Defines the schema for the 'text_chunk_embedding' table.\n+-- Separate table to optimize HNSW index creation memory usage\n\n DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;\n\n@@ -235,8 +234,8 @@\n DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;\n DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;\n\n-# Indexes\n-# DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n","events":null}
|
||||
File diff suppressed because one or more lines are too long
@@ -15,16 +15,12 @@ DEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;
|
||||
# metadata is Option<serde_json::Value>, store as object
|
||||
DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option<object>;
|
||||
|
||||
# Define embedding as a standard array of floats for schema definition
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array<float>;
|
||||
# The specific vector nature is handled by the index definition below
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
||||
|
||||
# Indexes based on build_indexes and query patterns
|
||||
# The INDEX definition correctly specifies the vector properties
|
||||
DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
-- Indexes based on build_indexes and query patterns
|
||||
-- HNSW index now defined on knowledge_entity_embedding table for better memory usage
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||
|
||||
18
common/schemas/knowledge_entity_embedding.surql
Normal file
18
common/schemas/knowledge_entity_embedding.surql
Normal file
@@ -0,0 +1,18 @@
|
||||
-- Defines the schema for the 'knowledge_entity_embedding' table.
|
||||
-- Separate table to optimize HNSW index creation memory usage
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||
|
||||
-- Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||
|
||||
-- Custom fields
|
||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||
|
||||
-- Indexes
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||
@@ -10,14 +10,8 @@ DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;
|
||||
|
||||
# Define embedding as a standard array of floats for schema definition
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array<float>;
|
||||
# The specific vector nature is handled by the index definition below
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;
|
||||
|
||||
# Indexes based on build_indexes and query patterns (delete_by_source_id)
|
||||
# The INDEX definition correctly specifies the vector properties
|
||||
DEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
||||
|
||||
20
common/schemas/text_chunk_embedding.surql
Normal file
20
common/schemas/text_chunk_embedding.surql
Normal file
@@ -0,0 +1,20 @@
|
||||
-- Defines the schema for the 'text_chunk_embedding' table.
|
||||
-- Separate table to optimize HNSW index creation memory usage
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||
|
||||
# Custom fields
|
||||
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||
|
||||
-- Indexes
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||
@@ -1,5 +1,8 @@
|
||||
use super::types::StoredObject;
|
||||
use crate::error::AppError;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{indexes::ensure_runtime_indexes, types::system_settings::SystemSettings},
|
||||
};
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::Stream;
|
||||
@@ -96,20 +99,22 @@ impl SurrealDbClient {
|
||||
}
|
||||
|
||||
/// Operation to rebuild indexes
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), AppError> {
|
||||
debug!("Rebuilding indexes");
|
||||
let rebuild_sql = r#"
|
||||
BEGIN TRANSACTION;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
|
||||
COMMIT TRANSACTION;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding;
|
||||
"#;
|
||||
|
||||
self.client.query(rebuild_sql).await?;
|
||||
self.client
|
||||
.query(rebuild_sql)
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
589
common/src/storage/indexes.rs
Normal file
589
common/src/storage/indexes.rs
Normal file
@@ -0,0 +1,589 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_secs(2);
|
||||
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
||||
|
||||
/// Specification for one HNSW vector index: which index name, on which
/// table, plus trailing options appended to the DEFINE INDEX statement.
#[derive(Clone, Copy)]
struct HnswIndexSpec {
    /// Name of the index as created in SurrealDB.
    index_name: &'static str,
    /// Table the index is defined on.
    table: &'static str,
    /// Extra HNSW option text appended verbatim after the DIMENSION clause.
    options: &'static str,
}
|
||||
|
||||
impl HnswIndexSpec {
|
||||
fn definition_if_not_exists(&self, dimension: usize) -> String {
|
||||
format!(
|
||||
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} \
|
||||
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
dimension = dimension,
|
||||
options = self.options,
|
||||
)
|
||||
}
|
||||
|
||||
fn definition_overwrite(&self, dimension: usize) -> String {
|
||||
format!(
|
||||
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} \
|
||||
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
dimension = dimension,
|
||||
options = self.options,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Declarative description of one full-text-search index.
#[derive(Clone, Copy)]
struct FtsIndexSpec {
    /// Name of the index as it appears in Surreal.
    index_name: &'static str,
    /// Table the index is defined on.
    table: &'static str,
    /// Field being indexed; may be a dotted path (e.g. `url_info.url`).
    field: &'static str,
    /// Analyzer attached via `SEARCH ANALYZER`; `None` emits a plain index.
    analyzer: Option<&'static str>,
    /// Ranking method appended after the analyzer name (e.g. `BM25`).
    method: &'static str,
}
|
||||
|
||||
impl FtsIndexSpec {
|
||||
fn definition(&self) -> String {
|
||||
let analyzer_clause = self
|
||||
.analyzer
|
||||
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
|
||||
.unwrap_or_default();
|
||||
|
||||
format!(
|
||||
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
field = self.field,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||
pub async fn ensure_runtime_indexes(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
ensure_runtime_indexes_inner(db, embedding_dimension)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
async fn ensure_runtime_indexes_inner(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<()> {
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
for spec in fts_index_specs() {
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition(),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
for spec in hnsw_index_specs() {
|
||||
ensure_hnsw_index(db, &spec, embedding_dimension).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
spec: &HnswIndexSpec,
|
||||
dimension: usize,
|
||||
) -> Result<()> {
|
||||
let definition = match hnsw_index_state(db, spec, dimension).await? {
|
||||
HnswIndexState::Missing => spec.definition_if_not_exists(dimension),
|
||||
HnswIndexState::Matches(_) => spec.definition_if_not_exists(dimension),
|
||||
HnswIndexState::Different(existing) => {
|
||||
info!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
existing_dimension = existing,
|
||||
target_dimension = dimension,
|
||||
"Overwriting HNSW index to match new embedding dimension"
|
||||
);
|
||||
spec.definition_overwrite(dimension)
|
||||
}
|
||||
};
|
||||
|
||||
create_index_with_polling(
|
||||
db,
|
||||
definition,
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Inspect `INFO FOR TABLE` output to determine whether the spec's HNSW index
/// exists and, if so, whether its DIMENSION matches `expected_dimension`.
///
/// Any shape mismatch while parsing the response is reported as `Missing`,
/// which makes the caller fall back to the safe `IF NOT EXISTS` definition.
async fn hnsw_index_state(
    db: &SurrealDbClient,
    spec: &HnswIndexSpec,
    expected_dimension: usize,
) -> Result<HnswIndexState> {
    let info_query = format!("INFO FOR TABLE {table};", table = spec.table);
    let mut response = db
        .client
        .query(info_query)
        .await
        .with_context(|| format!("fetching table info for {}", spec.table))?;

    let info: surrealdb::Value = response
        .take(0)
        .context("failed to take table info response")?;

    // Round-trip through serde_json so the response can be walked generically.
    // NOTE(review): the "Object"/"Strand" wrapper keys below depend on how
    // `surrealdb::Value` serializes — confirm against the pinned SDK version.
    let info_json: Value =
        serde_json::to_value(info).context("serializing table info to JSON for parsing")?;

    let Some(indexes) = info_json
        .get("Object")
        .and_then(|o| o.get("indexes"))
        .and_then(|i| i.get("Object"))
        .and_then(|i| i.as_object())
    else {
        return Ok(HnswIndexState::Missing);
    };

    // The index definition is stored as its raw SurrealQL source string.
    let Some(definition) = indexes
        .get(spec.index_name)
        .and_then(|details| details.get("Strand"))
        .and_then(|v| v.as_str())
    else {
        return Ok(HnswIndexState::Missing);
    };

    // No parseable DIMENSION token means we cannot compare; treat as missing.
    let Some(current_dimension) = extract_dimension(definition) else {
        return Ok(HnswIndexState::Missing);
    };

    if current_dimension == expected_dimension as u64 {
        Ok(HnswIndexState::Matches(current_dimension))
    } else {
        Ok(HnswIndexState::Different(current_dimension))
    }
}
|
||||
|
||||
/// Result of comparing an existing HNSW index against the requested dimension.
enum HnswIndexState {
    /// No index (or no parseable definition) found on the table.
    Missing,
    /// Index exists with the expected DIMENSION (value carried for logging).
    Matches(u64),
    /// Index exists with a different DIMENSION and must be overwritten.
    Different(u64),
}
|
||||
|
||||
/// Pull the numeric value that follows the `DIMENSION` keyword out of a raw
/// SurrealQL index definition string.
///
/// Returns `None` when the keyword is absent or not followed by a token that
/// parses as an integer (a trailing `;` on the token is tolerated).
fn extract_dimension(definition: &str) -> Option<u64> {
    let (_, after_keyword) = definition.split_once("DIMENSION")?;
    let token = after_keyword.split_whitespace().next()?;
    token.trim_end_matches(';').parse().ok()
}
|
||||
|
||||
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
let analyzer_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
);
|
||||
|
||||
let res = db
|
||||
.client
|
||||
.query(analyzer_query)
|
||||
.await
|
||||
.context("creating FTS analyzer")?;
|
||||
|
||||
res.check().context("failed to create FTS analyzer")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_index_with_polling(
|
||||
db: &SurrealDbClient,
|
||||
definition: String,
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
progress_table: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let expected_total = match progress_table {
|
||||
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||
format!("counting rows in {table} for index {index_name} progress")
|
||||
})?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let res = db
|
||||
.client
|
||||
.query(definition)
|
||||
.await
|
||||
.with_context(|| format!("creating index {index_name} on table {table}"))?;
|
||||
res.check()
|
||||
.with_context(|| format!("index definition failed for {index_name} on {table}"))?;
|
||||
|
||||
info!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
expected_rows = ?expected_total,
|
||||
"Index definition submitted; waiting for build to finish"
|
||||
);
|
||||
|
||||
poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await
|
||||
}
|
||||
|
||||
/// Poll `INFO FOR INDEX` until the build finishes, logging progress each tick.
///
/// Exits when the index reports "ready", when the info response is empty
/// (index definition apparently missing), or when the build reports an error
/// status. Always returns `Ok(())` in those cases; only transport/query
/// failures propagate as errors.
///
/// NOTE(review): there is no overall timeout — a build stuck in "indexing"
/// keeps this loop alive indefinitely.
async fn poll_index_build_status(
    db: &SurrealDbClient,
    index_name: &str,
    table: &str,
    total_rows: Option<u64>,
    poll_every: Duration,
) -> Result<()> {
    let started_at = std::time::Instant::now();

    loop {
        // Sleep first: the definition was submitted just before this call,
        // so an immediate check would almost always report "indexing".
        tokio::time::sleep(poll_every).await;

        let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
        let mut info_res = db.client.query(info_query).await.with_context(|| {
            format!("checking index build status for {index_name} on {table}")
        })?;

        let info: Option<Value> = info_res
            .take(0)
            .context("failed to deserialize INFO FOR INDEX result")?;

        // `None` snapshot means INFO returned nothing — stop rather than spin.
        let Some(snapshot) = parse_index_build_info(info, total_rows) else {
            warn!(
                index = %index_name,
                table = %table,
                "INFO FOR INDEX returned no data; assuming index definition might be missing"
            );
            break;
        };

        // Two log variants: with a percentage only when the table size is known.
        match snapshot.progress_pct {
            Some(pct) => info!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                initial = snapshot.initial,
                pending = snapshot.pending,
                updated = snapshot.updated,
                processed = snapshot.processed,
                total = snapshot.total_rows,
                progress_pct = format_args!("{pct:.1}"),
                "Index build status"
            ),
            None => info!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                initial = snapshot.initial,
                pending = snapshot.pending,
                updated = snapshot.updated,
                processed = snapshot.processed,
                "Index build status"
            ),
        }

        if snapshot.is_ready() {
            info!(
                index = %index_name,
                table = %table,
                elapsed = ?started_at.elapsed(),
                processed = snapshot.processed,
                total = snapshot.total_rows,
                "Index is ready"
            );
            break;
        }

        // An errored build will never become ready; stop polling but do not
        // fail the caller — the error condition is only logged here.
        if snapshot.status.eq_ignore_ascii_case("error") {
            warn!(
                index = %index_name,
                table = %table,
                status = snapshot.status,
                "Index build reported error status; stopping polling"
            );
            break;
        }
    }

    Ok(())
}
|
||||
|
||||
/// One parsed observation of an index build taken from `INFO FOR INDEX`.
#[derive(Debug, PartialEq)]
struct IndexBuildSnapshot {
    /// Build status string from Surreal (e.g. "indexing", "ready", "error").
    status: String,
    /// Rows seen when the build started.
    initial: u64,
    /// Rows queued but not yet indexed.
    pending: u64,
    /// Rows written after the build started and already accounted for.
    updated: u64,
    /// `initial + updated` (saturating) — rows processed so far.
    processed: u64,
    /// Total rows in the indexed table, when the caller captured a count.
    total_rows: Option<u64>,
    /// `processed / total_rows` as a percentage, capped at 100.
    progress_pct: Option<f64>,
}
|
||||
|
||||
impl IndexBuildSnapshot {
|
||||
fn is_ready(&self) -> bool {
|
||||
self.status.eq_ignore_ascii_case("ready")
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_index_build_info(
|
||||
info: Option<Value>,
|
||||
total_rows: Option<u64>,
|
||||
) -> Option<IndexBuildSnapshot> {
|
||||
let info = info?;
|
||||
let building = info.get("building");
|
||||
|
||||
let status = building
|
||||
.and_then(|b| b.get("status"))
|
||||
.and_then(|s| s.as_str())
|
||||
// If there's no `building` block at all, treat as "ready" (index not building anymore)
|
||||
.unwrap_or("ready")
|
||||
.to_string();
|
||||
|
||||
let initial = building
|
||||
.and_then(|b| b.get("initial"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let pending = building
|
||||
.and_then(|b| b.get("pending"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let updated = building
|
||||
.and_then(|b| b.get("updated"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
||||
let processed = initial.saturating_add(updated);
|
||||
|
||||
let progress_pct = total_rows.map(|total| {
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((processed as f64 / total as f64).min(1.0)) * 100.0
|
||||
}
|
||||
});
|
||||
|
||||
Some(IndexBuildSnapshot {
|
||||
status,
|
||||
initial,
|
||||
pending,
|
||||
updated,
|
||||
processed,
|
||||
total_rows,
|
||||
progress_pct,
|
||||
})
|
||||
}
|
||||
|
||||
/// Row shape returned by the `SELECT count() AS count ... GROUP ALL` query.
#[derive(Debug, Deserialize)]
struct CountRow {
    /// Total number of rows in the queried table.
    count: u64,
}
|
||||
|
||||
async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
|
||||
let query = format!("SELECT count() AS count FROM {table} GROUP ALL;");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| format!("counting rows in {table}"))?;
|
||||
let rows: Vec<CountRow> = response
|
||||
.take(0)
|
||||
.context("failed to deserialize count() response")?;
|
||||
Ok(rows.first().map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Static catalogue of the HNSW vector indexes maintained at runtime.
///
/// DIMENSION is intentionally absent from `options`: it is injected by the
/// `HnswIndexSpec` definition builders from the configured embedding size.
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
    [
        HnswIndexSpec {
            index_name: "idx_embedding_text_chunk_embedding",
            table: "text_chunk_embedding",
            options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
        },
        HnswIndexSpec {
            index_name: "idx_embedding_knowledge_entity_embedding",
            table: "knowledge_entity_embedding",
            options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
        },
    ]
}
|
||||
|
||||
const fn fts_index_specs() -> [FtsIndexSpec; 9] {
|
||||
[
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_fts_idx",
|
||||
table: "text_content",
|
||||
field: "text",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_category_fts_idx",
|
||||
table: "text_content",
|
||||
field: "category",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_context_fts_idx",
|
||||
table: "text_content",
|
||||
field: "context",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_file_name_fts_idx",
|
||||
table: "text_content",
|
||||
field: "file_info.file_name",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_url_fts_idx",
|
||||
table: "text_content",
|
||||
field: "url_info.url",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_url_title_fts_idx",
|
||||
table: "text_content",
|
||||
field: "url_info.title",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "knowledge_entity_fts_name_idx",
|
||||
table: "knowledge_entity",
|
||||
field: "name",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "knowledge_entity_fts_description_idx",
|
||||
table: "knowledge_entity",
|
||||
field: "description",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_chunk_fts_chunk_idx",
|
||||
table: "text_chunk",
|
||||
field: "chunk",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    // Snapshot parsing: an in-progress build should report counters and a
    // percentage derived from the supplied total.
    #[test]
    fn parse_index_build_info_reports_progress() {
        let info = json!({
            "building": {
                "initial": 56894,
                "pending": 0,
                "status": "indexing",
                "updated": 0
            }
        });

        let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
        assert_eq!(
            snapshot,
            IndexBuildSnapshot {
                status: "indexing".to_string(),
                initial: 56894,
                pending: 0,
                updated: 0,
                processed: 56894,
                total_rows: Some(61081),
                progress_pct: Some((56894_f64 / 61081_f64) * 100.0),
            }
        );
        assert!(!snapshot.is_ready());
    }

    #[test]
    fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
        // Surreal returns `{}` when the index exists but isn't building.
        let info = json!({});
        let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
        assert!(snapshot.is_ready());
        assert_eq!(snapshot.processed, 0);
        assert_eq!(snapshot.progress_pct, Some(0.0));
    }

    // DIMENSION extraction from a raw SurrealQL definition string.
    #[test]
    fn extract_dimension_parses_value() {
        let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
        assert_eq!(extract_dimension(definition), Some(1536));
    }

    // Running the full setup twice must succeed (IF NOT EXISTS paths).
    #[tokio::test]
    async fn ensure_runtime_indexes_is_idempotent() {
        let namespace = "indexes_ns";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("in-memory db");

        db.apply_migrations()
            .await
            .expect("migrations should succeed");

        // First run creates everything
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("initial index creation");

        // Second run should be a no-op and still succeed
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("second index creation");
    }

    // Changing the embedding dimension must take the OVERWRITE path.
    #[tokio::test]
    async fn ensure_hnsw_index_overwrites_dimension() {
        let namespace = "indexes_dim";
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("in-memory db");

        db.apply_migrations()
            .await
            .expect("migrations should succeed");

        // Create initial index with default dimension
        ensure_runtime_indexes(&db, 1536)
            .await
            .expect("initial index creation");

        // Change dimension and ensure overwrite path is exercised
        ensure_runtime_indexes(&db, 128)
            .await
            .expect("overwritten index creation");
    }
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod db;
|
||||
pub mod indexes;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient, stored_object,
|
||||
error::AppError, storage::db::SurrealDbClient,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
@@ -78,10 +79,16 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
/// Vector search result including hydrated entity.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct KnowledgeEntityVectorResult {
    /// The full `knowledge_entity` row fetched alongside the match.
    pub entity: KnowledgeEntity,
    /// Cosine similarity between the query embedding and the stored one
    /// (higher means closer).
    pub score: f32,
}
|
||||
|
||||
impl KnowledgeEntity {
|
||||
pub fn new(
|
||||
source_id: String,
|
||||
@@ -89,7 +96,6 @@ impl KnowledgeEntity {
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
@@ -102,7 +108,6 @@ impl KnowledgeEntity {
|
||||
description,
|
||||
entity_type,
|
||||
metadata,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
@@ -165,6 +170,89 @@ impl KnowledgeEntity {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Atomically store a knowledge entity and its embedding.
    /// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
    ///
    /// Both CREATEs run inside one transaction so an entity can never exist
    /// without its embedding row (or vice versa). Only the table names are
    /// interpolated (both come from `table_name()`); ids and payloads are
    /// passed as bound parameters.
    pub async fn store_with_embedding(
        entity: KnowledgeEntity,
        embedding: Vec<f32>,
        db: &SurrealDbClient,
    ) -> Result<(), AppError> {
        let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());

        let query = format!(
            "
            BEGIN TRANSACTION;
            CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
            CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
            COMMIT TRANSACTION;
            ",
            entity_table = Self::table_name(),
            emb_table = KnowledgeEntityEmbedding::table_name(),
        );

        db.client
            .query(query)
            .bind(("entity_id", entity.id.clone()))
            .bind(("entity", entity))
            .bind(("emb_id", emb.id.clone()))
            .bind(("emb", emb))
            .await
            .map_err(AppError::Database)?
            // `check()` surfaces per-statement errors that would otherwise be
            // embedded silently in the response.
            .check()
            .map_err(AppError::Database)?;

        Ok(())
    }
|
||||
|
||||
    /// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
    ///
    /// Results are restricted to `user_id`, ordered by cosine similarity
    /// (descending) and limited to `take` rows.
    pub async fn vector_search(
        take: usize,
        query_embedding: Vec<f32>,
        db: &SurrealDbClient,
        user_id: &str,
    ) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
        // Row shape returned by the query; `entity_id` is hydrated into the
        // full entity record by the FETCH clause.
        #[derive(Deserialize)]
        struct Row {
            entity_id: KnowledgeEntity,
            score: f32,
        }

        // `<|{take},100|>` is Surreal's KNN operator against the HNSW index
        // (k = take; 100 is presumably the HNSW search EF parameter — confirm
        // against the SurrealDB docs for the pinned version). Only `take` and
        // the table name are interpolated; the embedding and user id are
        // bound parameters.
        let sql = format!(
            r#"
            SELECT
                entity_id,
                vector::similarity::cosine(embedding, $embedding) AS score
            FROM {emb_table}
            WHERE user_id = $user_id
            AND embedding <|{take},100|> $embedding
            ORDER BY score DESC
            LIMIT {take}
            FETCH entity_id;
            "#,
            emb_table = KnowledgeEntityEmbedding::table_name(),
            take = take
        );

        let mut response = db
            .query(&sql)
            .bind(("embedding", query_embedding))
            .bind(("user_id", user_id.to_string()))
            .await
            .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;

        // Surface per-statement errors before attempting deserialization.
        response = response.check().map_err(AppError::Database)?;

        let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;

        Ok(rows
            .into_iter()
            .map(|r| KnowledgeEntityVectorResult {
                entity: r.entity_id,
                score: r.score,
            })
            .collect())
    }
|
||||
|
||||
pub async fn patch(
|
||||
id: &str,
|
||||
name: &str,
|
||||
@@ -178,32 +266,55 @@ impl KnowledgeEntity {
|
||||
name, description, entity_type
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
let user_id = Self::get_user_id_by_id(id, db_client).await?;
|
||||
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing($table, $id)
|
||||
SET name = $name,
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type,
|
||||
embedding = $embedding
|
||||
RETURN AFTER",
|
||||
"BEGIN TRANSACTION;
|
||||
UPDATE type::thing($table, $id)
|
||||
SET name = $name,
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type;
|
||||
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("emb_table", KnowledgeEntityEmbedding::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.bind(("name", name.to_string()))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("entity_type", entity_type.to_owned()))
|
||||
.bind(("embedding", embedding))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.bind(("description", description.to_string()))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.get(0)
|
||||
.map(|r| r.user_id.clone())
|
||||
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It follows the same
|
||||
@@ -228,22 +339,13 @@ impl KnowledgeEntity {
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query
|
||||
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} entities to process.", total_entities);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
for entity in all_entities.iter() {
|
||||
let embedding_input = format!(
|
||||
@@ -271,17 +373,16 @@ impl KnowledgeEntity {
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), embedding);
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying schema and data changes in a transaction...");
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements
|
||||
for (id, embedding) in new_embeddings {
|
||||
// We must properly serialize the vector for the SurrealQL query string
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
@@ -291,18 +392,22 @@ impl KnowledgeEntity {
|
||||
.join(",")
|
||||
);
|
||||
transaction_query.push_str(&format!(
|
||||
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
|
||||
id, embedding_str
|
||||
));
|
||||
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id
|
||||
));
|
||||
}
|
||||
|
||||
// Re-create the index after updating the data that it will index
|
||||
transaction_query
|
||||
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
@@ -317,7 +422,9 @@ impl KnowledgeEntity {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
@@ -327,7 +434,6 @@ mod tests {
|
||||
let description = "Test Description".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let metadata = Some(json!({"key": "value"}));
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -336,7 +442,6 @@ mod tests {
|
||||
description.clone(),
|
||||
entity_type.clone(),
|
||||
metadata.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -346,7 +451,6 @@ mod tests {
|
||||
assert_eq!(entity.description, description);
|
||||
assert_eq!(entity.entity_type, entity_type);
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.embedding, embedding);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
}
|
||||
@@ -410,20 +514,25 @@ mod tests {
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -433,7 +542,6 @@ mod tests {
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -445,18 +553,18 @@ mod tests {
|
||||
"Different Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
// Store the entities
|
||||
db.store_item(entity1)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(different_entity.clone())
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
|
||||
@@ -505,6 +613,162 @@ mod tests {
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
// Note: We can't easily test the patch method without mocking the OpenAI client
|
||||
// and the generate_embedding function. This would require more complex setup.
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.expect("vector search");
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"hello".to_string(),
|
||||
"world".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {}",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
.await
|
||||
.expect("fetch embedding");
|
||||
assert!(fetched_emb.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
assert_eq!(res.entity.id, entity.id);
|
||||
assert_eq!(res.entity.source_id, source_id);
|
||||
assert_eq!(res.entity.name, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
"s1".to_string(),
|
||||
"entity one".to_string(),
|
||||
"desc".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
let e2 = KnowledgeEntity::new(
|
||||
"s2".to_string(),
|
||||
"entity two".to_string(),
|
||||
"desc".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e1");
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e2");
|
||||
|
||||
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
|
||||
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
|
||||
assert!(stored_e1.is_some() && stored_e2.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {}",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
assert_eq!(stored_embeddings.len(), 2);
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].entity.id, e2.id);
|
||||
assert_eq!(results[1].entity.id, e1.id);
|
||||
}
|
||||
}
|
||||
|
||||
385
common/src/storage/types/knowledge_entity_embedding.rs
Normal file
385
common/src/storage/types/knowledge_entity_embedding.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
entity_id: RecordId,
|
||||
embedding: Vec<f32>,
|
||||
/// Denormalized user id for query scoping
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
|
||||
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
||||
COMMIT TRANSACTION;",
|
||||
table = Self::table_name(),
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
||||
res.check().map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding
|
||||
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding by entity ID
|
||||
pub async fn get_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Get embeddings for multiple entities in batch
|
||||
pub async fn get_by_entity_ids(
|
||||
entity_ids: &[RecordId],
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<HashMap<String, Vec<f32>>, AppError> {
|
||||
if entity_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let ids_list: Vec<RecordId> = entity_ids.iter().cloned().collect();
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_ids", ids_list))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(embeddings
|
||||
.into_iter()
|
||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete embedding by entity ID
|
||||
pub async fn delete_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
||||
let mut res = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
for row in ids {
|
||||
Self::delete_by_entity_id(&row.id, db).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::db::SurrealDbClient;
    use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
    use chrono::Utc;
    use surrealdb::Value as SurrealValue;
    use uuid::Uuid;

    /// Boot a fresh in-memory database (random name) with migrations applied.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("Failed to start in-memory surrealdb");

        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");

        db
    }

    /// Build a `KnowledgeEntity` with a fixed record key so the test can
    /// predict its `RecordId` later.
    fn build_knowledge_entity_with_id(
        key: &str,
        source_id: &str,
        user_id: &str,
    ) -> KnowledgeEntity {
        KnowledgeEntity {
            id: key.to_owned(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            source_id: source_id.to_owned(),
            name: "Test entity".to_owned(),
            description: "Desc".to_owned(),
            entity_type: KnowledgeEntityType::Document,
            metadata: None,
            user_id: user_id.to_owned(),
        }
    }

    #[tokio::test]
    async fn test_create_and_get_by_entity_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-1";
        let source_id = "source-ke";

        let vector = vec![0.11_f32, 0.22, 0.33];
        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);

        KnowledgeEntity::store_with_embedding(entity.clone(), vector.clone(), &db)
            .await
            .expect("Failed to store entity with embedding");

        let rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);

        // The embedding row must round-trip with the same link, user and vector.
        let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
            .await
            .expect("Failed to get embedding by entity_id")
            .expect("Expected embedding to exist");

        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.entity_id, rid);
        assert_eq!(fetched.embedding, vector);
    }

    #[tokio::test]
    async fn test_delete_by_entity_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-delete";
        let source_id = "source-del";

        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);

        KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
            .await
            .expect("Failed to store entity with embedding");

        let rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);

        // Embedding exists before deletion...
        let before = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
            .await
            .expect("Failed to get embedding before delete");
        assert!(before.is_some());

        KnowledgeEntityEmbedding::delete_by_entity_id(&rid, &db)
            .await
            .expect("Failed to delete by entity_id");

        // ...and is gone afterwards.
        let after = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
            .await
            .expect("Failed to get embedding after delete");
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn test_store_with_embedding_creates_entity_and_embedding() {
        let db = setup_test_db().await;
        let user_id = "user_store";
        let source_id = "source_store";
        let vector = vec![0.2_f32, 0.3, 0.4];

        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, vector.len())
            .await
            .expect("set test index dimension");

        let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);

        KnowledgeEntity::store_with_embedding(entity.clone(), vector.clone(), &db)
            .await
            .expect("Failed to store entity with embedding");

        // The entity row itself was written...
        let stored: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
        assert!(stored.is_some());

        // ...and so was the linked embedding row, scoped to the same user.
        let rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
        let embedding_row = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
            .await
            .expect("Failed to fetch embedding");
        assert!(embedding_row.is_some());
        let embedding_row = embedding_row.unwrap();
        assert_eq!(embedding_row.user_id, user_id);
        assert_eq!(embedding_row.entity_id, rid);
    }

    #[tokio::test]
    async fn test_delete_by_source_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let source_id = "shared-ke";
        let other_source = "other-ke";

        // Two entities share `source_id`; a third belongs to another source.
        let shared_a = build_knowledge_entity_with_id("entity-s1", source_id, user_id);
        let shared_b = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
        let unrelated = build_knowledge_entity_with_id("entity-other", other_source, user_id);

        KnowledgeEntity::store_with_embedding(shared_a.clone(), vec![1.0_f32, 1.1, 1.2], &db)
            .await
            .expect("Failed to store entity with embedding");
        KnowledgeEntity::store_with_embedding(shared_b.clone(), vec![2.0_f32, 2.1, 2.2], &db)
            .await
            .expect("Failed to store entity with embedding");
        KnowledgeEntity::store_with_embedding(unrelated.clone(), vec![3.0_f32, 3.1, 3.2], &db)
            .await
            .expect("Failed to store entity with embedding");

        let shared_a_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &shared_a.id);
        let shared_b_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &shared_b.id);
        let unrelated_rid =
            RecordId::from_table_key(KnowledgeEntity::table_name(), &unrelated.id);

        KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
            .await
            .expect("Failed to delete by source_id");

        // Embeddings for the shared source are gone; the unrelated one survives.
        assert!(KnowledgeEntityEmbedding::get_by_entity_id(&shared_a_rid, &db)
            .await
            .unwrap()
            .is_none());
        assert!(KnowledgeEntityEmbedding::get_by_entity_id(&shared_b_rid, &db)
            .await
            .unwrap()
            .is_none());
        assert!(KnowledgeEntityEmbedding::get_by_entity_id(&unrelated_rid, &db)
            .await
            .unwrap()
            .is_some());
    }

    #[tokio::test]
    async fn test_redefine_hnsw_index_updates_dimension() {
        let db = setup_test_db().await;

        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
            .await
            .expect("failed to redefine index");

        // Inspect the table metadata to confirm the index was rebuilt with
        // the requested dimension.
        let mut info_res = db
            .client
            .query("INFO FOR TABLE knowledge_entity_embedding;")
            .await
            .expect("info query failed");
        let info: SurrealValue = info_res.take(0).expect("failed to take info result");
        let info_json: serde_json::Value =
            serde_json::to_value(info).expect("failed to convert info to json");
        let idx_sql = info_json["Object"]["indexes"]["Object"]
            ["idx_embedding_knowledge_entity_embedding"]["Strand"]
            .as_str()
            .unwrap_or_default();

        assert!(
            idx_sql.contains("DIMENSION 16"),
            "expected index definition to contain new dimension, got: {idx_sql}"
        );
    }

    #[tokio::test]
    async fn test_fetch_entity_via_record_id() {
        let db = setup_test_db().await;
        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
            .await
            .expect("set test index dimension");
        let user_id = "user_ke";
        let entity_key = "entity-fetch";
        let source_id = "source-fetch";

        let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
        KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
            .await
            .expect("Failed to store entity with embedding");

        let rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);

        // `FETCH entity_id` should hydrate the link into the full entity row.
        #[derive(Deserialize)]
        struct Row {
            entity_id: KnowledgeEntity,
        }

        let mut res = db
            .client
            .query(
                "SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;",
            )
            .bind(("id", rid.clone()))
            .await
            .expect("failed to fetch embedding with FETCH");
        let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");

        assert_eq!(rows.len(), 1);
        let hydrated = &rows[0].entity_id;
        assert_eq!(hydrated.id, entity_key);
        assert_eq!(hydrated.name, "Test entity");
        assert_eq!(hydrated.user_id, user_id);
    }
}
|
||||
@@ -119,7 +119,6 @@ mod tests {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -128,7 +127,6 @@ mod tests {
|
||||
description,
|
||||
entity_type,
|
||||
None,
|
||||
embedding,
|
||||
user_id,
|
||||
);
|
||||
|
||||
|
||||
@@ -5,12 +5,14 @@ pub mod file_info;
|
||||
pub mod ingestion_payload;
|
||||
pub mod ingestion_task;
|
||||
pub mod knowledge_entity;
|
||||
pub mod knowledge_entity_embedding;
|
||||
pub mod knowledge_relationship;
|
||||
pub mod message;
|
||||
pub mod scratchpad;
|
||||
pub mod system_prompts;
|
||||
pub mod system_settings;
|
||||
pub mod text_chunk;
|
||||
pub mod text_chunk_embedding;
|
||||
pub mod text_content;
|
||||
pub mod user;
|
||||
|
||||
|
||||
@@ -71,25 +71,22 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
|
||||
let info: Option<serde_json::Value> = response
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
|
||||
let info = info.expect("Table info result missing");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("Failed to convert info to json");
|
||||
|
||||
let indexes = info
|
||||
.get("indexes")
|
||||
.or_else(|| {
|
||||
info.get("tables")
|
||||
.and_then(|tables| tables.get(table_name))
|
||||
.and_then(|table| table.get("indexes"))
|
||||
})
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
|
||||
let indexes = info_json["Object"]["indexes"]["Object"]
|
||||
.as_object()
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
|
||||
.and_then(|definition| definition.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
@@ -261,48 +258,56 @@ mod tests {
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
"This chunk has the original dimension".into(),
|
||||
vec![0.1; 1536],
|
||||
"user1".into(),
|
||||
);
|
||||
|
||||
db.store_item(initial_chunk.clone())
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
.await
|
||||
.expect("Failed to store initial chunk");
|
||||
.expect("Failed to store initial chunk with embedding");
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) {
|
||||
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
|
||||
.await
|
||||
.unwrap();
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.expect("Re-defining index should succeed");
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
let update_result = db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await;
|
||||
|
||||
assert!(update_result.is_ok());
|
||||
}
|
||||
|
||||
simulate_reembedding(&db, 768, initial_chunk).await;
|
||||
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||
let target_dimension = 1536usize;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await;
|
||||
|
||||
let migration_result = db.apply_migrations().await;
|
||||
|
||||
assert!(migration_result.is_ok(), "Migrations should not fail");
|
||||
assert!(
|
||||
migration_result.is_ok(),
|
||||
"Migrations should not fail: {:?}",
|
||||
migration_result.err()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -320,8 +325,12 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
|
||||
let initial_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
@@ -352,10 +361,18 @@ mod tests {
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
|
||||
let text_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
let knowledge_dimension =
|
||||
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
|
||||
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
let knowledge_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"knowledge_entity_embedding",
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
@@ -13,12 +14,18 @@ use uuid::Uuid;
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
source_id: String,
|
||||
chunk: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
/// Vector search result including hydrated chunk.
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||
pub struct TextChunkVectorResult {
|
||||
pub chunk: TextChunk,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl TextChunk {
|
||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -26,7 +33,6 @@ impl TextChunk {
|
||||
updated_at: now,
|
||||
source_id,
|
||||
chunk,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
@@ -45,6 +51,94 @@ impl TextChunk {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Atomically store a text chunk and its embedding.
|
||||
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
|
||||
pub async fn store_with_embedding(
|
||||
chunk: TextChunk,
|
||||
embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let emb = TextChunkEmbedding::new(
|
||||
&chunk.id,
|
||||
chunk.source_id.clone(),
|
||||
embedding,
|
||||
chunk.user_id.clone(),
|
||||
);
|
||||
|
||||
// Create both records in a single query
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
|
||||
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
chunk_table = Self::table_name(),
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk.id.clone()))
|
||||
.bind(("chunk", chunk))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkVectorResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
chunk_id: TextChunk,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
chunk_id,
|
||||
embedding,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH chunk_id;
|
||||
"#,
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| TextChunkVectorResult {
|
||||
chunk: r.chunk_id,
|
||||
score: r.score,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It performs these steps:
|
||||
@@ -70,21 +164,14 @@ impl TextChunk {
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
for chunk in all_chunks.iter() {
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
@@ -108,16 +195,18 @@ impl TextChunk {
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(chunk.id.clone(), embedding);
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying schema and data changes in a transaction...");
|
||||
// Perform DB updates in a single transaction against the embedding table
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements
|
||||
for (id, embedding) in new_embeddings {
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
@@ -126,22 +215,29 @@ impl TextChunk {
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
// Use the chunk id as the embedding record id to keep a 1:1 mapping
|
||||
transaction_query.push_str(&format!(
|
||||
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
|
||||
id, embedding_str
|
||||
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
|
||||
chunk_id = type::thing('text_chunk', '{id}'), \
|
||||
source_id = '{source_id}', \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
));
|
||||
}
|
||||
|
||||
// Re-create the index inside the same transaction
|
||||
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for text chunks completed successfully.");
|
||||
@@ -152,171 +248,269 @@ impl TextChunk {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
// Test basic object creation
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
let text_chunk = TextChunk::new(source_id.clone(), chunk.clone(), user_id.clone());
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_chunk.source_id, source_id);
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.embedding, embedding);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
// Create test data
|
||||
let source_id = "source123".to_string();
|
||||
let chunk1 = "First chunk from the same source".to_string();
|
||||
let chunk2 = "Second chunk from the same source".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
// Create two chunks with the same source_id
|
||||
let text_chunk1 = TextChunk::new(
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk1,
|
||||
embedding.clone(),
|
||||
"First chunk from the same source".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let text_chunk2 = TextChunk::new(
|
||||
let chunk2 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk2,
|
||||
embedding.clone(),
|
||||
"Second chunk from the same source".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create a chunk with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_chunk = TextChunk::new(
|
||||
different_source_id.clone(),
|
||||
"different_source".to_string(),
|
||||
"Different source chunk".to_string(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the chunks
|
||||
db.store_item(text_chunk1)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("Failed to store text chunk 1");
|
||||
db.store_item(text_chunk2)
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("Failed to store text chunk 2");
|
||||
db.store_item(different_chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store different chunk");
|
||||
.expect("store chunk2");
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("store different chunk");
|
||||
|
||||
// Delete by source_id
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
|
||||
// Verify all chunks with the original source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All chunks with the source_id should be deleted"
|
||||
);
|
||||
assert_eq!(remaining.len(), 0);
|
||||
|
||||
// Verify the different source_id chunk still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
"different_source"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Chunk with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining.len(), 1);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
// Create a chunk with a real source_id
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = "Test chunk".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
|
||||
|
||||
// Store the chunk
|
||||
db.store_item(text_chunk)
|
||||
.await
|
||||
.expect("Failed to store text chunk");
|
||||
|
||||
// Delete using nonexistent source_id
|
||||
let nonexistent_source_id = "nonexistent_source";
|
||||
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
|
||||
.await
|
||||
.expect("Delete operation with nonexistent source_id should not fail");
|
||||
|
||||
// Verify the real chunk still exists
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
let chunk = TextChunk::new(
|
||||
real_source_id.clone(),
|
||||
"Test chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk");
|
||||
|
||||
TextChunk::delete_by_source_id("nonexistent_source", &db)
|
||||
.await
|
||||
.expect("Delete should succeed");
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
1,
|
||||
"Chunk with real source_id should still exist"
|
||||
assert_eq!(remaining.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some());
|
||||
let stored_chunk = stored_chunk.unwrap();
|
||||
assert_eq!(stored_chunk.source_id, source_id);
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some());
|
||||
let embedding = embedding.unwrap();
|
||||
assert_eq!(embedding.chunk_id, rid);
|
||||
assert_eq!(embedding.user_id, user_id);
|
||||
assert_eq!(embedding.source_id, source_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let results: Vec<TextChunkVectorResult> =
|
||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
"hello world".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store");
|
||||
|
||||
let results: Vec<TextChunkVectorResult> =
|
||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
assert_eq!(res.chunk.id, chunk.id);
|
||||
assert_eq!(res.chunk.source_id, source_id);
|
||||
assert_eq!(res.chunk.chunk, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
let results: Vec<TextChunkVectorResult> =
|
||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].chunk.id, chunk2.id);
|
||||
assert_eq!(results[1].chunk.id, chunk1.id);
|
||||
assert!(results[0].score >= results[1].score);
|
||||
}
|
||||
}
|
||||
|
||||
435
common/src/storage/types/text_chunk_embedding.rs
Normal file
435
common/src/storage/types/text_chunk_embedding.rs
Normal file
@@ -0,0 +1,435 @@
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
/// Record link to the owning text_chunk
|
||||
chunk_id: RecordId,
|
||||
/// Denormalized source id for bulk deletes
|
||||
source_id: String,
|
||||
/// Embedding vector
|
||||
embedding: Vec<f32>,
|
||||
/// Denormalized user id (for scoping + permissions)
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunkEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
///
|
||||
/// This is useful when the embedding length changes; Surreal requires the
|
||||
/// index definition to be recreated with the updated dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table};
|
||||
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
||||
COMMIT TRANSACTION;",
|
||||
table = Self::table_name(),
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
||||
res.check().map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new text chunk embedding
|
||||
///
|
||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID),
|
||||
/// not "text_chunk:uuid".
|
||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
// NOTE: `stored_object!` macro defines `id` as `String`
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
// Create a record<text_chunk> link: text_chunk:<chunk_id>
|
||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||
source_id,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a single embedding by its chunk RecordId
|
||||
pub async fn get_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Delete embeddings for a given chunk RecordId
|
||||
pub async fn delete_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||
///
|
||||
/// This uses a subquery to the `text_chunk` table:
|
||||
///
|
||||
/// DELETE FROM text_chunk_embedding
|
||||
/// WHERE chunk_id IN (SELECT id FROM text_chunk WHERE source_id = $source_id)
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let ids_query = format!(
|
||||
"SELECT id FROM {} WHERE source_id = $source_id",
|
||||
TextChunk::table_name()
|
||||
);
|
||||
let mut res = db
|
||||
.client
|
||||
.query(ids_query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
if ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let delete_query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id IN $chunk_ids",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(delete_query)
|
||||
.bind((
|
||||
"chunk_ids",
|
||||
ids.into_iter().map(|row| row.id).collect::<Vec<_>>(),
|
||||
))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Helper to create an in-memory DB and apply migrations
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
/// Helper: create a text_chunk with a known key, return its RecordId
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
) -> RecordId {
|
||||
let chunk = TextChunk {
|
||||
id: key.to_owned(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
source_id: source_id.to_owned(),
|
||||
chunk: "Some test chunk text".to_owned(),
|
||||
user_id: user_id.to_owned(),
|
||||
};
|
||||
|
||||
db.store_item(chunk)
|
||||
.await
|
||||
.expect("Failed to create text_chunk");
|
||||
|
||||
RecordId::from_table_key(TextChunk::table_name(), key)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_a";
|
||||
let chunk_key = "chunk-123";
|
||||
let source_id = "source-1";
|
||||
|
||||
// 1) Create a text_chunk with a known key
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
|
||||
// 2) Create and store an embedding for that chunk
|
||||
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
source_id.to_string(),
|
||||
embedding_vec.clone(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
|
||||
// 3) Fetch it via get_by_chunk_id
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by chunk_id");
|
||||
|
||||
assert!(fetched.is_some(), "Expected an embedding to be found");
|
||||
let fetched = fetched.unwrap();
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_b";
|
||||
let chunk_key = "chunk-delete";
|
||||
let source_id = "source-del";
|
||||
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
source_id.to_string(),
|
||||
vec![0.4_f32, 0.5, 0.6],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
|
||||
// Ensure it exists
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
// Delete by chunk_id
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by chunk_id");
|
||||
|
||||
// Ensure it no longer exists
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_c";
|
||||
let source_id = "shared-source";
|
||||
let other_source = "other-source";
|
||||
|
||||
// Two chunks with the same source_id
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
|
||||
|
||||
// One chunk with a different source_id
|
||||
let chunk_other_rid =
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
|
||||
|
||||
// Create embeddings for all three
|
||||
let emb1 = TextChunkEmbedding::new(
|
||||
"chunk-s1",
|
||||
source_id.to_string(),
|
||||
vec![0.1],
|
||||
user_id.to_string(),
|
||||
);
|
||||
let emb2 = TextChunkEmbedding::new(
|
||||
"chunk-s2",
|
||||
source_id.to_string(),
|
||||
vec![0.2],
|
||||
user_id.to_string(),
|
||||
);
|
||||
let emb3 = TextChunkEmbedding::new(
|
||||
"chunk-other",
|
||||
other_source.to_string(),
|
||||
vec![0.3],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
// Update length on index
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
for emb in [emb1, emb2, emb3] {
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
}
|
||||
|
||||
// Sanity check: they all exist
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
// Delete embeddings by source_id (shared-source)
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
|
||||
// Chunks from shared-source should have no embeddings
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
// The other chunk should still have its embedding
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Change the index dimension from default (1536) to a smaller test value.
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 8"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_is_idempotent() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("first redefine failed");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("second redefine failed");
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 4"),
|
||||
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -146,12 +146,12 @@ impl TextContent {
|
||||
search::highlight('<b>', '</b>', 4) AS highlighted_url,
|
||||
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
|
||||
(
|
||||
search::score(0) +
|
||||
search::score(1) +
|
||||
search::score(2) +
|
||||
search::score(3) +
|
||||
search::score(4) +
|
||||
search::score(5)
|
||||
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
|
||||
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END +
|
||||
IF search::score(2) != NONE THEN search::score(2) ELSE 0 END +
|
||||
IF search::score(3) != NONE THEN search::score(3) ELSE 0 END +
|
||||
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
|
||||
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
|
||||
) AS score
|
||||
FROM text_content
|
||||
WHERE
|
||||
|
||||
@@ -1,10 +1,279 @@
|
||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum EmbeddingBackend {
|
||||
OpenAI,
|
||||
FastEmbed,
|
||||
Hashed,
|
||||
}
|
||||
|
||||
impl Default for EmbeddingBackend {
|
||||
fn default() -> Self {
|
||||
Self::FastEmbed
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(anyhow!(
|
||||
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingProvider {
|
||||
inner: EmbeddingInner,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum EmbeddingInner {
|
||||
OpenAI {
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
},
|
||||
Hashed {
|
||||
dimension: usize,
|
||||
},
|
||||
FastEmbed {
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
model_name: EmbeddingModel,
|
||||
dimension: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
pub fn backend_label(&self) -> &'static str {
|
||||
match self.inner {
|
||||
EmbeddingInner::Hashed { .. } => "hashed",
|
||||
EmbeddingInner::FastEmbed { .. } => "fastembed",
|
||||
EmbeddingInner::OpenAI { .. } => "openai",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimension(&self) -> usize {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => *dimension,
|
||||
EmbeddingInner::FastEmbed { dimension, .. } => *dimension,
|
||||
EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_code(&self) -> Option<String> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
||||
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
|
||||
EmbeddingInner::Hashed { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let mut guard = model.lock().await;
|
||||
let embeddings = guard
|
||||
.embed(vec![text.to_owned()], None)
|
||||
.context("generating fastembed vector")?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
} => {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.clone())
|
||||
.input([text])
|
||||
.dimensions(*dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut guard = model.lock().await;
|
||||
guard
|
||||
.embed(texts, None)
|
||||
.context("generating fastembed batch embeddings")
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
} => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.clone())
|
||||
.input(texts)
|
||||
.dimensions(*dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embeddings: Vec<Vec<f32>> = response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|item| item.embedding)
|
||||
.collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_openai(
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
) -> Result<Self> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
|
||||
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
.context("joining FastEmbed initialisation task")??;
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
model: Arc::new(Mutex::new(model)),
|
||||
model_name,
|
||||
dimension,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::Hashed {
|
||||
dimension: dimension.max(1),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for hashed embeddings

/// Produces a deterministic, L2-normalised bag-of-words embedding by
/// hashing each token into one of `dimension` buckets (the "hashing
/// trick").
///
/// `dimension` is clamped to at least 1. Text containing no ASCII
/// alphanumeric tokens (including the empty string) yields an all-zero
/// vector.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    let dim = dimension.max(1);
    let mut vector = vec![0.0f32; dim];

    // Accumulate raw token counts per bucket. (The previous version kept
    // a separate token_count only to early-return on zero tokens; the
    // `norm > 0.0` guard below already covers that case, so the counter
    // and the duplicate empty-text check were dead code.)
    for token in tokens(text) {
        vector[bucket(&token, dim)] += 1.0;
    }

    // L2-normalise; an all-zero vector (no tokens) is returned unchanged.
    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for value in &mut vector {
            *value /= norm;
        }
    }

    vector
}

/// Splits `text` into lowercase ASCII-alphanumeric tokens; every other
/// character acts as a separator and empty fragments are dropped.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| token.to_ascii_lowercase())
}

/// Maps a token to a bucket index in `[0, dimension)` using the standard
/// library's default hasher.
///
/// NOTE(review): `DefaultHasher::new()` is deterministic within one std
/// version but not guaranteed stable across Rust releases — if these
/// embeddings are persisted and compared across builds, confirm that is
/// acceptable.
fn bucket(token: &str, dimension: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    (hasher.finish() as usize) % dimension
}
|
||||
|
||||
// Backward compatibility function
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(AppError::from)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
|
||||
Reference in New Issue
Block a user