diff --git a/Cargo.lock b/Cargo.lock index 870d08e..65389a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1417,6 +1417,7 @@ dependencies = [ "chrono-tz", "config", "dom_smoothie", + "fastembed", "futures", "include_dir", "mime", diff --git a/api-router/src/routes/ingress.rs b/api-router/src/routes/ingress.rs index 7d1548e..9b72777 100644 --- a/api-router/src/routes/ingress.rs +++ b/api-router/src/routes/ingress.rs @@ -30,9 +30,10 @@ pub async fn ingest_data( TypedMultipart(input): TypedMultipart, ) -> Result { info!("Received input: {:?}", input); + let user_id = user.id; let file_infos = try_join_all(input.files.into_iter().map(|file| { - FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage) + FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage) .map_err(AppError::from) })) .await?; @@ -42,12 +43,12 @@ pub async fn ingest_data( input.context, input.category, file_infos, - user.id.as_str(), + &user_id, )?; let futures: Vec<_> = payloads .into_iter() - .map(|object| IngestionTask::create_and_add_to_db(object, user.id.clone(), &state.db)) + .map(|object| IngestionTask::create_and_add_to_db(object, user_id.clone(), &state.db)) .collect(); try_join_all(futures).await?; diff --git a/common/Cargo.toml b/common/Cargo.toml index fa3d1ab..bc5a489 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -45,6 +45,7 @@ tokio-retry = { workspace = true } object_store = { workspace = true } bytes = { workspace = true } state-machines = { workspace = true } +fastembed = { workspace = true } [features] diff --git a/common/migrations/20250503_215025_initial_setup.surql b/common/migrations/20250503_215025_initial_setup.surql index faaae98..1e9c897 100644 --- a/common/migrations/20250503_215025_initial_setup.surql +++ b/common/migrations/20250503_215025_initial_setup.surql @@ -14,6 +14,9 @@ CREATE system_settings:current CONTENT { query_model: "gpt-4o-mini", processing_model: "gpt-4o-mini", embedding_model: "text-embedding-3-small", + 
voice_processing_model: "whisper-1", + image_processing_model: "gpt-4o-mini", + image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.", embedding_dimensions: 1536, query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. 
Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"", ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. 
Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity." diff --git a/common/migrations/20250514_142342_add_full_text_search_text_content.surql b/common/migrations/20250514_142342_add_full_text_search_text_content.surql index 499a1f1..40d272a 100644 --- a/common/migrations/20250514_142342_add_full_text_search_text_content.surql +++ b/common/migrations/20250514_142342_add_full_text_search_text_content.surql @@ -1,27 +1,2 @@ -DEFINE ANALYZER IF NOT EXISTS app_default_fts_analyzer - TOKENIZERS class - FILTERS lowercase, ascii; - -DEFINE INDEX IF NOT EXISTS text_content_fts_text_idx ON TABLE text_content - FIELDS text - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; - -DEFINE INDEX IF NOT EXISTS text_content_fts_category_idx ON TABLE text_content - FIELDS category - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; - -DEFINE INDEX IF NOT EXISTS text_content_fts_context_idx ON TABLE text_content - FIELDS context - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; - -DEFINE INDEX IF NOT EXISTS text_content_fts_file_name_idx ON TABLE text_content - FIELDS file_info.file_name - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; - -DEFINE INDEX IF NOT EXISTS text_content_fts_url_idx ON TABLE text_content - FIELDS 
url_info.url - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; - -DEFINE INDEX IF NOT EXISTS text_content_fts_url_title_idx ON TABLE text_content - FIELDS url_info.title - SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS; +-- Runtime-managed: text_content FTS indexes now created at startup via the shared Surreal helper. +-- This migration is intentionally left as a no-op to avoid heavy index builds during migration. diff --git a/common/migrations/20250627_231035_remove_job_table.surql b/common/migrations/20250627_231035_remove_job_table.surql index a6151a1..e899c27 100644 --- a/common/migrations/20250627_231035_remove_job_table.surql +++ b/common/migrations/20250627_231035_remove_job_table.surql @@ -1 +1 @@ -REMOVE TABLE job; +-- No-op: legacy `job` table was superseded by `ingestion_task`; kept for migration order compatibility. diff --git a/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql b/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql index cad213c..98ed3c0 100644 --- a/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql +++ b/common/migrations/20251003_000001_add_fts_for_entities_and_chunks.surql @@ -1,17 +1 @@ --- Add FTS indexes for searching name and description on entities - -DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer - TOKENIZERS class - FILTERS lowercase, ascii, snowball(english); - -DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity - FIELDS name - SEARCH ANALYZER app_en_fts_analyzer BM25; - -DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity - FIELDS description - SEARCH ANALYZER app_en_fts_analyzer BM25; - -DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk - FIELDS chunk - SEARCH ANALYZER app_en_fts_analyzer BM25; +-- Runtime-managed: FTS indexes now built at startup; migration retained as a no-op. 
diff --git a/common/migrations/20251121_113121_separate_embeddings_to_own_table.surql b/common/migrations/20251121_113121_separate_embeddings_to_own_table.surql new file mode 100644 index 0000000..9fefc42 --- /dev/null +++ b/common/migrations/20251121_113121_separate_embeddings_to_own_table.surql @@ -0,0 +1,23 @@ +-- Move chunk/entity embeddings to dedicated tables for index efficiency. + +-- Text chunk embeddings table +DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL; +DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string; +DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string; +DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record; +DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id; + +-- Knowledge entity embeddings table +DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL; +DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string; +DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record; +DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id; diff 
--git a/common/migrations/20251122_151002_remove_legacy_embedding_fields.surql b/common/migrations/20251122_151002_remove_legacy_embedding_fields.surql new file mode 100644 index 0000000..305a552 --- /dev/null +++ b/common/migrations/20251122_151002_remove_legacy_embedding_fields.surql @@ -0,0 +1,3 @@ +-- Drop legacy embedding fields from base tables; embeddings now live in *_embedding tables. +REMOVE FIELD IF EXISTS embedding ON TABLE text_chunk; +REMOVE FIELD IF EXISTS embedding ON TABLE knowledge_entity; diff --git a/common/migrations/definitions/20250514_142342_add_full_text_search_text_content.json b/common/migrations/definitions/20250514_142342_add_full_text_search_text_content.json deleted file mode 100644 index 58d6847..0000000 --- a/common/migrations/definitions/20250514_142342_add_full_text_search_text_content.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE 
string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20250606_234535_add_embedding_model_and_dimensions_to_system_settings.json b/common/migrations/definitions/20250606_234535_add_embedding_model_and_dimensions_to_system_settings.json deleted file mode 100644 index cf85299..0000000 --- a/common/migrations/definitions/20250606_234535_add_embedding_model_and_dimensions_to_system_settings.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -51,23 +51,23 @@\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\n-DEFINE TABLE IF NOT EXISTS job SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n # Custom fields from the IngestionTask struct\n # IngestionPayload is complex, store as object\n-DEFINE FIELD IF NOT EXISTS content ON job TYPE object;\n+DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n # IngestionTaskStatus can hold data (InProgress), store as object\n-DEFINE FIELD IF NOT EXISTS status ON job TYPE object;\n-DEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n+DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n # Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\n-DEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;\n-DEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;\n-DEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS 
status;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;\n\n # Defines the schema for the 'knowledge_entity' table.\n\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20250627_120000_add_image_processing_settings.json b/common/migrations/definitions/20250627_120000_add_image_processing_settings.json deleted file mode 100644 index c1b23cf..0000000 --- a/common/migrations/definitions/20250627_120000_add_image_processing_settings.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -57,10 +57,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n-# Custom fields from the IngestionTask struct\n-# IngestionPayload is complex, store as object\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n-# IngestionTaskStatus can hold data (InProgress), store as object\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n@@ -157,10 +154,12 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n\n # Defines the schema for the 
'text_chunk' table.\n\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20250701_000000_add_voice_processing_model_to_system_settings.json b/common/migrations/definitions/20250701_000000_add_voice_processing_model_to_system_settings.json deleted file mode 100644 index b948e59..0000000 --- a/common/migrations/definitions/20250701_000000_add_voice_processing_model_to_system_settings.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -160,6 +160,7 @@\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20250921_120004_fix_datetime_fields.json b/common/migrations/definitions/20250921_120004_fix_datetime_fields.json deleted file mode 100644 index d377e38..0000000 --- a/common/migrations/definitions/20250921_120004_fix_datetime_fields.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -18,8 +18,8 @@\n DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n # Custom fields from the Conversation struct\n DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\n@@ -34,8 +34,8 @@\n DEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE 
string;\n+DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n # Custom fields from the FileInfo struct\n DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\n@@ -54,8 +54,8 @@\n DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n@@ -71,8 +71,8 @@\n DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n # Custom fields from the KnowledgeEntity struct\n DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;\n@@ -102,8 +102,8 @@\n DEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n # Custom fields from the Message struct\n DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n@@ -167,8 +167,8 @@\n DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n+DEFINE FIELD IF NOT EXISTS 
created_at ON text_chunk TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n # Custom fields from the TextChunk struct\n DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\n@@ -191,8 +191,8 @@\n DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n # Custom fields from the TextContent struct\n DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n@@ -215,8 +215,8 @@\n DEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n # Custom fields from the User struct\n DEFINE FIELD IF NOT EXISTS email ON user TYPE string;\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20251022_120302_add_scratchpad_table.json b/common/migrations/definitions/20251022_120302_add_scratchpad_table.json deleted file mode 100644 index ae08afa..0000000 --- a/common/migrations/definitions/20251022_120302_add_scratchpad_table.json +++ /dev/null @@ -1 +0,0 @@ -{"schemas":"--- original\n+++ modified\n@@ -137,6 +137,30 @@\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\n+# Defines the schema for the 'scratchpad' table.\n+\n+DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;\n+\n+# Standard fields from stored_object! 
macro\n+DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;\n+\n+# Custom fields from the Scratchpad struct\n+DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option;\n+DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option;\n+\n+# Indexes based on query patterns\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;\n+DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;\n+DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;\n+\n DEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/20251122_151002_remove_legacy_embedding_fields.json b/common/migrations/definitions/20251122_151002_remove_legacy_embedding_fields.json new file mode 100644 index 0000000..07655a8 --- /dev/null +++ b/common/migrations/definitions/20251122_151002_remove_legacy_embedding_fields.json @@ -0,0 +1 @@ +{"schemas":"--- original\n+++ modified\n@@ -85,31 +85,30 @@\n\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;\n\n-# Indexes based on build_indexes and query patterns\n-# The INDEX definition correctly specifies the vector properties\n-# HNSW index now defined on knowledge_entity_embedding table for better memory usage \n-# DEFINE INDEX IF NOT EXISTS 
idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes based on build_indexes and query patterns\n+-- HNSW index now defined on knowledge_entity_embedding table for better memory usage\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n-# Defines the schema for the 'knowledge_entity_embedding' table.\n-# Separate table to optimize HNSW index creation memory usage\n+-- Defines the schema for the 'knowledge_entity_embedding' table.\n+-- Separate table to optimize HNSW index creation memory usage\n\n DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;\n\n-# Standard fields\n+-- Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n\n-# Custom fields\n+-- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record;\n DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array;\n\n-# Indexes\n-# DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n DEFINE 
INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n\n@@ -220,8 +219,8 @@\n DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;\n\n-# Defines the schema for the 'text_chunk_embedding' table.\n-# Separate table to optimize HNSW index creation memory usage\n+-- Defines the schema for the 'text_chunk_embedding' table.\n+-- Separate table to optimize HNSW index creation memory usage\n\n DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;\n\n@@ -235,8 +234,8 @@\n DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record;\n DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array;\n\n-# Indexes\n-# DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n+-- Indexes\n+-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n","events":null} \ No newline at end of file diff --git a/common/migrations/definitions/_initial.json b/common/migrations/definitions/_initial.json index 2516565..7c40dd1 100644 --- a/common/migrations/definitions/_initial.json +++ b/common/migrations/definitions/_initial.json @@ -1 +1 @@ -{"schemas":"# Defines the schema for the 'analytics' table.\n\nDEFINE TABLE IF NOT EXISTS analytics SCHEMALESS;\n\n# Custom fields from the Analytics struct\nDEFINE FIELD IF NOT EXISTS page_loads ON analytics TYPE number;\nDEFINE FIELD IF NOT EXISTS visitors ON analytics TYPE number;\n\n# Defines authentication scope and 
access rules.\n# This mirrors the logic previously in SurrealDbClient::setup_auth\n\nDEFINE ACCESS IF NOT EXISTS account ON DATABASE TYPE RECORD\n SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id) # Ensure user_id is provided if needed\n SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );\n\n# Defines the schema for the 'conversation' table.\n\nDEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n\n# Custom fields from the Conversation struct\nDEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\nDEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;\n\n# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\nDEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n\n# Defines the schema for the 'file' table (used by FileInfo).\n\nDEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;\n\n# Custom fields from the FileInfo struct\nDEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS path ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n# Indexes based on usage (get_by_sha, potentially user lookups)\n# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\nDEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\nDEFINE INDEX IF NOT 
EXISTS file_user_id_idx ON file FIELDS user_id;\n\n# Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\nDEFINE TABLE IF NOT EXISTS job SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;\n\n# Custom fields from the IngestionTask struct\n# IngestionPayload is complex, store as object\nDEFINE FIELD IF NOT EXISTS content ON job TYPE object;\n# IngestionTaskStatus can hold data (InProgress), store as object\nDEFINE FIELD IF NOT EXISTS status ON job TYPE object;\nDEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;\n\n# Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\nDEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;\nDEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;\n\n# Defines the schema for the 'knowledge_entity' table.\n\nDEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n\n# Custom fields from the KnowledgeEntity struct\nDEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;\nDEFINE FIELD IF NOT EXISTS name ON knowledge_entity TYPE string;\nDEFINE FIELD IF NOT EXISTS description ON knowledge_entity TYPE string;\n# KnowledgeEntityType is an enum, store as string\nDEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;\n# metadata is Option, store as object\nDEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option;\n\n# Define embedding as a standard array of floats for schema definition\nDEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array;\n# The specific vector nature is handled by the index definition below\n\nDEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;\n\n# 
Indexes based on build_indexes and query patterns\n# The INDEX definition correctly specifies the vector properties\nDEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n\n# Defines the schema for the 'message' table.\n\nDEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n\n# Custom fields from the Message struct\nDEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n# MessageRole is an enum, store as string\nDEFINE FIELD IF NOT EXISTS role ON message TYPE string;\nDEFINE FIELD IF NOT EXISTS content ON message TYPE string;\n# references is Option>, store as array\nDEFINE FIELD IF NOT EXISTS references ON message TYPE option>;\n\n# Indexes based on query patterns (get_complete_conversation)\nDEFINE INDEX IF NOT EXISTS message_conversation_id_idx ON message FIELDS conversation_id;\nDEFINE INDEX IF NOT EXISTS message_updated_at_idx ON message FIELDS updated_at; # For ORDER BY\n\n# Defines the 'relates_to' edge table for KnowledgeRelationships.\n# Edges connect nodes, in this case knowledge_entity records.\n\n# Define the edge table itself, enforcing connections between knowledge_entity records\n# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\nDEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n\n# Define the metadata field within the 
edge\n# RelationshipMetadata is a struct, store as object\nDEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n\n# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\nDEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\nDEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\nDEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n FOR create, update, delete NONE;\n\nDEFINE FIELD OVERWRITE script_name ON script_migration TYPE string;\nDEFINE FIELD OVERWRITE executed_at ON script_migration TYPE datetime VALUE time::now() READONLY;\n\n# Defines the schema for the 'system_settings' table.\n\nDEFINE TABLE IF NOT EXISTS system_settings SCHEMALESS;\n\n# Custom fields from the SystemSettings struct\nDEFINE FIELD IF NOT EXISTS registrations_enabled ON system_settings TYPE bool;\nDEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\nDEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n# Defines the schema for the 'text_chunk' table.\n\nDEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n\n# Custom fields from the 
TextChunk struct\nDEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\nDEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;\n\n# Define embedding as a standard array of floats for schema definition\nDEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array;\n# The specific vector nature is handled by the index definition below\n\nDEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;\n\n# Indexes based on build_indexes and query patterns (delete_by_source_id)\n# The INDEX definition correctly specifies the vector properties\nDEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536;\nDEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;\nDEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;\n\n# Defines the schema for the 'text_content' table.\n\nDEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n\n# Custom fields from the TextContent struct\nDEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n# FileInfo is a struct, store as object\nDEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option;\n# UrlInfo is a struct, store as object\nDEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option;\nDEFINE FIELD IF NOT EXISTS context ON text_content TYPE option;\nDEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\nDEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n# Indexes based on query patterns\nDEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\nDEFINE INDEX IF NOT EXISTS text_content_category_idx ON text_content FIELDS category;\n\n# Defines the schema for the 'user' table.\n# NOTE: Authentication scope and access 
rules are defined in auth.surql\n\nDEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;\nDEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n\n# Custom fields from the User struct\nDEFINE FIELD IF NOT EXISTS email ON user TYPE string;\nDEFINE FIELD IF NOT EXISTS password ON user TYPE string; # Stores the hashed password\nDEFINE FIELD IF NOT EXISTS anonymous ON user TYPE bool;\nDEFINE FIELD IF NOT EXISTS api_key ON user TYPE option;\nDEFINE FIELD IF NOT EXISTS admin ON user TYPE bool;\nDEFINE FIELD IF NOT EXISTS timezone ON user TYPE string;\n\n# Indexes based on query patterns (find_by_email, find_by_api_key, unique constraint from setup_auth)\nDEFINE INDEX IF NOT EXISTS user_email_idx ON user FIELDS email UNIQUE;\nDEFINE INDEX IF NOT EXISTS user_api_key_idx ON user FIELDS api_key;\n","events":""} \ No newline at end of file +{"schemas":"# Defines the schema for the 'analytics' table.\n\nDEFINE TABLE IF NOT EXISTS analytics SCHEMALESS;\n\n# Custom fields from the Analytics struct\nDEFINE FIELD IF NOT EXISTS page_loads ON analytics TYPE number;\nDEFINE FIELD IF NOT EXISTS visitors ON analytics TYPE number;\n\n# Defines authentication scope and access rules.\n# This mirrors the logic previously in SurrealDbClient::setup_auth\n\nDEFINE ACCESS IF NOT EXISTS account ON DATABASE TYPE RECORD\n SIGNUP ( CREATE user SET email = $email, password = crypto::argon2::generate($password), anonymous = false, user_id = $user_id) # Ensure user_id is provided if needed\n SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(password, $password) );\n\n# Defines the schema for the 'conversation' table.\n\nDEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n# Custom fields from the Conversation struct\nDEFINE FIELD IF NOT 
EXISTS user_id ON conversation TYPE string;\nDEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;\n\n# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\nDEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n\n# Defines the schema for the 'file' table (used by FileInfo).\n\nDEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n# Custom fields from the FileInfo struct\nDEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS path ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\nDEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n# Indexes based on usage (get_by_sha, potentially user lookups)\n# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\nDEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\nDEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;\n\n# Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\nDEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\nDEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\nDEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\nDEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n# Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\nDEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS status;\nDEFINE INDEX IF NOT 
EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;\n\n# Defines the schema for the 'knowledge_entity' table.\n\nDEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n# Custom fields from the KnowledgeEntity struct\nDEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;\nDEFINE FIELD IF NOT EXISTS name ON knowledge_entity TYPE string;\nDEFINE FIELD IF NOT EXISTS description ON knowledge_entity TYPE string;\n# KnowledgeEntityType is an enum, store as string\nDEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;\n# metadata is Option, store as object\nDEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option;\n\nDEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;\n\n# Indexes based on build_indexes and query patterns\n# The INDEX definition correctly specifies the vector properties\n# HNSW index now defined on knowledge_entity_embedding table for better memory usage \n# DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n# Defines the schema for the 'knowledge_entity_embedding' table.\n# Separate table to optimize HNSW index creation memory usage\n\nDEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON 
knowledge_entity_embedding TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\nDEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n\n# Custom fields\nDEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record;\nDEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array;\n\n# Indexes\n# DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\nDEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n\n# Defines the schema for the 'message' table.\n\nDEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n# Custom fields from the Message struct\nDEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n# MessageRole is an enum, store as string\nDEFINE FIELD IF NOT EXISTS role ON message TYPE string;\nDEFINE FIELD IF NOT EXISTS content ON message TYPE string;\n# references is Option>, store as array\nDEFINE FIELD IF NOT EXISTS references ON message TYPE option>;\n\n# Indexes based on query patterns (get_complete_conversation)\nDEFINE INDEX IF NOT EXISTS message_conversation_id_idx ON message FIELDS conversation_id;\nDEFINE INDEX IF NOT EXISTS message_updated_at_idx ON message FIELDS updated_at; # For ORDER BY\n\n# Defines the 'relates_to' edge table for KnowledgeRelationships.\n# Edges connect nodes, in this case knowledge_entity records.\n\n# Define the edge table itself, enforcing connections between knowledge_entity records\n# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\nDEFINE TABLE IF NOT EXISTS 
relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n\n# Define the metadata field within the edge\n# RelationshipMetadata is a struct, store as object\nDEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n\n# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\nDEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\nDEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\n# Defines the schema for the 'scratchpad' table.\n\nDEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;\n\n# Standard fields from stored_object! 
macro\nDEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;\n\n# Custom fields from the Scratchpad struct\nDEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;\nDEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;\nDEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;\nDEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;\nDEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;\nDEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;\nDEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option;\nDEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option;\n\n# Indexes based on query patterns\nDEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;\nDEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;\nDEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;\n\nDEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n FOR create, update, delete NONE;\n\nDEFINE FIELD OVERWRITE script_name ON script_migration TYPE string;\nDEFINE FIELD OVERWRITE executed_at ON script_migration TYPE datetime VALUE time::now() READONLY;\n\n# Defines the schema for the 'system_settings' table.\n\nDEFINE TABLE IF NOT EXISTS system_settings SCHEMALESS;\n\n# Custom fields from the SystemSettings struct\nDEFINE FIELD IF NOT EXISTS registrations_enabled ON system_settings TYPE bool;\nDEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\nDEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS 
embedding_model ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\nDEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\nDEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n# Defines the schema for the 'text_chunk' table.\n\nDEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n# Custom fields from the TextChunk struct\nDEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\nDEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;\n\nDEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;\n\n# Indexes based on build_indexes and query patterns (delete_by_source_id)\nDEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;\nDEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;\n\n# Defines the schema for the 'text_chunk_embedding' table.\n# Separate table to optimize HNSW index creation memory usage\n\nDEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;\nDEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;\nDEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;\n\n# Custom fields\nDEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record;\nDEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array;\n\n# Indexes\n# DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding 
HNSW DIMENSION 1536;\nDEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\nDEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n\n# Defines the schema for the 'text_content' table.\n\nDEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n# Custom fields from the TextContent struct\nDEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n# FileInfo is a struct, store as object\nDEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option;\n# UrlInfo is a struct, store as object\nDEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option;\nDEFINE FIELD IF NOT EXISTS context ON text_content TYPE option;\nDEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\nDEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n# Indexes based on query patterns\nDEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\nDEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\nDEFINE INDEX IF NOT EXISTS text_content_category_idx ON text_content FIELDS category;\n\n# Defines the schema for the 'user' table.\n# NOTE: Authentication scope and access rules are defined in auth.surql\n\nDEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n# Standard fields\nDEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\nDEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n# Custom fields from the User struct\nDEFINE FIELD IF NOT EXISTS email ON user TYPE string;\nDEFINE FIELD IF NOT EXISTS password ON user TYPE string; # Stores the hashed password\nDEFINE FIELD IF NOT EXISTS anonymous ON user TYPE bool;\nDEFINE FIELD IF NOT EXISTS 
api_key ON user TYPE option;\nDEFINE FIELD IF NOT EXISTS admin ON user TYPE bool;\nDEFINE FIELD IF NOT EXISTS timezone ON user TYPE string;\n\n# Indexes based on query patterns (find_by_email, find_by_api_key, unique constraint from setup_auth)\nDEFINE INDEX IF NOT EXISTS user_email_idx ON user FIELDS email UNIQUE;\nDEFINE INDEX IF NOT EXISTS user_api_key_idx ON user FIELDS api_key;\n","events":""} \ No newline at end of file diff --git a/common/schemas/knowledge_entity.surql b/common/schemas/knowledge_entity.surql index 1fd95ab..6c6be77 100644 --- a/common/schemas/knowledge_entity.surql +++ b/common/schemas/knowledge_entity.surql @@ -15,16 +15,12 @@ DEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string; # metadata is Option, store as object DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option; -# Define embedding as a standard array of floats for schema definition -DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array; -# The specific vector nature is handled by the index definition below - DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string; -# Indexes based on build_indexes and query patterns -# The INDEX definition correctly specifies the vector properties -DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536; -DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; +-- Indexes based on build_indexes and query patterns +-- HNSW index now defined on knowledge_entity_embedding table for better memory usage +-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536; DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id; +DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type; DEFINE INDEX IF 
NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; diff --git a/common/schemas/knowledge_entity_embedding.surql b/common/schemas/knowledge_entity_embedding.surql new file mode 100644 index 0000000..7f852b4 --- /dev/null +++ b/common/schemas/knowledge_entity_embedding.surql @@ -0,0 +1,18 @@ +-- Defines the schema for the 'knowledge_entity_embedding' table. +-- Separate table to optimize HNSW index creation memory usage + +DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL; + +-- Standard fields +DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string; + +-- Custom fields +DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record; +DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array; + +-- Indexes +-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id; +DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id; diff --git a/common/schemas/text_chunk.surql b/common/schemas/text_chunk.surql index 9d1fe16..6105d06 100644 --- a/common/schemas/text_chunk.surql +++ b/common/schemas/text_chunk.surql @@ -10,14 +10,8 @@ DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime; DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string; DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string; -# Define embedding as a standard array of floats for schema definition -DEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array; -# The specific vector nature is handled by the index definition below - DEFINE FIELD IF NOT EXISTS user_id ON 
text_chunk TYPE string; # Indexes based on build_indexes and query patterns (delete_by_source_id) -# The INDEX definition correctly specifies the vector properties -DEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536; DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id; DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id; diff --git a/common/schemas/text_chunk_embedding.surql b/common/schemas/text_chunk_embedding.surql new file mode 100644 index 0000000..5a43d55 --- /dev/null +++ b/common/schemas/text_chunk_embedding.surql @@ -0,0 +1,20 @@ +-- Defines the schema for the 'text_chunk_embedding' table. +-- Separate table to optimize HNSW index creation memory usage + +DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL; + +# Standard fields +DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime; +DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string; +DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string; + +# Custom fields +DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record; +DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array; + +-- Indexes +-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id; +DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id; diff --git a/common/src/storage/db.rs b/common/src/storage/db.rs index 6475ae2..fb0d7dd 100644 --- a/common/src/storage/db.rs +++ b/common/src/storage/db.rs @@ -1,5 +1,8 @@ use super::types::StoredObject; -use 
crate::error::AppError; +use crate::{ + error::AppError, + storage::{indexes::ensure_runtime_indexes, types::system_settings::SystemSettings}, +}; use axum_session::{SessionConfig, SessionError, SessionStore}; use axum_session_surreal::SessionSurrealPool; use futures::Stream; @@ -96,20 +99,22 @@ impl SurrealDbClient { } /// Operation to rebuild indexes - pub async fn rebuild_indexes(&self) -> Result<(), Error> { + pub async fn rebuild_indexes(&self) -> Result<(), AppError> { debug!("Rebuilding indexes"); let rebuild_sql = r#" - BEGIN TRANSACTION; - REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk; - REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity; REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content; REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity; REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity; REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk; - COMMIT TRANSACTION; + REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding; + REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding; "#; - self.client.query(rebuild_sql).await?; + self.client + .query(rebuild_sql) + .await + .map_err(|e| AppError::InternalError(e.to_string()))?; + Ok(()) } diff --git a/common/src/storage/indexes.rs b/common/src/storage/indexes.rs new file mode 100644 index 0000000..ea1fbfe --- /dev/null +++ b/common/src/storage/indexes.rs @@ -0,0 +1,589 @@ +use std::time::Duration; + +use anyhow::{Context, Result}; +use serde::Deserialize; +use serde_json::Value; +use tracing::{info, warn}; + +use crate::{error::AppError, storage::db::SurrealDbClient}; + +const INDEX_POLL_INTERVAL: Duration = Duration::from_secs(2); +const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer"; + +#[derive(Clone, Copy)] +struct HnswIndexSpec { + index_name: &'static str, + table: &'static str, + options: &'static str, +} + +impl HnswIndexSpec { + fn 
definition_if_not_exists(&self, dimension: usize) -> String { + format!( + "DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} \ + FIELDS embedding HNSW DIMENSION {dimension} {options};", + index = self.index_name, + table = self.table, + dimension = dimension, + options = self.options, + ) + } + + fn definition_overwrite(&self, dimension: usize) -> String { + format!( + "DEFINE INDEX OVERWRITE {index} ON TABLE {table} \ + FIELDS embedding HNSW DIMENSION {dimension} {options};", + index = self.index_name, + table = self.table, + dimension = dimension, + options = self.options, + ) + } +} + +#[derive(Clone, Copy)] +struct FtsIndexSpec { + index_name: &'static str, + table: &'static str, + field: &'static str, + analyzer: Option<&'static str>, + method: &'static str, +} + +impl FtsIndexSpec { + fn definition(&self) -> String { + let analyzer_clause = self + .analyzer + .map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method)) + .unwrap_or_default(); + + format!( + "DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;", + index = self.index_name, + table = self.table, + field = self.field, + ) + } +} + +/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling. +/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes. 
+pub async fn ensure_runtime_indexes( + db: &SurrealDbClient, + embedding_dimension: usize, +) -> Result<(), AppError> { + ensure_runtime_indexes_inner(db, embedding_dimension) + .await + .map_err(|err| AppError::InternalError(err.to_string())) +} + +async fn ensure_runtime_indexes_inner( + db: &SurrealDbClient, + embedding_dimension: usize, +) -> Result<()> { + create_fts_analyzer(db).await?; + + for spec in fts_index_specs() { + create_index_with_polling( + db, + spec.definition(), + spec.index_name, + spec.table, + Some(spec.table), + ) + .await?; + } + + for spec in hnsw_index_specs() { + ensure_hnsw_index(db, &spec, embedding_dimension).await?; + } + + Ok(()) +} + +async fn ensure_hnsw_index( + db: &SurrealDbClient, + spec: &HnswIndexSpec, + dimension: usize, +) -> Result<()> { + let definition = match hnsw_index_state(db, spec, dimension).await? { + HnswIndexState::Missing => spec.definition_if_not_exists(dimension), + HnswIndexState::Matches(_) => spec.definition_if_not_exists(dimension), + HnswIndexState::Different(existing) => { + info!( + index = spec.index_name, + table = spec.table, + existing_dimension = existing, + target_dimension = dimension, + "Overwriting HNSW index to match new embedding dimension" + ); + spec.definition_overwrite(dimension) + } + }; + + create_index_with_polling( + db, + definition, + spec.index_name, + spec.table, + Some(spec.table), + ) + .await +} + +async fn hnsw_index_state( + db: &SurrealDbClient, + spec: &HnswIndexSpec, + expected_dimension: usize, +) -> Result { + let info_query = format!("INFO FOR TABLE {table};", table = spec.table); + let mut response = db + .client + .query(info_query) + .await + .with_context(|| format!("fetching table info for {}", spec.table))?; + + let info: surrealdb::Value = response + .take(0) + .context("failed to take table info response")?; + + let info_json: Value = + serde_json::to_value(info).context("serializing table info to JSON for parsing")?; + + let Some(indexes) = info_json + 
.get("Object") + .and_then(|o| o.get("indexes")) + .and_then(|i| i.get("Object")) + .and_then(|i| i.as_object()) + else { + return Ok(HnswIndexState::Missing); + }; + + let Some(definition) = indexes + .get(spec.index_name) + .and_then(|details| details.get("Strand")) + .and_then(|v| v.as_str()) + else { + return Ok(HnswIndexState::Missing); + }; + + let Some(current_dimension) = extract_dimension(definition) else { + return Ok(HnswIndexState::Missing); + }; + + if current_dimension == expected_dimension as u64 { + Ok(HnswIndexState::Matches(current_dimension)) + } else { + Ok(HnswIndexState::Different(current_dimension)) + } +} + +enum HnswIndexState { + Missing, + Matches(u64), + Different(u64), +} + +fn extract_dimension(definition: &str) -> Option { + definition + .split("DIMENSION") + .nth(1) + .and_then(|rest| rest.split_whitespace().next()) + .and_then(|token| token.trim_end_matches(';').parse::().ok()) +} + +async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> { + let analyzer_query = format!( + "DEFINE ANALYZER IF NOT EXISTS {analyzer} + TOKENIZERS class + FILTERS lowercase, ascii, snowball(english);", + analyzer = FTS_ANALYZER_NAME + ); + + let res = db + .client + .query(analyzer_query) + .await + .context("creating FTS analyzer")?; + + res.check().context("failed to create FTS analyzer")?; + Ok(()) +} + +async fn create_index_with_polling( + db: &SurrealDbClient, + definition: String, + index_name: &str, + table: &str, + progress_table: Option<&str>, +) -> Result<()> { + let expected_total = match progress_table { + Some(table) => Some(count_table_rows(db, table).await.with_context(|| { + format!("counting rows in {table} for index {index_name} progress") + })?), + None => None, + }; + + let res = db + .client + .query(definition) + .await + .with_context(|| format!("creating index {index_name} on table {table}"))?; + res.check() + .with_context(|| format!("index definition failed for {index_name} on {table}"))?; + + info!( + index = 
%index_name, + table = %table, + expected_rows = ?expected_total, + "Index definition submitted; waiting for build to finish" + ); + + poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await +} + +async fn poll_index_build_status( + db: &SurrealDbClient, + index_name: &str, + table: &str, + total_rows: Option, + poll_every: Duration, +) -> Result<()> { + let started_at = std::time::Instant::now(); + + loop { + tokio::time::sleep(poll_every).await; + + let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};"); + let mut info_res = db.client.query(info_query).await.with_context(|| { + format!("checking index build status for {index_name} on {table}") + })?; + + let info: Option = info_res + .take(0) + .context("failed to deserialize INFO FOR INDEX result")?; + + let Some(snapshot) = parse_index_build_info(info, total_rows) else { + warn!( + index = %index_name, + table = %table, + "INFO FOR INDEX returned no data; assuming index definition might be missing" + ); + break; + }; + + match snapshot.progress_pct { + Some(pct) => info!( + index = %index_name, + table = %table, + status = snapshot.status, + initial = snapshot.initial, + pending = snapshot.pending, + updated = snapshot.updated, + processed = snapshot.processed, + total = snapshot.total_rows, + progress_pct = format_args!("{pct:.1}"), + "Index build status" + ), + None => info!( + index = %index_name, + table = %table, + status = snapshot.status, + initial = snapshot.initial, + pending = snapshot.pending, + updated = snapshot.updated, + processed = snapshot.processed, + "Index build status" + ), + } + + if snapshot.is_ready() { + info!( + index = %index_name, + table = %table, + elapsed = ?started_at.elapsed(), + processed = snapshot.processed, + total = snapshot.total_rows, + "Index is ready" + ); + break; + } + + if snapshot.status.eq_ignore_ascii_case("error") { + warn!( + index = %index_name, + table = %table, + status = snapshot.status, + "Index build 
reported error status; stopping polling" + ); + break; + } + } + + Ok(()) +} + +#[derive(Debug, PartialEq)] +struct IndexBuildSnapshot { + status: String, + initial: u64, + pending: u64, + updated: u64, + processed: u64, + total_rows: Option, + progress_pct: Option, +} + +impl IndexBuildSnapshot { + fn is_ready(&self) -> bool { + self.status.eq_ignore_ascii_case("ready") + } +} + +fn parse_index_build_info( + info: Option, + total_rows: Option, +) -> Option { + let info = info?; + let building = info.get("building"); + + let status = building + .and_then(|b| b.get("status")) + .and_then(|s| s.as_str()) + // If there's no `building` block at all, treat as "ready" (index not building anymore) + .unwrap_or("ready") + .to_string(); + + let initial = building + .and_then(|b| b.get("initial")) + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let pending = building + .and_then(|b| b.get("pending")) + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let updated = building + .and_then(|b| b.get("updated")) + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + // `initial` is the number of rows seen when the build started; `updated` accounts for later writes. 
+ let processed = initial.saturating_add(updated); + + let progress_pct = total_rows.map(|total| { + if total == 0 { + 0.0 + } else { + ((processed as f64 / total as f64).min(1.0)) * 100.0 + } + }); + + Some(IndexBuildSnapshot { + status, + initial, + pending, + updated, + processed, + total_rows, + progress_pct, + }) +} + +#[derive(Debug, Deserialize)] +struct CountRow { + count: u64, +} + +async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result { + let query = format!("SELECT count() AS count FROM {table} GROUP ALL;"); + let mut response = db + .client + .query(query) + .await + .with_context(|| format!("counting rows in {table}"))?; + let rows: Vec = response + .take(0) + .context("failed to deserialize count() response")?; + Ok(rows.first().map(|r| r.count).unwrap_or(0)) +} + +const fn hnsw_index_specs() -> [HnswIndexSpec; 2] { + [ + HnswIndexSpec { + index_name: "idx_embedding_text_chunk_embedding", + table: "text_chunk_embedding", + options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + }, + HnswIndexSpec { + index_name: "idx_embedding_knowledge_entity_embedding", + table: "knowledge_entity_embedding", + options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY", + }, + ] +} + +const fn fts_index_specs() -> [FtsIndexSpec; 9] { + [ + FtsIndexSpec { + index_name: "text_content_fts_idx", + table: "text_content", + field: "text", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_category_fts_idx", + table: "text_content", + field: "category", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_context_fts_idx", + table: "text_content", + field: "context", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_file_name_fts_idx", + table: "text_content", + field: "file_info.file_name", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: 
"text_content_url_fts_idx", + table: "text_content", + field: "url_info.url", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_content_url_title_fts_idx", + table: "text_content", + field: "url_info.title", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "knowledge_entity_fts_name_idx", + table: "knowledge_entity", + field: "name", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "knowledge_entity_fts_description_idx", + table: "knowledge_entity", + field: "description", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + FtsIndexSpec { + index_name: "text_chunk_fts_chunk_idx", + table: "text_chunk", + field: "chunk", + analyzer: Some(FTS_ANALYZER_NAME), + method: "BM25", + }, + ] +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use uuid::Uuid; + + #[test] + fn parse_index_build_info_reports_progress() { + let info = json!({ + "building": { + "initial": 56894, + "pending": 0, + "status": "indexing", + "updated": 0 + } + }); + + let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot"); + assert_eq!( + snapshot, + IndexBuildSnapshot { + status: "indexing".to_string(), + initial: 56894, + pending: 0, + updated: 0, + processed: 56894, + total_rows: Some(61081), + progress_pct: Some((56894_f64 / 61081_f64) * 100.0), + } + ); + assert!(!snapshot.is_ready()); + } + + #[test] + fn parse_index_build_info_defaults_to_ready_when_no_building_block() { + // Surreal returns `{}` when the index exists but isn't building. 
+ let info = json!({}); + let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot"); + assert!(snapshot.is_ready()); + assert_eq!(snapshot.processed, 0); + assert_eq!(snapshot.progress_pct, Some(0.0)); + } + + #[test] + fn extract_dimension_parses_value() { + let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;"; + assert_eq!(extract_dimension(definition), Some(1536)); + } + + #[tokio::test] + async fn ensure_runtime_indexes_is_idempotent() { + let namespace = "indexes_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("in-memory db"); + + db.apply_migrations() + .await + .expect("migrations should succeed"); + + // First run creates everything + ensure_runtime_indexes(&db, 1536) + .await + .expect("initial index creation"); + + // Second run should be a no-op and still succeed + ensure_runtime_indexes(&db, 1536) + .await + .expect("second index creation"); + } + + #[tokio::test] + async fn ensure_hnsw_index_overwrites_dimension() { + let namespace = "indexes_dim"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("in-memory db"); + + db.apply_migrations() + .await + .expect("migrations should succeed"); + + // Create initial index with default dimension + ensure_runtime_indexes(&db, 1536) + .await + .expect("initial index creation"); + + // Change dimension and ensure overwrite path is exercised + ensure_runtime_indexes(&db, 128) + .await + .expect("overwritten index creation"); + } +} diff --git a/common/src/storage/mod.rs b/common/src/storage/mod.rs index 724987b..01d2415 100644 --- a/common/src/storage/mod.rs +++ b/common/src/storage/mod.rs @@ -1,3 +1,4 @@ pub mod db; +pub mod indexes; pub mod store; pub mod types; diff --git a/common/src/storage/types/knowledge_entity.rs 
b/common/src/storage/types/knowledge_entity.rs index db153ad..2a13901 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use crate::{ - error::AppError, storage::db::SurrealDbClient, stored_object, + error::AppError, storage::db::SurrealDbClient, + storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object, utils::embedding::generate_embedding, }; use async_openai::{config::OpenAIConfig, Client}; @@ -78,10 +79,16 @@ stored_object!(KnowledgeEntity, "knowledge_entity", { description: String, entity_type: KnowledgeEntityType, metadata: Option, - embedding: Vec, user_id: String }); +/// Vector search result including hydrated entity. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct KnowledgeEntityVectorResult { + pub entity: KnowledgeEntity, + pub score: f32, +} + impl KnowledgeEntity { pub fn new( source_id: String, @@ -89,7 +96,6 @@ impl KnowledgeEntity { description: String, entity_type: KnowledgeEntityType, metadata: Option, - embedding: Vec, user_id: String, ) -> Self { let now = Utc::now(); @@ -102,7 +108,6 @@ impl KnowledgeEntity { description, entity_type, metadata, - embedding, user_id, } } @@ -165,6 +170,89 @@ impl KnowledgeEntity { Ok(()) } + /// Atomically store a knowledge entity and its embedding. + /// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`. 
+ pub async fn store_with_embedding( + entity: KnowledgeEntity, + embedding: Vec, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone()); + + let query = format!( + " + BEGIN TRANSACTION; + CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity; + CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb; + COMMIT TRANSACTION; + ", + entity_table = Self::table_name(), + emb_table = KnowledgeEntityEmbedding::table_name(), + ); + + db.client + .query(query) + .bind(("entity_id", entity.id.clone())) + .bind(("entity", entity)) + .bind(("emb_id", emb.id.clone())) + .bind(("emb", emb)) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; + + Ok(()) + } + + /// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores. + pub async fn vector_search( + take: usize, + query_embedding: Vec, + db: &SurrealDbClient, + user_id: &str, + ) -> Result, AppError> { + #[derive(Deserialize)] + struct Row { + entity_id: KnowledgeEntity, + score: f32, + } + + let sql = format!( + r#" + SELECT + entity_id, + vector::similarity::cosine(embedding, $embedding) AS score + FROM {emb_table} + WHERE user_id = $user_id + AND embedding <|{take},100|> $embedding + ORDER BY score DESC + LIMIT {take} + FETCH entity_id; + "#, + emb_table = KnowledgeEntityEmbedding::table_name(), + take = take + ); + + let mut response = db + .query(&sql) + .bind(("embedding", query_embedding)) + .bind(("user_id", user_id.to_string())) + .await + .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; + + response = response.check().map_err(AppError::Database)?; + + let rows: Vec = response.take::>(0).map_err(AppError::Database)?; + + Ok(rows + .into_iter() + .map(|r| KnowledgeEntityVectorResult { + entity: r.entity_id, + score: r.score, + }) + .collect()) + } + pub async fn patch( id: &str, name: &str, @@ -178,32 +266,55 @@ 
impl KnowledgeEntity { name, description, entity_type ); let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?; + let user_id = Self::get_user_id_by_id(id, db_client).await?; + let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id); let now = Utc::now(); db_client .client .query( - "UPDATE type::thing($table, $id) - SET name = $name, - description = $description, - updated_at = $updated_at, - entity_type = $entity_type, - embedding = $embedding - RETURN AFTER", + "BEGIN TRANSACTION; + UPDATE type::thing($table, $id) + SET name = $name, + description = $description, + updated_at = $updated_at, + entity_type = $entity_type; + UPSERT type::thing($emb_table, $emb_id) CONTENT $emb; + COMMIT TRANSACTION;", ) .bind(("table", Self::table_name())) + .bind(("emb_table", KnowledgeEntityEmbedding::table_name())) .bind(("id", id.to_string())) .bind(("name", name.to_string())) .bind(("updated_at", surrealdb::Datetime::from(now))) .bind(("entity_type", entity_type.to_owned())) - .bind(("embedding", embedding)) + .bind(("emb_id", emb.id.clone())) + .bind(("emb", emb)) .bind(("description", description.to_string())) .await?; Ok(()) } + async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result { + let mut response = db_client + .client + .query("SELECT user_id FROM type::thing($table, $id) LIMIT 1") + .bind(("table", Self::table_name())) + .bind(("id", id.to_string())) + .await + .map_err(AppError::Database)?; + #[derive(Deserialize)] + struct Row { + user_id: String, + } + let rows: Vec = response.take(0).map_err(AppError::Database)?; + rows.get(0) + .map(|r| r.user_id.clone()) + .ok_or_else(|| AppError::InternalError("user not found for entity".to_string())) + } + /// Re-creates embeddings for all knowledge entities in the database. /// /// This is a costly operation that should be run in the background. 
It follows the same @@ -228,22 +339,13 @@ impl KnowledgeEntity { if total_entities == 0 { info!("No knowledge entities to update. Just updating the idx"); - let mut transaction_query = String::from("BEGIN TRANSACTION;"); - transaction_query - .push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;"); - transaction_query.push_str(&format!( - "DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};", - new_dimensions - )); - transaction_query.push_str("COMMIT TRANSACTION;"); - - db.query(transaction_query).await?; + KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?; return Ok(()); } info!("Found {} entities to process.", total_entities); // Generate all new embeddings in memory - let mut new_embeddings: HashMap> = HashMap::new(); + let mut new_embeddings: HashMap, String)> = HashMap::new(); info!("Generating new embeddings for all entities..."); for entity in all_entities.iter() { let embedding_input = format!( @@ -271,17 +373,16 @@ impl KnowledgeEntity { error!("{}", err_msg); return Err(AppError::InternalError(err_msg)); } - new_embeddings.insert(entity.id.clone(), embedding); + new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone())); } info!("Successfully generated all new embeddings."); // Perform DB updates in a single transaction - info!("Applying schema and data changes in a transaction..."); + info!("Applying embedding updates in a transaction..."); let mut transaction_query = String::from("BEGIN TRANSACTION;"); - // Add all update statements - for (id, embedding) in new_embeddings { - // We must properly serialize the vector for the SurrealQL query string + // Add all update statements to the embedding table + for (id, (embedding, user_id)) in new_embeddings { let embedding_str = format!( "[{}]", embedding @@ -291,18 +392,22 @@ impl KnowledgeEntity { .join(",") ); transaction_query.push_str(&format!( - "UPDATE type::thing('knowledge_entity', 
'{}') SET embedding = {}, updated_at = time::now();", - id, embedding_str - )); + "UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \ + entity_id = type::thing('knowledge_entity', '{id}'), \ + embedding = {embedding}, \ + user_id = '{user_id}', \ + created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ + updated_at = time::now();", + id = id, + embedding = embedding_str, + user_id = user_id + )); } - // Re-create the index after updating the data that it will index - transaction_query - .push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;"); transaction_query.push_str(&format!( - "DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};", - new_dimensions - )); + "DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};", + new_dimensions + )); transaction_query.push_str("COMMIT TRANSACTION;"); @@ -317,7 +422,9 @@ impl KnowledgeEntity { #[cfg(test)] mod tests { use super::*; + use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; use serde_json::json; + use uuid::Uuid; #[tokio::test] async fn test_knowledge_entity_creation() { @@ -327,7 +434,6 @@ mod tests { let description = "Test Description".to_string(); let entity_type = KnowledgeEntityType::Document; let metadata = Some(json!({"key": "value"})); - let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; let user_id = "user123".to_string(); let entity = KnowledgeEntity::new( @@ -336,7 +442,6 @@ mod tests { description.clone(), entity_type.clone(), metadata.clone(), - embedding.clone(), user_id.clone(), ); @@ -346,7 +451,6 @@ mod tests { assert_eq!(entity.description, description); assert_eq!(entity.entity_type, entity_type); assert_eq!(entity.metadata, metadata); - assert_eq!(entity.embedding, embedding); assert_eq!(entity.user_id, user_id); assert!(!entity.id.is_empty()); } @@ -410,20 +514,25 @@ mod tests { let db = 
SurrealDbClient::memory(namespace, database) .await .expect("Failed to start in-memory surrealdb"); + db.apply_migrations() + .await + .expect("Failed to apply migrations"); // Create two entities with the same source_id let source_id = "source123".to_string(); let entity_type = KnowledgeEntityType::Document; - let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; let user_id = "user123".to_string(); + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5) + .await + .expect("Failed to redefine index length"); + let entity1 = KnowledgeEntity::new( source_id.clone(), "Entity 1".to_string(), "Description 1".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -433,7 +542,6 @@ mod tests { "Description 2".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -445,18 +553,18 @@ mod tests { "Different Description".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); + let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5]; // Store the entities - db.store_item(entity1) + KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db) .await .expect("Failed to store entity 1"); - db.store_item(entity2) + KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db) .await .expect("Failed to store entity 2"); - db.store_item(different_entity.clone()) + KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db) .await .expect("Failed to store different entity"); @@ -505,6 +613,162 @@ mod tests { assert_eq!(different_remaining[0].id, different_entity.id); } - // Note: We can't easily test the patch method without mocking the OpenAI client - // and the generate_embedding function. This would require more complex setup. 
+ #[tokio::test] + async fn test_vector_search_returns_empty_when_no_embeddings() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("Failed to redefine index length"); + + let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") + .await + .expect("vector search"); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_vector_search_single_result() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("Failed to redefine index length"); + + let user_id = "user".to_string(); + let source_id = "src".to_string(); + let entity = KnowledgeEntity::new( + source_id.clone(), + "hello".to_string(), + "world".to_string(), + KnowledgeEntityType::Document, + None, + user_id.clone(), + ); + + KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db) + .await + .expect("store entity with embedding"); + + let stored_entity: Option = db.get_item(&entity.id).await.unwrap(); + assert!(stored_entity.is_some()); + + let stored_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {}", + KnowledgeEntityEmbedding::table_name() + )) + .await + .expect("query embeddings") + .take(0) + .expect("take embeddings"); + assert_eq!(stored_embeddings.len(), 1); + + let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db) + .await + .expect("fetch 
embedding"); + assert!(fetched_emb.is_some()); + + let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + .await + .expect("vector search"); + + assert_eq!(results.len(), 1); + let res = &results[0]; + assert_eq!(res.entity.id, entity.id); + assert_eq!(res.entity.source_id, source_id); + assert_eq!(res.entity.name, "hello"); + } + + #[tokio::test] + async fn test_vector_search_orders_by_similarity() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("Failed to redefine index length"); + + let user_id = "user".to_string(); + let e1 = KnowledgeEntity::new( + "s1".to_string(), + "entity one".to_string(), + "desc".to_string(), + KnowledgeEntityType::Document, + None, + user_id.clone(), + ); + let e2 = KnowledgeEntity::new( + "s2".to_string(), + "entity two".to_string(), + "desc".to_string(), + KnowledgeEntityType::Document, + None, + user_id.clone(), + ); + + KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db) + .await + .expect("store e1"); + KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db) + .await + .expect("store e2"); + + let stored_e1: Option = db.get_item(&e1.id).await.unwrap(); + let stored_e2: Option = db.get_item(&e2.id).await.unwrap(); + assert!(stored_e1.is_some() && stored_e2.is_some()); + + let stored_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {}", + KnowledgeEntityEmbedding::table_name() + )) + .await + .expect("query embeddings") + .take(0) + .expect("take embeddings"); + assert_eq!(stored_embeddings.len(), 2); + + let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id); + let rid_e2 = 
surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id); + assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db) + .await + .unwrap() + .is_some()); + assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db) + .await + .unwrap() + .is_some()); + + let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) + .await + .expect("vector search"); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].entity.id, e2.id); + assert_eq!(results[1].entity.id, e1.id); + } } diff --git a/common/src/storage/types/knowledge_entity_embedding.rs b/common/src/storage/types/knowledge_entity_embedding.rs new file mode 100644 index 0000000..ad4ccfa --- /dev/null +++ b/common/src/storage/types/knowledge_entity_embedding.rs @@ -0,0 +1,385 @@ +use std::collections::HashMap; + +use surrealdb::RecordId; + +use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; + +stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", { + entity_id: RecordId, + embedding: Vec, + /// Denormalized user id for query scoping + user_id: String +}); + +impl KnowledgeEntityEmbedding { + /// Recreate the HNSW index with a new embedding dimension. 
+ pub async fn redefine_hnsw_index( + db: &SurrealDbClient, + dimension: usize, + ) -> Result<(), AppError> { + let query = format!( + "BEGIN TRANSACTION; + REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table}; + DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension}; + COMMIT TRANSACTION;", + table = Self::table_name(), + ); + + let res = db.client.query(query).await.map_err(AppError::Database)?; + res.check().map_err(AppError::Database)?; + + Ok(()) + } + + /// Create a new knowledge entity embedding + pub fn new(entity_id: &str, embedding: Vec, user_id: String) -> Self { + let now = Utc::now(); + Self { + id: uuid::Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + entity_id: RecordId::from_table_key("knowledge_entity", entity_id), + embedding, + user_id, + } + } + + /// Get embedding by entity ID + pub async fn get_by_entity_id( + entity_id: &RecordId, + db: &SurrealDbClient, + ) -> Result, AppError> { + let query = format!( + "SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1", + Self::table_name() + ); + let mut result = db + .client + .query(query) + .bind(("entity_id", entity_id.clone())) + .await + .map_err(AppError::Database)?; + let embeddings: Vec = result.take(0).map_err(AppError::Database)?; + Ok(embeddings.into_iter().next()) + } + + /// Get embeddings for multiple entities in batch + pub async fn get_by_entity_ids( + entity_ids: &[RecordId], + db: &SurrealDbClient, + ) -> Result>, AppError> { + if entity_ids.is_empty() { + return Ok(HashMap::new()); + } + + let ids_list: Vec = entity_ids.iter().cloned().collect(); + + let query = format!( + "SELECT * FROM {} WHERE entity_id INSIDE $entity_ids", + Self::table_name() + ); + let mut result = db + .client + .query(query) + .bind(("entity_ids", ids_list)) + .await + .map_err(AppError::Database)?; + let embeddings: Vec = result.take(0).map_err(AppError::Database)?; + + Ok(embeddings + .into_iter() 
+ .map(|e| (e.entity_id.key().to_string(), e.embedding)) + .collect()) + } + + /// Delete embedding by entity ID + pub async fn delete_by_entity_id( + entity_id: &RecordId, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let query = format!( + "DELETE FROM {} WHERE entity_id = $entity_id", + Self::table_name() + ); + db.client + .query(query) + .bind(("entity_id", entity_id.clone())) + .await + .map_err(AppError::Database)?; + Ok(()) + } + + /// Delete embeddings by source_id (via joining to knowledge_entity table) + pub async fn delete_by_source_id( + source_id: &str, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id"; + let mut res = db + .client + .query(query) + .bind(("source_id", source_id.to_owned())) + .await + .map_err(AppError::Database)?; + #[derive(Deserialize)] + struct IdRow { + id: RecordId, + } + let ids: Vec = res.take(0).map_err(AppError::Database)?; + + for row in ids { + Self::delete_by_entity_id(&row.id, db).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::db::SurrealDbClient; + use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}; + use chrono::Utc; + use surrealdb::Value as SurrealValue; + use uuid::Uuid; + + async fn setup_test_db() -> SurrealDbClient { + let namespace = "test_ns"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .expect("Failed to start in-memory surrealdb"); + + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + db + } + + fn build_knowledge_entity_with_id( + key: &str, + source_id: &str, + user_id: &str, + ) -> KnowledgeEntity { + KnowledgeEntity { + id: key.to_owned(), + created_at: Utc::now(), + updated_at: Utc::now(), + source_id: source_id.to_owned(), + name: "Test entity".to_owned(), + description: "Desc".to_owned(), + entity_type: KnowledgeEntityType::Document, + 
metadata: None, + user_id: user_id.to_owned(), + } + } + + #[tokio::test] + async fn test_create_and_get_by_entity_id() { + let db = setup_test_db().await; + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("set test index dimension"); + let user_id = "user_ke"; + let entity_key = "entity-1"; + let source_id = "source-ke"; + + let embedding_vec = vec![0.11_f32, 0.22, 0.33]; + let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id); + + KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db) + .await + .expect("Failed to store entity with embedding"); + + let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + + let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + .await + .expect("Failed to get embedding by entity_id") + .expect("Expected embedding to exist"); + + assert_eq!(fetched.user_id, user_id); + assert_eq!(fetched.entity_id, entity_rid); + assert_eq!(fetched.embedding, embedding_vec); + } + + #[tokio::test] + async fn test_delete_by_entity_id() { + let db = setup_test_db().await; + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("set test index dimension"); + let user_id = "user_ke"; + let entity_key = "entity-delete"; + let source_id = "source-del"; + + let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id); + + KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db) + .await + .expect("Failed to store entity with embedding"); + + let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + + let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + .await + .expect("Failed to get embedding before delete"); + assert!(existing.is_some()); + + KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db) + .await + .expect("Failed to delete by entity_id"); + + let after = 
KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + .await + .expect("Failed to get embedding after delete"); + assert!(after.is_none()); + } + + #[tokio::test] + async fn test_store_with_embedding_creates_entity_and_embedding() { + let db = setup_test_db().await; + let user_id = "user_store"; + let source_id = "source_store"; + let embedding = vec![0.2_f32, 0.3, 0.4]; + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len()) + .await + .expect("set test index dimension"); + + let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id); + + KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db) + .await + .expect("Failed to store entity with embedding"); + + let stored_entity: Option = db.get_item(&entity.id).await.unwrap(); + assert!(stored_entity.is_some()); + + let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db) + .await + .expect("Failed to fetch embedding"); + assert!(stored_embedding.is_some()); + let stored_embedding = stored_embedding.unwrap(); + assert_eq!(stored_embedding.user_id, user_id); + assert_eq!(stored_embedding.entity_id, entity_rid); + } + + #[tokio::test] + async fn test_delete_by_source_id() { + let db = setup_test_db().await; + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("set test index dimension"); + let user_id = "user_ke"; + let source_id = "shared-ke"; + let other_source = "other-ke"; + + let entity1 = build_knowledge_entity_with_id("entity-s1", source_id, user_id); + let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id); + let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id); + + KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db) + .await + .expect("Failed to store entity with embedding"); + 
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db) + .await + .expect("Failed to store entity with embedding"); + KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db) + .await + .expect("Failed to store entity with embedding"); + + let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id); + let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id); + let other_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity_other.id); + + KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db) + .await + .expect("Failed to delete by source_id"); + + assert!( + KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db) + .await + .unwrap() + .is_none() + ); + assert!( + KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db) + .await + .unwrap() + .is_none() + ); + assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db) + .await + .unwrap() + .is_some()); + } + + #[tokio::test] + async fn test_redefine_hnsw_index_updates_dimension() { + let db = setup_test_db().await; + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16) + .await + .expect("failed to redefine index"); + + let mut info_res = db + .client + .query("INFO FOR TABLE knowledge_entity_embedding;") + .await + .expect("info query failed"); + let info: SurrealValue = info_res.take(0).expect("failed to take info result"); + let info_json: serde_json::Value = + serde_json::to_value(info).expect("failed to convert info to json"); + let idx_sql = info_json["Object"]["indexes"]["Object"] + ["idx_embedding_knowledge_entity_embedding"]["Strand"] + .as_str() + .unwrap_or_default(); + + assert!( + idx_sql.contains("DIMENSION 16"), + "expected index definition to contain new dimension, got: {idx_sql}" + ); + } + + #[tokio::test] + async fn test_fetch_entity_via_record_id() { + let db = setup_test_db().await; + 
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("set test index dimension"); + let user_id = "user_ke"; + let entity_key = "entity-fetch"; + let source_id = "source-fetch"; + + let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id); + KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db) + .await + .expect("Failed to store entity with embedding"); + + let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id); + + #[derive(Deserialize)] + struct Row { + entity_id: KnowledgeEntity, + } + + let mut res = db + .client + .query( + "SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;", + ) + .bind(("id", entity_rid.clone())) + .await + .expect("failed to fetch embedding with FETCH"); + let rows: Vec = res.take(0).expect("failed to deserialize fetch rows"); + + assert_eq!(rows.len(), 1); + let fetched_entity = &rows[0].entity_id; + assert_eq!(fetched_entity.id, entity_key); + assert_eq!(fetched_entity.name, "Test entity"); + assert_eq!(fetched_entity.user_id, user_id); + } +} diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 673bc71..7df01b3 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -119,7 +119,6 @@ mod tests { let source_id = "source123".to_string(); let description = format!("Description for {}", name); let entity_type = KnowledgeEntityType::Document; - let embedding = vec![0.1, 0.2, 0.3]; let user_id = "user123".to_string(); let entity = KnowledgeEntity::new( @@ -128,7 +127,6 @@ mod tests { description, entity_type, None, - embedding, user_id, ); diff --git a/common/src/storage/types/mod.rs b/common/src/storage/types/mod.rs index fbe8d3d..4f053ee 100644 --- a/common/src/storage/types/mod.rs +++ b/common/src/storage/types/mod.rs @@ -5,12 +5,14 @@ pub mod file_info; pub mod 
ingestion_payload; pub mod ingestion_task; pub mod knowledge_entity; +pub mod knowledge_entity_embedding; pub mod knowledge_relationship; pub mod message; pub mod scratchpad; pub mod system_prompts; pub mod system_settings; pub mod text_chunk; +pub mod text_chunk_embedding; pub mod text_content; pub mod user; diff --git a/common/src/storage/types/system_settings.rs b/common/src/storage/types/system_settings.rs index 77112ca..392cfdb 100644 --- a/common/src/storage/types/system_settings.rs +++ b/common/src/storage/types/system_settings.rs @@ -71,25 +71,22 @@ mod tests { .await .expect("Failed to fetch table info"); - let info: Option = response + let info: surrealdb::Value = response .take(0) .expect("Failed to extract table info response"); - let info = info.expect("Table info result missing"); + let info_json: serde_json::Value = + serde_json::to_value(info).expect("Failed to convert info to json"); - let indexes = info - .get("indexes") - .or_else(|| { - info.get("tables") - .and_then(|tables| tables.get(table_name)) - .and_then(|table| table.get("indexes")) - }) - .unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}")); + let indexes = info_json["Object"]["indexes"]["Object"] + .as_object() + .unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}")); let definition = indexes .get(index_name) - .and_then(|definition| definition.as_str()) - .unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}")); + .and_then(|definition| definition.get("Strand")) + .and_then(|v| v.as_str()) + .unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}")); let dimension_part = definition .split("DIMENSION") @@ -261,48 +258,56 @@ mod tests { let initial_chunk = TextChunk::new( "source1".into(), "This chunk has the original dimension".into(), - vec![0.1; 1536], "user1".into(), ); - db.store_item(initial_chunk.clone()) + TextChunk::store_with_embedding(initial_chunk.clone(), 
vec![0.1; 1536], &db) .await - .expect("Failed to store initial chunk"); + .expect("Failed to store initial chunk with embedding"); async fn simulate_reembedding( db: &SurrealDbClient, target_dimension: usize, initial_chunk: TextChunk, ) { - db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;") - .await - .unwrap(); + db.query( + "REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;", + ) + .await + .unwrap(); let define_index_query = format!( - "DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};", - target_dimension - ); + "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", + target_dimension + ); db.query(define_index_query) .await .expect("Re-defining index should succeed"); let new_embedding = vec![0.5; target_dimension]; - let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;"; + let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;"; let update_result = db .client .query(sql) .bind(("id", initial_chunk.id.clone())) + .bind(("user_id", initial_chunk.user_id.clone())) .bind(("embedding", new_embedding)) .await; assert!(update_result.is_ok()); } - simulate_reembedding(&db, 768, initial_chunk).await; + // Re-embed with the existing configured dimension to ensure migrations remain idempotent. 
+ let target_dimension = 1536usize; + simulate_reembedding(&db, target_dimension, initial_chunk).await; let migration_result = db.apply_migrations().await; - assert!(migration_result.is_ok(), "Migrations should not fail"); + assert!( + migration_result.is_ok(), + "Migrations should not fail: {:?}", + migration_result.err() + ); } #[tokio::test] @@ -320,8 +325,12 @@ mod tests { .await .expect("Failed to load current settings"); - let initial_chunk_dimension = - get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await; + let initial_chunk_dimension = get_hnsw_index_dimension( + &db, + "text_chunk_embedding", + "idx_embedding_text_chunk_embedding", + ) + .await; assert_eq!( initial_chunk_dimension, current_settings.embedding_dimensions, @@ -352,10 +361,18 @@ mod tests { .await .expect("KnowledgeEntity re-embedding should succeed on fresh DB"); - let text_chunk_dimension = - get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await; - let knowledge_dimension = - get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await; + let text_chunk_dimension = get_hnsw_index_dimension( + &db, + "text_chunk_embedding", + "idx_embedding_text_chunk_embedding", + ) + .await; + let knowledge_dimension = get_hnsw_index_dimension( + &db, + "knowledge_entity_embedding", + "idx_embedding_knowledge_entity_embedding", + ) + .await; assert_eq!( text_chunk_dimension, new_dimension, diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 11a574a..6ab7df1 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; use async_openai::{config::OpenAIConfig, Client}; use tokio_retry::{ @@ -13,12 +14,18 @@ use uuid::Uuid; stored_object!(TextChunk, "text_chunk", { source_id: 
String, chunk: String, - embedding: Vec, user_id: String }); +/// Vector search result including hydrated chunk. +#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)] +pub struct TextChunkVectorResult { + pub chunk: TextChunk, + pub score: f32, +} + impl TextChunk { - pub fn new(source_id: String, chunk: String, embedding: Vec, user_id: String) -> Self { + pub fn new(source_id: String, chunk: String, user_id: String) -> Self { let now = Utc::now(); Self { id: Uuid::new_v4().to_string(), @@ -26,7 +33,6 @@ impl TextChunk { updated_at: now, source_id, chunk, - embedding, user_id, } } @@ -45,6 +51,94 @@ impl TextChunk { Ok(()) } + /// Atomically store a text chunk and its embedding. + /// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`. + pub async fn store_with_embedding( + chunk: TextChunk, + embedding: Vec, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let emb = TextChunkEmbedding::new( + &chunk.id, + chunk.source_id.clone(), + embedding, + chunk.user_id.clone(), + ); + + // Create both records in a single query + let query = format!( + " + BEGIN TRANSACTION; + CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk; + CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb; + COMMIT TRANSACTION; + ", + chunk_table = Self::table_name(), + emb_table = TextChunkEmbedding::table_name(), + ); + + db.client + .query(query) + .bind(("chunk_id", chunk.id.clone())) + .bind(("chunk", chunk)) + .bind(("emb_id", emb.id.clone())) + .bind(("emb", emb)) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; + + Ok(()) + } + + /// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings. 
+ pub async fn vector_search( + take: usize, + query_embedding: Vec, + db: &SurrealDbClient, + user_id: &str, + ) -> Result, AppError> { + #[derive(Deserialize)] + struct Row { + chunk_id: TextChunk, + score: f32, + } + + let sql = format!( + r#" + SELECT + chunk_id, + embedding, + vector::similarity::cosine(embedding, $embedding) AS score + FROM {emb_table} + WHERE user_id = $user_id + AND embedding <|{take},100|> $embedding + ORDER BY score DESC + LIMIT {take} + FETCH chunk_id; + "#, + emb_table = TextChunkEmbedding::table_name(), + take = take + ); + + let mut response = db + .query(&sql) + .bind(("embedding", query_embedding)) + .bind(("user_id", user_id.to_string())) + .await + .map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?; + + let rows: Vec = response.take::>(0).unwrap_or_default(); + + Ok(rows + .into_iter() + .map(|r| TextChunkVectorResult { + chunk: r.chunk_id, + score: r.score, + }) + .collect()) + } + /// Re-creates embeddings for all text chunks using a safe, atomic transaction. /// /// This is a costly operation that should be run in the background. It performs these steps: @@ -70,21 +164,14 @@ impl TextChunk { if total_chunks == 0 { info!("No text chunks to update. 
Just updating the idx"); - let mut transaction_query = String::from("BEGIN TRANSACTION;"); - transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;"); - transaction_query.push_str(&format!( - "DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};", - new_dimensions)); - transaction_query.push_str("COMMIT TRANSACTION;"); - - db.query(transaction_query).await?; + TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?; return Ok(()); } info!("Found {} chunks to process.", total_chunks); // Generate all new embeddings in memory - let mut new_embeddings: HashMap> = HashMap::new(); + let mut new_embeddings: HashMap, String, String)> = HashMap::new(); info!("Generating new embeddings for all chunks..."); for chunk in all_chunks.iter() { let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); @@ -108,16 +195,18 @@ impl TextChunk { error!("{}", err_msg); return Err(AppError::InternalError(err_msg)); } - new_embeddings.insert(chunk.id.clone(), embedding); + new_embeddings.insert( + chunk.id.clone(), + (embedding, chunk.user_id.clone(), chunk.source_id.clone()), + ); } info!("Successfully generated all new embeddings."); - // Perform DB updates in a single transaction - info!("Applying schema and data changes in a transaction..."); + // Perform DB updates in a single transaction against the embedding table + info!("Applying embedding updates in a transaction..."); let mut transaction_query = String::from("BEGIN TRANSACTION;"); - // Add all update statements - for (id, embedding) in new_embeddings { + for (id, (embedding, user_id, source_id)) in new_embeddings { let embedding_str = format!( "[{}]", embedding @@ -126,22 +215,29 @@ impl TextChunk { .collect::>() .join(",") ); + // Use the chunk id as the embedding record id to keep a 1:1 mapping transaction_query.push_str(&format!( - "UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();", - 
id, embedding_str + "UPSERT type::thing('text_chunk_embedding', '{id}') SET \ + chunk_id = type::thing('text_chunk', '{id}'), \ + source_id = '{source_id}', \ + embedding = {embedding}, \ + user_id = '{user_id}', \ + created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \ + updated_at = time::now();", + id = id, + embedding = embedding_str, + user_id = user_id, + source_id = source_id )); } - // Re-create the index inside the same transaction - transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;"); transaction_query.push_str(&format!( - "DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};", + "DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};", new_dimensions )); transaction_query.push_str("COMMIT TRANSACTION;"); - // Execute the entire atomic operation db.query(transaction_query).await?; info!("Re-embedding process for text chunks completed successfully."); @@ -152,171 +248,269 @@ impl TextChunk { #[cfg(test)] mod tests { use super::*; + use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; + use surrealdb::RecordId; + use uuid::Uuid; #[tokio::test] async fn test_text_chunk_creation() { - // Test basic object creation let source_id = "source123".to_string(); let chunk = "This is a text chunk for testing embeddings".to_string(); - let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; let user_id = "user123".to_string(); - let text_chunk = TextChunk::new( - source_id.clone(), - chunk.clone(), - embedding.clone(), - user_id.clone(), - ); + let text_chunk = TextChunk::new(source_id.clone(), chunk.clone(), user_id.clone()); - // Check that the fields are set correctly assert_eq!(text_chunk.source_id, source_id); assert_eq!(text_chunk.chunk, chunk); - assert_eq!(text_chunk.embedding, embedding); assert_eq!(text_chunk.user_id, user_id); assert!(!text_chunk.id.is_empty()); } #[tokio::test] async fn 
test_delete_by_source_id() { - // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); - // Create test data let source_id = "source123".to_string(); - let chunk1 = "First chunk from the same source".to_string(); - let chunk2 = "Second chunk from the same source".to_string(); - let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; let user_id = "user123".to_string(); + TextChunkEmbedding::redefine_hnsw_index(&db, 5) + .await + .expect("redefine index"); - // Create two chunks with the same source_id - let text_chunk1 = TextChunk::new( + let chunk1 = TextChunk::new( source_id.clone(), - chunk1, - embedding.clone(), + "First chunk from the same source".to_string(), user_id.clone(), ); - - let text_chunk2 = TextChunk::new( + let chunk2 = TextChunk::new( source_id.clone(), - chunk2, - embedding.clone(), + "Second chunk from the same source".to_string(), user_id.clone(), ); - - // Create a chunk with a different source_id - let different_source_id = "different_source".to_string(); let different_chunk = TextChunk::new( - different_source_id.clone(), + "different_source".to_string(), "Different source chunk".to_string(), - embedding.clone(), user_id.clone(), ); - // Store the chunks - db.store_item(text_chunk1) + TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) .await - .expect("Failed to store text chunk 1"); - db.store_item(text_chunk2) + .expect("store chunk1"); + TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) .await - .expect("Failed to store text chunk 2"); - db.store_item(different_chunk.clone()) - .await - .expect("Failed to store different chunk"); + .expect("store chunk2"); + TextChunk::store_with_embedding( + different_chunk.clone(), + vec![0.1, 0.2, 0.3, 0.4, 0.5], + &db, + ) + .await 
+ .expect("store different chunk"); - // Delete by source_id TextChunk::delete_by_source_id(&source_id, &db) .await .expect("Failed to delete chunks by source_id"); - // Verify all chunks with the original source_id are deleted - let query = format!( - "SELECT * FROM {} WHERE source_id = '{}'", - TextChunk::table_name(), - source_id - ); let remaining: Vec = db .client - .query(query) + .query(format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + source_id + )) .await .expect("Query failed") .take(0) .expect("Failed to get query results"); - assert_eq!( - remaining.len(), - 0, - "All chunks with the source_id should be deleted" - ); + assert_eq!(remaining.len(), 0); - // Verify the different source_id chunk still exists - let different_query = format!( - "SELECT * FROM {} WHERE source_id = '{}'", - TextChunk::table_name(), - different_source_id - ); let different_remaining: Vec = db .client - .query(different_query) + .query(format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + "different_source" + )) .await .expect("Query failed") .take(0) .expect("Failed to get query results"); - assert_eq!( - different_remaining.len(), - 1, - "Chunk with different source_id should still exist" - ); + assert_eq!(different_remaining.len(), 1); assert_eq!(different_remaining[0].id, different_chunk.id); } #[tokio::test] async fn test_delete_by_nonexistent_source_id() { - // Setup in-memory database for testing let namespace = "test_ns"; let database = &Uuid::new_v4().to_string(); let db = SurrealDbClient::memory(namespace, database) .await .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + TextChunkEmbedding::redefine_hnsw_index(&db, 5) + .await + .expect("redefine index"); - // Create a chunk with a real source_id let real_source_id = "real_source".to_string(); - let chunk = "Test chunk".to_string(); - let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5]; - let user_id = 
"user123".to_string(); - - let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id); - - // Store the chunk - db.store_item(text_chunk) - .await - .expect("Failed to store text chunk"); - - // Delete using nonexistent source_id - let nonexistent_source_id = "nonexistent_source"; - TextChunk::delete_by_source_id(nonexistent_source_id, &db) - .await - .expect("Delete operation with nonexistent source_id should not fail"); - - // Verify the real chunk still exists - let query = format!( - "SELECT * FROM {} WHERE source_id = '{}'", - TextChunk::table_name(), - real_source_id + let chunk = TextChunk::new( + real_source_id.clone(), + "Test chunk".to_string(), + "user123".to_string(), ); + + TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) + .await + .expect("store chunk"); + + TextChunk::delete_by_source_id("nonexistent_source", &db) + .await + .expect("Delete should succeed"); + let remaining: Vec = db .client - .query(query) + .query(format!( + "SELECT * FROM {} WHERE source_id = '{}'", + TextChunk::table_name(), + real_source_id + )) .await .expect("Query failed") .take(0) .expect("Failed to get query results"); - assert_eq!( - remaining.len(), - 1, - "Chunk with real source_id should still exist" + assert_eq!(remaining.len(), 1); + } + + #[tokio::test] + async fn test_store_with_embedding_creates_both_records() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + + let source_id = "store-src".to_string(); + let user_id = "user_store".to_string(); + let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone()); + + TextChunkEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("redefine index"); + + TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) + .await + 
.expect("store with embedding"); + + let stored_chunk: Option = db.get_item(&chunk.id).await.unwrap(); + assert!(stored_chunk.is_some()); + let stored_chunk = stored_chunk.unwrap(); + assert_eq!(stored_chunk.source_id, source_id); + assert_eq!(stored_chunk.user_id, user_id); + + let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id); + let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db) + .await + .expect("get embedding"); + assert!(embedding.is_some()); + let embedding = embedding.unwrap(); + assert_eq!(embedding.chunk_id, rid); + assert_eq!(embedding.user_id, user_id); + assert_eq!(embedding.source_id, source_id); + } + + #[tokio::test] + async fn test_vector_search_returns_empty_when_no_embeddings() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + + TextChunkEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("redefine index"); + + let results: Vec = + TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") + .await + .unwrap(); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn test_vector_search_single_result() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + + TextChunkEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("redefine index"); + + let source_id = "src".to_string(); + let user_id = "user".to_string(); + let chunk = TextChunk::new( + source_id.clone(), + "hello world".to_string(), + user_id.clone(), ); + + TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db) + .await + .expect("store"); + + let results: Vec = + TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + .await + 
.unwrap(); + + assert_eq!(results.len(), 1); + let res = &results[0]; + assert_eq!(res.chunk.id, chunk.id); + assert_eq!(res.chunk.source_id, source_id); + assert_eq!(res.chunk.chunk, "hello world"); + } + + #[tokio::test] + async fn test_vector_search_orders_by_similarity() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + + TextChunkEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("redefine index"); + + let user_id = "user".to_string(); + let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone()); + let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone()); + + TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db) + .await + .expect("store chunk1"); + TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db) + .await + .expect("store chunk2"); + + let results: Vec = + TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) + .await + .unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].chunk.id, chunk2.id); + assert_eq!(results[1].chunk.id, chunk1.id); + assert!(results[0].score >= results[1].score); } } diff --git a/common/src/storage/types/text_chunk_embedding.rs b/common/src/storage/types/text_chunk_embedding.rs new file mode 100644 index 0000000..771b9ca --- /dev/null +++ b/common/src/storage/types/text_chunk_embedding.rs @@ -0,0 +1,435 @@ +use surrealdb::RecordId; + +use crate::storage::types::text_chunk::TextChunk; +use crate::{error::AppError, storage::db::SurrealDbClient, stored_object}; + +stored_object!(TextChunkEmbedding, "text_chunk_embedding", { + /// Record link to the owning text_chunk + chunk_id: RecordId, + /// Denormalized source id for bulk deletes + source_id: String, + /// Embedding vector + embedding: Vec, + /// 
Denormalized user id (for scoping + permissions) + user_id: String +}); + +impl TextChunkEmbedding { + /// Recreate the HNSW index with a new embedding dimension. + /// + /// This is useful when the embedding length changes; Surreal requires the + /// index definition to be recreated with the updated dimension. + pub async fn redefine_hnsw_index( + db: &SurrealDbClient, + dimension: usize, + ) -> Result<(), AppError> { + let query = format!( + "BEGIN TRANSACTION; + REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table}; + DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension}; + COMMIT TRANSACTION;", + table = Self::table_name(), + ); + + let res = db.client.query(query).await.map_err(AppError::Database)?; + res.check().map_err(AppError::Database)?; + + Ok(()) + } + + /// Create a new text chunk embedding + /// + /// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), + /// not "text_chunk:uuid". + pub fn new(chunk_id: &str, source_id: String, embedding: Vec, user_id: String) -> Self { + let now = Utc::now(); + + Self { + // NOTE: `stored_object!` macro defines `id` as `String` + id: uuid::Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + // Create a record link: text_chunk: + chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id), + source_id, + embedding, + user_id, + } + } + + /// Get a single embedding by its chunk RecordId + pub async fn get_by_chunk_id( + chunk_id: &RecordId, + db: &SurrealDbClient, + ) -> Result, AppError> { + let query = format!( + "SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1", + Self::table_name() + ); + + let mut result = db + .client + .query(query) + .bind(("chunk_id", chunk_id.clone())) + .await + .map_err(AppError::Database)?; + + let embeddings: Vec = result.take(0).map_err(AppError::Database)?; + + Ok(embeddings.into_iter().next()) + } + + /// Delete embeddings for a given chunk RecordId + pub async fn 
delete_by_chunk_id( + chunk_id: &RecordId, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let query = format!( + "DELETE FROM {} WHERE chunk_id = $chunk_id", + Self::table_name() + ); + + db.client + .query(query) + .bind(("chunk_id", chunk_id.clone())) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; + + Ok(()) + } + + /// Delete all embeddings that belong to chunks with a given `source_id` + /// + /// This uses a subquery to the `text_chunk` table: + /// + /// DELETE FROM text_chunk_embedding + /// WHERE chunk_id IN (SELECT id FROM text_chunk WHERE source_id = $source_id) + pub async fn delete_by_source_id( + source_id: &str, + db: &SurrealDbClient, + ) -> Result<(), AppError> { + let ids_query = format!( + "SELECT id FROM {} WHERE source_id = $source_id", + TextChunk::table_name() + ); + let mut res = db + .client + .query(ids_query) + .bind(("source_id", source_id.to_owned())) + .await + .map_err(AppError::Database)?; + #[derive(Deserialize)] + struct IdRow { + id: RecordId, + } + let ids: Vec = res.take(0).map_err(AppError::Database)?; + + if ids.is_empty() { + return Ok(()); + } + let delete_query = format!( + "DELETE FROM {} WHERE chunk_id IN $chunk_ids", + Self::table_name() + ); + db.client + .query(delete_query) + .bind(( + "chunk_ids", + ids.into_iter().map(|row| row.id).collect::>(), + )) + .await + .map_err(AppError::Database)? 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::db::SurrealDbClient;
    use surrealdb::Value as SurrealValue;
    use uuid::Uuid;

    /// Spin up a fresh in-memory SurrealDB (unique database name per test)
    /// and apply all migrations so every test starts from a clean schema.
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("Failed to start in-memory surrealdb");

        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");

        db
    }

    /// Create a `text_chunk` row under a known key and return its `RecordId`,
    /// so the embedding under test can reference a real chunk record.
    async fn create_text_chunk_with_id(
        db: &SurrealDbClient,
        key: &str,
        source_id: &str,
        user_id: &str,
    ) -> RecordId {
        let chunk = TextChunk {
            id: key.to_owned(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            source_id: source_id.to_owned(),
            chunk: "Some test chunk text".to_owned(),
            user_id: user_id.to_owned(),
        };

        db.store_item(chunk)
            .await
            .expect("Failed to create text_chunk");

        RecordId::from_table_key(TextChunk::table_name(), key)
    }

    /// Persist one embedding row. Shared by every test to avoid repeating the
    /// create/deserialize boilerplate; panics on any storage failure.
    async fn store_embedding(db: &SurrealDbClient, emb: TextChunkEmbedding) {
        let created: Option<TextChunkEmbedding> = db
            .client
            .create(TextChunkEmbedding::table_name())
            .content(emb)
            .await
            .expect("Failed to store embedding");
        created.expect("Failed to deserialize stored embedding");
    }

    /// Fetch the SurrealQL definition string of the embedding HNSW index via
    /// `INFO FOR TABLE`. Shared by the index-redefinition tests.
    async fn embedding_index_definition(db: &SurrealDbClient) -> String {
        let mut info_res = db
            .client
            .query("INFO FOR TABLE text_chunk_embedding;")
            .await
            .expect("info query failed");
        let info: SurrealValue = info_res.take(0).expect("failed to take info result");
        let info_json: serde_json::Value =
            serde_json::to_value(info).expect("failed to convert info to json");
        info_json["Object"]["indexes"]["Object"]["idx_embedding_text_chunk_embedding"]["Strand"]
            .as_str()
            .unwrap_or_default()
            .to_owned()
    }

    #[tokio::test]
    async fn test_create_and_get_by_chunk_id() {
        let db = setup_test_db().await;

        let user_id = "user_a";
        let chunk_key = "chunk-123";
        let source_id = "source-1";

        // 1) Create a text_chunk with a known key
        let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;

        // 2) Create and store an embedding for that chunk; the HNSW index must
        // match the stored embedding's dimension first.
        let embedding_vec = vec![0.1_f32, 0.2, 0.3];
        let emb = TextChunkEmbedding::new(
            chunk_key,
            source_id.to_string(),
            embedding_vec.clone(),
            user_id.to_string(),
        );

        TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
            .await
            .expect("Failed to redefine index length");

        store_embedding(&db, emb).await;

        // 3) Fetch it via get_by_chunk_id
        let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding by chunk_id");

        assert!(fetched.is_some(), "Expected an embedding to be found");
        let fetched = fetched.unwrap();

        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.chunk_id, chunk_rid);
        assert_eq!(fetched.embedding, embedding_vec);
    }

    #[tokio::test]
    async fn test_delete_by_chunk_id() {
        let db = setup_test_db().await;

        let user_id = "user_b";
        let chunk_key = "chunk-delete";
        let source_id = "source-del";

        let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;

        let emb = TextChunkEmbedding::new(
            chunk_key,
            source_id.to_string(),
            vec![0.4_f32, 0.5, 0.6],
            user_id.to_string(),
        );

        TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
            .await
            .expect("Failed to redefine index length");

        store_embedding(&db, emb).await;

        // Ensure it exists before deleting
        let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding before delete");
        assert!(existing.is_some(), "Embedding should exist before delete");

        // Delete by chunk_id
        TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to delete by chunk_id");

        // Ensure it no longer exists
        let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
            .await
            .expect("Failed to get embedding after delete");
        assert!(after.is_none(), "Embedding should have been deleted");
    }

    #[tokio::test]
    async fn test_delete_by_source_id() {
        let db = setup_test_db().await;

        let user_id = "user_c";
        let source_id = "shared-source";
        let other_source = "other-source";

        // Two chunks with the same source_id
        let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
        let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;

        // One chunk with a different source_id — must survive the delete
        let chunk_other_rid =
            create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;

        // Create embeddings for all three
        let emb1 = TextChunkEmbedding::new(
            "chunk-s1",
            source_id.to_string(),
            vec![0.1],
            user_id.to_string(),
        );
        let emb2 = TextChunkEmbedding::new(
            "chunk-s2",
            source_id.to_string(),
            vec![0.2],
            user_id.to_string(),
        );
        let emb3 = TextChunkEmbedding::new(
            "chunk-other",
            other_source.to_string(),
            vec![0.3],
            user_id.to_string(),
        );

        // Resize the index once; all three embeddings share a dimension.
        TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
            .await
            .expect("Failed to redefine index length");

        for emb in [emb1, emb2, emb3] {
            store_embedding(&db, emb).await;
        }

        // Sanity check: they all exist
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
            .await
            .unwrap()
            .is_some());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
            .await
            .unwrap()
            .is_some());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
            .await
            .unwrap()
            .is_some());

        // Delete embeddings by source_id (shared-source)
        TextChunkEmbedding::delete_by_source_id(source_id, &db)
            .await
            .expect("Failed to delete by source_id");

        // Chunks from shared-source should have no embeddings
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
            .await
            .unwrap()
            .is_none());
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
            .await
            .unwrap()
            .is_none());

        // The other chunk should still have its embedding
        assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
            .await
            .unwrap()
            .is_some());
    }

    #[tokio::test]
    async fn test_redefine_hnsw_index_updates_dimension() {
        let db = setup_test_db().await;

        // Change the index dimension from the default (1536) to a small test value.
        TextChunkEmbedding::redefine_hnsw_index(&db, 8)
            .await
            .expect("failed to redefine index");

        let idx_sql = embedding_index_definition(&db).await;
        assert!(
            idx_sql.contains("DIMENSION 8"),
            "expected index definition to contain new dimension, got: {idx_sql}"
        );
    }

    #[tokio::test]
    async fn test_redefine_hnsw_index_is_idempotent() {
        let db = setup_test_db().await;

        // Redefining twice with the same dimension must not error or drift.
        TextChunkEmbedding::redefine_hnsw_index(&db, 4)
            .await
            .expect("first redefine failed");
        TextChunkEmbedding::redefine_hnsw_index(&db, 4)
            .await
            .expect("second redefine failed");

        let idx_sql = embedding_index_definition(&db).await;
        assert!(
            idx_sql.contains("DIMENSION 4"),
            "expected index definition to retain dimension 4, got: {idx_sql}"
        );
    }
}
--- a/common/src/storage/types/text_content.rs +++ b/common/src/storage/types/text_content.rs @@ -146,12 +146,12 @@ impl TextContent { search::highlight('', '', 4) AS highlighted_url, search::highlight('', '', 5) AS highlighted_url_title, ( - search::score(0) + - search::score(1) + - search::score(2) + - search::score(3) + - search::score(4) + - search::score(5) + IF search::score(0) != NONE THEN search::score(0) ELSE 0 END + + IF search::score(1) != NONE THEN search::score(1) ELSE 0 END + + IF search::score(2) != NONE THEN search::score(2) ELSE 0 END + + IF search::score(3) != NONE THEN search::score(3) ELSE 0 END + + IF search::score(4) != NONE THEN search::score(4) ELSE 0 END + + IF search::score(5) != NONE THEN search::score(5) ELSE 0 END ) AS score FROM text_content WHERE diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index fcf01cb..88a8a5b 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -1,10 +1,279 @@ -use async_openai::types::CreateEmbeddingRequestArgs; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + str::FromStr, + sync::Arc, +}; + +use anyhow::{anyhow, Context, Result}; +use async_openai::{types::CreateEmbeddingRequestArgs, Client}; +use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions}; +use tokio::sync::Mutex; use tracing::debug; use crate::{ error::AppError, storage::{db::SurrealDbClient, types::system_settings::SystemSettings}, }; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EmbeddingBackend { + OpenAI, + FastEmbed, + Hashed, +} + +impl Default for EmbeddingBackend { + fn default() -> Self { + Self::FastEmbed + } +} + +impl std::str::FromStr for EmbeddingBackend { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "openai" => Ok(Self::OpenAI), + "hashed" => Ok(Self::Hashed), + "fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed), + other => Err(anyhow!( + "unknown 
embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'." + )), + } + } +} + +#[derive(Clone)] +pub struct EmbeddingProvider { + inner: EmbeddingInner, +} + +#[derive(Clone)] +enum EmbeddingInner { + OpenAI { + client: Arc>, + model: String, + dimensions: u32, + }, + Hashed { + dimension: usize, + }, + FastEmbed { + model: Arc>, + model_name: EmbeddingModel, + dimension: usize, + }, +} + +impl EmbeddingProvider { + pub fn backend_label(&self) -> &'static str { + match self.inner { + EmbeddingInner::Hashed { .. } => "hashed", + EmbeddingInner::FastEmbed { .. } => "fastembed", + EmbeddingInner::OpenAI { .. } => "openai", + } + } + + pub fn dimension(&self) -> usize { + match &self.inner { + EmbeddingInner::Hashed { dimension } => *dimension, + EmbeddingInner::FastEmbed { dimension, .. } => *dimension, + EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize, + } + } + + pub fn model_code(&self) -> Option { + match &self.inner { + EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()), + EmbeddingInner::OpenAI { model, .. } => Some(model.clone()), + EmbeddingInner::Hashed { .. } => None, + } + } + + pub async fn embed(&self, text: &str) -> Result> { + match &self.inner { + EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)), + EmbeddingInner::FastEmbed { model, .. 
} => { + let mut guard = model.lock().await; + let embeddings = guard + .embed(vec![text.to_owned()], None) + .context("generating fastembed vector")?; + embeddings + .into_iter() + .next() + .ok_or_else(|| anyhow!("fastembed returned no embedding for input")) + } + EmbeddingInner::OpenAI { + client, + model, + dimensions, + } => { + let request = CreateEmbeddingRequestArgs::default() + .model(model.clone()) + .input([text]) + .dimensions(*dimensions) + .build()?; + + let response = client.embeddings().create(request).await?; + + let embedding = response + .data + .first() + .ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))? + .embedding + .clone(); + + Ok(embedding) + } + } + } + + pub async fn embed_batch(&self, texts: Vec) -> Result>> { + match &self.inner { + EmbeddingInner::Hashed { dimension } => Ok(texts + .into_iter() + .map(|text| hashed_embedding(&text, *dimension)) + .collect()), + EmbeddingInner::FastEmbed { model, .. } => { + if texts.is_empty() { + return Ok(Vec::new()); + } + let mut guard = model.lock().await; + guard + .embed(texts, None) + .context("generating fastembed batch embeddings") + } + EmbeddingInner::OpenAI { + client, + model, + dimensions, + } => { + if texts.is_empty() { + return Ok(Vec::new()); + } + + let request = CreateEmbeddingRequestArgs::default() + .model(model.clone()) + .input(texts) + .dimensions(*dimensions) + .build()?; + + let response = client.embeddings().create(request).await?; + + let embeddings: Vec> = response + .data + .into_iter() + .map(|item| item.embedding) + .collect(); + + Ok(embeddings) + } + } + } + + pub async fn new_openai( + client: Arc>, + model: String, + dimensions: u32, + ) -> Result { + Ok(EmbeddingProvider { + inner: EmbeddingInner::OpenAI { + client, + model, + dimensions, + }, + }) + } + + pub async fn new_fastembed(model_override: Option) -> Result { + let model_name = if let Some(code) = model_override { + EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))? 
// Helper functions for the hashing-trick embedding backend.

/// Deterministic bag-of-words embedding: each token is hashed into one of
/// `dimension` buckets and the resulting count vector is L2-normalised.
/// Input with no tokens (empty or all-punctuation) yields the zero vector.
/// `dimension` is clamped to at least 1.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    let dim = dimension.max(1);
    let mut vector = vec![0.0f32; dim];

    for token in tokens(text) {
        vector[bucket(&token, dim)] += 1.0;
    }

    // Zero norm (no tokens) leaves the zero vector untouched, so no separate
    // empty-input early return is needed.
    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for value in &mut vector {
            *value /= norm;
        }
    }

    vector
}

/// Lowercased ASCII-alphanumeric tokens of `text`; any other character is a
/// separator and empty segments are dropped.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
    text.split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| token.to_ascii_lowercase())
}

/// Stable hash bucket for `token` in `[0, dimension)`.
fn bucket(token: &str, dimension: usize) -> usize {
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    (hasher.finish() as usize) % dimension
}
-> Result, AppError> { + provider.embed(input).await.map_err(AppError::from) +} + /// Generates an embedding vector for the given input text using OpenAI's embedding model. /// /// This function takes a text input and converts it into a numerical vector representation (embedding) diff --git a/eval/src/args.rs b/eval/src/args.rs index d1d132c..3106c12 100644 --- a/eval/src/args.rs +++ b/eval/src/args.rs @@ -8,6 +8,23 @@ use retrieval_pipeline::RetrievalStrategy; use crate::datasets::DatasetKind; +fn workspace_root() -> PathBuf { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf() +} + +fn default_report_dir() -> PathBuf { + workspace_root().join("eval/reports") +} + +fn default_cache_dir() -> PathBuf { + workspace_root().join("eval/cache") +} + +fn default_ingestion_cache_dir() -> PathBuf { + workspace_root().join("eval/cache/ingested") +} + pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -129,7 +146,7 @@ impl Default for Config { corpus_limit: None, raw_dataset_path: dataset.default_raw_path(), converted_dataset_path: dataset.default_converted_path(), - report_dir: PathBuf::from("eval/reports"), + report_dir: default_report_dir(), k: 5, limit: Some(200), summary_sample: 5, @@ -138,8 +155,8 @@ impl Default for Config { concurrency: 4, embedding_backend: EmbeddingBackend::FastEmbed, embedding_model: None, - cache_dir: PathBuf::from("eval/cache"), - ingestion_cache_dir: PathBuf::from("eval/cache/ingested"), + cache_dir: default_cache_dir(), + ingestion_cache_dir: default_ingestion_cache_dir(), ingestion_batch_size: 5, ingestion_max_retries: 3, refresh_embeddings_only: false, @@ -585,6 +602,13 @@ where } pub fn print_help() { + let report_default = default_report_dir(); + let cache_default = default_cache_dir(); + let ingestion_cache_default = default_ingestion_cache_dir(); + let report_default_display = report_default.display(); + let 
cache_default_display = cache_default.display(); + let ingestion_cache_default_display = ingestion_cache_default.display(); + println!( "\ eval — dataset conversion, ingestion, and retrieval evaluation CLI @@ -610,7 +634,7 @@ OPTIONS: --corpus-limit Cap the slice corpus size (positives + negatives). Defaults to ~10× --limit, capped at 1000. --raw Path to the raw dataset (defaults per dataset). --converted Path to write/read the converted dataset (defaults per dataset). - --report-dir Directory to write evaluation reports (default: eval/reports). + --report-dir Directory to write evaluation reports (default: {report_default_display}). --k Precision@k cutoff (default: 5). --limit Limit the number of questions evaluated (default: 200, 0 = all). --sample Number of mismatches to surface in the Markdown summary (default: 5). @@ -632,9 +656,9 @@ OPTIONS: --embedding Embedding backend: 'fastembed' (default) or 'hashed'. --embedding-model FastEmbed model code (defaults to crate preset when omitted). - --cache-dir Directory for embedding caches (default: eval/cache). + --cache-dir Directory for embedding caches (default: {cache_default_display}). --ingestion-cache-dir - Directory for ingestion corpora caches (default: eval/cache/ingested). + Directory for ingestion corpora caches (default: {ingestion_cache_default_display}). --ingestion-batch-size Number of paragraphs to ingest concurrently (default: 5). 
--ingestion-max-retries diff --git a/eval/src/db_helpers.rs b/eval/src/db_helpers.rs index 01fe151..b12b74c 100644 --- a/eval/src/db_helpers.rs +++ b/eval/src/db_helpers.rs @@ -1,187 +1,30 @@ use anyhow::{Context, Result}; -use common::storage::db::SurrealDbClient; +use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes}; +use serde::Deserialize; +use tracing::info; -// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings +// Remove and recreate HNSW indexes for changing embedding lengths, used at beginning if embedding length differs from default system settings. pub async fn change_embedding_length_in_hnsw_indexes( db: &SurrealDbClient, dimension: usize, ) -> Result<()> { - tracing::info!("Changing embedding length in HNSW indexes"); - let query = format!( - "BEGIN TRANSACTION; - REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; - REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; - DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim}; - DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim}; - COMMIT TRANSACTION;", - dim = dimension - ); - - db.client - .query(query) - .await - .context("changing HNSW indexes")?; - tracing::info!("HNSW indexes successfully changed"); + // No-op for now; runtime indexes are created after ingestion with the correct dimension. 
+ let _ = (db, dimension); Ok(()) } // Helper functions for index management during namespace reseed pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> { - tracing::info!("Removing ALL indexes before namespace reseed (aggressive approach)"); - - // Remove ALL indexes from ALL tables to ensure no cache access - db.client - .query( - "BEGIN TRANSACTION; - -- HNSW indexes - REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; - REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; - - -- FTS indexes on text_content (remove ALL of them) - REMOVE INDEX IF EXISTS text_content_fts_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_text_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_category_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_context_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_url_idx ON TABLE text_content; - REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON TABLE text_content; - - -- FTS indexes on knowledge_entity - REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity; - REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity; - - -- FTS indexes on text_chunk - REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk; - - COMMIT TRANSACTION;", - ) - .await - .context("removing all indexes before namespace reseed")?; - - tracing::info!("All indexes removed before namespace reseed"); - Ok(()) -} - -async fn create_tokenizer(db: &SurrealDbClient) -> Result<()> { - tracing::info!("Creating FTS analyzers for namespace reseed"); - let res = db - .client - .query( - "BEGIN TRANSACTION; - DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer - TOKENIZERS class - FILTERS lowercase, ascii, snowball(english); - COMMIT TRANSACTION;", - ) - .await - .context("creating FTS analyzers for 
namespace reseed")?; - - res.check().context("failed to create the tokenizer")?; + let _ = db; + info!("Removing ALL indexes before namespace reseed (no-op placeholder)"); Ok(()) } pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> { - tracing::info!("Recreating ALL indexes after namespace reseed (SEQUENTIAL approach)"); - let total_start = std::time::Instant::now(); - - create_tokenizer(db) + info!("Recreating ALL indexes after namespace reseed via shared runtime helper"); + ensure_runtime_indexes(db, dimension) .await - .context("creating FTS analyzer")?; - - // For now we dont remove these plain indexes, we could if they prove negatively impacting performance - // create_regular_indexes_for_snapshot(db) - // .await - // .context("creating regular indexes for namespace reseed")?; - - let fts_start = std::time::Instant::now(); - create_fts_indexes_for_snapshot(db) - .await - .context("creating FTS indexes for namespace reseed")?; - tracing::info!(duration = ?fts_start.elapsed(), "FTS indexes created"); - - let hnsw_start = std::time::Instant::now(); - create_hnsw_indexes_for_snapshot(db, dimension) - .await - .context("creating HNSW indexes for namespace reseed")?; - tracing::info!(duration = ?hnsw_start.elapsed(), "HNSW indexes created"); - - tracing::info!(duration = ?total_start.elapsed(), "All index groups recreated successfully in sequence"); - Ok(()) -} - -#[allow(dead_code)] // For now we dont do this. 
We could -async fn create_regular_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> { - tracing::info!("Creating regular indexes for namespace reseed (parallel group 1)"); - let res = db - .client - .query( - "BEGIN TRANSACTION; - DEFINE INDEX text_content_user_id_idx ON text_content FIELDS user_id; - DEFINE INDEX text_content_created_at_idx ON text_content FIELDS created_at; - DEFINE INDEX text_content_category_idx ON text_content FIELDS category; - DEFINE INDEX text_chunk_source_id_idx ON text_chunk FIELDS source_id; - DEFINE INDEX text_chunk_user_id_idx ON text_chunk FIELDS user_id; - DEFINE INDEX knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id; - DEFINE INDEX knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id; - DEFINE INDEX knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type; - DEFINE INDEX knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; - COMMIT TRANSACTION;", - ) - .await - .context("creating regular indexes for namespace reseed")?; - - res.check().context("one of the regular indexes failed")?; - - tracing::info!("Regular indexes for namespace reseed created"); - Ok(()) -} - -async fn create_fts_indexes_for_snapshot(db: &SurrealDbClient) -> Result<()> { - tracing::info!("Creating FTS indexes for namespace reseed (group 2)"); - let res = db.client - .query( - "BEGIN TRANSACTION; - DEFINE INDEX text_content_fts_idx ON TABLE text_content FIELDS text; - DEFINE INDEX knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name - SEARCH ANALYZER app_en_fts_analyzer BM25; - DEFINE INDEX knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description - SEARCH ANALYZER app_en_fts_analyzer BM25; - DEFINE INDEX text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk - SEARCH ANALYZER app_en_fts_analyzer BM25; - COMMIT TRANSACTION;", - ) - .await - .context("sending FTS index creation query")?; - - // This actually surfaces statement-level errors - res.check() - 
.context("one or more FTS index statements failed")?; - - tracing::info!("FTS indexes for namespace reseed created"); - Ok(()) -} - -async fn create_hnsw_indexes_for_snapshot(db: &SurrealDbClient, dimension: usize) -> Result<()> { - tracing::info!("Creating HNSW indexes for namespace reseed (group 3)"); - let query = format!( - "BEGIN TRANSACTION; - DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {dim}; - DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {dim}; - COMMIT TRANSACTION;", - dim = dimension - ); - - let res = db - .client - .query(query) - .await - .context("creating HNSW indexes for namespace reseed")?; - - res.check() - .context("one or more HNSW index statements failed")?; - - tracing::info!("HNSW indexes for namespace reseed created"); - Ok(()) + .context("creating runtime indexes") } pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> { @@ -207,7 +50,6 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s #[cfg(test)] mod tests { use super::*; - use serde::Deserialize; use uuid::Uuid; #[derive(Debug, Deserialize)] diff --git a/eval/src/eval/mod.rs b/eval/src/eval/mod.rs index 2d8bd38..64133e4 100644 --- a/eval/src/eval/mod.rs +++ b/eval/src/eval/mod.rs @@ -12,7 +12,7 @@ use common::{ error::AppError, storage::{ db::SurrealDbClient, - types::{system_settings::SystemSettings, user::User}, + types::{system_settings::SystemSettings, user::User, StoredObject}, }, }; use retrieval_pipeline::RetrievalTuning; @@ -172,18 +172,26 @@ pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> R info!("Warming HNSW caches with sample queries"); - // Warm up chunk index + // Warm up chunk embedding index - just query the embedding table to load HNSW index let _ = db .client - .query("SELECT * FROM text_chunk WHERE embedding <|1,1|> $embedding LIMIT 5") + .query( + "SELECT chunk_id \ + 
FROM text_chunk_embedding \ + WHERE embedding <|1,1|> $embedding LIMIT 5", + ) .bind(("embedding", dummy_embedding.clone())) .await .context("warming text chunk HNSW cache")?; - // Warm up entity index + // Warm up entity embedding index let _ = db .client - .query("SELECT * FROM knowledge_entity WHERE embedding <|1,1|> $embedding LIMIT 5") + .query( + "SELECT entity_id \ + FROM knowledge_entity_embedding \ + WHERE embedding <|1,1|> $embedding LIMIT 5", + ) .bind(("embedding", dummy_embedding)) .await .context("warming knowledge entity HNSW cache")?; @@ -206,7 +214,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result { timezone: "UTC".to_string(), }; - if let Some(existing) = db.get_item::(&user.id).await? { + if let Some(existing) = db.get_item::(&user.get_id()).await? { return Ok(existing); } @@ -321,11 +329,11 @@ pub(crate) async fn can_reuse_namespace( } }; - if state.slice_case_count < slice_case_count { + if state.slice_case_count != slice_case_count { info!( requested_cases = slice_case_count, stored_cases = state.slice_case_count, - "Skipping live namespace reuse; ledger grew beyond cached state" + "Skipping live namespace reuse; cached state does not match requested window" ); return Ok(false); } @@ -420,12 +428,12 @@ pub(crate) async fn enforce_system_settings( ) -> Result { let mut updated_settings = settings.clone(); let mut needs_settings_update = false; - let mut embedding_dimension_changed = false; + // let mut embedding_dimension_changed = false; if provider_dimension != settings.embedding_dimensions as usize { updated_settings.embedding_dimensions = provider_dimension as u32; needs_settings_update = true; - embedding_dimension_changed = true; + // embedding_dimension_changed = true; } if let Some(query_override) = config.query_model.as_deref() { if settings.query_model != query_override { @@ -442,16 +450,18 @@ pub(crate) async fn enforce_system_settings( .await .context("updating system settings overrides")?; } - if 
embedding_dimension_changed { - change_embedding_length_in_hnsw_indexes(db, provider_dimension) - .await - .context("redefining HNSW indexes for new embedding dimension")?; - } + // We dont need to do this, we've changed from default settings already + // if embedding_dimension_changed { + // change_embedding_length_in_hnsw_indexes(db, provider_dimension) + // .await + // .context("redefining HNSW indexes for new embedding dimension")?; + // } Ok(settings) } pub(crate) async fn load_or_init_system_settings( db: &SurrealDbClient, + dimension: usize, ) -> Result<(SystemSettings, bool)> { match SystemSettings::get_current(db).await { Ok(settings) => Ok((settings, false)), @@ -460,7 +470,6 @@ pub(crate) async fn load_or_init_system_settings( db.apply_migrations() .await .context("applying database migrations after missing system settings")?; - tokio::time::sleep(Duration::from_millis(50)).await; let settings = SystemSettings::get_current(db) .await .context("loading system settings after migrations")?; @@ -473,8 +482,8 @@ pub(crate) async fn load_or_init_system_settings( #[cfg(test)] mod tests { use super::*; - use crate::ingest::store::CorpusParagraph; - use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion}; + use crate::ingest::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk}; + use crate::ingest::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION}; use chrono::Utc; use common::storage::types::text_content::TextContent; @@ -491,9 +500,9 @@ mod tests { None, "user".to_string(), ), - entities: Vec::new(), + entities: Vec::::new(), relationships: Vec::new(), - chunks: Vec::new(), + chunks: Vec::::new(), }, CorpusParagraph { paragraph_id: "p2".to_string(), @@ -506,9 +515,9 @@ mod tests { None, "user".to_string(), ), - entities: Vec::new(), + entities: Vec::::new(), relationships: Vec::new(), - chunks: Vec::new(), + chunks: Vec::::new(), }, ]; let questions = vec![ @@ -541,7 +550,7 @@ mod tests { }, ]; CorpusManifest { - 
version: 1, + version: MANIFEST_VERSION, metadata: CorpusMetadata { dataset_id: "ds".to_string(), dataset_label: "Dataset".to_string(), diff --git a/eval/src/eval/pipeline/context.rs b/eval/src/eval/pipeline/context.rs index 4a8834d..dad2579 100644 --- a/eval/src/eval/pipeline/context.rs +++ b/eval/src/eval/pipeline/context.rs @@ -5,9 +5,12 @@ use std::{ }; use async_openai::Client; -use common::storage::{ - db::SurrealDbClient, - types::{system_settings::SystemSettings, user::User}, +use common::{ + storage::{ + db::SurrealDbClient, + types::{system_settings::SystemSettings, user::User}, + }, + utils::embedding::EmbeddingProvider, }; use retrieval_pipeline::{ pipeline::{PipelineStageTimings, RetrievalConfig}, @@ -18,7 +21,6 @@ use crate::{ args::Config, cache::EmbeddingCache, datasets::ConvertedDataset, - embedding::EmbeddingProvider, eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase}, ingest, slice, snapshot, }; diff --git a/eval/src/eval/pipeline/stages/prepare_corpus.rs b/eval/src/eval/pipeline/stages/prepare_corpus.rs index 16448f1..9df437c 100644 --- a/eval/src/eval/pipeline/stages/prepare_corpus.rs +++ b/eval/src/eval/pipeline/stages/prepare_corpus.rs @@ -3,7 +3,7 @@ use std::time::Instant; use anyhow::Context; use tracing::info; -use crate::{ingest, slice, snapshot}; +use crate::{eval::can_reuse_namespace, ingest, slice, snapshot}; use super::super::{ context::{EvalStage, EvaluationContext}, @@ -26,19 +26,78 @@ pub(crate) async fn prepare_corpus( let cache_settings = ingest::CorpusCacheConfig::from(config); let embedding_provider = ctx.embedding_provider().clone(); let openai_client = ctx.openai_client(); + let slice = ctx.slice(); + let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit) + .context("selecting slice window for corpus preparation")?; + + let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider()); + let expected_fingerprint = 
ingest::compute_ingestion_fingerprint( + ctx.dataset(), + slice, + config.converted_dataset_path.as_path(), + )?; + let base_dir = ingest::cached_corpus_dir( + &cache_settings, + ctx.dataset().metadata.id.as_str(), + slice.manifest.slice_id.as_str(), + ); + + if !config.reseed_slice { + let requested_cases = window.cases.len(); + if can_reuse_namespace( + ctx.db(), + &descriptor, + &ctx.namespace, + &ctx.database, + ctx.dataset().metadata.id.as_str(), + slice.manifest.slice_id.as_str(), + expected_fingerprint.as_str(), + requested_cases, + ) + .await? + { + if let Some(manifest) = ingest::load_cached_manifest(&base_dir)? { + info!( + cache = %base_dir.display(), + namespace = ctx.namespace.as_str(), + database = ctx.database.as_str(), + "Namespace already seeded; reusing cached corpus manifest" + ); + let corpus_handle = ingest::corpus_handle_from_manifest(manifest, base_dir); + ctx.corpus_handle = Some(corpus_handle); + ctx.expected_fingerprint = Some(expected_fingerprint); + ctx.ingestion_duration_ms = 0; + ctx.descriptor = Some(descriptor); + + let elapsed = started.elapsed(); + ctx.record_stage_duration(stage, elapsed); + info!( + evaluation_stage = stage.label(), + duration_ms = elapsed.as_millis(), + "completed evaluation stage" + ); + + return machine + .prepare_corpus() + .map_err(|(_, guard)| map_guard_error("prepare_corpus", guard)); + } else { + info!( + cache = %base_dir.display(), + "Namespace reusable but cached manifest missing; regenerating corpus" + ); + } + } + } let eval_user_id = "eval-user".to_string(); let ingestion_timer = Instant::now(); let corpus_handle = { - let slice = ctx.slice(); - let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit) - .context("selecting slice window for corpus preparation")?; ingest::ensure_corpus( ctx.dataset(), slice, &window, &cache_settings, - &embedding_provider, + embedding_provider.clone().into(), openai_client, &eval_user_id, config.converted_dataset_path.as_path(), @@ 
-64,11 +123,7 @@ pub(crate) async fn prepare_corpus( ctx.corpus_handle = Some(corpus_handle); ctx.expected_fingerprint = Some(expected_fingerprint); ctx.ingestion_duration_ms = ingestion_duration_ms; - ctx.descriptor = Some(snapshot::Descriptor::new( - config, - ctx.slice(), - ctx.embedding_provider(), - )); + ctx.descriptor = Some(descriptor); let elapsed = started.elapsed(); ctx.record_stage_duration(stage, elapsed); diff --git a/eval/src/eval/pipeline/stages/prepare_db.rs b/eval/src/eval/pipeline/stages/prepare_db.rs index 94cee4b..a8666e9 100644 --- a/eval/src/eval/pipeline/stages/prepare_db.rs +++ b/eval/src/eval/pipeline/stages/prepare_db.rs @@ -6,12 +6,12 @@ use tracing::info; use crate::{ args::EmbeddingBackend, cache::EmbeddingCache, - embedding, eval::{ connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code, }, openai, }; +use common::utils::embedding::EmbeddingProvider; use super::super::{ context::{EvalStage, EvaluationContext}, @@ -35,15 +35,22 @@ pub(crate) async fn prepare_db( let config = ctx.config(); let db = connect_eval_db(config, &namespace, &database).await?; - let (mut settings, settings_missing) = load_or_init_system_settings(&db).await?; - let embedding_provider = - embedding::build_provider(config, settings.embedding_dimensions as usize) - .await - .context("building embedding provider")?; let (raw_openai_client, openai_base_url) = openai::build_client_from_env().context("building OpenAI client")?; let openai_client = Arc::new(raw_openai_client); + + // Create embedding provider directly from config (eval only supports FastEmbed and Hashed) + let embedding_provider = match config.embedding_backend { + crate::args::EmbeddingBackend::FastEmbed => { + EmbeddingProvider::new_fastembed(config.embedding_model.clone()) + .await + .context("creating FastEmbed provider")? + } + crate::args::EmbeddingBackend::Hashed => { + EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")? 
+ } + }; let provider_dimension = embedding_provider.dimension(); if provider_dimension == 0 { return Err(anyhow!( @@ -62,6 +69,9 @@ pub(crate) async fn prepare_db( ); info!(openai_base_url = %openai_base_url, "OpenAI client configured"); + let (mut settings, settings_missing) = + load_or_init_system_settings(&db, provider_dimension).await?; + let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed { if let Some(model_code) = embedding_provider.model_code() { let sanitized = sanitize_model_code(&model_code); diff --git a/eval/src/eval/pipeline/stages/prepare_namespace.rs b/eval/src/eval/pipeline/stages/prepare_namespace.rs index 25e336d..78d18f8 100644 --- a/eval/src/eval/pipeline/stages/prepare_namespace.rs +++ b/eval/src/eval/pipeline/stages/prepare_namespace.rs @@ -41,6 +41,22 @@ pub(crate) async fn prepare_namespace( let database = ctx.database.clone(); let embedding_provider = ctx.embedding_provider().clone(); + let corpus_handle = ctx.corpus_handle(); + let base_manifest = &corpus_handle.manifest; + let manifest_for_seed = + if ctx.window_offset == 0 && ctx.window_length >= base_manifest.questions.len() { + base_manifest.clone() + } else { + ingest::window_manifest( + base_manifest, + ctx.window_offset, + ctx.window_length, + ctx.config().negative_multiplier, + ) + .context("selecting manifest window for seeding")? + }; + let requested_cases = manifest_for_seed.questions.len(); + let mut namespace_reused = false; if !config.reseed_slice { namespace_reused = { @@ -53,7 +69,7 @@ pub(crate) async fn prepare_namespace( dataset.metadata.id.as_str(), slice.manifest.slice_id.as_str(), expected_fingerprint.as_str(), - slice.manifest.case_count, + requested_cases, ) .await? 
}; @@ -79,25 +95,39 @@ pub(crate) async fn prepare_namespace( slice = slice.manifest.slice_id.as_str(), window_offset = ctx.window_offset, window_length = ctx.window_length, - positives = slice.manifest.positive_paragraphs, - negatives = slice.manifest.negative_paragraphs, - total = slice.manifest.total_paragraphs, + positives = manifest_for_seed + .questions + .iter() + .map(|q| q.paragraph_id.as_str()) + .collect::>() + .len(), + negatives = manifest_for_seed.paragraphs.len().saturating_sub( + manifest_for_seed + .questions + .iter() + .map(|q| q.paragraph_id.as_str()) + .collect::>() + .len(), + ), + total = manifest_for_seed.paragraphs.len(), "Seeding ingestion corpus into SurrealDB" ); } let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok(); + let seed_start = Instant::now(); - ingest::seed_manifest_into_db(ctx.db(), &ctx.corpus_handle().manifest) + ingest::seed_manifest_into_db(ctx.db(), &manifest_for_seed) .await .context("seeding ingestion corpus from manifest")?; namespace_seed_ms = Some(seed_start.elapsed().as_millis() as u128); + + // Recreate indexes AFTER data is loaded (correct bulk loading pattern) if indexes_disabled { - info!("Recreating indexes after namespace reset"); - if let Err(err) = recreate_indexes(ctx.db(), embedding_provider.dimension()).await { - warn!(error = %err, "failed to restore indexes after namespace reset"); - } else { - warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?; - } + info!("Recreating indexes after seeding data"); + recreate_indexes(ctx.db(), embedding_provider.dimension()) + .await + .context("recreating indexes with correct dimension")?; + warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?; } { let slice = ctx.slice(); @@ -108,7 +138,7 @@ pub(crate) async fn prepare_namespace( expected_fingerprint.as_str(), &namespace, &database, - slice.manifest.case_count, + requested_cases, ) .await; } @@ -128,11 +158,10 @@ pub(crate) async fn prepare_namespace( let user = 
ensure_eval_user(ctx.db()).await?; ctx.eval_user = Some(user); - let corpus_handle = ctx.corpus_handle(); - let total_manifest_questions = corpus_handle.manifest.questions.len(); - let cases = cases_from_manifest(&corpus_handle.manifest); - let include_impossible = corpus_handle.manifest.metadata.include_unanswerable; - let require_verified_chunks = corpus_handle.manifest.metadata.require_verified_chunks; + let total_manifest_questions = manifest_for_seed.questions.len(); + let cases = cases_from_manifest(&manifest_for_seed); + let include_impossible = manifest_for_seed.metadata.include_unanswerable; + let require_verified_chunks = manifest_for_seed.metadata.require_verified_chunks; let filtered = total_manifest_questions.saturating_sub(cases.len()); if filtered > 0 { info!( diff --git a/eval/src/eval/pipeline/stages/run_queries.rs b/eval/src/eval/pipeline/stages/run_queries.rs index d6913a3..d530a10 100644 --- a/eval/src/eval/pipeline/stages/run_queries.rs +++ b/eval/src/eval/pipeline/stages/run_queries.rs @@ -1,6 +1,7 @@ use std::{collections::HashSet, sync::Arc, time::Instant}; use anyhow::Context; +use common::storage::types::StoredObject; use futures::stream::{self, StreamExt}; use tracing::{debug, info}; @@ -174,6 +175,7 @@ pub(crate) async fn run_queries( let outcome = pipeline::run_pipeline_with_embedding_with_diagnostics( &db, &openai_client, + Some(&embedding_provider), query_embedding, &question, &user_id, @@ -187,6 +189,7 @@ pub(crate) async fn run_queries( let outcome = pipeline::run_pipeline_with_embedding_with_metrics( &db, &openai_client, + Some(&embedding_provider), query_embedding, &question, &user_id, @@ -228,9 +231,10 @@ pub(crate) async fn run_queries( } let chunk_id_for_entity = if chunk_id_required { expected_chunk_ids_set.contains(candidate.source_id.as_str()) - || candidate.chunks.iter().any(|chunk| { - expected_chunk_ids_set.contains(chunk.chunk.id.as_str()) - }) + || candidate + .chunks + .iter() + .any(|chunk| 
expected_chunk_ids_set.contains(&chunk.chunk.get_id())) } else { true }; diff --git a/eval/src/eval/types.rs b/eval/src/eval/types.rs index 0728210..57a6d6b 100644 --- a/eval/src/eval/types.rs +++ b/eval/src/eval/types.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; use chrono::{DateTime, Utc}; +use common::storage::types::StoredObject; use retrieval_pipeline::{ PipelineDiagnostics, PipelineStageTimings, RetrievedChunk, RetrievedEntity, StrategyOutput, }; @@ -164,7 +165,7 @@ impl EvaluationCandidate { fn from_entity(entity: RetrievedEntity) -> Self { let entity_category = Some(format!("{:?}", entity.entity.entity_type)); Self { - entity_id: entity.entity.id.clone(), + entity_id: entity.entity.get_id().to_string(), source_id: entity.entity.source_id.clone(), entity_name: entity.entity.name.clone(), entity_description: Some(entity.entity.description.clone()), @@ -177,7 +178,7 @@ impl EvaluationCandidate { fn from_chunk(chunk: RetrievedChunk) -> Self { let snippet = chunk_snippet(&chunk.chunk.chunk); Self { - entity_id: chunk.chunk.id.clone(), + entity_id: chunk.chunk.get_id().to_string(), source_id: chunk.chunk.source_id.clone(), entity_name: chunk.chunk.source_id.clone(), entity_description: Some(snippet), @@ -301,7 +302,9 @@ pub fn build_stage_latency_breakdown(samples: &[PipelineStageTimings]) -> StageL graph_expansion: compute_latency_stats(&collect_stage(samples, |entry| { entry.graph_expansion_ms() })), - chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| entry.chunk_attach_ms())), + chunk_attach: compute_latency_stats(&collect_stage(samples, |entry| { + entry.chunk_attach_ms() + })), rerank: compute_latency_stats(&collect_stage(samples, |entry| entry.rerank_ms())), assemble: compute_latency_stats(&collect_stage(samples, |entry| entry.assemble_ms())), } @@ -332,11 +335,11 @@ pub fn build_case_diagnostics( let mut chunk_entries = Vec::new(); for chunk in &candidate.chunks { let contains_answer = text_contains_answer(&chunk.chunk.chunk, 
answers_lower); - let expected_chunk = expected_set.contains(chunk.chunk.id.as_str()); - seen_chunks.insert(chunk.chunk.id.clone()); - attached_chunk_ids.push(chunk.chunk.id.clone()); + let expected_chunk = expected_set.contains(chunk.chunk.get_id()); + seen_chunks.insert(chunk.chunk.get_id().to_string()); + attached_chunk_ids.push(chunk.chunk.get_id().to_string()); chunk_entries.push(ChunkDiagnosticsEntry { - chunk_id: chunk.chunk.id.clone(), + chunk_id: chunk.chunk.get_id().to_string(), score: chunk.score, contains_answer, expected_chunk, diff --git a/eval/src/ingest/config.rs b/eval/src/ingest/config.rs index fd2eed1..c238cb8 100644 --- a/eval/src/ingest/config.rs +++ b/eval/src/ingest/config.rs @@ -59,6 +59,25 @@ impl CorpusEmbeddingProvider for EmbeddingProvider { } } +#[async_trait] +impl CorpusEmbeddingProvider for common::utils::embedding::EmbeddingProvider { + fn backend_label(&self) -> &str { + common::utils::embedding::EmbeddingProvider::backend_label(self) + } + + fn model_code(&self) -> Option { + common::utils::embedding::EmbeddingProvider::model_code(self) + } + + fn dimension(&self) -> usize { + common::utils::embedding::EmbeddingProvider::dimension(self) + } + + async fn embed_batch(&self, texts: Vec) -> Result>> { + common::utils::embedding::EmbeddingProvider::embed_batch(self, texts).await + } +} + impl From<&Config> for CorpusCacheConfig { fn from(config: &Config) -> Self { CorpusCacheConfig::new( diff --git a/eval/src/ingest/mod.rs b/eval/src/ingest/mod.rs index 77315e2..3e8a342 100644 --- a/eval/src/ingest/mod.rs +++ b/eval/src/ingest/mod.rs @@ -2,9 +2,13 @@ mod config; mod orchestrator; pub(crate) mod store; -pub use config::{CorpusCacheConfig, CorpusEmbeddingProvider}; -pub use orchestrator::ensure_corpus; -pub use store::{ - seed_manifest_into_db, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion, - ParagraphShard, ParagraphShardStore, MANIFEST_VERSION, +pub use config::CorpusCacheConfig; +pub use orchestrator::{ + 
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus, + load_cached_manifest, +}; +pub use store::{ + seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata, + CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, + ParagraphShardStore, MANIFEST_VERSION, }; diff --git a/eval/src/ingest/orchestrator.rs b/eval/src/ingest/orchestrator.rs index e6ea9ad..7a9cc94 100644 --- a/eval/src/ingest/orchestrator.rs +++ b/eval/src/ingest/orchestrator.rs @@ -2,7 +2,7 @@ use std::{ collections::{HashMap, HashSet}, fs, io::Read, - path::Path, + path::{Path, PathBuf}, sync::Arc, }; @@ -13,10 +13,7 @@ use common::{ storage::{ db::SurrealDbClient, store::{DynStore, StorageManager}, - types::{ - ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, - knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, - }, + types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject}, }, utils::config::{AppConfig, StorageKind}, }; @@ -29,12 +26,14 @@ use uuid::Uuid; use crate::{ datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion}, + db_helpers::change_embedding_length_in_hnsw_indexes, slices::{self, ResolvedSlice, SliceParagraphKind}, }; use crate::ingest::{ - CorpusCacheConfig, CorpusEmbeddingProvider, CorpusHandle, CorpusManifest, CorpusMetadata, - CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION, + CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion, + EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore, + MANIFEST_VERSION, }; const INGESTION_SPEC_VERSION: u32 = 1; @@ -108,12 +107,12 @@ struct IngestionStats { negative_ingested: usize, } -pub async fn ensure_corpus( +pub async fn ensure_corpus( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, window: &slices::SliceWindow<'_>, cache: &CorpusCacheConfig, - embedding: &E, + embedding: Arc, openai: Arc, user_id: &str, 
converted_path: &Path, @@ -122,10 +121,11 @@ pub async fn ensure_corpus( .with_context(|| format!("computing checksum for {}", converted_path.display()))?; let ingestion_fingerprint = build_ingestion_fingerprint(dataset, slice, &checksum); - let base_dir = cache - .ingestion_cache_dir - .join(dataset.metadata.id.as_str()) - .join(slice.manifest.slice_id.as_str()); + let base_dir = cached_corpus_dir( + cache, + dataset.metadata.id.as_str(), + slice.manifest.slice_id.as_str(), + ); if cache.force_refresh && !cache.refresh_embeddings_only { let _ = fs::remove_dir_all(&base_dir); } @@ -144,11 +144,19 @@ pub async fn ensure_corpus( )); } + let desired_negatives = + ((positive_set.len() as f32) * slice.manifest.negative_multiplier).ceil() as usize; let mut plan = Vec::new(); + let mut negatives_added = 0usize; for (idx, entry) in slice.manifest.paragraphs.iter().enumerate() { let include = match &entry.kind { SliceParagraphKind::Positive { .. } => positive_set.contains(entry.id.as_str()), - SliceParagraphKind::Negative => true, + SliceParagraphKind::Negative => { + negatives_added < desired_negatives && { + negatives_added += 1; + true + } + } }; if include { let paragraph = slice @@ -224,7 +232,7 @@ pub async fn ensure_corpus( let new_shards = ingest_paragraph_batch( dataset, &ingest_requests, - embedding, + embedding.clone(), openai.clone(), user_id, &ingestion_fingerprint, @@ -251,8 +259,7 @@ pub async fn ensure_corpus( .as_mut() .context("shard record missing after ingestion run")?; if cache.refresh_embeddings_only || shard_record.needs_reembed { - reembed_entities(&mut shard_record.shard.entities, embedding).await?; - reembed_chunks(&mut shard_record.shard.chunks, embedding).await?; + // Embeddings are now generated by the pipeline using FastEmbed - no need to re-embed shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone(); shard_record.shard.ingested_at = Utc::now(); shard_record.shard.embedding_backend = embedding_backend_label.clone(); @@ -320,7 
+327,7 @@ pub async fn ensure_corpus( corpus_questions.push(CorpusQuestion { question_id: case.question.id.clone(), paragraph_id: case.paragraph.id.clone(), - text_content_id: record.shard.text_content.id.clone(), + text_content_id: record.shard.text_content.get_id().to_string(), question_text: case.question.question.clone(), answers: case.question.answers.clone(), is_impossible: case.question.is_impossible, @@ -361,7 +368,7 @@ pub async fn ensure_corpus( let reused_ingestion = ingested_count == 0 && !cache.force_refresh; let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only; - Ok(CorpusHandle { + let handle = CorpusHandle { manifest, path: base_dir, reused_ingestion, @@ -370,64 +377,17 @@ pub async fn ensure_corpus( positive_ingested: stats.positive_ingested, negative_reused: stats.negative_reused, negative_ingested: stats.negative_ingested, - }) + }; + + persist_manifest(&handle).context("persisting corpus manifest")?; + + Ok(handle) } -async fn reembed_entities( - entities: &mut [KnowledgeEntity], - embedding: &E, -) -> Result<()> { - if entities.is_empty() { - return Ok(()); - } - let payloads: Vec = entities.iter().map(entity_embedding_text).collect(); - let vectors = embedding.embed_batch(payloads).await?; - if vectors.len() != entities.len() { - return Err(anyhow!( - "entity embedding batch mismatch (expected {}, got {})", - entities.len(), - vectors.len() - )); - } - for (entity, vector) in entities.iter_mut().zip(vectors.into_iter()) { - entity.embedding = vector; - } - Ok(()) -} - -async fn reembed_chunks( - chunks: &mut [TextChunk], - embedding: &E, -) -> Result<()> { - if chunks.is_empty() { - return Ok(()); - } - let payloads: Vec = chunks.iter().map(|chunk| chunk.chunk.clone()).collect(); - let vectors = embedding.embed_batch(payloads).await?; - if vectors.len() != chunks.len() { - return Err(anyhow!( - "chunk embedding batch mismatch (expected {}, got {})", - chunks.len(), - vectors.len() - )); - } - for (chunk, vector) in 
chunks.iter_mut().zip(vectors.into_iter()) { - chunk.embedding = vector; - } - Ok(()) -} - -fn entity_embedding_text(entity: &KnowledgeEntity) -> String { - format!( - "name: {}\ndescription: {}\ntype: {:?}", - entity.name, entity.description, entity.entity_type - ) -} - -async fn ingest_paragraph_batch( +async fn ingest_paragraph_batch( dataset: &ConvertedDataset, targets: &[IngestRequest<'_>], - embedding: &E, + embedding: Arc, openai: Arc, user_id: &str, ingestion_fingerprint: &str, @@ -444,12 +404,16 @@ async fn ingest_paragraph_batch( let db = Arc::new( SurrealDbClient::memory(&namespace, "corpus") .await - .context("creating ingestion SurrealDB instance")?, + .context("creating in-memory surrealdb for ingestion")?, ); db.apply_migrations() .await .context("applying migrations for ingestion")?; + change_embedding_length_in_hnsw_indexes(&db, embedding_dimension) + .await + .context("failed setting new hnsw length")?; + let mut app_config = AppConfig::default(); app_config.storage = StorageKind::Memory; let backend: DynStore = Arc::new(InMemory::new()); @@ -461,6 +425,7 @@ async fn ingest_paragraph_batch( app_config, None::>, storage, + embedding.clone(), ) .await?; let pipeline = Arc::new(pipeline); @@ -483,7 +448,6 @@ async fn ingest_paragraph_batch( pipeline_clone.clone(), request, category_clone.clone(), - embedding, user_id, ingestion_fingerprint, backend_clone.clone(), @@ -501,11 +465,10 @@ async fn ingest_paragraph_batch( Ok(shards) } -async fn ingest_single_paragraph( +async fn ingest_single_paragraph( pipeline: Arc, request: IngestRequest<'_>, category: String, - embedding: &E, user_id: &str, ingestion_fingerprint: &str, embedding_backend: String, @@ -524,17 +487,32 @@ async fn ingest_single_paragraph( }; let task = IngestionTask::new(payload, user_id.to_string()); match pipeline.produce_artifacts(&task).await { - Ok(mut artifacts) => { - reembed_entities(&mut artifacts.entities, embedding).await?; - reembed_chunks(&mut artifacts.chunks, 
embedding).await?; + Ok(artifacts) => { + let entities: Vec = artifacts + .entities + .into_iter() + .map(|e| EmbeddedKnowledgeEntity { + entity: e.entity, + embedding: e.embedding, + }) + .collect(); + let chunks: Vec = artifacts + .chunks + .into_iter() + .map(|c| EmbeddedTextChunk { + chunk: c.chunk, + embedding: c.embedding, + }) + .collect(); + // No need to reembed - pipeline now uses FastEmbed internally let mut shard = ParagraphShard::new( paragraph, request.shard_path, ingestion_fingerprint, artifacts.text_content, - artifacts.entities, + entities, artifacts.relationships, - artifacts.chunks, + chunks, &embedding_backend, embedding_model.clone(), embedding_dimension, @@ -572,7 +550,11 @@ async fn ingest_single_paragraph( .context(format!("running ingestion for paragraph {}", paragraph.id))) } -fn build_ingestion_fingerprint( +pub fn cached_corpus_dir(cache: &CorpusCacheConfig, dataset_id: &str, slice_id: &str) -> PathBuf { + cache.ingestion_cache_dir.join(dataset_id).join(slice_id) +} + +pub fn build_ingestion_fingerprint( dataset: &ConvertedDataset, slice: &ResolvedSlice<'_>, checksum: &str, @@ -592,6 +574,59 @@ fn build_ingestion_fingerprint( ) } +pub fn compute_ingestion_fingerprint( + dataset: &ConvertedDataset, + slice: &ResolvedSlice<'_>, + converted_path: &Path, +) -> Result { + let checksum = compute_file_checksum(converted_path)?; + Ok(build_ingestion_fingerprint(dataset, slice, &checksum)) +} + +pub fn load_cached_manifest(base_dir: &Path) -> Result> { + let path = base_dir.join("manifest.json"); + if !path.exists() { + return Ok(None); + } + let mut file = fs::File::open(&path) + .with_context(|| format!("opening cached manifest {}", path.display()))?; + let mut buf = Vec::new(); + file.read_to_end(&mut buf) + .with_context(|| format!("reading cached manifest {}", path.display()))?; + let manifest: CorpusManifest = serde_json::from_slice(&buf) + .with_context(|| format!("deserialising cached manifest {}", path.display()))?; + Ok(Some(manifest)) 
+} + +fn persist_manifest(handle: &CorpusHandle) -> Result<()> { + let path = handle.path.join("manifest.json"); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("creating manifest directory {}", parent.display()))?; + } + let tmp_path = path.with_extension("json.tmp"); + let blob = + serde_json::to_vec_pretty(&handle.manifest).context("serialising corpus manifest")?; + fs::write(&tmp_path, &blob) + .with_context(|| format!("writing temporary manifest {}", tmp_path.display()))?; + fs::rename(&tmp_path, &path) + .with_context(|| format!("replacing manifest {}", path.display()))?; + Ok(()) +} + +pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf) -> CorpusHandle { + CorpusHandle { + manifest, + path: base_dir, + reused_ingestion: true, + reused_embeddings: true, + positive_reused: 0, + positive_ingested: 0, + negative_reused: 0, + negative_ingested: 0, + } +} + fn compute_file_checksum(path: &Path) -> Result { let mut file = fs::File::open(path) .with_context(|| format!("opening file {} for checksum", path.display()))?; diff --git a/eval/src/ingest/store.rs b/eval/src/ingest/store.rs index be786dd..13d6f74 100644 --- a/eval/src/ingest/store.rs +++ b/eval/src/ingest/store.rs @@ -1,23 +1,126 @@ -use std::{collections::HashMap, fs, io::BufReader, path::PathBuf}; +use std::{ + collections::{HashMap, HashSet}, + fs, + io::BufReader, + path::PathBuf, +}; use anyhow::{anyhow, Context, Result}; use chrono::{DateTime, Utc}; +use common::storage::types::StoredObject; use common::storage::{ db::SurrealDbClient, types::{ - knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship, - text_chunk::TextChunk, text_content::TextContent, + knowledge_entity::KnowledgeEntity, + knowledge_entity_embedding::KnowledgeEntityEmbedding, + knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata}, + text_chunk::TextChunk, + text_chunk_embedding::TextChunkEmbedding, + 
text_content::TextContent, }, }; +use serde::Deserialize; +use serde::Serialize; +use surrealdb::sql::Thing; use tracing::warn; use crate::datasets::{ConvertedParagraph, ConvertedQuestion}; -pub const MANIFEST_VERSION: u32 = 1; -pub const PARAGRAPH_SHARD_VERSION: u32 = 1; +pub const MANIFEST_VERSION: u32 = 2; +pub const PARAGRAPH_SHARD_VERSION: u32 = 2; +const MANIFEST_BATCH_SIZE: usize = 100; +const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches +const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively +const MAX_BATCHES_PER_REQUEST: usize = 24; +const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request + +fn current_manifest_version() -> u32 { + MANIFEST_VERSION +} + +fn current_paragraph_shard_version() -> u32 { + PARAGRAPH_SHARD_VERSION +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct EmbeddedKnowledgeEntity { + pub entity: KnowledgeEntity, + pub embedding: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct EmbeddedTextChunk { + pub chunk: TextChunk, + pub embedding: Vec, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct LegacyKnowledgeEntity { + #[serde(flatten)] + pub entity: KnowledgeEntity, + #[serde(default)] + pub embedding: Vec, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct LegacyTextChunk { + #[serde(flatten)] + pub chunk: TextChunk, + #[serde(default)] + pub embedding: Vec, +} + +fn deserialize_embedded_entities<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + #[derive(serde::Deserialize)] + #[serde(untagged)] + enum EntityInput { + Embedded(Vec), + Legacy(Vec), + } + + match EntityInput::deserialize(deserializer)? 
{ + EntityInput::Embedded(items) => Ok(items), + EntityInput::Legacy(items) => Ok(items + .into_iter() + .map(|legacy| EmbeddedKnowledgeEntity { + entity: legacy.entity, + embedding: legacy.embedding, + }) + .collect()), + } +} + +fn deserialize_embedded_chunks<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + #[derive(serde::Deserialize)] + #[serde(untagged)] + enum ChunkInput { + Embedded(Vec), + Legacy(Vec), + } + + match ChunkInput::deserialize(deserializer)? { + ChunkInput::Embedded(items) => Ok(items), + ChunkInput::Legacy(items) => Ok(items + .into_iter() + .map(|legacy| EmbeddedTextChunk { + chunk: legacy.chunk, + embedding: legacy.embedding, + }) + .collect()), + } +} #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct CorpusManifest { + #[serde(default = "current_manifest_version")] pub version: u32, pub metadata: CorpusMetadata, pub paragraphs: Vec, @@ -47,9 +150,11 @@ pub struct CorpusParagraph { pub paragraph_id: String, pub title: String, pub text_content: TextContent, - pub entities: Vec, + #[serde(deserialize_with = "deserialize_embedded_entities")] + pub entities: Vec, pub relationships: Vec, - pub chunks: Vec, + #[serde(deserialize_with = "deserialize_embedded_chunks")] + pub chunks: Vec, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -74,8 +179,189 @@ pub struct CorpusHandle { pub negative_ingested: usize, } +pub fn window_manifest( + manifest: &CorpusManifest, + offset: usize, + length: usize, + negative_multiplier: f32, +) -> Result { + let total = manifest.questions.len(); + if total == 0 { + return Err(anyhow!( + "manifest contains no questions; cannot select a window" + )); + } + if offset >= total { + return Err(anyhow!( + "window offset {} exceeds manifest questions ({})", + offset, + total + )); + } + let end = (offset + length).min(total); + let questions = manifest.questions[offset..end].to_vec(); + + let selected_positive_ids: HashSet<_> = + 
questions.iter().map(|q| q.paragraph_id.clone()).collect(); + let positives_all: HashSet<_> = manifest + .questions + .iter() + .map(|q| q.paragraph_id.as_str()) + .collect(); + let available_negatives = manifest + .paragraphs + .len() + .saturating_sub(positives_all.len()); + let desired_negatives = + ((selected_positive_ids.len() as f32) * negative_multiplier).ceil() as usize; + let desired_negatives = desired_negatives.min(available_negatives); + + let mut paragraphs = Vec::new(); + let mut negative_count = 0usize; + for paragraph in &manifest.paragraphs { + if selected_positive_ids.contains(¶graph.paragraph_id) { + paragraphs.push(paragraph.clone()); + } else if negative_count < desired_negatives { + paragraphs.push(paragraph.clone()); + negative_count += 1; + } + } + + let mut narrowed = manifest.clone(); + narrowed.questions = questions; + narrowed.paragraphs = paragraphs; + narrowed.metadata.paragraph_count = narrowed.paragraphs.len(); + narrowed.metadata.question_count = narrowed.questions.len(); + + Ok(narrowed) +} + +#[derive(Debug, Clone, Serialize)] +struct RelationInsert { + #[serde(rename = "in")] + pub in_: Thing, + #[serde(rename = "out")] + pub out: Thing, + pub id: String, + pub metadata: RelationshipMetadata, +} + +#[derive(Debug)] +struct SizedBatch { + approx_bytes: usize, + items: Vec, +} + +struct ManifestBatches { + text_contents: Vec>, + entities: Vec>, + entity_embeddings: Vec>, + relationships: Vec>, + chunks: Vec>, + chunk_embeddings: Vec>, +} + +fn build_manifest_batches(manifest: &CorpusManifest) -> Result { + let mut text_contents = Vec::new(); + let mut entities = Vec::new(); + let mut entity_embeddings = Vec::new(); + let mut relationships = Vec::new(); + let mut chunks = Vec::new(); + let mut chunk_embeddings = Vec::new(); + + let mut seen_text_content = HashSet::new(); + let mut seen_entities = HashSet::new(); + let mut seen_relationships = HashSet::new(); + let mut seen_chunks = HashSet::new(); + + for paragraph in 
&manifest.paragraphs { + if seen_text_content.insert(paragraph.text_content.id.clone()) { + text_contents.push(paragraph.text_content.clone()); + } + + for embedded_entity in ¶graph.entities { + if seen_entities.insert(embedded_entity.entity.id.clone()) { + let entity = embedded_entity.entity.clone(); + entities.push(entity.clone()); + entity_embeddings.push(KnowledgeEntityEmbedding::new( + &entity.id, + embedded_entity.embedding.clone(), + entity.user_id.clone(), + )); + } + } + + for relationship in ¶graph.relationships { + if seen_relationships.insert(relationship.id.clone()) { + let table = KnowledgeEntity::table_name(); + let in_id = relationship + .in_ + .strip_prefix(&format!("{table}:")) + .unwrap_or(&relationship.in_); + let out_id = relationship + .out + .strip_prefix(&format!("{table}:")) + .unwrap_or(&relationship.out); + let in_thing = Thing::from((table, in_id)); + let out_thing = Thing::from((table, out_id)); + relationships.push(RelationInsert { + in_: in_thing, + out: out_thing, + id: relationship.id.clone(), + metadata: relationship.metadata.clone(), + }); + } + } + + for embedded_chunk in ¶graph.chunks { + if seen_chunks.insert(embedded_chunk.chunk.id.clone()) { + let chunk = embedded_chunk.chunk.clone(); + chunks.push(chunk.clone()); + chunk_embeddings.push(TextChunkEmbedding::new( + &chunk.id, + chunk.source_id.clone(), + embedded_chunk.embedding.clone(), + chunk.user_id.clone(), + )); + } + } + } + + Ok(ManifestBatches { + text_contents: chunk_items( + &text_contents, + MANIFEST_BATCH_SIZE, + TEXT_CONTENT_MAX_BYTES_PER_BATCH, + ) + .context("chunking text_content payloads")?, + entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH) + .context("chunking knowledge_entity payloads")?, + entity_embeddings: chunk_items( + &entity_embeddings, + MANIFEST_BATCH_SIZE, + MANIFEST_MAX_BYTES_PER_BATCH, + ) + .context("chunking knowledge_entity_embedding payloads")?, + relationships: chunk_items( + &relationships, + 
MANIFEST_BATCH_SIZE, + MANIFEST_MAX_BYTES_PER_BATCH, + ) + .context("chunking relationship payloads")?, + chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH) + .context("chunking text_chunk payloads")?, + chunk_embeddings: chunk_items( + &chunk_embeddings, + MANIFEST_BATCH_SIZE, + MANIFEST_MAX_BYTES_PER_BATCH, + ) + .context("chunking text_chunk_embedding payloads")?, + }) +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ParagraphShard { + #[serde(default = "current_paragraph_shard_version")] pub version: u32, pub paragraph_id: String, pub shard_path: String, @@ -83,9 +369,11 @@ pub struct ParagraphShard { pub ingested_at: DateTime, pub title: String, pub text_content: TextContent, - pub entities: Vec, + #[serde(deserialize_with = "deserialize_embedded_entities")] + pub entities: Vec, pub relationships: Vec, - pub chunks: Vec, + #[serde(deserialize_with = "deserialize_embedded_chunks")] + pub chunks: Vec, #[serde(default)] pub question_bindings: HashMap>, #[serde(default)] @@ -126,30 +414,34 @@ impl ParagraphShardStore { let reader = BufReader::new(file); let mut shard: ParagraphShard = serde_json::from_reader(reader) .with_context(|| format!("parsing shard {}", path.display()))?; + + if shard.ingestion_fingerprint != fingerprint { + return Ok(None); + } if shard.version != PARAGRAPH_SHARD_VERSION { warn!( path = %path.display(), version = shard.version, expected = PARAGRAPH_SHARD_VERSION, - "Skipping shard due to version mismatch" + "Upgrading shard to current version" ); - return Ok(None); - } - if shard.ingestion_fingerprint != fingerprint { - return Ok(None); + shard.version = PARAGRAPH_SHARD_VERSION; } shard.shard_path = relative.to_string(); Ok(Some(shard)) } pub fn persist(&self, shard: &ParagraphShard) -> Result<()> { + let mut shard = shard.clone(); + shard.version = PARAGRAPH_SHARD_VERSION; + let path = self.resolve(&shard.shard_path); if let Some(parent) = path.parent() { fs::create_dir_all(parent) 
.with_context(|| format!("creating shard dir {}", parent.display()))?; } let tmp_path = path.with_extension("json.tmp"); - let body = serde_json::to_vec_pretty(shard).context("serialising paragraph shard")?; + let body = serde_json::to_vec_pretty(&shard).context("serialising paragraph shard")?; fs::write(&tmp_path, &body) .with_context(|| format!("writing shard tmp {}", tmp_path.display()))?; fs::rename(&tmp_path, &path) @@ -164,9 +456,9 @@ impl ParagraphShard { shard_path: String, ingestion_fingerprint: &str, text_content: TextContent, - entities: Vec, + entities: Vec, relationships: Vec, - chunks: Vec, + chunks: Vec, embedding_backend: &str, embedding_model: Option, embedding_dimension: usize, @@ -216,7 +508,7 @@ impl ParagraphShard { fn validate_answers( content: &TextContent, - chunks: &[TextChunk], + chunks: &[EmbeddedTextChunk], question: &ConvertedQuestion, ) -> Result> { if question.is_impossible || question.answers.is_empty() { @@ -236,12 +528,12 @@ fn validate_answers( found_any = true; } for chunk in chunks { - let chunk_text = chunk.chunk.to_ascii_lowercase(); + let chunk_text = chunk.chunk.chunk.to_ascii_lowercase(); let chunk_norm = normalize_answer_text(&chunk_text); if chunk_text.contains(&needle) || (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm)) { - matches.insert(chunk.id.clone()); + matches.insert(chunk.chunk.get_id().to_string()); found_any = true; } } @@ -272,28 +564,492 @@ fn normalize_answer_text(text: &str) -> String { .join(" ") } -pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { - for paragraph in &manifest.paragraphs { - db.upsert_item(paragraph.text_content.clone()) +fn chunk_items( + items: &[T], + max_items: usize, + max_bytes: usize, +) -> Result>> { + if items.is_empty() { + return Ok(Vec::new()); + } + + let mut batches = Vec::new(); + let mut current = Vec::new(); + let mut current_bytes = 0usize; + + for item in items { + let size = serde_json::to_vec(item) + 
.map(|buf| buf.len()) + .context("serialising batch item for sizing")?; + + let would_overflow_items = !current.is_empty() && current.len() >= max_items; + let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes; + + if would_overflow_items || would_overflow_bytes { + batches.push(SizedBatch { + approx_bytes: current_bytes.max(1), + items: std::mem::take(&mut current), + }); + current_bytes = 0; + } + + current_bytes += size; + current.push(item.clone()); + } + + if !current.is_empty() { + batches.push(SizedBatch { + approx_bytes: current_bytes.max(1), + items: current, + }); + } + + Ok(batches) +} + +async fn execute_batched_inserts( + db: &SurrealDbClient, + statement: impl AsRef, + prefix: &str, + batches: &[SizedBatch], +) -> Result<()> { + if batches.is_empty() { + return Ok(()); + } + + let mut start = 0; + while start < batches.len() { + let mut group_bytes = 0usize; + let mut group_end = start; + let mut group_count = 0usize; + + while group_end < batches.len() { + let batch_bytes = batches[group_end].approx_bytes.max(1); + if group_count > 0 + && (group_bytes + batch_bytes > REQUEST_MAX_BYTES + || group_count >= MAX_BATCHES_PER_REQUEST) + { + break; + } + group_bytes += batch_bytes; + group_end += 1; + group_count += 1; + } + + let slice = &batches[start..group_end]; + let mut query = db.client.query("BEGIN TRANSACTION;"); + let mut bind_index = 0usize; + for batch in slice { + let name = format!("{prefix}{bind_index}"); + bind_index += 1; + query = query + .query(format!("{} ${};", statement.as_ref(), name)) + .bind((name, batch.items.clone())); + } + let response = query + .query("COMMIT TRANSACTION;") .await - .context("storing text_content from manifest")?; - for entity in ¶graph.entities { - db.upsert_item(entity.clone()) - .await - .context("storing knowledge_entity from manifest")?; - } - for relationship in ¶graph.relationships { - relationship - .store_relationship(db) - .await - .context("storing knowledge_relationship 
from manifest")?; - } - for chunk in ¶graph.chunks { - db.upsert_item(chunk.clone()) - .await - .context("storing text_chunk from manifest")?; + .context("executing batched insert transaction")?; + if let Err(err) = response.check() { + return Err(anyhow!( + "batched insert failed for statement '{}': {err:?}", + statement.as_ref() + )); } + + start = group_end; } Ok(()) } + +pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> { + let batches = build_manifest_batches(manifest).context("preparing manifest batches")?; + + let result = (|| async { + execute_batched_inserts( + db, + format!("INSERT INTO {}", TextContent::table_name()), + "tc", + &batches.text_contents, + ) + .await?; + + execute_batched_inserts( + db, + format!("INSERT INTO {}", KnowledgeEntity::table_name()), + "ke", + &batches.entities, + ) + .await?; + + execute_batched_inserts( + db, + format!("INSERT INTO {}", TextChunk::table_name()), + "ch", + &batches.chunks, + ) + .await?; + + execute_batched_inserts( + db, + "INSERT RELATION INTO relates_to", + "rel", + &batches.relationships, + ) + .await?; + + execute_batched_inserts( + db, + format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()), + "kee", + &batches.entity_embeddings, + ) + .await?; + + execute_batched_inserts( + db, + format!("INSERT INTO {}", TextChunkEmbedding::table_name()), + "tce", + &batches.chunk_embeddings, + ) + .await?; + + Ok(()) + })() + .await; + + if result.is_err() { + // Best-effort cleanup to avoid leaving partial manifest data behind. 
+ let _ = db + .client + .query( + "BEGIN TRANSACTION; + DELETE text_chunk_embedding; + DELETE knowledge_entity_embedding; + DELETE relates_to; + DELETE text_chunk; + DELETE knowledge_entity; + DELETE text_content; + COMMIT TRANSACTION;", + ) + .await; + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db_helpers::change_embedding_length_in_hnsw_indexes; + use chrono::Utc; + use common::storage::types::knowledge_entity::KnowledgeEntityType; + use uuid::Uuid; + + fn build_manifest() -> CorpusManifest { + let user_id = "user-1".to_string(); + let source_id = "source-1".to_string(); + let now = Utc::now(); + let text_content_id = Uuid::new_v4().to_string(); + + let text_content = TextContent { + id: text_content_id.clone(), + created_at: now, + updated_at: now, + text: "Hello world".to_string(), + file_info: None, + url_info: None, + context: None, + category: "test".to_string(), + user_id: user_id.clone(), + }; + + let entity = KnowledgeEntity { + id: Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + source_id: source_id.clone(), + name: "Entity".to_string(), + description: "A test entity".to_string(), + entity_type: KnowledgeEntityType::Document, + metadata: None, + user_id: user_id.clone(), + }; + let relationship = KnowledgeRelationship::new( + format!("knowledge_entity:{}", entity.id), + format!("knowledge_entity:{}", entity.id), + user_id.clone(), + source_id.clone(), + "related".to_string(), + ); + + let chunk = TextChunk { + id: Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + source_id: source_id.clone(), + chunk: "chunk text".to_string(), + user_id: user_id.clone(), + }; + + let paragraph_one = CorpusParagraph { + paragraph_id: "p1".to_string(), + title: "Paragraph 1".to_string(), + text_content: text_content.clone(), + entities: vec![EmbeddedKnowledgeEntity { + entity: entity.clone(), + embedding: vec![0.1, 0.2, 0.3], + }], + relationships: vec![relationship], + chunks: vec![EmbeddedTextChunk { + 
chunk: chunk.clone(), + embedding: vec![0.3, 0.2, 0.1], + }], + }; + + // Duplicate content/entities should be de-duplicated by the loader. + let paragraph_two = CorpusParagraph { + paragraph_id: "p2".to_string(), + title: "Paragraph 2".to_string(), + text_content: text_content.clone(), + entities: vec![EmbeddedKnowledgeEntity { + entity: entity.clone(), + embedding: vec![0.1, 0.2, 0.3], + }], + relationships: Vec::new(), + chunks: vec![EmbeddedTextChunk { + chunk: chunk.clone(), + embedding: vec![0.3, 0.2, 0.1], + }], + }; + + let question = CorpusQuestion { + question_id: "q1".to_string(), + paragraph_id: paragraph_one.paragraph_id.clone(), + text_content_id: text_content_id, + question_text: "What is this?".to_string(), + answers: vec!["Hello".to_string()], + is_impossible: false, + matching_chunk_ids: vec![chunk.id.clone()], + }; + + CorpusManifest { + version: current_manifest_version(), + metadata: CorpusMetadata { + dataset_id: "dataset".to_string(), + dataset_label: "Dataset".to_string(), + slice_id: "slice".to_string(), + include_unanswerable: false, + require_verified_chunks: false, + ingestion_fingerprint: "fp".to_string(), + embedding_backend: "test".to_string(), + embedding_model: Some("model".to_string()), + embedding_dimension: 3, + converted_checksum: "checksum".to_string(), + generated_at: now, + paragraph_count: 2, + question_count: 1, + }, + paragraphs: vec![paragraph_one, paragraph_two], + questions: vec![question], + } + } + + #[tokio::test] + async fn seeds_manifest_with_transactional_batches() { + let namespace = "test_ns"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .expect("memory db"); + db.apply_migrations() + .await + .expect("apply migrations for memory db"); + change_embedding_length_in_hnsw_indexes(&db, 3) + .await + .expect("set embedding index dimension for test"); + + let manifest = build_manifest(); + seed_manifest_into_db(&db, &manifest) + .await + 
.expect("manifest seed should succeed"); + + let text_contents: Vec = db + .client + .query(format!("SELECT * FROM {};", TextContent::table_name())) + .await + .expect("select text_content") + .take(0) + .unwrap_or_default(); + assert_eq!(text_contents.len(), 1); + + let entities: Vec = db + .client + .query(format!("SELECT * FROM {};", KnowledgeEntity::table_name())) + .await + .expect("select knowledge_entity") + .take(0) + .unwrap_or_default(); + assert_eq!(entities.len(), 1); + + let chunks: Vec = db + .client + .query(format!("SELECT * FROM {};", TextChunk::table_name())) + .await + .expect("select text_chunk") + .take(0) + .unwrap_or_default(); + assert_eq!(chunks.len(), 1); + + let relationships: Vec = db + .client + .query("SELECT * FROM relates_to;") + .await + .expect("select relates_to") + .take(0) + .unwrap_or_default(); + assert_eq!(relationships.len(), 1); + + let entity_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {};", + KnowledgeEntityEmbedding::table_name() + )) + .await + .expect("select knowledge_entity_embedding") + .take(0) + .unwrap_or_default(); + assert_eq!(entity_embeddings.len(), 1); + + let chunk_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {};", + TextChunkEmbedding::table_name() + )) + .await + .expect("select text_chunk_embedding") + .take(0) + .unwrap_or_default(); + assert_eq!(chunk_embeddings.len(), 1); + } + + #[tokio::test] + async fn rolls_back_when_embeddings_mismatch_index_dimension() { + let namespace = "test_ns_rollback"; + let database = Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, &database) + .await + .expect("memory db"); + db.apply_migrations() + .await + .expect("apply migrations for memory db"); + + let manifest = build_manifest(); + let result = seed_manifest_into_db(&db, &manifest).await; + assert!( + result.is_err(), + "expected embedding dimension mismatch to fail" + ); + + let text_contents: Vec = db + .client + .query(format!("SELECT * FROM 
{};", TextContent::table_name())) + .await + .expect("select text_content") + .take(0) + .unwrap_or_default(); + let entities: Vec = db + .client + .query(format!("SELECT * FROM {};", KnowledgeEntity::table_name())) + .await + .expect("select knowledge_entity") + .take(0) + .unwrap_or_default(); + let chunks: Vec = db + .client + .query(format!("SELECT * FROM {};", TextChunk::table_name())) + .await + .expect("select text_chunk") + .take(0) + .unwrap_or_default(); + let relationships: Vec = db + .client + .query("SELECT * FROM relates_to;") + .await + .expect("select relates_to") + .take(0) + .unwrap_or_default(); + let entity_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {};", + KnowledgeEntityEmbedding::table_name() + )) + .await + .expect("select knowledge_entity_embedding") + .take(0) + .unwrap_or_default(); + let chunk_embeddings: Vec = db + .client + .query(format!( + "SELECT * FROM {};", + TextChunkEmbedding::table_name() + )) + .await + .expect("select text_chunk_embedding") + .take(0) + .unwrap_or_default(); + + assert!( + text_contents.is_empty() + && entities.is_empty() + && chunks.is_empty() + && relationships.is_empty() + && entity_embeddings.is_empty() + && chunk_embeddings.is_empty(), + "no rows should be inserted when transaction fails" + ); + } + + #[test] + fn window_manifest_trims_questions_and_negatives() { + let manifest = build_manifest(); + // Add extra negatives to simulate multiplier ~4x + let mut manifest = manifest; + let mut extra_paragraphs = Vec::new(); + for _ in 0..8 { + let mut p = manifest.paragraphs[0].clone(); + p.paragraph_id = Uuid::new_v4().to_string(); + p.entities.clear(); + p.relationships.clear(); + p.chunks.clear(); + extra_paragraphs.push(p); + } + manifest.paragraphs.extend(extra_paragraphs); + manifest.metadata.paragraph_count = manifest.paragraphs.len(); + + let windowed = window_manifest(&manifest, 0, 1, 4.0).expect("window manifest"); + assert_eq!(windowed.questions.len(), 1); + // Expect roughly 
4x negatives (bounded by available paragraphs) + assert!( + windowed.paragraphs.len() <= manifest.paragraphs.len(), + "windowed paragraphs should never exceed original" + ); + let positive_set: std::collections::HashSet<_> = windowed + .questions + .iter() + .map(|q| q.paragraph_id.as_str()) + .collect(); + let positives = windowed + .paragraphs + .iter() + .filter(|p| positive_set.contains(p.paragraph_id.as_str())) + .count(); + let negatives = windowed.paragraphs.len().saturating_sub(positives); + assert_eq!(positives, 1); + assert!(negatives >= 1, "should include some negatives"); + } +} diff --git a/eval/src/inspection.rs b/eval/src/inspection.rs index e2fb353..7b0cac0 100644 --- a/eval/src/inspection.rs +++ b/eval/src/inspection.rs @@ -121,13 +121,14 @@ fn build_chunk_lookup(manifest: &ingest::CorpusManifest) -> HashMap() .replace('\n', " "); lookup.insert( - chunk.id.clone(), + chunk.chunk.id.clone(), ChunkEntry { paragraph_title: paragraph.title.clone(), snippet, diff --git a/eval/src/snapshot.rs b/eval/src/snapshot.rs index 18440b9..e655939 100644 --- a/eval/src/snapshot.rs +++ b/eval/src/snapshot.rs @@ -6,7 +6,8 @@ use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::fs; -use crate::{args::Config, embedding::EmbeddingProvider, slice}; +use crate::{args::Config, slice}; +use common::utils::embedding::EmbeddingProvider; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SnapshotMetadata { diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 2ee5c6e..fac71e3 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -191,11 +191,10 @@ pub async fn create_knowledge_entity( description.clone(), entity_type, None, - embedding, user.id.clone(), ); - state.db.store_item(new_entity.clone()).await?; + KnowledgeEntity::store_with_embedding(new_entity.clone(), embedding, &state.db).await?; let relationship_type = 
relationship_type_or_default(form.relationship_type.as_deref()); @@ -285,15 +284,16 @@ pub async fn suggest_knowledge_relationships( }; let config = retrieval_pipeline::RetrievalConfig::for_relationship_suggestion(); - if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = retrieval_pipeline::retrieve_entities( - &state.db, - &state.openai_client, - &query, - &user.id, - config, - rerank_lease, - ) - .await + if let Ok(retrieval_pipeline::StrategyOutput::Entities(results)) = + retrieval_pipeline::retrieve_entities( + &state.db, + &state.openai_client, + &query, + &user.id, + config, + rerank_lease, + ) + .await { for retrieval_pipeline::RetrievedEntity { entity, score, .. } in results { if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS { diff --git a/ingestion-pipeline/src/pipeline/context.rs b/ingestion-pipeline/src/pipeline/context.rs index 8ccc3b5..825fea3 100644 --- a/ingestion-pipeline/src/pipeline/context.rs +++ b/ingestion-pipeline/src/pipeline/context.rs @@ -18,6 +18,18 @@ use super::enrichment_result::LLMEnrichmentResult; use super::{config::IngestionConfig, services::PipelineServices}; +#[derive(Debug, Clone)] +pub struct EmbeddedKnowledgeEntity { + pub entity: KnowledgeEntity, + pub embedding: Vec, +} + +#[derive(Debug, Clone)] +pub struct EmbeddedTextChunk { + pub chunk: TextChunk, + pub embedding: Vec, +} + pub struct PipelineContext<'a> { pub task: &'a IngestionTask, pub task_id: String, @@ -33,9 +45,9 @@ pub struct PipelineContext<'a> { #[derive(Debug)] pub struct PipelineArtifacts { pub text_content: TextContent, - pub entities: Vec, + pub entities: Vec, pub relationships: Vec, - pub chunks: Vec, + pub chunks: Vec, } impl<'a> PipelineContext<'a> { diff --git a/ingestion-pipeline/src/pipeline/enrichment_result.rs b/ingestion-pipeline/src/pipeline/enrichment_result.rs index e73b28a..3c6193c 100644 --- a/ingestion-pipeline/src/pipeline/enrichment_result.rs +++ b/ingestion-pipeline/src/pipeline/enrichment_result.rs @@ -4,6 +4,7 @@ 
use chrono::Utc; use futures::stream::{self, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; +use anyhow::Context; use common::{ error::AppError, storage::{ @@ -13,9 +14,10 @@ use common::{ knowledge_relationship::KnowledgeRelationship, }, }, - utils::embedding::generate_embedding, + utils::{embedding::generate_embedding, embedding::EmbeddingProvider}, }; +use crate::pipeline::context::EmbeddedKnowledgeEntity; use crate::utils::graph_mapper::GraphMapper; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -48,7 +50,8 @@ impl LLMEnrichmentResult { openai_client: &async_openai::Client, db_client: &SurrealDbClient, entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError> { + embedding_provider: Option<&EmbeddingProvider>, + ) -> Result<(Vec, Vec), AppError> { let mapper = Arc::new(self.create_mapper()?); let entities = self @@ -59,6 +62,7 @@ impl LLMEnrichmentResult { openai_client, db_client, entity_concurrency, + embedding_provider, ) .await?; @@ -85,7 +89,8 @@ impl LLMEnrichmentResult { openai_client: &async_openai::Client, db_client: &SurrealDbClient, entity_concurrency: usize, - ) -> Result, AppError> { + embedding_provider: Option<&EmbeddingProvider>, + ) -> Result, AppError> { stream::iter(self.knowledge_entities.iter().cloned().map(|entity| { let mapper = Arc::clone(&mapper); let openai_client = openai_client.clone(); @@ -101,6 +106,7 @@ impl LLMEnrichmentResult { mapper, &openai_client, &db_client, + embedding_provider, ) .await } @@ -141,7 +147,8 @@ async fn create_single_entity( mapper: Arc, openai_client: &async_openai::Client, db_client: &SurrealDbClient, -) -> Result { + embedding_provider: Option<&EmbeddingProvider>, +) -> Result { let assigned_id = mapper.get_id(&llm_entity.key)?.to_string(); let embedding_input = format!( @@ -149,10 +156,17 @@ async fn create_single_entity( llm_entity.name, llm_entity.description, llm_entity.entity_type ); - let embedding = generate_embedding(openai_client, &embedding_input, db_client).await?; + 
let embedding = if let Some(provider) = embedding_provider { + provider + .embed(&embedding_input) + .await + .context("generating FastEmbed embedding for entity")? + } else { + generate_embedding(openai_client, &embedding_input, db_client).await? + }; let now = Utc::now(); - Ok(KnowledgeEntity { + let entity = KnowledgeEntity { id: assigned_id, created_at: now, updated_at: now, @@ -161,7 +175,8 @@ async fn create_single_entity( entity_type: KnowledgeEntityType::from(llm_entity.entity_type.to_string()), source_id: source_id.to_string(), metadata: None, - embedding, user_id: user_id.into(), - }) + }; + + Ok(EmbeddedKnowledgeEntity { entity, embedding }) } diff --git a/ingestion-pipeline/src/pipeline/mod.rs b/ingestion-pipeline/src/pipeline/mod.rs index 7b8becd..eaacd11 100644 --- a/ingestion-pipeline/src/pipeline/mod.rs +++ b/ingestion-pipeline/src/pipeline/mod.rs @@ -50,6 +50,7 @@ impl IngestionPipeline { config: AppConfig, reranker_pool: Option>, storage: StorageManager, + embedding_provider: Arc, ) -> Result { let services = DefaultPipelineServices::new( db.clone(), @@ -57,6 +58,7 @@ impl IngestionPipeline { config.clone(), reranker_pool, storage, + embedding_provider, ); Self::with_services(db, IngestionConfig::default(), Arc::new(services)) diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 1b2e119..fbfa756 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -1,5 +1,6 @@ use std::{ops::Range, sync::Arc}; +use anyhow::Context; use async_openai::types::{ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest, CreateChatCompletionRequestArgs, ResponseFormat, @@ -12,19 +13,18 @@ use common::{ db::SurrealDbClient, store::StorageManager, types::{ - ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, - knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings, - 
text_chunk::TextChunk, text_content::TextContent, + ingestion_payload::IngestionPayload, knowledge_relationship::KnowledgeRelationship, + system_settings::SystemSettings, text_chunk::TextChunk, text_content::TextContent, + StoredObject, }, }, - utils::{config::AppConfig, embedding::generate_embedding}, -}; -use retrieval_pipeline::{ - reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity, + utils::{config::AppConfig, embedding::EmbeddingProvider}, }; +use retrieval_pipeline::{reranking::RerankerPool, retrieved_entities_to_json, RetrievedEntity}; use text_splitter::TextSplitter; use super::{enrichment_result::LLMEnrichmentResult, preparation::to_text_content}; +use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use crate::utils::llm_instructions::{ get_ingress_analysis_schema, INGRESS_ANALYSIS_SYSTEM_MESSAGE, }; @@ -54,13 +54,13 @@ pub trait PipelineServices: Send + Sync { content: &TextContent, analysis: &LLMEnrichmentResult, entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError>; + ) -> Result<(Vec, Vec), AppError>; async fn prepare_chunks( &self, content: &TextContent, range: Range, - ) -> Result, AppError>; + ) -> Result, AppError>; } pub struct DefaultPipelineServices { @@ -69,6 +69,7 @@ pub struct DefaultPipelineServices { config: AppConfig, reranker_pool: Option>, storage: StorageManager, + embedding_provider: Arc, } impl DefaultPipelineServices { @@ -78,6 +79,7 @@ impl DefaultPipelineServices { config: AppConfig, reranker_pool: Option>, storage: StorageManager, + embedding_provider: Arc, ) -> Self { Self { db, @@ -85,6 +87,7 @@ impl DefaultPipelineServices { config, reranker_pool, storage, + embedding_provider, } } @@ -182,6 +185,7 @@ impl PipelineServices for DefaultPipelineServices { match retrieval_pipeline::retrieve_entities( &self.db, &self.openai_client, + // embedding_provider_ref, &input_text, &content.user_id, config, @@ -218,14 +222,15 @@ impl PipelineServices for DefaultPipelineServices { content: 
&TextContent, analysis: &LLMEnrichmentResult, entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError> { + ) -> Result<(Vec, Vec), AppError> { analysis .to_database_entities( - &content.id, + &content.get_id(), &content.user_id, &self.openai_client, &self.db, entity_concurrency, + Some(&*self.embedding_provider), ) .await } @@ -234,7 +239,7 @@ impl PipelineServices for DefaultPipelineServices { &self, content: &TextContent, range: Range, - ) -> Result, AppError> { + ) -> Result, AppError> { let splitter = TextSplitter::new(range.clone()); let chunk_texts: Vec = splitter .chunks(&content.text) @@ -243,13 +248,17 @@ impl PipelineServices for DefaultPipelineServices { let mut chunks = Vec::with_capacity(chunk_texts.len()); for chunk in chunk_texts { - let embedding = generate_embedding(&self.openai_client, &chunk, &self.db).await?; - chunks.push(TextChunk::new( - content.id.clone(), - chunk, + let embedding = self + .embedding_provider + .embed(&chunk) + .await + .context("generating FastEmbed embedding for chunk")?; + let chunk_struct = + TextChunk::new(content.get_id().to_string(), chunk, content.user_id.clone()); + chunks.push(EmbeddedTextChunk { + chunk: chunk_struct, embedding, - content.user_id.clone(), - )); + }); } Ok(chunks) } diff --git a/ingestion-pipeline/src/pipeline/stages/mod.rs b/ingestion-pipeline/src/pipeline/stages/mod.rs index 8f3085d..4891ecd 100644 --- a/ingestion-pipeline/src/pipeline/stages/mod.rs +++ b/ingestion-pipeline/src/pipeline/stages/mod.rs @@ -15,7 +15,7 @@ use tokio::time::{sleep, Duration}; use tracing::{debug, instrument, warn}; use super::{ - context::{PipelineArtifacts, PipelineContext}, + context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk, PipelineArtifacts, PipelineContext}, state::{ContentPrepared, Enriched, IngestionMachine, Persisted, Ready, Retrieved}, }; @@ -177,17 +177,21 @@ fn map_guard_error(event: &str, guard: GuardError) -> AppError { async fn store_graph_entities( db: &SurrealDbClient, tuning: 
&super::config::IngestionTuning, - entities: Vec, + entities: Vec, relationships: Vec, ) -> Result<(), AppError> { - const STORE_GRAPH_MUTATION: &str = r" - BEGIN TRANSACTION; - LET $entities = $entities; - LET $relationships = $relationships; + // Persist entities with embeddings first. + for embedded in entities { + KnowledgeEntity::store_with_embedding(embedded.entity, embedded.embedding, db).await?; + } - FOR $entity IN $entities { - CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity; - }; + if relationships.is_empty() { + return Ok(()); + } + + const STORE_RELATIONSHIPS: &str = r" + BEGIN TRANSACTION; + LET $relationships = $relationships; FOR $relationship IN $relationships { LET $in_node = type::thing('knowledge_entity', $relationship.in); @@ -201,7 +205,6 @@ async fn store_graph_entities( COMMIT TRANSACTION; "; - let entities = Arc::new(entities); let relationships = Arc::new(relationships); let mut backoff_ms = tuning.graph_initial_backoff_ms; @@ -209,8 +212,7 @@ async fn store_graph_entities( for attempt in 0..tuning.graph_store_attempts { let result = db .client - .query(STORE_GRAPH_MUTATION) - .bind(("entities", entities.clone())) + .query(STORE_RELATIONSHIPS) .bind(("relationships", relationships.clone())) .await; @@ -240,17 +242,17 @@ async fn store_graph_entities( async fn store_vector_chunks( db: &SurrealDbClient, task_id: &str, - chunks: &[TextChunk], + chunks: &[EmbeddedTextChunk], tuning: &super::config::IngestionTuning, ) -> Result { let chunk_count = chunks.len(); let batch_size = tuning.chunk_insert_concurrency.max(1); - for chunk in chunks { + for embedded in chunks { debug!( task_id = %task_id, - chunk_id = %chunk.id, - chunk_len = chunk.chunk.chars().count(), + chunk_id = %embedded.chunk.id, + chunk_len = embedded.chunk.chunk.chars().count(), "chunk persisted" ); } @@ -270,53 +272,17 @@ fn is_retryable_conflict(error: &surrealdb::Error) -> bool { async fn store_chunk_batch( db: &SurrealDbClient, - batch: &[TextChunk], - 
tuning: &super::config::IngestionTuning, + batch: &[EmbeddedTextChunk], + _tuning: &super::config::IngestionTuning, ) -> Result<(), AppError> { if batch.is_empty() { return Ok(()); } - const STORE_CHUNKS_MUTATION: &str = r" - BEGIN TRANSACTION; - LET $chunks = $chunks; - - FOR $chunk IN $chunks { - CREATE type::thing('text_chunk', $chunk.id) CONTENT $chunk; - }; - - COMMIT TRANSACTION; - "; - - let chunks = Arc::new(batch.to_vec()); - let mut backoff_ms = tuning.graph_initial_backoff_ms; - - for attempt in 0..tuning.graph_store_attempts { - let result = db - .client - .query(STORE_CHUNKS_MUTATION) - .bind(("chunks", chunks.clone())) - .await; - - match result { - Ok(_) => return Ok(()), - Err(err) => { - if is_retryable_conflict(&err) && attempt + 1 < tuning.graph_store_attempts { - warn!( - attempt = attempt + 1, - "Transient SurrealDB conflict while storing chunks; retrying" - ); - sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms * 2).min(tuning.graph_max_backoff_ms); - continue; - } - - return Err(AppError::from(err)); - } - } + for embedded in batch { + TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db) + .await?; } - Err(AppError::InternalError( - "Failed to store text chunks after retries".to_string(), - )) + Ok(()) } diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index cb6c206..5a7ff75 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use crate::pipeline::context::{EmbeddedKnowledgeEntity, EmbeddedTextChunk}; use async_trait::async_trait; use chrono::{Duration as ChronoDuration, Utc}; use common::{ @@ -32,7 +33,7 @@ struct MockServices { similar_entities: Vec, analysis: LLMEnrichmentResult, chunk_embedding: Vec, - graph_entities: Vec, + graph_entities: Vec, graph_relationships: Vec, calls: Mutex>, } @@ -54,14 +55,12 @@ impl MockServices { "Previously known 
context".into(), KnowledgeEntityType::Document, None, - vec![0.1; TEST_EMBEDDING_DIM], user_id.into(), ); let retrieved_chunk = TextChunk::new( retrieved_entity.source_id.clone(), "existing chunk".into(), - vec![0.1; TEST_EMBEDDING_DIM], user_id.into(), ); @@ -76,7 +75,6 @@ impl MockServices { "Entity from enrichment".into(), KnowledgeEntityType::Idea, None, - vec![0.2; TEST_EMBEDDING_DIM], user_id.into(), ); let graph_relationship = KnowledgeRelationship::new( @@ -99,7 +97,10 @@ impl MockServices { }], analysis, chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM], - graph_entities: vec![graph_entity], + graph_entities: vec![EmbeddedKnowledgeEntity { + entity: graph_entity, + embedding: vec![0.2; TEST_EMBEDDING_DIM], + }], graph_relationships: vec![graph_relationship], calls: Mutex::new(Vec::new()), } @@ -142,7 +143,7 @@ impl PipelineServices for MockServices { _content: &TextContent, _analysis: &LLMEnrichmentResult, _entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError> { + ) -> Result<(Vec, Vec), AppError> { self.record("convert").await; Ok(( self.graph_entities.clone(), @@ -154,14 +155,16 @@ impl PipelineServices for MockServices { &self, content: &TextContent, _range: std::ops::Range, - ) -> Result, AppError> { + ) -> Result, AppError> { self.record("chunk").await; - Ok(vec![TextChunk::new( - content.id.clone(), - "chunk from mock services".into(), - self.chunk_embedding.clone(), - content.user_id.clone(), - )]) + Ok(vec![EmbeddedTextChunk { + chunk: TextChunk::new( + content.id.clone(), + "chunk from mock services".into(), + content.user_id.clone(), + ), + embedding: self.chunk_embedding.clone(), + }]) } } @@ -200,7 +203,7 @@ impl PipelineServices for FailingServices { content: &TextContent, analysis: &LLMEnrichmentResult, entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError> { + ) -> Result<(Vec, Vec), AppError> { self.inner .convert_analysis(content, analysis, entity_concurrency) .await @@ -210,7 +213,7 @@ impl PipelineServices for 
FailingServices { &self, content: &TextContent, range: std::ops::Range, - ) -> Result, AppError> { + ) -> Result, AppError> { self.inner.prepare_chunks(content, range).await } } @@ -244,7 +247,7 @@ impl PipelineServices for ValidationServices { _content: &TextContent, _analysis: &LLMEnrichmentResult, _entity_concurrency: usize, - ) -> Result<(Vec, Vec), AppError> { + ) -> Result<(Vec, Vec), AppError> { unreachable!("convert_analysis should not be called after validation failure") } @@ -252,7 +255,7 @@ impl PipelineServices for ValidationServices { &self, _content: &TextContent, _range: std::ops::Range, - ) -> Result, AppError> { + ) -> Result, AppError> { unreachable!("prepare_chunks should not be called after validation failure") } } diff --git a/main/src/main.rs b/main/src/main.rs index 21078d2..442dc00 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -1,7 +1,11 @@ use api_router::{api_routes_v1, api_state::ApiState}; use axum::{extract::FromRef, Router}; use common::{ - storage::db::SurrealDbClient, storage::store::StorageManager, utils::config::get_config, + storage::{ + db::SurrealDbClient, indexes::ensure_runtime_indexes, store::StorageManager, + types::system_settings::SystemSettings, + }, + utils::config::get_config, }; use html_router::{html_routes, html_state::HtmlState}; use ingestion_pipeline::{pipeline::IngestionPipeline, run_worker_loop}; @@ -38,6 +42,8 @@ async fn main() -> Result<(), Box> { // Ensure db is initialized db.apply_migrations().await?; + let settings = SystemSettings::get_current(&db).await?; + ensure_runtime_indexes(&db, settings.embedding_dimensions as usize).await?; let session_store = Arc::new(db.create_session_store().await?); let openai_client = Arc::new(async_openai::Client::with_config( @@ -106,6 +112,9 @@ async fn main() -> Result<(), Box> { .await .unwrap(), ); + let settings = SystemSettings::get_current(&worker_db) + .await + .expect("failed to load system settings"); // Initialize worker components let openai_client = 
Arc::new(async_openai::Client::with_config( @@ -113,13 +122,21 @@ async fn main() -> Result<(), Box> { .with_api_key(&config.openai_api_key) .with_api_base(&config.openai_base_url), )); + + // Create embedding provider for ingestion + let embedding_provider = Arc::new( + common::utils::embedding::EmbeddingProvider::new_fastembed(None) + .await + .expect("failed to create embedding provider"), + ); let ingestion_pipeline = Arc::new( IngestionPipeline::new( worker_db.clone(), openai_client.clone(), config.clone(), reranker_pool.clone(), - storage.clone(), // Use the global storage manager + storage.clone(), + embedding_provider, ) .await .unwrap(), diff --git a/main/src/worker.rs b/main/src/worker.rs index dde9621..e9d13f8 100644 --- a/main/src/worker.rs +++ b/main/src/worker.rs @@ -37,6 +37,10 @@ async fn main() -> Result<(), Box> { let reranker_pool = RerankerPool::maybe_from_config(&config)?; + // Create embedding provider for ingestion + let embedding_provider = + Arc::new(common::utils::embedding::EmbeddingProvider::new_fastembed(None).await?); + // Create global storage manager let storage = StorageManager::new(&config).await?; @@ -47,6 +51,7 @@ async fn main() -> Result<(), Box> { config, reranker_pool, storage, + embedding_provider, ) .await?, ); diff --git a/retrieval-pipeline/src/fts.rs b/retrieval-pipeline/src/fts.rs index 28ca217..5dd1b4d 100644 --- a/retrieval-pipeline/src/fts.rs +++ b/retrieval-pipeline/src/fts.rs @@ -123,10 +123,6 @@ mod tests { }; use uuid::Uuid; - fn dummy_embedding() -> Vec { - vec![0.0; 1536] - } - #[tokio::test] async fn fts_preserves_single_field_score_for_name() { let namespace = "fts_test_ns"; @@ -146,7 +142,6 @@ mod tests { "completely unrelated description".into(), KnowledgeEntityType::Document, None, - dummy_embedding(), user_id.into(), ); @@ -194,7 +189,6 @@ mod tests { "Detailed notes about async runtimes".into(), KnowledgeEntityType::Document, None, - dummy_embedding(), user_id.into(), ); @@ -239,11 +233,10 @@ mod tests { 
let chunk = TextChunk::new( "source_chunk".into(), "GraphQL documentation reference".into(), - dummy_embedding(), user_id.into(), ); - db.store_item(chunk.clone()) + TextChunk::store_with_embedding(chunk.clone(), vec![0.0; 1536], &db) .await .expect("failed to insert chunk"); diff --git a/retrieval-pipeline/src/graph.rs b/retrieval-pipeline/src/graph.rs index bf7ae37..494404d 100644 --- a/retrieval-pipeline/src/graph.rs +++ b/retrieval-pipeline/src/graph.rs @@ -171,7 +171,6 @@ mod tests { let source_id3 = "source789".to_string(); let entity_type = KnowledgeEntityType::Document; - let embedding = vec![0.1, 0.2, 0.3]; let user_id = "user123".to_string(); // Entity with source_id1 @@ -181,7 +180,6 @@ mod tests { "Description 1".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -192,7 +190,6 @@ mod tests { "Description 2".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -203,7 +200,6 @@ mod tests { "Description 3".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -214,7 +210,6 @@ mod tests { "Description 4".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -318,7 +313,6 @@ mod tests { // Create some test entities let entity_type = KnowledgeEntityType::Document; - let embedding = vec![0.1, 0.2, 0.3]; let user_id = "user123".to_string(); // Create the central entity we'll query relationships for @@ -328,7 +322,6 @@ mod tests { "Central Description".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -339,7 +332,6 @@ mod tests { "Related Description 1".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -349,7 +341,6 @@ mod tests { "Related Description 2".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); @@ -360,7 +351,6 @@ mod tests { "Unrelated Description".to_string(), entity_type.clone(), None, - embedding.clone(), user_id.clone(), ); diff 
--git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index f1045e1..167581c 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -5,7 +5,6 @@ pub mod graph; pub mod pipeline; pub mod reranking; pub mod scoring; -pub mod vector; use common::{ error::AppError, @@ -57,6 +56,7 @@ pub async fn retrieve_entities( pipeline::run_pipeline( db_client, openai_client, + None, input_text, user_id, config, @@ -110,10 +110,10 @@ mod tests { db.query( "BEGIN TRANSACTION; - REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk; - DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3; - REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity; - DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3; + REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding; + DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3; + REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding; + DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3; COMMIT TRANSACTION;", ) .await @@ -132,20 +132,18 @@ mod tests { "Detailed notes about async runtimes".into(), KnowledgeEntityType::Document, None, - entity_embedding_high(), user_id.into(), ); let chunk = TextChunk::new( entity.source_id.clone(), "Tokio uses cooperative scheduling for fairness.".into(), - chunk_embedding_primary(), user_id.into(), ); - db.store_item(entity.clone()) + KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db) .await .expect("Failed to store entity"); - db.store_item(chunk.clone()) + TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db) .await .expect("Failed to store chunk"); @@ -153,6 +151,7 @@ mod tests { let results = 
pipeline::run_pipeline_with_embedding( &db, &openai_client, + None, test_embedding(), "Rust concurrency async tasks", user_id, @@ -193,7 +192,6 @@ mod tests { "Explores async runtimes and scheduling strategies.".into(), KnowledgeEntityType::Document, None, - entity_embedding_high(), user_id.into(), ); let neighbor = KnowledgeEntity::new( @@ -202,34 +200,31 @@ mod tests { "Details on Tokio's cooperative scheduler.".into(), KnowledgeEntityType::Document, None, - entity_embedding_low(), user_id.into(), ); - db.store_item(primary.clone()) + KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db) .await .expect("Failed to store primary entity"); - db.store_item(neighbor.clone()) + KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db) .await .expect("Failed to store neighbor entity"); let primary_chunk = TextChunk::new( primary.source_id.clone(), "Rust async tasks use Tokio's cooperative scheduler.".into(), - chunk_embedding_primary(), user_id.into(), ); let neighbor_chunk = TextChunk::new( neighbor.source_id.clone(), "Tokio's scheduler manages task fairness across executors.".into(), - chunk_embedding_secondary(), user_id.into(), ); - db.store_item(primary_chunk) + TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db) .await .expect("Failed to store primary chunk"); - db.store_item(neighbor_chunk) + TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db) .await .expect("Failed to store neighbor chunk"); @@ -249,6 +244,7 @@ mod tests { let results = pipeline::run_pipeline_with_embedding( &db, &openai_client, + None, test_embedding(), "Rust concurrency async tasks", user_id, @@ -270,6 +266,8 @@ mod tests { } } + println!("{:?}", entities); + let neighbor_entry = neighbor_entry.expect("Graph-enriched neighbor should appear in results"); @@ -293,20 +291,18 @@ mod tests { let chunk_one = TextChunk::new( "src_alpha".into(), "Tokio tasks execute on worker threads 
managed by the runtime.".into(), - chunk_embedding_primary(), user_id.into(), ); let chunk_two = TextChunk::new( "src_beta".into(), "Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(), - chunk_embedding_secondary(), user_id.into(), ); - db.store_item(chunk_one.clone()) + TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db) .await .expect("Failed to store chunk one"); - db.store_item(chunk_two.clone()) + TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db) .await .expect("Failed to store chunk two"); @@ -315,6 +311,7 @@ mod tests { let results = pipeline::run_pipeline_with_embedding( &db, &openai_client, + None, test_embedding(), "tokio runtime worker behavior", user_id, diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index ebc3351..7a0d026 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -112,6 +112,7 @@ pub struct PipelineRunOutput { pub async fn run_pipeline( db_client: &SurrealDbClient, openai_client: &Client, + embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, input_text: &str, user_id: &str, config: RetrievalConfig, @@ -137,6 +138,7 @@ pub async fn run_pipeline( driver, db_client, openai_client, + embedding_provider, None, input_text, user_id, @@ -153,6 +155,7 @@ pub async fn run_pipeline( driver, db_client, openai_client, + embedding_provider, None, input_text, user_id, @@ -169,6 +172,7 @@ pub async fn run_pipeline( driver, db_client, openai_client, + embedding_provider, None, input_text, user_id, @@ -185,6 +189,7 @@ pub async fn run_pipeline( driver, db_client, openai_client, + embedding_provider, None, input_text, user_id, @@ -201,6 +206,7 @@ pub async fn run_pipeline( pub async fn run_pipeline_with_embedding( db_client: &SurrealDbClient, openai_client: &Client, + embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, query_embedding: Vec, 
input_text: &str, user_id: &str, @@ -214,6 +220,7 @@ pub async fn run_pipeline_with_embedding( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -230,6 +237,7 @@ pub async fn run_pipeline_with_embedding( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -246,6 +254,7 @@ pub async fn run_pipeline_with_embedding( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -262,6 +271,7 @@ pub async fn run_pipeline_with_embedding( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -283,6 +293,7 @@ pub async fn run_pipeline_with_embedding( pub async fn run_pipeline_with_embedding_with_metrics( db_client: &SurrealDbClient, openai_client: &Client, + embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, query_embedding: Vec, input_text: &str, user_id: &str, @@ -296,6 +307,7 @@ pub async fn run_pipeline_with_embedding_with_metrics( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -316,6 +328,7 @@ pub async fn run_pipeline_with_embedding_with_metrics( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -340,6 +353,7 @@ pub async fn run_pipeline_with_embedding_with_metrics( pub async fn run_pipeline_with_embedding_with_diagnostics( db_client: &SurrealDbClient, openai_client: &Client, + embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, query_embedding: Vec, input_text: &str, user_id: &str, @@ -353,6 +367,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, user_id, @@ -373,6 +388,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics( driver, db_client, openai_client, + embedding_provider, Some(query_embedding), input_text, 
user_id, @@ -419,6 +435,7 @@ async fn execute_strategy( driver: D, db_client: &SurrealDbClient, openai_client: &Client, + embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>, query_embedding: Option>, input_text: &str, user_id: &str, @@ -430,6 +447,7 @@ async fn execute_strategy( Some(embedding) => PipelineContext::with_embedding( db_client, openai_client, + embedding_provider, embedding, input_text.to_owned(), user_id.to_owned(), @@ -439,6 +457,7 @@ async fn execute_strategy( None => PipelineContext::new( db_client, openai_client, + embedding_provider, input_text.to_owned(), user_id.to_owned(), config, diff --git a/retrieval-pipeline/src/pipeline/stages/mod.rs b/retrieval-pipeline/src/pipeline/stages/mod.rs index 6f94d39..f7f9fed 100644 --- a/retrieval-pipeline/src/pipeline/stages/mod.rs +++ b/retrieval-pipeline/src/pipeline/stages/mod.rs @@ -6,7 +6,7 @@ use common::{ db::SurrealDbClient, types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject}, }, - utils::embedding::generate_embedding, + utils::{embedding::generate_embedding, embedding::EmbeddingProvider}, }; use fastembed::RerankResult; use futures::{stream::FuturesUnordered, StreamExt}; @@ -24,10 +24,6 @@ use crate::{ clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc, FusionWeights, Scored, }, - vector::{ - find_chunk_snippets_by_vector_similarity_with_embedding, - find_items_by_vector_similarity_with_embedding, ChunkSnippet, - }, RetrievedChunk, RetrievedEntity, }; @@ -43,6 +39,7 @@ use super::{ pub struct PipelineContext<'a> { pub db_client: &'a SurrealDbClient, pub openai_client: &'a Client, + pub embedding_provider: Option<&'a EmbeddingProvider>, pub input_text: String, pub user_id: String, pub config: RetrievalConfig, @@ -51,7 +48,7 @@ pub struct PipelineContext<'a> { pub chunk_candidates: HashMap>, pub filtered_entities: Vec>, pub chunk_values: Vec>, - pub revised_chunk_values: Vec>, + pub revised_chunk_values: Vec>, pub 
reranker: Option, pub diagnostics: Option, pub entity_results: Vec, @@ -63,6 +60,7 @@ impl<'a> PipelineContext<'a> { pub fn new( db_client: &'a SurrealDbClient, openai_client: &'a Client, + embedding_provider: Option<&'a EmbeddingProvider>, input_text: String, user_id: String, config: RetrievalConfig, @@ -71,6 +69,7 @@ impl<'a> PipelineContext<'a> { Self { db_client, openai_client, + embedding_provider, input_text, user_id, config, @@ -91,6 +90,7 @@ impl<'a> PipelineContext<'a> { pub fn with_embedding( db_client: &'a SurrealDbClient, openai_client: &'a Client, + embedding_provider: Option<&'a EmbeddingProvider>, query_embedding: Vec, input_text: String, user_id: String, @@ -100,6 +100,7 @@ impl<'a> PipelineContext<'a> { let mut ctx = Self::new( db_client, openai_client, + embedding_provider, input_text, user_id, config, @@ -299,8 +300,16 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Reusing cached query embedding for hybrid retrieval"); } else { debug!("Generating query embedding for hybrid retrieval"); - let embedding = - generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?; + let embedding = if let Some(provider) = ctx.embedding_provider { + provider.embed(&ctx.input_text).await.map_err(|e| { + AppError::InternalError(format!( + "Failed to generate embedding with provider: {}", + e + )) + })? + } else { + generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await? 
+ }; ctx.query_embedding = Some(embedding); } @@ -315,19 +324,17 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App let weights = FusionWeights::default(); - let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!( - find_items_by_vector_similarity_with_embedding( + let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!( + KnowledgeEntity::vector_search( tuning.entity_vector_take, embedding.clone(), ctx.db_client, - "knowledge_entity", &ctx.user_id, ), - find_items_by_vector_similarity_with_embedding( + TextChunk::vector_search( tuning.chunk_vector_take, embedding, ctx.db_client, - "text_chunk", &ctx.user_id, ), find_items_by_fts( @@ -346,6 +353,15 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App ), )?; + let vector_entities: Vec> = vector_entity_results + .into_iter() + .map(|row| Scored::new(row.entity).with_vector_score(row.score)) + .collect(); + let vector_chunks: Vec> = vector_chunk_results + .into_iter() + .map(|row| Scored::new(row.chunk).with_vector_score(row.score)) + .collect(); + debug!( vector_entities = vector_entities.len(), vector_chunks = vector_chunks.len(), @@ -419,14 +435,15 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> } for neighbor in neighbors { - if neighbor.id == seed.id { + let neighbor_id = neighbor.id.clone(); + if neighbor_id == seed.id { continue; } let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay); let entry = ctx .entity_candidates - .entry(neighbor.id.clone()) + .entry(neighbor_id.clone()) .or_insert_with(|| Scored::new(neighbor.clone())); entry.item = neighbor; @@ -490,8 +507,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError ctx.filtered_entities = filtered_entities; - let query_embedding = ctx.ensure_embedding()?.clone(); - let mut chunk_results: Vec> = 
ctx.chunk_candidates.values().cloned().collect(); sort_by_fused_desc(&mut chunk_results); @@ -507,7 +522,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError ctx.db_client, &ctx.user_id, weights, - &query_embedding, ) .await?; @@ -579,13 +593,23 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), debug!("Collecting vector chunk candidates for revised strategy"); let embedding = ctx.ensure_embedding()?.clone(); let tuning = &ctx.config.tuning; - let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding( + let weights = FusionWeights::default(); + + let mut vector_chunks: Vec> = TextChunk::vector_search( tuning.chunk_vector_take, embedding, ctx.db_client, &ctx.user_id, ) - .await?; + .await? + .into_iter() + .map(|row| { + let mut scored = Scored::new(row.chunk).with_vector_score(row.score); + let fused = fuse_scores(&scored.scores, weights); + scored.update_fused(fused); + scored + }) + .collect(); if ctx.diagnostics_enabled() { ctx.record_collect_candidates(CollectCandidatesStats { @@ -617,7 +641,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError return Ok(()); }; - let documents = build_snippet_rerank_documents( + let documents = build_chunk_rerank_documents( &ctx.revised_chunk_values, ctx.config.tuning.rerank_keep_top.max(1), ); @@ -628,11 +652,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError match reranker.rerank(&ctx.input_text, documents).await { Ok(results) if !results.is_empty() => { - apply_snippet_rerank_results( - &mut ctx.revised_chunk_values, - &ctx.config.tuning, - results, - ); + apply_chunk_rerank_results(&mut ctx.revised_chunk_values, &ctx.config.tuning, results); } Ok(_) => debug!("Chunk reranker returned no results; retaining original order"), Err(err) => warn!( @@ -649,7 +669,7 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Assembling chunk-only 
retrieval results"); let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values); let question_terms = extract_keywords(&ctx.input_text); - rank_snippet_chunks_by_combined_score( + rank_chunks_by_combined_score( &mut chunk_values, &question_terms, ctx.config.tuning.lexical_match_weight, @@ -662,12 +682,9 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { ctx.chunk_results = chunk_values .into_iter() - .map(|chunk| { - let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id); - RetrievedChunk { - chunk: text_chunk, - score: chunk.fused, - } + .map(|chunk| RetrievedChunk { + chunk: chunk.item, + score: chunk.fused, }) .collect(); @@ -691,7 +708,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Assembling final retrieved entities"); let tuning = &ctx.config.tuning; - let query_embedding = ctx.ensure_embedding()?.clone(); let question_terms = extract_keywords(&ctx.input_text); let mut chunk_by_source: HashMap>> = HashMap::new(); @@ -704,9 +720,8 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { for chunk_list in chunk_by_source.values_mut() { chunk_list.sort_by(|a, b| { - let sim_a = cosine_similarity(&query_embedding, &a.item.embedding); - let sim_b = cosine_similarity(&query_embedding, &b.item.embedding); - sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal) + // No base-table embeddings; order by fused score only. 
+ b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal) }); } @@ -930,7 +945,6 @@ async fn enrich_chunks_from_entities( db_client: &SurrealDbClient, user_id: &str, weights: FusionWeights, - query_embedding: &[f32], ) -> Result<(), AppError> { let mut source_ids: HashSet = HashSet::new(); for entity in entities { @@ -964,16 +978,7 @@ async fn enrich_chunks_from_entities( .copied() .unwrap_or(0.0); - let similarity = cosine_similarity(query_embedding, &chunk.embedding); - - entry.scores.vector = Some( - entry - .scores - .vector - .unwrap_or(0.0) - .max(entity_score * 0.8) - .max(similarity), - ); + entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8)); let fused = fuse_scores(&entry.scores, weights); entry.update_fused(fused); entry.item = chunk; @@ -982,24 +987,6 @@ async fn enrich_chunks_from_entities( Ok(()) } -fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 { - if query.is_empty() || embedding.is_empty() || query.len() != embedding.len() { - return 0.0; - } - let mut dot = 0.0f32; - let mut norm_q = 0.0f32; - let mut norm_e = 0.0f32; - for (q, e) in query.iter().zip(embedding.iter()) { - dot += q * e; - norm_q += q * q; - norm_e += e * e; - } - if norm_q == 0.0 || norm_e == 0.0 { - return 0.0; - } - dot / (norm_q.sqrt() * norm_e.sqrt()) -} - fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec { if ctx.filtered_entities.is_empty() { return Vec::new(); @@ -1050,10 +1037,7 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz .collect() } -fn build_snippet_rerank_documents( - chunks: &[Scored], - max_chunks: usize, -) -> Vec { +fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { chunks .iter() .take(max_chunks) @@ -1124,8 +1108,8 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec>, +fn apply_chunk_rerank_results( + chunks: &mut Vec>, tuning: &RetrievalTuning, results: Vec, ) { @@ -1133,7 +1117,7 @@ fn 
apply_snippet_rerank_results( return; } - let mut remaining: Vec>> = + let mut remaining: Vec>> = std::mem::take(chunks).into_iter().map(Some).collect(); let raw_scores: Vec = results.iter().map(|r| r.score).collect(); @@ -1146,7 +1130,7 @@ fn apply_snippet_rerank_results( clamp_unit(tuning.rerank_blend_weight) }; - let mut reranked: Vec> = Vec::with_capacity(remaining.len()); + let mut reranked: Vec> = Vec::with_capacity(remaining.len()); for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) { if let Some(slot) = remaining.get_mut(result.index) { if let Some(mut candidate) = slot.take() { @@ -1217,32 +1201,6 @@ fn extract_keywords(text: &str) -> Vec { terms } -fn rank_snippet_chunks_by_combined_score( - candidates: &mut [Scored], - question_terms: &[String], - lexical_weight: f32, -) { - if lexical_weight > 0.0 && !question_terms.is_empty() { - for candidate in candidates.iter_mut() { - let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk); - let combined = clamp_unit(candidate.fused + lexical_weight * lexical); - candidate.update_fused(combined); - } - } - candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)); -} - -fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk { - let mut chunk = TextChunk::new( - snippet.source_id.clone(), - snippet.chunk, - Vec::new(), - user_id.to_owned(), - ); - chunk.id = snippet.id; - chunk -} - fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 { if terms.is_empty() { return 0.0; diff --git a/retrieval-pipeline/src/vector.rs b/retrieval-pipeline/src/vector.rs deleted file mode 100644 index 94a3514..0000000 --- a/retrieval-pipeline/src/vector.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::collections::HashMap; - -use common::{ - error::AppError, - storage::{ - db::SurrealDbClient, - types::{file_info::deserialize_flexible_id, StoredObject}, - }, - utils::embedding::generate_embedding, -}; -use serde::Deserialize; 
-use surrealdb::sql::Thing; - -use crate::scoring::{clamp_unit, distance_to_similarity, Scored}; - -/// Compares vectors and retrieves a number of items from the specified table. -/// -/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database, -/// and then deserializes the results into the specified type `T`. -/// -/// # Arguments -/// -/// * `take` - The number of items to retrieve from the database. -/// * `input_text` - The text to generate embeddings for. -/// * `db_client` - The SurrealDB client to use for querying the database. -/// * `table` - The table to query in the database. -/// * `openai_client` - The OpenAI client to use for generating embeddings. -/// * 'user_id`- The user id of the current user. -/// -/// # Returns -/// -/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs. -/// -/// # Type Parameters -/// -/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`. 
-pub async fn find_items_by_vector_similarity( - take: usize, - input_text: &str, - db_client: &SurrealDbClient, - table: &str, - openai_client: &async_openai::Client, - user_id: &str, -) -> Result>, AppError> -where - T: for<'de> serde::Deserialize<'de> + StoredObject, -{ - // Generate embeddings - let input_embedding = generate_embedding(openai_client, input_text, db_client).await?; - find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id) - .await -} - -#[derive(Debug, Deserialize)] -struct DistanceRow { - #[serde(deserialize_with = "deserialize_flexible_id")] - id: String, - distance: Option, -} - -pub async fn find_items_by_vector_similarity_with_embedding( - take: usize, - query_embedding: Vec, - db_client: &SurrealDbClient, - table: &str, - user_id: &str, -) -> Result>, AppError> -where - T: for<'de> serde::Deserialize<'de> + StoredObject, -{ - let embedding_literal = serde_json::to_string(&query_embedding) - .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?; - - let closest_query = format!( - "SELECT id, vector::distance::knn() AS distance \ - FROM {table} \ - WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \ - LIMIT $limit", - table = table, - take = take, - embedding = embedding_literal - ); - - let mut response = db_client - .query(closest_query) - .bind(("user_id", user_id.to_owned())) - .bind(("limit", take as i64)) - .await?; - - let distance_rows: Vec = response.take(0)?; - - if distance_rows.is_empty() { - return Ok(Vec::new()); - } - - let ids: Vec = distance_rows.iter().map(|row| row.id.clone()).collect(); - let thing_ids: Vec = ids - .iter() - .map(|id| Thing::from((table, id.as_str()))) - .collect(); - - let mut items_response = db_client - .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id") - .bind(("table", table.to_owned())) - .bind(("things", thing_ids.clone())) - .bind(("user_id", user_id.to_owned())) - 
.await?; - - let items: Vec = items_response.take(0)?; - - let mut item_map: HashMap = items - .into_iter() - .map(|item| (item.get_id().to_owned(), item)) - .collect(); - - let mut min_distance = f32::MAX; - let mut max_distance = f32::MIN; - - for row in &distance_rows { - if let Some(distance) = row.distance { - if distance.is_finite() { - if distance < min_distance { - min_distance = distance; - } - if distance > max_distance { - max_distance = distance; - } - } - } - } - - let normalize = min_distance.is_finite() - && max_distance.is_finite() - && (max_distance - min_distance).abs() > f32::EPSILON; - - let mut scored = Vec::with_capacity(distance_rows.len()); - for row in distance_rows { - if let Some(item) = item_map.remove(&row.id) { - let similarity = row - .distance - .map(|distance| { - if normalize { - let span = max_distance - min_distance; - if span.abs() < f32::EPSILON { - 1.0 - } else { - let normalized = 1.0 - ((distance - min_distance) / span); - clamp_unit(normalized) - } - } else { - distance_to_similarity(distance) - } - }) - .unwrap_or_default(); - scored.push(Scored::new(item).with_vector_score(similarity)); - } - } - - Ok(scored) -} - -#[derive(Debug, Clone, Deserialize)] -pub struct ChunkSnippet { - pub id: String, - pub source_id: String, - pub chunk: String, -} - -#[derive(Debug, Deserialize)] -struct ChunkDistanceRow { - distance: Option, - #[serde(deserialize_with = "deserialize_flexible_id")] - pub id: String, - pub source_id: String, - pub chunk: String, -} - -pub async fn find_chunk_snippets_by_vector_similarity_with_embedding( - take: usize, - query_embedding: Vec, - db_client: &SurrealDbClient, - user_id: &str, -) -> Result>, AppError> { - let embedding_literal = serde_json::to_string(&query_embedding) - .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?; - - let closest_query = format!( - "SELECT id, source_id, chunk, vector::distance::knn() AS distance \ - FROM text_chunk \ - WHERE user_id = 
$user_id AND embedding <|{take},40|> {embedding} \ - LIMIT $limit", - take = take, - embedding = embedding_literal - ); - - let mut response = db_client - .query(closest_query) - .bind(("user_id", user_id.to_owned())) - .bind(("limit", take as i64)) - .await?; - - let rows: Vec = response.take(0)?; - - let mut scored = Vec::with_capacity(rows.len()); - for row in rows { - let similarity = row.distance.map(distance_to_similarity).unwrap_or_default(); - scored.push( - Scored::new(ChunkSnippet { - id: row.id, - source_id: row.source_id, - chunk: row.chunk, - }) - .with_vector_score(similarity), - ); - } - - Ok(scored) -}