mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-25 03:16:26 +02:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bddb603bd9 | |||
| aee9136f1e | |||
| 253ac9819d | |||
| 511e42a078 | |||
| d8e839bf46 | |||
| 588e616baf | |||
| 87e6fa14b2 | |||
| 09e545816e | |||
| 01ef1bcb7a | |||
| 530cd0a8f1 | |||
| 3b20adc50f | |||
| b3d42d2586 | |||
| fb51a8b55f | |||
| adc04d8c6d | |||
| 1013035731 | |||
| cf69cb7b05 | |||
| ead17530bd | |||
| 4e8a58fff1 |
+1
-1
@@ -1,2 +1,2 @@
|
|||||||
[alias]
|
[alias]
|
||||||
eval = "run -p evaluations --"
|
eval = "run -p evaluations --release --"
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
# Git stuff
|
|
||||||
.git/
|
|
||||||
.gitignore
|
|
||||||
.github
|
|
||||||
|
|
||||||
# Node build artifacts
|
|
||||||
**/node_modules/
|
|
||||||
|
|
||||||
# Nix/Devenv environment files
|
|
||||||
.direnv/
|
|
||||||
.devenv/
|
|
||||||
devenv.lock
|
|
||||||
devenv.nix
|
|
||||||
devenv.yaml
|
|
||||||
docker-compose.yml
|
|
||||||
.envrc
|
|
||||||
.devenv.flake.nix
|
|
||||||
flake.lock
|
|
||||||
flake.nix
|
|
||||||
|
|
||||||
# Rust build artifacts (crucial for multi-stage builds)
|
|
||||||
**/target/
|
|
||||||
|
|
||||||
# Runtime data directories
|
|
||||||
data/
|
|
||||||
database/
|
|
||||||
|
|
||||||
# Local environment config (sensitive)
|
|
||||||
.env
|
|
||||||
|
|
||||||
# IDE specific
|
|
||||||
.vscode/
|
|
||||||
.idea/
|
|
||||||
|
|
||||||
# OS specific
|
|
||||||
.DS_Store
|
|
||||||
Thumbs.db
|
|
||||||
|
|
||||||
# Logs / Temporary files
|
|
||||||
*.log
|
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
name: CI
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
actions: write
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check:
|
||||||
|
name: Format, lint, build & test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- uses: DeterminateSystems/determinate-nix-action@v3
|
||||||
|
|
||||||
|
- uses: nix-community/cache-nix-action@v7
|
||||||
|
with:
|
||||||
|
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||||
|
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||||
|
gc-max-store-size-linux: 10G
|
||||||
|
|
||||||
|
- name: Check formatting, clippy lint, unit tests & ort version
|
||||||
|
run: nix flake check --show-trace
|
||||||
@@ -7,7 +7,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
- "**[0-9]+.[0-9]+.[0-9]+*"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
plan:
|
plan:
|
||||||
@@ -17,6 +17,7 @@ jobs:
|
|||||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||||
tag-flag: ${{ !github.event.pull_request && format('--tag={0}', github.ref_name) || '' }}
|
tag-flag: ${{ !github.event.pull_request && format('--tag={0}', github.ref_name) || '' }}
|
||||||
publishing: ${{ !github.event.pull_request }}
|
publishing: ${{ !github.event.pull_request }}
|
||||||
|
ort-version: ${{ steps.ort_version.outputs.value }}
|
||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
@@ -25,13 +26,20 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Install Nix
|
- name: Install Nix
|
||||||
uses: cachix/install-nix-action@v27
|
uses: DeterminateSystems/determinate-nix-action@v3
|
||||||
|
|
||||||
|
- uses: nix-community/cache-nix-action@v7
|
||||||
with:
|
with:
|
||||||
extra_nix_config: |
|
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||||
experimental-features = nix-command flakes
|
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||||
|
gc-max-store-size-linux: 10G
|
||||||
|
|
||||||
|
- name: Read ORT version from flake
|
||||||
|
id: ort_version
|
||||||
|
run: echo "value=$(nix eval .#lib.ortVersion --raw)" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Verify ort-version matches nixpkgs onnxruntime
|
- name: Verify ort-version matches nixpkgs onnxruntime
|
||||||
run: nix flake check --system x86_64-linux -L
|
run: nix flake check --system x86_64-linux
|
||||||
|
|
||||||
- name: Install dist
|
- name: Install dist
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -78,7 +86,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Load ONNX Runtime version
|
- name: Load ONNX Runtime version
|
||||||
shell: bash
|
shell: bash
|
||||||
run: echo "ORT_VER=$(tr -d '[:space:]' < ort-version)" >> "$GITHUB_ENV"
|
run: echo "ORT_VER=${{ needs.plan.outputs.ort-version }}" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Install Rust non-interactively if not already installed
|
- name: Install Rust non-interactively if not already installed
|
||||||
if: ${{ matrix.container }}
|
if: ${{ matrix.container }}
|
||||||
@@ -108,7 +116,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
mkdir -p lib
|
mkdir -p lib
|
||||||
rm -f lib/*
|
rm -f lib/*
|
||||||
|
|
||||||
# Windows PowerShell
|
# Windows PowerShell
|
||||||
- name: Prepare lib dir (Windows)
|
- name: Prepare lib dir (Windows)
|
||||||
if: runner.os == 'Windows'
|
if: runner.os == 'Windows'
|
||||||
@@ -158,7 +166,6 @@ jobs:
|
|||||||
echo "lib/ contents:"
|
echo "lib/ contents:"
|
||||||
ls -l lib || dir lib
|
ls -l lib || dir lib
|
||||||
# ===== END: Injected ORT staging =====
|
# ===== END: Injected ORT staging =====
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
${{ matrix.packages_install }}
|
${{ matrix.packages_install }}
|
||||||
@@ -186,21 +193,31 @@ jobs:
|
|||||||
${{ env.BUILD_MANIFEST_NAME }}
|
${{ env.BUILD_MANIFEST_NAME }}
|
||||||
|
|
||||||
build_and_push_docker_image:
|
build_and_push_docker_image:
|
||||||
name: Build and Push Docker Image
|
name: Build and Push Docker Image (Nix)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [plan]
|
needs: [plan]
|
||||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
id-token: write
|
||||||
packages: write
|
packages: write
|
||||||
|
actions: write
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- uses: actions/checkout@v4
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Install Nix
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: DeterminateSystems/determinate-nix-action@v3
|
||||||
|
|
||||||
|
- uses: nix-community/cache-nix-action@v7
|
||||||
|
with:
|
||||||
|
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||||
|
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||||
|
gc-max-store-size-linux: 10G
|
||||||
|
|
||||||
|
- name: Build Docker image with Nix
|
||||||
|
run: nix build .#dockerImage -L --show-trace
|
||||||
|
|
||||||
- name: Log in to GitHub Container Registry
|
- name: Log in to GitHub Container Registry
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
@@ -215,15 +232,16 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
images: ghcr.io/${{ github.repository }}
|
images: ghcr.io/${{ github.repository }}
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Load and push Docker image
|
||||||
uses: docker/build-push-action@v5
|
env:
|
||||||
with:
|
IMAGE_NAME: ghcr.io/${{ github.repository }}
|
||||||
context: .
|
IMAGE_TAG: ${{ needs.plan.outputs.tag }}
|
||||||
push: true
|
run: |
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
docker load < result
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
docker tag "minne:1.0.3" "$IMAGE_NAME:$IMAGE_TAG"
|
||||||
cache-from: type=gha
|
docker tag "minne:1.0.3" "$IMAGE_NAME:latest"
|
||||||
cache-to: type=gha,mode=max
|
docker push "$IMAGE_NAME:$IMAGE_TAG"
|
||||||
|
docker push "$IMAGE_NAME:latest"
|
||||||
|
|
||||||
build-global-artifacts:
|
build-global-artifacts:
|
||||||
needs: [plan, build-local-artifacts]
|
needs: [plan, build-local-artifacts]
|
||||||
|
|||||||
+41
-1
@@ -1,7 +1,31 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
|
- Infra: CI workflow fixes. CI is now a nix flake check which includes compilation, caching and running tests, clippy, fmt, validation for ort version.
|
||||||
|
- Docker-compose: The example now references the ghcr image, this is so we can remove the Dockerfile and reducing maintenance scope.
|
||||||
|
- Refactor: web scraping now uses `servo-fetch` (pure-Rust Servo engine) and PDF rendering uses `pdfium-render` (direct PDFium bindings) — reduces Docker image size by ~300MB, improves startup latency by ~100× for PDF rendering, and provides more stable output
|
||||||
|
- Fix: added `pkgs.libglvnd` to `LD_LIBRARY_PATH` in devenv so Servo engine can find `libEGL.so` at runtime
|
||||||
|
- Fix: updated Dockerfile to add `libegl1 libegl-mesa0 libgles2 libfontconfig1 libfreetype6` runtime dependencies for servo-fetch
|
||||||
|
- Docs: updated architecture, features, and installation docs to reflect the new web processing stack
|
||||||
|
- Fix: added pre-commit hooks to further maintain code consistency.
|
||||||
|
- Security: updated some deps because dependabot told me, good bot.
|
||||||
|
- Security: bump `async-openai` to 0.41.1 (feature-gated types, transcription API rename; removes `backoff` transitive dep)
|
||||||
|
- Refactor: deduplicated test database setup across common/src/storage/.
|
||||||
|
- Refactor: split knowledge-graph.js monolith into focused functions.
|
||||||
|
- Evaluations: simplified crate layout — linear pipeline, sharded-only converted store, in-memory ingestion, `db/` and `cli/` modules; namespace reuse state in corpus manifest (removed `cache/snapshots/`); no legacy JSON/history compatibility (re-run `--warm` after upgrade)
|
||||||
|
- Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables)
|
||||||
|
- Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id)
|
||||||
|
- Performance: entity embeddings during ingestion use batched `embed_batch`, matching chunk embedding
|
||||||
|
- Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed
|
||||||
|
- Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children`
|
||||||
|
- Fix: regression re suggestion of relationships
|
||||||
|
- Internal: extracted duplicate entity+embedding patterns into `HasEmbedding` and `EmbeddingRecord` traits with generic `store_with_embedding`, `delete_by_source_id`, and `vector_search` on `SurrealDbClient`.
|
||||||
|
- Infra: `ort-version` file removed — version inlined in `flake.nix` and `devenv.nix`; `release.yml` reads it via `nix eval .#lib.ortVersion` from the plan job
|
||||||
|
- Infra: `screenshot-graph.webp` and `.dockerignore` deleted — stale artifacts from Dockerfile era
|
||||||
|
|
||||||
## 1.0.3 (2026-06-12)
|
## 1.0.3 (2026-06-12)
|
||||||
|
|
||||||
- Search: filter results by type — knowledge entities, ingested content, or both
|
- Search: filter results by type — knowledge entities, ingested content, or both
|
||||||
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
||||||
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
||||||
@@ -11,6 +35,7 @@
|
|||||||
- Fix: API key revocation now correctly clears the stored key
|
- Fix: API key revocation now correctly clears the stored key
|
||||||
|
|
||||||
## 1.0.2 (2026-02-15)
|
## 1.0.2 (2026-02-15)
|
||||||
|
|
||||||
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||||
- Fix: chat references now validate and render more reliably
|
- Fix: chat references now validate and render more reliably
|
||||||
- Fix: improved admin access checks for restricted routes
|
- Fix: improved admin access checks for restricted routes
|
||||||
@@ -19,73 +44,88 @@
|
|||||||
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
||||||
|
|
||||||
## 1.0.1 (2026-02-11)
|
## 1.0.1 (2026-02-11)
|
||||||
|
|
||||||
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
||||||
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
||||||
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
|
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
|
||||||
|
|
||||||
## 1.0.0 (2026-01-02)
|
## 1.0.0 (2026-01-02)
|
||||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
|
||||||
|
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||||
- Added a benchmarks create for evaluating the retrieval process
|
- Added a benchmarks create for evaluating the retrieval process
|
||||||
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
|
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
|
||||||
- Embeddings stored on own table.
|
- Embeddings stored on own table.
|
||||||
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
||||||
|
|
||||||
## Version 0.2.7 (2025-12-04)
|
## Version 0.2.7 (2025-12-04)
|
||||||
|
|
||||||
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||||
- Fix: timezone aware info in scratchpad
|
- Fix: timezone aware info in scratchpad
|
||||||
|
|
||||||
## Version 0.2.6 (2025-10-29)
|
## Version 0.2.6 (2025-10-29)
|
||||||
|
|
||||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||||
- Fix: default name for relationships harmonized across application
|
- Fix: default name for relationships harmonized across application
|
||||||
|
|
||||||
## Version 0.2.5 (2025-10-24)
|
## Version 0.2.5 (2025-10-24)
|
||||||
|
|
||||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||||
- Scratchpad feature, with the feature to convert scratchpads to content.
|
- Scratchpad feature, with the feature to convert scratchpads to content.
|
||||||
- Added knowledge entity search results to the global search
|
- Added knowledge entity search results to the global search
|
||||||
- Backend fixes for improved performance when ingesting and retrieval
|
- Backend fixes for improved performance when ingesting and retrieval
|
||||||
|
|
||||||
## Version 0.2.4 (2025-10-15)
|
## Version 0.2.4 (2025-10-15)
|
||||||
|
|
||||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||||
- Ingestion task archive
|
- Ingestion task archive
|
||||||
|
|
||||||
## Version 0.2.3 (2025-10-12)
|
## Version 0.2.3 (2025-10-12)
|
||||||
|
|
||||||
- Fix changing vector dimensions on a fresh database (#3)
|
- Fix changing vector dimensions on a fresh database (#3)
|
||||||
|
|
||||||
## Version 0.2.2 (2025-10-07)
|
## Version 0.2.2 (2025-10-07)
|
||||||
|
|
||||||
- Support for ingestion of PDF files
|
- Support for ingestion of PDF files
|
||||||
- Improved ingestion speed
|
- Improved ingestion speed
|
||||||
- Fix deletion of items work as expected
|
- Fix deletion of items work as expected
|
||||||
- Fix enabling GPT-5 use via OpenAI API
|
- Fix enabling GPT-5 use via OpenAI API
|
||||||
|
|
||||||
## Version 0.2.1 (2025-09-24)
|
## Version 0.2.1 (2025-09-24)
|
||||||
|
|
||||||
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||||
|
|
||||||
## Version 0.2.0 (2025-09-23)
|
## Version 0.2.0 (2025-09-23)
|
||||||
|
|
||||||
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||||
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||||
- Enabled audio ingestion and merged the new storage backend.
|
- Enabled audio ingestion and merged the new storage backend.
|
||||||
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||||
|
|
||||||
## Version 0.1.4 (2025-07-01)
|
## Version 0.1.4 (2025-07-01)
|
||||||
|
|
||||||
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||||
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||||
|
|
||||||
## Version 0.1.3 (2025-06-08)
|
## Version 0.1.3 (2025-06-08)
|
||||||
|
|
||||||
- Added support for AI providers beyond OpenAI.
|
- Added support for AI providers beyond OpenAI.
|
||||||
- Made the HTTP port configurable for deployments.
|
- Made the HTTP port configurable for deployments.
|
||||||
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||||
|
|
||||||
## Version 0.1.2 (2025-05-26)
|
## Version 0.1.2 (2025-05-26)
|
||||||
|
|
||||||
- Introduced full-text search across indexed knowledge.
|
- Introduced full-text search across indexed knowledge.
|
||||||
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||||
- Fixed search result links and SurrealDB vector formatting glitches.
|
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||||
|
|
||||||
## Version 0.1.1 (2025-05-13)
|
## Version 0.1.1 (2025-05-13)
|
||||||
|
|
||||||
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||||
- Made the data storage path configurable.
|
- Made the data storage path configurable.
|
||||||
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||||
|
|
||||||
## Version 0.1.0 (2025-05-06)
|
## Version 0.1.0 (2025-05-06)
|
||||||
|
|
||||||
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||||
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||||
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||||
|
|||||||
Generated
+6203
-364
File diff suppressed because it is too large
Load Diff
+26
-13
@@ -7,19 +7,24 @@ members = [
|
|||||||
"ingestion-pipeline",
|
"ingestion-pipeline",
|
||||||
"retrieval-pipeline",
|
"retrieval-pipeline",
|
||||||
"json-stream-parser",
|
"json-stream-parser",
|
||||||
"evaluations"
|
"evaluations",
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "3"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
anyhow = "1.0.94"
|
anyhow = "1.0.94"
|
||||||
async-openai = "0.29.3"
|
async-openai = { version = "0.41.1", features = [
|
||||||
|
"chat-completion",
|
||||||
|
"embedding",
|
||||||
|
"audio",
|
||||||
|
"model",
|
||||||
|
] }
|
||||||
async-stream = "0.3.6"
|
async-stream = "0.3.6"
|
||||||
async-trait = "0.1.88"
|
async-trait = "0.1.88"
|
||||||
axum-htmx = "0.7.0"
|
axum-htmx = "0.7.0"
|
||||||
axum_session = "0.16"
|
axum_session = "0.18"
|
||||||
axum_session_auth = "0.16"
|
axum_session_auth = "0.18"
|
||||||
axum_session_surreal = "0.4"
|
axum_session_surreal = "0.6"
|
||||||
axum_typed_multipart = "0.16"
|
axum_typed_multipart = "0.16"
|
||||||
axum = { version = "0.8", features = ["multipart", "macros"] }
|
axum = { version = "0.8", features = ["multipart", "macros"] }
|
||||||
chrono-tz = "0.10.1"
|
chrono-tz = "0.10.1"
|
||||||
@@ -27,7 +32,6 @@ chrono = { version = "0.4.39", features = ["serde"] }
|
|||||||
config = "0.15.4"
|
config = "0.15.4"
|
||||||
dom_smoothie = "0.10.0"
|
dom_smoothie = "0.10.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
headless_chrome = "1.0.17"
|
|
||||||
include_dir = "0.7.4"
|
include_dir = "0.7.4"
|
||||||
mime = "0.3.17"
|
mime = "0.3.17"
|
||||||
mime_guess = "2.0.5"
|
mime_guess = "2.0.5"
|
||||||
@@ -35,12 +39,12 @@ minijinja-autoreload = "2.5.0"
|
|||||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||||
minijinja-embed = { version = "2.8.0" }
|
minijinja-embed = { version = "2.8.0" }
|
||||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
reqwest = { version = "0.12.12", features = ["charset", "json"] }
|
||||||
serde_json = "1.0.128"
|
serde_json = "1.0.128"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
sha2 = "0.10.8"
|
sha2 = "0.10.8"
|
||||||
surrealdb-migrations = "2.2.2"
|
surrealdb-migrations = "2.4.0"
|
||||||
surrealdb = { version = "2" }
|
surrealdb = { version = "2.6" }
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.12.0"
|
||||||
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||||
tokenizers = { version = "0.20.4", features = ["http"] }
|
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||||
@@ -61,14 +65,24 @@ bytes = "1.7.1"
|
|||||||
state-machines = "0.9"
|
state-machines = "0.9"
|
||||||
pdf-extract = "0.9"
|
pdf-extract = "0.9"
|
||||||
lopdf = "0.32"
|
lopdf = "0.32"
|
||||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
pdfium-auto = "0.3"
|
||||||
|
pdfium-render = "0.8"
|
||||||
|
servo-fetch = "0.13"
|
||||||
|
tendril = "0.4"
|
||||||
|
image = { version = "0.25", default-features = false, features = ["png"] }
|
||||||
|
fastembed = { version = "5.2.0", default-features = false, features = [
|
||||||
|
"hf-hub-native-tls",
|
||||||
|
"ort-load-dynamic",
|
||||||
|
] }
|
||||||
|
|
||||||
[profile.dist]
|
[profile.dist]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
lto = "thin"
|
lto = "thin"
|
||||||
|
|
||||||
[workspace.lints.rust]
|
[workspace.lints.rust]
|
||||||
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(feature, values(\"inspect\"))"] }
|
unexpected_cfgs = { level = "warn", check-cfg = [
|
||||||
|
"cfg(feature, values(\"inspect\"))",
|
||||||
|
] }
|
||||||
|
|
||||||
[workspace.lints.clippy]
|
[workspace.lints.clippy]
|
||||||
# Performance-focused lints
|
# Performance-focused lints
|
||||||
@@ -118,4 +132,3 @@ needless_raw_string_hashes = "allow"
|
|||||||
multiple_bound_locations = "allow"
|
multiple_bound_locations = "allow"
|
||||||
cargo_common_metadata = "allow"
|
cargo_common_metadata = "allow"
|
||||||
multiple-crate-versions = "allow"
|
multiple-crate-versions = "allow"
|
||||||
|
|
||||||
|
|||||||
-53
@@ -1,53 +0,0 @@
|
|||||||
# === Builder ===
|
|
||||||
FROM rust:1.91.1-bookworm AS builder
|
|
||||||
WORKDIR /usr/src/minne
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Cache deps
|
|
||||||
COPY Cargo.toml Cargo.lock ./
|
|
||||||
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
|
|
||||||
COPY api-router/Cargo.toml ./api-router/
|
|
||||||
COPY common/Cargo.toml ./common/
|
|
||||||
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
|
|
||||||
COPY html-router/Cargo.toml ./html-router/
|
|
||||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
|
||||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
|
||||||
COPY main/Cargo.toml ./main/
|
|
||||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
|
||||||
|
|
||||||
# Build
|
|
||||||
COPY . .
|
|
||||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
|
||||||
|
|
||||||
# === Runtime ===
|
|
||||||
FROM debian:bookworm-slim
|
|
||||||
|
|
||||||
# Chromium + runtime deps + OpenMP for ORT
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
|
||||||
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
|
||||||
libgomp1 libstdc++6 curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# ONNX Runtime (CPU). Version is read from ort-version (override with --build-arg ORT_VERSION=...).
|
|
||||||
COPY ort-version /tmp/ort-version
|
|
||||||
ARG ORT_VERSION
|
|
||||||
RUN ORT_VERSION="${ORT_VERSION:-$(tr -d '[:space:]' < /tmp/ort-version)}" && \
|
|
||||||
mkdir -p /opt/onnxruntime && \
|
|
||||||
curl -fsSL -o /tmp/ort.tgz \
|
|
||||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
|
||||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
|
||||||
|
|
||||||
ENV CHROME_BIN=/usr/bin/chromium \
|
|
||||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
|
||||||
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
|
||||||
|
|
||||||
# Non-root
|
|
||||||
RUN useradd -m appuser
|
|
||||||
USER appuser
|
|
||||||
WORKDIR /home/appuser
|
|
||||||
|
|
||||||
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
|
||||||
EXPOSE 3000
|
|
||||||
CMD ["main"]
|
|
||||||
@@ -121,7 +121,7 @@ fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults t
|
|||||||
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
||||||
- **Database:** SurrealDB (graph, document, and vector search)
|
- **Database:** SurrealDB (graph, document, and vector search)
|
||||||
- **AI Integration:** OpenAI-compatible API with structured outputs
|
- **AI Integration:** OpenAI-compatible API with structured outputs
|
||||||
- **Web Processing:** Headless Chrome for robust webpage content extraction
|
- **Web Processing:** Embedded Servo engine (servo-fetch) for webpage content extraction + PDFium for PDF rendering
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ cd minne
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
|
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||||
|
|
||||||
### 2. Nix
|
### 2. Nix
|
||||||
|
|
||||||
@@ -180,13 +180,13 @@ The included `docker-compose.yml` handles SurrealDB and Chromium dependencies au
|
|||||||
nix run 'github:perstarkse/minne#main'
|
nix run 'github:perstarkse/minne#main'
|
||||||
```
|
```
|
||||||
|
|
||||||
This fetches Minne and all dependencies, including Chromium.
|
This fetches Minne and all dependencies.
|
||||||
|
|
||||||
### 3. Pre-built Binaries
|
### 3. Pre-built Binaries
|
||||||
|
|
||||||
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||||
|
|
||||||
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
|
**Requirements:** You'll need to provide SurrealDB separately.
|
||||||
|
|
||||||
### 4. Build from Source
|
### 4. Build from Source
|
||||||
|
|
||||||
@@ -196,7 +196,7 @@ cd minne
|
|||||||
cargo run --release --bin main
|
cargo run --release --bin main
|
||||||
```
|
```
|
||||||
|
|
||||||
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
|
**Requirements:** SurrealDB must be installed and accessible in your PATH.
|
||||||
|
|
||||||
## Application Architecture
|
## Application Architecture
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -24,7 +24,7 @@ dom_smoothie = { workspace = true }
|
|||||||
axum_session = { workspace = true }
|
axum_session = { workspace = true }
|
||||||
axum_session_auth = { workspace = true }
|
axum_session_auth = { workspace = true }
|
||||||
axum_session_surreal = { workspace = true}
|
axum_session_surreal = { workspace = true}
|
||||||
axum_typed_multipart = { workspace = true}
|
axum_typed_multipart = { workspace = true}
|
||||||
include_dir = { workspace = true }
|
include_dir = { workspace = true }
|
||||||
minijinja = { workspace = true }
|
minijinja = { workspace = true }
|
||||||
minijinja-autoreload = { workspace = true }
|
minijinja-autoreload = { workspace = true }
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
-- Track scheduled runtime index rebuild state on the system_settings singleton.
|
||||||
|
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -201,6 +201,10 @@\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||||
@@ -14,3 +14,7 @@ DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
|
|||||||
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||||
|
|||||||
+157
-119
@@ -1,9 +1,11 @@
|
|||||||
use super::types::StoredObject;
|
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
|
||||||
use crate::error::AppError;
|
use crate::error::AppError;
|
||||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||||
use axum_session_surreal::SessionSurrealPool;
|
use axum_session_surreal::SessionSurrealPool;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use include_dir::{include_dir, Dir};
|
use include_dir::{include_dir, Dir};
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
use serde::Serialize;
|
||||||
use std::{ops::Deref, sync::Arc};
|
use std::{ops::Deref, sync::Arc};
|
||||||
use surrealdb::{
|
use surrealdb::{
|
||||||
engine::any::{connect, Any},
|
engine::any::{connect, Any},
|
||||||
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SurrealDbClient {
|
impl SurrealDbClient {
|
||||||
/// Initialize a new database client.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
|
||||||
/// * `username` — Root username for authentication.
|
|
||||||
/// * `password` — Root password for authentication.
|
|
||||||
/// * `namespace` — SurrealDB namespace to use.
|
|
||||||
/// * `database` — SurrealDB database to use.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
|
||||||
/// In-memory (`mem://`) connections skip authentication.
|
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
address: &str,
|
address: &str,
|
||||||
username: &str,
|
username: &str,
|
||||||
@@ -49,30 +37,15 @@ impl SurrealDbClient {
|
|||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
let db = connect(address).await?;
|
let db = connect(address).await?;
|
||||||
|
|
||||||
// Skip sign-in for in-memory engine (no auth support)
|
|
||||||
if !address.starts_with("mem://") {
|
if !address.starts_with("mem://") {
|
||||||
db.signin(Root { username, password }).await?;
|
db.signin(Root { username, password }).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set namespace
|
|
||||||
db.use_ns(namespace).use_db(database).await?;
|
db.use_ns(namespace).use_db(database).await?;
|
||||||
|
|
||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize a new database client using namespace-level authentication.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `address` — Database connection string.
|
|
||||||
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
|
||||||
/// * `username` — Namespace username for authentication.
|
|
||||||
/// * `password` — Namespace password for authentication.
|
|
||||||
/// * `database` — SurrealDB database to use.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
|
||||||
pub async fn new_with_namespace_user(
|
pub async fn new_with_namespace_user(
|
||||||
address: &str,
|
address: &str,
|
||||||
namespace: &str,
|
namespace: &str,
|
||||||
@@ -91,11 +64,6 @@ impl SurrealDbClient {
|
|||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an Axum session store backed by SurrealDB.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `SessionError` if the session store configuration or table creation fails.
|
|
||||||
pub async fn create_session_store(
|
pub async fn create_session_store(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||||
@@ -109,15 +77,6 @@ impl SurrealDbClient {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
|
|
||||||
///
|
|
||||||
/// This function should be called during application startup, after connecting to
|
|
||||||
/// the database and selecting the appropriate namespace and database, but before
|
|
||||||
/// the application starts performing operations that rely on the schema.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
|
||||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||||
debug!("Applying migrations");
|
debug!("Applying migrations");
|
||||||
MigrationRunner::new(&self.client)
|
MigrationRunner::new(&self.client)
|
||||||
@@ -129,15 +88,6 @@ impl SurrealDbClient {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store an object in SurrealDB.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `item` — The item to store. Must implement `StoredObject`.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database create operation fails.
|
|
||||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: StoredObject + Send + Sync + 'static,
|
T: StoredObject + Send + Sync + 'static,
|
||||||
@@ -148,13 +98,6 @@ impl SurrealDbClient {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
|
||||||
///
|
|
||||||
/// Useful for idempotent ingestion flows.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database upsert operation fails.
|
|
||||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: StoredObject + Send + Sync + 'static,
|
T: StoredObject + Send + Sync + 'static,
|
||||||
@@ -166,11 +109,6 @@ impl SurrealDbClient {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieve all objects from a table.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database select operation fails.
|
|
||||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -178,16 +116,6 @@ impl SurrealDbClient {
|
|||||||
self.client.select(T::table_name()).await
|
self.client.select(T::table_name()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieve a single object by its ID.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `id` — The ID of the item to retrieve.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database select operation fails.
|
|
||||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
|
||||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -195,16 +123,6 @@ impl SurrealDbClient {
|
|||||||
self.client.select((T::table_name(), id)).await
|
self.client.select((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete a single object by its ID.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `id` — The ID of the item to delete.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database delete operation fails.
|
|
||||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
|
||||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -212,11 +130,6 @@ impl SurrealDbClient {
|
|||||||
self.client.delete((T::table_name(), id)).await
|
self.client.delete((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Listen to a table for real-time updates via a live query stream.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// Returns `Err` if the database live query subscription fails.
|
|
||||||
pub async fn listen<T>(
|
pub async fn listen<T>(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||||
@@ -225,6 +138,156 @@ impl SurrealDbClient {
|
|||||||
{
|
{
|
||||||
self.client.select(T::table_name()).live().await
|
self.client.select(T::table_name()).live().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Atomically store an entity and its embedding vector in a single
|
||||||
|
/// SurrealDB transaction.
|
||||||
|
///
|
||||||
|
/// Creates (or overwrites) the entity row and upserts the linked
|
||||||
|
/// embedding record. The embedding dimension is validated against
|
||||||
|
/// `embedding_dimensions` before the query is issued.
|
||||||
|
pub async fn store_with_embedding<E>(
|
||||||
|
&self,
|
||||||
|
entity: E,
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
embedding_dimensions: usize,
|
||||||
|
) -> Result<(), AppError>
|
||||||
|
where
|
||||||
|
E: HasEmbedding + Serialize + Send + Sync + 'static,
|
||||||
|
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
|
||||||
|
{
|
||||||
|
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||||
|
|
||||||
|
let entity_id = entity.id().to_string();
|
||||||
|
let emb = <E as HasEmbedding>::Embedding::new(
|
||||||
|
&entity_id,
|
||||||
|
entity.source_id().to_string(),
|
||||||
|
embedding,
|
||||||
|
entity.user_id().to_string(),
|
||||||
|
E::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"
|
||||||
|
BEGIN TRANSACTION;
|
||||||
|
CREATE type::thing('{et}', $id) CONTENT $entity;
|
||||||
|
UPSERT type::thing('{emt}', $id) CONTENT $emb;
|
||||||
|
COMMIT TRANSACTION;
|
||||||
|
",
|
||||||
|
et = E::table_name(),
|
||||||
|
emt = <E as HasEmbedding>::Embedding::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.query(sql)
|
||||||
|
.bind(("id", entity_id))
|
||||||
|
.bind(("entity", entity))
|
||||||
|
.bind(("emb", emb))
|
||||||
|
.await?
|
||||||
|
.check()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete all entity and embedding rows matching a given `source_id`.
|
||||||
|
///
|
||||||
|
/// Runs inside a SurrealDB transaction so that entity and embedding
|
||||||
|
/// deletes are atomic.
|
||||||
|
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
|
||||||
|
where
|
||||||
|
E: HasEmbedding,
|
||||||
|
E::Embedding: Send + Sync,
|
||||||
|
{
|
||||||
|
self.client
|
||||||
|
.query("BEGIN TRANSACTION;")
|
||||||
|
.query(format!(
|
||||||
|
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||||
|
E::Embedding::table_name()
|
||||||
|
))
|
||||||
|
.query(format!(
|
||||||
|
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||||
|
E::table_name()
|
||||||
|
))
|
||||||
|
.query("COMMIT TRANSACTION;")
|
||||||
|
.bind(("source_id", source_id.to_owned()))
|
||||||
|
.await?
|
||||||
|
.check()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Vector similarity search over entities using HNSW index.
|
||||||
|
///
|
||||||
|
/// Performs a cosine-similarity search against the embedding table,
|
||||||
|
/// fetches the corresponding entity rows server-side via `FETCH`,
|
||||||
|
/// and returns `(entity, score)` pairs ordered by descending
|
||||||
|
/// similarity. Orphaned embeddings (entity deleted but its
|
||||||
|
/// embedding row remains) are logged as a warning and dropped.
|
||||||
|
///
|
||||||
|
/// This is a single round-trip — SurrealDB resolves the link field
|
||||||
|
/// (`entity_id` or `chunk_id`) inside the query engine.
|
||||||
|
pub async fn vector_search<E, Emb>(
|
||||||
|
&self,
|
||||||
|
take: usize,
|
||||||
|
query_embedding: &[f32],
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<(E, f32)>, AppError>
|
||||||
|
where
|
||||||
|
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
|
||||||
|
Emb: EmbeddingRecord + Send + Sync,
|
||||||
|
{
|
||||||
|
// Generic row that works with both `entity_id` and `chunk_id` link
|
||||||
|
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
|
||||||
|
// server-side so we get the full entity in a single round-trip.
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct FetchRow<Ent> {
|
||||||
|
score: f32,
|
||||||
|
#[serde(alias = "entity_id", alias = "chunk_id")]
|
||||||
|
entity: Option<Ent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let link_field = Emb::link_field();
|
||||||
|
let sql = format!(
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
{link_field},
|
||||||
|
vector::similarity::cosine(embedding, $embedding) AS score
|
||||||
|
FROM {emb_table}
|
||||||
|
WHERE user_id = $user_id
|
||||||
|
AND embedding <|{take},100|> $embedding
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT {take}
|
||||||
|
FETCH {link_field}
|
||||||
|
"#,
|
||||||
|
link_field = link_field,
|
||||||
|
emb_table = Emb::table_name(),
|
||||||
|
take = take,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut response = self
|
||||||
|
.client
|
||||||
|
.query(sql)
|
||||||
|
.bind(("embedding", query_embedding.to_vec()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
response = response.check()?;
|
||||||
|
|
||||||
|
let rows: Vec<FetchRow<E>> = response.take(0)?;
|
||||||
|
|
||||||
|
let mut results = Vec::with_capacity(rows.len());
|
||||||
|
for r in rows {
|
||||||
|
if let Some(entity) = r.entity {
|
||||||
|
results.push((entity, r.score));
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"Vector search hit orphaned {} row with missing {link_field}",
|
||||||
|
Emb::table_name()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for SurrealDbClient {
|
impl Deref for SurrealDbClient {
|
||||||
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
|
|||||||
|
|
||||||
#[cfg(any(test, feature = "test-utils"))]
|
#[cfg(any(test, feature = "test-utils"))]
|
||||||
impl SurrealDbClient {
|
impl SurrealDbClient {
|
||||||
/// Create an in-memory SurrealDB client for testing.
|
|
||||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||||
let db = connect("mem://").await?;
|
let db = connect("mem://").await?;
|
||||||
|
|
||||||
db.use_ns(namespace).use_db(database).await?;
|
db.use_ns(namespace).use_db(database).await?;
|
||||||
|
|
||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -253,8 +313,7 @@ mod tests {
|
|||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use crate::test_utils::setup_test_db;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
stored_object!(Dummy, "dummy", {
|
stored_object!(Dummy, "dummy", {
|
||||||
name: String
|
name: String
|
||||||
@@ -262,15 +321,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
|
||||||
|
|
||||||
let dummy = Dummy {
|
let dummy = Dummy {
|
||||||
id: "abc".to_string(),
|
id: "abc".to_string(),
|
||||||
@@ -314,15 +365,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
|
||||||
|
|
||||||
let mut dummy = Dummy {
|
let mut dummy = Dummy {
|
||||||
id: "abc".to_string(),
|
id: "abc".to_string(),
|
||||||
@@ -371,12 +414,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_applying_migrations() -> anyhow::Result<()> {
|
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to build indexes".to_string())?;
|
.with_context(|| "Failed to build indexes".to_string())?;
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
use std::time::Duration;
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||||
|
};
|
||||||
|
|
||||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||||
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
||||||
@@ -204,6 +208,9 @@ pub async fn ensure_runtime(
|
|||||||
|
|
||||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||||
///
|
///
|
||||||
|
/// Uses `DEFINE INDEX OVERWRITE` and is reserved for dimension migrations, re-embed
|
||||||
|
/// flows, and tests. Routine optimization should use [`rebuild_runtime`].
|
||||||
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||||
@@ -211,6 +218,115 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
|||||||
rebuild_inner(db).await.map_err(AppError::internal)
|
rebuild_inner(db).await.map_err(AppError::internal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Rebuilds existing runtime FTS and HNSW indexes in place via SurrealQL `REBUILD INDEX`.
|
||||||
|
///
|
||||||
|
/// SurrealDB maintains ready indexes incrementally on writes; this is for periodic
|
||||||
|
/// optimization (for example a nightly maintainer job), not ingest correctness.
|
||||||
|
/// On SurrealDB 2.6 this runs synchronously (`CONCURRENTLY` is not supported on `REBUILD`).
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if any rebuild operation fails.
|
||||||
|
pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
rebuild_runtime_inner(db).await.map_err(AppError::internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether a scheduled index rebuild is due based on the persisted last-run time.
|
||||||
|
#[must_use]
|
||||||
|
pub fn scheduled_index_rebuild_due(
|
||||||
|
last_run: Option<DateTime<Utc>>,
|
||||||
|
interval_secs: u64,
|
||||||
|
now: DateTime<Utc>,
|
||||||
|
) -> bool {
|
||||||
|
if interval_secs == 0 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(last_run) = last_run else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let elapsed = now.signed_duration_since(last_run);
|
||||||
|
elapsed.num_seconds() >= i64::try_from(interval_secs).unwrap_or(i64::MAX)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs a scheduled native `REBUILD INDEX` pass when due, using a DB lock so only one
|
||||||
|
/// maintainer rebuilds at a time. Seeds a checkpoint on first run so the initial rebuild
|
||||||
|
/// waits one full interval after worker startup.
|
||||||
|
pub async fn maybe_run_scheduled_index_rebuild(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
worker_id: &str,
|
||||||
|
interval_secs: u64,
|
||||||
|
) {
|
||||||
|
if interval_secs == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let settings = match SystemSettings::get_current(db).await {
|
||||||
|
Ok(settings) => settings,
|
||||||
|
Err(err) => {
|
||||||
|
warn!(error = %err, "failed to load system settings for index rebuild schedule");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let last_run = settings.last_index_rebuild_at;
|
||||||
|
|
||||||
|
if last_run.is_none() {
|
||||||
|
match SystemSettings::seed_index_rebuild_checkpoint(db).await {
|
||||||
|
Ok(true) => debug!("seeded index rebuild checkpoint; first rebuild deferred"),
|
||||||
|
Ok(false) => {}
|
||||||
|
Err(err) => warn!(error = %err, "failed to seed index rebuild checkpoint"),
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scheduled_index_rebuild_due(last_run, interval_secs, now) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let lock_owner = format!("{worker_id}-index-rebuild");
|
||||||
|
let acquired = match SystemSettings::try_acquire_index_rebuild_lease(db, &lock_owner).await {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => {
|
||||||
|
warn!(error = %err, "failed to acquire index rebuild lease");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !acquired {
|
||||||
|
debug!("another maintainer is rebuilding indexes");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
info!(interval_secs, "starting scheduled runtime index rebuild");
|
||||||
|
let rebuild_result = rebuild_runtime(db).await;
|
||||||
|
|
||||||
|
match rebuild_result {
|
||||||
|
Ok(()) => {
|
||||||
|
if let Err(err) = SystemSettings::record_index_rebuild_completed(db, &lock_owner).await
|
||||||
|
{
|
||||||
|
warn!(error = %err, "failed to persist index rebuild checkpoint");
|
||||||
|
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||||
|
}
|
||||||
|
info!(
|
||||||
|
elapsed_ms = started.elapsed().as_millis(),
|
||||||
|
"scheduled runtime index rebuild completed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||||
|
error!(
|
||||||
|
error = %err,
|
||||||
|
elapsed_ms = started.elapsed().as_millis(),
|
||||||
|
"scheduled runtime index rebuild failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||||
///
|
///
|
||||||
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||||
@@ -382,6 +498,45 @@ async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
|||||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn rebuild_runtime_inner(db: &SurrealDbClient) -> Result<()> {
|
||||||
|
debug!("Rebuilding runtime indexes with REBUILD INDEX");
|
||||||
|
|
||||||
|
for spec in fts_index_specs() {
|
||||||
|
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||||
|
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await
|
||||||
|
});
|
||||||
|
|
||||||
|
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn rebuild_existing_index_in_place(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
index_name: &str,
|
||||||
|
table: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if !index_exists(db, table, index_name).await? {
|
||||||
|
debug!(
|
||||||
|
index = index_name,
|
||||||
|
table, "Skipping in-place rebuild because index is missing"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let query = format!("REBUILD INDEX IF EXISTS {index_name} ON {table};");
|
||||||
|
let res = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("rebuilding index {index_name} on table {table}"))?;
|
||||||
|
res.check()
|
||||||
|
.with_context(|| format!("rebuild index {index_name} on table {table} failed"))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn existing_hnsw_dimension(
|
async fn existing_hnsw_dimension(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
spec: &HnswIndexSpec,
|
spec: &HnswIndexSpec,
|
||||||
@@ -906,6 +1061,43 @@ mod tests {
|
|||||||
assert_eq!(extract_dimension(definition), Some(1536));
|
assert_eq!(extract_dimension(definition), Some(1536));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scheduled_index_rebuild_due_respects_interval_and_disabled() {
|
||||||
|
let now = Utc::now();
|
||||||
|
let last = now - chrono::Duration::hours(25);
|
||||||
|
|
||||||
|
assert!(!scheduled_index_rebuild_due(None, 86_400, now));
|
||||||
|
assert!(!scheduled_index_rebuild_due(Some(last), 0, now));
|
||||||
|
assert!(!scheduled_index_rebuild_due(
|
||||||
|
Some(now - chrono::Duration::hours(1)),
|
||||||
|
86_400,
|
||||||
|
now
|
||||||
|
));
|
||||||
|
assert!(scheduled_index_rebuild_due(Some(last), 86_400, now));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rebuild_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||||
|
let namespace = "indexes_in_place_rebuild";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.context("in-memory db")?;
|
||||||
|
|
||||||
|
db.apply_migrations().await.context("migrations")?;
|
||||||
|
ensure_runtime(&db, 8)
|
||||||
|
.await
|
||||||
|
.context("ensure runtime indexes")?;
|
||||||
|
|
||||||
|
rebuild_runtime(&db)
|
||||||
|
.await
|
||||||
|
.context("first in-place rebuild")?;
|
||||||
|
rebuild_runtime(&db)
|
||||||
|
.await
|
||||||
|
.context("second in-place rebuild")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||||
let namespace = "indexes_ns";
|
let namespace = "indexes_ns";
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ mod tests {
|
|||||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
use anyhow::{self};
|
use anyhow::{self};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -120,10 +121,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Test initialization of analytics
|
// Test initialization of analytics
|
||||||
let analytics = Analytics::ensure_initialized(&db).await?;
|
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||||
|
|
||||||
@@ -145,10 +143,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db).await?;
|
Analytics::ensure_initialized(&db).await?;
|
||||||
|
|
||||||
@@ -165,10 +160,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_visitors() -> anyhow::Result<()> {
|
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db).await?;
|
Analytics::ensure_initialized(&db).await?;
|
||||||
|
|
||||||
@@ -190,10 +182,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db).await?;
|
Analytics::ensure_initialized(&db).await?;
|
||||||
|
|
||||||
@@ -214,11 +203,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_users_amount() -> anyhow::Result<()> {
|
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Test with no users
|
// Test with no users
|
||||||
let count = Analytics::get_users_amount(&db).await?;
|
let count = Analytics::get_users_amount(&db).await?;
|
||||||
assert_eq!(count, 0);
|
assert_eq!(count, 0);
|
||||||
@@ -246,10 +231,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let analytics = Analytics::increment_visitors(&db).await?;
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
assert_eq!(analytics.visitors, 1);
|
assert_eq!(analytics.visitors, 1);
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
@@ -259,10 +241,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
assert_eq!(analytics.page_loads, 1);
|
assert_eq!(analytics.page_loads, 1);
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
@@ -272,10 +251,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let after_visitors = Analytics::increment_visitors(&db).await?;
|
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||||
assert_eq!(after_visitors.visitors, 1);
|
assert_eq!(after_visitors.visitors, 1);
|
||||||
assert_eq!(after_visitors.page_loads, 0);
|
assert_eq!(after_visitors.page_loads, 0);
|
||||||
@@ -293,10 +269,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_record_page_view() -> anyhow::Result<()> {
|
async fn test_record_page_view() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let first_view = Analytics::record_page_view(&db, true).await?;
|
let first_view = Analytics::record_page_view(&db, true).await?;
|
||||||
assert_eq!(first_view.visitors, 1);
|
assert_eq!(first_view.visitors, 1);
|
||||||
assert_eq!(first_view.page_loads, 1);
|
assert_eq!(first_view.page_loads, 1);
|
||||||
@@ -310,11 +283,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
// Don't initialize analytics and try to get it
|
// Don't initialize analytics and try to get it
|
||||||
let result = Analytics::get_current(&db).await;
|
let result = Analytics::get_current(&db).await;
|
||||||
|
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ impl Conversation {
|
|||||||
mod tests {
|
mod tests {
|
||||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use crate::storage::types::message::MessageRole;
|
use crate::storage::types::message::MessageRole;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -181,11 +182,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let title = "Test Conversation";
|
let title = "Test Conversation";
|
||||||
@@ -214,11 +211,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||||
@@ -234,11 +227,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation =
|
let conversation =
|
||||||
@@ -264,11 +253,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_success() -> anyhow::Result<()> {
|
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user_1";
|
let user_id = "user_1";
|
||||||
let original_title = "Original Title";
|
let original_title = "Original Title";
|
||||||
@@ -297,11 +282,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||||
|
|
||||||
@@ -316,11 +297,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user_id = "intruder";
|
let other_user_id = "intruder";
|
||||||
@@ -345,11 +322,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await.expect("setup_test_db");
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
let user_id = "sidebar_user";
|
let user_id = "sidebar_user";
|
||||||
let other_user_id = "other_user";
|
let other_user_id = "other_user";
|
||||||
@@ -398,11 +371,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await.expect("setup_test_db");
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
let user_id = "sidebar_patch_user";
|
let user_id = "sidebar_patch_user";
|
||||||
let base = Utc::now();
|
let base = Utc::now();
|
||||||
@@ -440,11 +409,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||||
@@ -527,11 +492,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await.expect("setup_test_db");
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
let owner = "sidebar_owner";
|
let owner = "sidebar_owner";
|
||||||
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||||
@@ -551,9 +512,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let owner = "owner_user";
|
let owner = "owner_user";
|
||||||
let intruder = "intruder_user";
|
let intruder = "intruder_user";
|
||||||
@@ -590,9 +549,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let user_id = "order_user";
|
let user_id = "order_user";
|
||||||
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||||
@@ -637,9 +594,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let owner = "owner";
|
let owner = "owner";
|
||||||
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||||
|
|||||||
@@ -327,6 +327,7 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::store::testing::TestStorageManager;
|
use crate::storage::store::testing::TestStorageManager;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
use axum::http::HeaderMap;
|
use axum::http::HeaderMap;
|
||||||
use axum_typed_multipart::FieldMetadata;
|
use axum_typed_multipart::FieldMetadata;
|
||||||
use std::{io::Write, path::Path};
|
use std::{io::Write, path::Path};
|
||||||
@@ -378,15 +379,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let content = b"This is a test file for StorageManager operations";
|
let content = b"This is a test file for StorageManager operations";
|
||||||
let file_name = "storage_manager_test.txt";
|
let file_name = "storage_manager_test.txt";
|
||||||
let field_data = create_test_file(content, file_name)?;
|
let field_data = create_test_file(content, file_name)?;
|
||||||
@@ -435,15 +428,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let content = b"filename sanitization";
|
let content = b"filename sanitization";
|
||||||
let original_name = "Complex name (1).txt";
|
let original_name = "Complex name (1).txt";
|
||||||
let expected_sanitized = "Complex_name__1_.txt";
|
let expected_sanitized = "Complex_name__1_.txt";
|
||||||
@@ -470,15 +455,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let content = b"This is a test file for StorageManager duplicate detection";
|
let content = b"This is a test file for StorageManager duplicate detection";
|
||||||
let file_name = "storage_manager_duplicate.txt";
|
let file_name = "storage_manager_duplicate.txt";
|
||||||
let field_data = create_test_file(content, file_name)?;
|
let field_data = create_test_file(content, file_name)?;
|
||||||
@@ -538,15 +515,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_file_creation() -> anyhow::Result<()> {
|
async fn test_file_creation() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let content = b"This is a test file content";
|
let content = b"This is a test file content";
|
||||||
let file_name = "test_file.txt";
|
let file_name = "test_file.txt";
|
||||||
let field_data = create_test_file(content, file_name)?;
|
let field_data = create_test_file(content, file_name)?;
|
||||||
@@ -585,15 +554,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
// First, store a file with known content
|
// First, store a file with known content
|
||||||
let content = b"This is a test file for duplicate detection";
|
let content = b"This is a test file for duplicate detection";
|
||||||
let file_name = "original.txt";
|
let file_name = "original.txt";
|
||||||
@@ -692,12 +653,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
@@ -710,12 +666,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_sha_resists_query_injection() {
|
async fn test_get_by_sha_resists_query_injection() {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await.expect("setup test db");
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let file_info = FileInfo {
|
let file_info = FileInfo {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
@@ -740,15 +691,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let content = b"shared content across users";
|
let content = b"shared content across users";
|
||||||
let test_storage = TestStorageManager::new_memory()
|
let test_storage = TestStorageManager::new_memory()
|
||||||
.await
|
.await
|
||||||
@@ -783,10 +726,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let sha = "abc123sha";
|
let sha = "abc123sha";
|
||||||
let owner = "owner_user";
|
let owner = "owner_user";
|
||||||
@@ -816,9 +756,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
let test_storage = TestStorageManager::new_memory().await?;
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
|
|
||||||
let field_data = create_test_file_without_name(b"data")?;
|
let field_data = create_test_file_without_name(b"data")?;
|
||||||
@@ -832,9 +770,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
let test_storage = TestStorageManager::new_memory().await?;
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
|
|
||||||
let file_info = FileInfo::new_with_storage(
|
let file_info = FileInfo::new_with_storage(
|
||||||
@@ -856,10 +792,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
|
||||||
db.apply_migrations().await?;
|
|
||||||
let test_storage = TestStorageManager::new_memory().await?;
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
let storage = test_storage.storage();
|
let storage = test_storage.storage();
|
||||||
let user_id = "dedup_user";
|
let user_id = "dedup_user";
|
||||||
@@ -901,12 +834,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_manual_file_info_creation() {
|
async fn test_manual_file_info_creation() {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await.expect("setup test db");
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Create a FileInfo instance directly
|
// Create a FileInfo instance directly
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let file_info = FileInfo {
|
let file_info = FileInfo {
|
||||||
@@ -939,15 +867,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_id() -> anyhow::Result<()> {
|
async fn test_delete_by_id() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
// Create and persist a test file via FileInfo::new_with_storage
|
// Create and persist a test file via FileInfo::new_with_storage
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let test_storage = TestStorageManager::new_memory()
|
let test_storage = TestStorageManager::new_memory()
|
||||||
@@ -985,12 +905,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Try to delete a file that doesn't exist
|
// Try to delete a file that doesn't exist
|
||||||
let test_storage = TestStorageManager::new_memory()
|
let test_storage = TestStorageManager::new_memory()
|
||||||
.await
|
.await
|
||||||
@@ -1006,12 +921,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_id() -> anyhow::Result<()> {
|
async fn test_get_by_id() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Create a FileInfo instance directly
|
// Create a FileInfo instance directly
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let file_id = Uuid::new_v4().to_string();
|
let file_id = Uuid::new_v4().to_string();
|
||||||
@@ -1045,12 +955,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Try to retrieve a non-existent ID
|
// Try to retrieve a non-existent ID
|
||||||
let non_existent_id = "non-existent-file-id";
|
let non_existent_id = "non-existent-file-id";
|
||||||
let result = FileInfo::get_by_id(non_existent_id, &db).await;
|
let result = FileInfo::get_by_id(non_existent_id, &db).await;
|
||||||
|
|||||||
@@ -630,6 +630,7 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
|
|
||||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||||
IngestionPayload::Text {
|
IngestionPayload::Text {
|
||||||
@@ -641,11 +642,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||||
let namespace = "test_ns";
|
setup_test_db().await
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "in-memory surrealdb".to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ use std::fmt::Write;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::db::SurrealDbClient,
|
storage::{
|
||||||
storage::indexes::hnsw_index_overwrite_sql,
|
db::SurrealDbClient,
|
||||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
indexes::hnsw_index_overwrite_sql,
|
||||||
storage::types::system_settings::SystemSettings,
|
types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||||
|
types::system_settings::SystemSettings,
|
||||||
|
types::{EmbeddingRecord, HasEmbedding},
|
||||||
|
},
|
||||||
stored_object,
|
stored_object,
|
||||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||||
};
|
};
|
||||||
@@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
|||||||
user_id: String
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
|
impl HasEmbedding for KnowledgeEntity {
|
||||||
|
type Embedding = KnowledgeEntityEmbedding;
|
||||||
|
|
||||||
|
fn source_id(&self) -> &str {
|
||||||
|
&self.source_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_id(&self) -> &str {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl KnowledgeEntity {
|
impl KnowledgeEntity {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
@@ -227,67 +242,22 @@ impl KnowledgeEntity {
|
|||||||
|
|
||||||
pub async fn delete_by_source_id(
|
pub async fn delete_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
db.delete_by_source_id::<Self>(source_id).await
|
||||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
|
||||||
|
|
||||||
db_client
|
|
||||||
.client
|
|
||||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
|
||||||
.bind(("table", Self::table_name()))
|
|
||||||
.bind(("source_id", source_id.to_owned()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Atomically store a knowledge entity and its embedding.
|
/// Atomically store one knowledge entity and its embedding (single-record path).
|
||||||
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
|
///
|
||||||
|
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||||
pub async fn store_with_embedding(
|
pub async fn store_with_embedding(
|
||||||
entity: KnowledgeEntity,
|
entity: KnowledgeEntity,
|
||||||
embedding: Vec<f32>,
|
embedding: Vec<f32>,
|
||||||
|
embedding_dimensions: usize,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let settings = SystemSettings::get_current(db).await?;
|
db.store_with_embedding(entity, embedding, embedding_dimensions)
|
||||||
KnowledgeEntityEmbedding::validate_dimension(
|
|
||||||
&embedding,
|
|
||||||
settings.embedding_dimensions as usize,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let entity_id = entity.id.clone();
|
|
||||||
let emb = KnowledgeEntityEmbedding::new(
|
|
||||||
&entity_id,
|
|
||||||
entity.source_id.clone(),
|
|
||||||
embedding,
|
|
||||||
entity.user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let query = format!(
|
|
||||||
"
|
|
||||||
BEGIN TRANSACTION;
|
|
||||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
|
||||||
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
|
|
||||||
COMMIT TRANSACTION;
|
|
||||||
",
|
|
||||||
entity_table = Self::table_name(),
|
|
||||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
|
||||||
);
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("entity_id", entity_id))
|
|
||||||
.bind(("entity", entity))
|
|
||||||
.bind(("emb", emb))
|
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||||
@@ -297,48 +267,14 @@ impl KnowledgeEntity {
|
|||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||||
#[derive(Deserialize)]
|
db.vector_search::<Self, KnowledgeEntityEmbedding>(take, query_embedding, user_id)
|
||||||
struct Row {
|
|
||||||
entity_id: Option<KnowledgeEntity>,
|
|
||||||
score: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
let sql = format!(
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
entity_id,
|
|
||||||
vector::similarity::cosine(embedding, $embedding) AS score
|
|
||||||
FROM {emb_table}
|
|
||||||
WHERE user_id = $user_id
|
|
||||||
AND embedding <|{take},100|> $embedding
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT {take}
|
|
||||||
FETCH entity_id;
|
|
||||||
"#,
|
|
||||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
|
||||||
take = take
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut response = db
|
|
||||||
.query(&sql)
|
|
||||||
.bind(("embedding", query_embedding.to_vec()))
|
|
||||||
.bind(("user_id", user_id.to_string()))
|
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map(|results| {
|
||||||
|
results
|
||||||
response = response.check().map_err(AppError::from)?;
|
.into_iter()
|
||||||
|
.map(|(entity, score)| KnowledgeEntitySearchResult { entity, score })
|
||||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
.collect()
|
||||||
|
|
||||||
Ok(rows
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|r| {
|
|
||||||
r.entity_id.map(|entity| KnowledgeEntitySearchResult {
|
|
||||||
entity,
|
|
||||||
score: r.score,
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.collect())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn patch(
|
pub async fn patch(
|
||||||
@@ -364,7 +300,13 @@ impl KnowledgeEntity {
|
|||||||
settings.embedding_dimensions as usize,
|
settings.embedding_dimensions as usize,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
|
let emb = KnowledgeEntityEmbedding::new(
|
||||||
|
id,
|
||||||
|
entity.source_id,
|
||||||
|
embedding,
|
||||||
|
entity.user_id,
|
||||||
|
Self::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
@@ -554,9 +496,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::indexes::rebuild;
|
use crate::storage::indexes::rebuild;
|
||||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||||
use crate::test_utils::configure_embedding_dimension;
|
use crate::test_utils::{ensure_fts_index, prepare_knowledge_entity_test_db, setup_test_db};
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn embedding_input_text_uses_canonical_type_label() {
|
fn embedding_input_text_uses_canonical_type_label() {
|
||||||
@@ -568,27 +509,6 @@ mod tests {
|
|||||||
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
|
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
|
|
||||||
let snowball_sql = r#"
|
|
||||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
"#;
|
|
||||||
|
|
||||||
if let Err(err) = db.client.query(snowball_sql).await {
|
|
||||||
let fallback_sql = r#"
|
|
||||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
"#;
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(fallback_sql)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("define entity fts index fallback: {err}"))?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -675,19 +595,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_knowledge_entity_test_db(5).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 5).await?;
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
|
||||||
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let entity_type = KnowledgeEntityType::Document;
|
let entity_type = KnowledgeEntityType::Document;
|
||||||
@@ -722,13 +630,13 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity 1".to_string())?;
|
.with_context(|| "Failed to store entity 1".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity 2".to_string())?;
|
.with_context(|| "Failed to store entity 2".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store different entity".to_string())?;
|
.with_context(|| "Failed to store different entity".to_string())?;
|
||||||
|
|
||||||
@@ -783,21 +691,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id_resists_query_injection() {
|
async fn test_delete_by_source_id_resists_query_injection() {
|
||||||
let namespace = "test_ns";
|
let db = prepare_knowledge_entity_test_db(3)
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.expect("prepare test db");
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3)
|
|
||||||
.await
|
|
||||||
.expect("configure dim");
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.expect("Failed to redefine index length");
|
|
||||||
|
|
||||||
let user_id = "user123".to_string();
|
let user_id = "user123".to_string();
|
||||||
|
|
||||||
@@ -819,10 +715,10 @@ mod tests {
|
|||||||
user_id,
|
user_id,
|
||||||
);
|
);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
|
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.expect("store entity1");
|
.expect("store entity1");
|
||||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
|
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], 3, &db)
|
||||||
.await
|
.await
|
||||||
.expect("store entity2");
|
.expect("store entity2");
|
||||||
|
|
||||||
@@ -849,18 +745,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||||
let namespace = "test_ns";
|
let db = prepare_knowledge_entity_test_db(3)
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.expect("prepare test db");
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.expect("Failed to redefine index length");
|
|
||||||
|
|
||||||
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||||
.await
|
.await
|
||||||
@@ -870,19 +757,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
let source_id = "src".to_string();
|
let source_id = "src".to_string();
|
||||||
@@ -895,7 +770,7 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store entity with embedding".to_string())?;
|
.with_context(|| "store entity with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -918,7 +793,7 @@ mod tests {
|
|||||||
assert_eq!(stored_embeddings.len(), 1);
|
assert_eq!(stored_embeddings.len(), 1);
|
||||||
|
|
||||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "fetch embedding".to_string())?;
|
.with_context(|| "fetch embedding".to_string())?;
|
||||||
assert!(fetched_emb.is_some());
|
assert!(fetched_emb.is_some());
|
||||||
@@ -938,19 +813,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
let e1 = KnowledgeEntity::new(
|
let e1 = KnowledgeEntity::new(
|
||||||
@@ -970,10 +833,10 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store e1".to_string())?;
|
.with_context(|| "store e1".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store e2".to_string())?;
|
.with_context(|| "store e2".to_string())?;
|
||||||
|
|
||||||
@@ -1001,11 +864,11 @@ mod tests {
|
|||||||
|
|
||||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get embedding e1".to_string())?
|
.with_context(|| "get embedding e1".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get embedding e2".to_string())?
|
.with_context(|| "get embedding e2".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
@@ -1037,19 +900,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns_orphan";
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
let source_id = "src".to_string();
|
let source_id = "src".to_string();
|
||||||
@@ -1062,7 +913,7 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store entity with embedding".to_string())?;
|
.with_context(|| "store entity with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -1089,15 +940,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_returns_empty_when_no_entities() -> anyhow::Result<()> {
|
async fn test_fts_search_returns_empty_when_no_entities() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_entity_ns_empty";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
&db,
|
||||||
.await
|
"knowledge_entity",
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
&[("name", "name"), ("description", "description")],
|
||||||
db.apply_migrations()
|
)
|
||||||
.await
|
.await?;
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_entity_fts_indexes(&db).await?;
|
|
||||||
rebuild(&db)
|
rebuild(&db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "rebuild indexes".to_string())?;
|
.with_context(|| "rebuild indexes".to_string())?;
|
||||||
@@ -1112,15 +961,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_entity_ns_single";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
&db,
|
||||||
.await
|
"knowledge_entity",
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
&[("name", "name"), ("description", "description")],
|
||||||
db.apply_migrations()
|
)
|
||||||
.await
|
.await?;
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_entity_fts_indexes(&db).await?;
|
|
||||||
|
|
||||||
let user_id = "fts_user";
|
let user_id = "fts_user";
|
||||||
let entity = KnowledgeEntity::new(
|
let entity = KnowledgeEntity::new(
|
||||||
@@ -1151,15 +998,13 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_entity_ns_order";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
&db,
|
||||||
.await
|
"knowledge_entity",
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
&[("name", "name"), ("description", "description")],
|
||||||
db.apply_migrations()
|
)
|
||||||
.await
|
.await?;
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_entity_fts_indexes(&db).await?;
|
|
||||||
|
|
||||||
let user_id = "fts_user_order";
|
let user_id = "fts_user_order";
|
||||||
let high_score_entity = KnowledgeEntity::new(
|
let high_score_entity = KnowledgeEntity::new(
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use surrealdb::RecordId;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
storage::{db::SurrealDbClient, types::EmbeddingRecord},
|
||||||
stored_object,
|
stored_object,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
|||||||
user_id: String
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
impl KnowledgeEntityEmbedding {
|
impl EmbeddingRecord for KnowledgeEntityEmbedding {
|
||||||
/// Recreate the HNSW index with a new embedding dimension.
|
fn link_field() -> &'static str {
|
||||||
pub async fn redefine_hnsw_index(
|
"entity_id"
|
||||||
db: &SurrealDbClient,
|
|
||||||
dimension: usize,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = hnsw_index_redefine_transaction_sql(
|
|
||||||
"idx_embedding_knowledge_entity_embedding",
|
|
||||||
Self::table_name(),
|
|
||||||
dimension,
|
|
||||||
);
|
|
||||||
|
|
||||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
|
||||||
res.check().map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
fn index_name() -> &'static str {
|
||||||
#[allow(clippy::result_large_err)]
|
"idx_embedding_knowledge_entity_embedding"
|
||||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
|
||||||
if embedding.len() != expected {
|
|
||||||
return Err(AppError::Validation(format!(
|
|
||||||
"embedding dimension mismatch: got {}, expected {expected}",
|
|
||||||
embedding.len()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new knowledge entity embedding.
|
fn source_id(&self) -> &str {
|
||||||
///
|
&self.source_id
|
||||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
}
|
||||||
#[must_use]
|
|
||||||
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
fn user_id(&self) -> &str {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embedding(&self) -> &[f32] {
|
||||||
|
&self.embedding
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(
|
||||||
|
entity_id: &str,
|
||||||
|
source_id: String,
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
user_id: String,
|
||||||
|
entity_table: &str,
|
||||||
|
) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
id: entity_id.to_owned(),
|
id: entity_id.to_owned(),
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now,
|
updated_at: now,
|
||||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
entity_id: RecordId::from_table_key(entity_table, entity_id),
|
||||||
embedding,
|
embedding,
|
||||||
source_id,
|
source_id,
|
||||||
user_id,
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get embedding by entity ID
|
impl KnowledgeEntityEmbedding {
|
||||||
pub async fn get_by_entity_id(
|
|
||||||
entity_id: &RecordId,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<Option<Self>, AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
let mut result = db
|
|
||||||
.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("entity_id", entity_id.clone()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
|
||||||
Ok(embeddings.into_iter().next())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get embeddings for multiple entities in batch
|
/// Get embeddings for multiple entities in batch
|
||||||
pub async fn get_by_entity_ids(
|
pub async fn get_by_entity_ids(
|
||||||
entity_ids: &[RecordId],
|
entity_ids: &[RecordId],
|
||||||
@@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding {
|
|||||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete embedding by entity ID
|
|
||||||
pub async fn delete_by_entity_id(
|
|
||||||
entity_id: &RecordId,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("entity_id", entity_id.clone()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Delete all embeddings with the given denormalized `source_id`.
|
|
||||||
pub async fn delete_by_source_id(
|
|
||||||
source_id: &str,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"DELETE FROM {} WHERE source_id = $source_id",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("source_id", source_id.to_owned()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -184,6 +122,7 @@ mod tests {
|
|||||||
"source-1".to_owned(),
|
"source-1".to_owned(),
|
||||||
vec![0.1, 0.2],
|
vec![0.1, 0.2],
|
||||||
"user-1".to_owned(),
|
"user-1".to_owned(),
|
||||||
|
KnowledgeEntity::table_name(),
|
||||||
);
|
);
|
||||||
assert_eq!(emb.id, "entity-abc");
|
assert_eq!(emb.id, "entity-abc");
|
||||||
}
|
}
|
||||||
@@ -205,13 +144,13 @@ mod tests {
|
|||||||
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
||||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||||
@@ -234,22 +173,22 @@ mod tests {
|
|||||||
|
|
||||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
assert!(existing.is_some());
|
assert!(existing.is_some());
|
||||||
|
|
||||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||||
|
|
||||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
assert!(after.is_none());
|
assert!(after.is_none());
|
||||||
@@ -266,7 +205,7 @@ mod tests {
|
|||||||
|
|
||||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -277,7 +216,7 @@ mod tests {
|
|||||||
assert!(stored_entity.is_some());
|
assert!(stored_entity.is_some());
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||||
let stored_embedding =
|
let stored_embedding =
|
||||||
@@ -295,7 +234,7 @@ mod tests {
|
|||||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], 3, &db).await;
|
||||||
|
|
||||||
assert!(matches!(result, Err(AppError::Validation(_))));
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
|
||||||
@@ -313,15 +252,20 @@ mod tests {
|
|||||||
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
||||||
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db)
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
|
||||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
KnowledgeEntity::store_with_embedding(
|
||||||
|
entity_other.clone(),
|
||||||
|
vec![3.0_f32, 3.1, 3.2],
|
||||||
|
3,
|
||||||
|
&db,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||||
@@ -332,18 +276,18 @@ mod tests {
|
|||||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||||
.is_none()
|
.is_none()
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||||
.is_none()
|
.is_none()
|
||||||
);
|
);
|
||||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get other embedding after delete".to_string())?
|
.with_context(|| "get other embedding after delete".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
@@ -403,7 +347,7 @@ mod tests {
|
|||||||
let source_id = "source-fetch";
|
let source_id = "source-fetch";
|
||||||
|
|
||||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -441,7 +385,7 @@ mod tests {
|
|||||||
let source_id = "source-upsert";
|
let source_id = "source-upsert";
|
||||||
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "initial store".to_string())?;
|
.with_context(|| "initial store".to_string())?;
|
||||||
|
|
||||||
@@ -450,6 +394,7 @@ mod tests {
|
|||||||
source_id.to_owned(),
|
source_id.to_owned(),
|
||||||
vec![0.0, 1.0, 0.0],
|
vec![0.0, 1.0, 0.0],
|
||||||
user_id.to_owned(),
|
user_id.to_owned(),
|
||||||
|
KnowledgeEntity::table_name(),
|
||||||
);
|
);
|
||||||
db.upsert_item(replacement)
|
db.upsert_item(replacement)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ pub fn format_history(history: &[Message]) -> String {
|
|||||||
mod tests {
|
mod tests {
|
||||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::test_utils::setup_test_db;
|
||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -106,11 +106,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_persistence() -> anyhow::Result<()> {
|
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &uuid::Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
let conversation_id = "test_conversation";
|
let conversation_id = "test_conversation";
|
||||||
let message = Message::new(
|
let message = Message::new(
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#![allow(clippy::unsafe_derive_deserialize)]
|
#![allow(clippy::unsafe_derive_deserialize)]
|
||||||
|
#![allow(async_fn_in_trait)]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
pub mod analytics;
|
pub mod analytics;
|
||||||
pub mod conversation;
|
pub mod conversation;
|
||||||
@@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
|||||||
fn id(&self) -> &str;
|
fn id(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An entity that has an associated embedding record for vector search.
|
||||||
|
pub trait HasEmbedding: StoredObject {
|
||||||
|
/// The embedding record type paired with this entity.
|
||||||
|
type Embedding: EmbeddingRecord;
|
||||||
|
|
||||||
|
fn source_id(&self) -> &str;
|
||||||
|
fn user_id(&self) -> &str;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An embedding record linked to a `HasEmbedding` entity.
|
||||||
|
pub trait EmbeddingRecord: StoredObject {
|
||||||
|
/// The field name in the embedding table that links back to the entity
|
||||||
|
/// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses.
|
||||||
|
fn link_field() -> &'static str;
|
||||||
|
|
||||||
|
/// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`).
|
||||||
|
fn index_name() -> &'static str;
|
||||||
|
|
||||||
|
fn source_id(&self) -> &str;
|
||||||
|
fn user_id(&self) -> &str;
|
||||||
|
fn embedding(&self) -> &[f32];
|
||||||
|
|
||||||
|
/// Construct a new embedding record.
|
||||||
|
///
|
||||||
|
/// * `id` – shared record id (same as the entity id).
|
||||||
|
/// * `source_id` – denormalised source id for bulk deletes.
|
||||||
|
/// * `embedding` – the embedding vector.
|
||||||
|
/// * `user_id` – denormalised user id for query scoping.
|
||||||
|
/// * `entity_table` – the entity's table name (used to build the link `RecordId`).
|
||||||
|
fn new(
|
||||||
|
id: &str,
|
||||||
|
source_id: String,
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
user_id: String,
|
||||||
|
entity_table: &str,
|
||||||
|
) -> Self;
|
||||||
|
|
||||||
|
/// Validate that an embedding vector matches the expected dimension.
|
||||||
|
fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
if embedding.len() != expected {
|
||||||
|
return Err(crate::error::AppError::Validation(format!(
|
||||||
|
"embedding dimension mismatch: got {}, expected {expected}",
|
||||||
|
embedding.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recreate the HNSW vector index with a new dimension.
|
||||||
|
///
|
||||||
|
/// This drops and recreates the index inside a transaction.
|
||||||
|
async fn redefine_hnsw_index(
|
||||||
|
db: &crate::storage::db::SurrealDbClient,
|
||||||
|
dimension: usize,
|
||||||
|
) -> Result<(), crate::error::AppError>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql(
|
||||||
|
Self::index_name(),
|
||||||
|
Self::table_name(),
|
||||||
|
dimension,
|
||||||
|
);
|
||||||
|
db.client.query(query).await?.check()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch a single embedding record by its link `RecordId`.
|
||||||
|
async fn get_by_record_id(
|
||||||
|
db: &crate::storage::db::SurrealDbClient,
|
||||||
|
rid: &surrealdb::RecordId,
|
||||||
|
) -> Result<Option<Self>, crate::error::AppError>
|
||||||
|
where
|
||||||
|
Self: Sized + serde::de::DeserializeOwned,
|
||||||
|
{
|
||||||
|
let query = format!(
|
||||||
|
"SELECT * FROM {} WHERE {} = $rid LIMIT 1",
|
||||||
|
Self::table_name(),
|
||||||
|
Self::link_field(),
|
||||||
|
);
|
||||||
|
let mut result = db.client.query(query).bind(("rid", rid.clone())).await?;
|
||||||
|
Ok(result.take(0)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete an embedding record by its link `RecordId`.
|
||||||
|
async fn delete_by_record_id(
|
||||||
|
db: &crate::storage::db::SurrealDbClient,
|
||||||
|
rid: &surrealdb::RecordId,
|
||||||
|
) -> Result<(), crate::error::AppError>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE {} = $rid",
|
||||||
|
Self::table_name(),
|
||||||
|
Self::link_field(),
|
||||||
|
);
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("rid", rid.clone()))
|
||||||
|
.await?
|
||||||
|
.check()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete all embedding records with a given `source_id`.
|
||||||
|
async fn delete_by_source_id(
|
||||||
|
source_id: &str,
|
||||||
|
db: &crate::storage::db::SurrealDbClient,
|
||||||
|
) -> Result<(), crate::error::AppError>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE source_id = $source_id",
|
||||||
|
Self::table_name(),
|
||||||
|
);
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("source_id", source_id.to_owned()))
|
||||||
|
.await?
|
||||||
|
.check()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! stored_object {
|
macro_rules! stored_object {
|
||||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||||
|
|||||||
@@ -221,19 +221,11 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
// Create a new scratchpad
|
// Create a new scratchpad
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
@@ -271,15 +263,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_user() -> anyhow::Result<()> {
|
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
|
|
||||||
@@ -333,15 +317,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
@@ -368,15 +344,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_content() -> anyhow::Result<()> {
|
async fn test_update_content() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
@@ -398,15 +366,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user = "other_user";
|
let other_user = "other_user";
|
||||||
@@ -428,15 +388,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
@@ -461,15 +413,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user = "other_user";
|
let other_user = "other_user";
|
||||||
@@ -498,13 +442,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to create test database".to_string())?;
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "test_user_123";
|
let user_id = "test_user_123";
|
||||||
let scratchpad =
|
let scratchpad =
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::utils::config::EmbeddingBackend;
|
use crate::utils::config::EmbeddingBackend;
|
||||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -22,6 +25,15 @@ pub struct SystemSettings {
|
|||||||
pub image_processing_model: String,
|
pub image_processing_model: String,
|
||||||
pub image_processing_prompt: String,
|
pub image_processing_prompt: String,
|
||||||
pub voice_processing_model: String,
|
pub voice_processing_model: String,
|
||||||
|
/// When the maintainer last completed a scheduled `REBUILD INDEX` pass.
|
||||||
|
#[serde(default)]
|
||||||
|
pub last_index_rebuild_at: Option<DateTime<Utc>>,
|
||||||
|
/// Worker id holding the index-rebuild lease, if any.
|
||||||
|
#[serde(default)]
|
||||||
|
pub index_rebuild_lease_owner: Option<String>,
|
||||||
|
/// Lease expiry for in-flight scheduled index rebuilds.
|
||||||
|
#[serde(default)]
|
||||||
|
pub index_rebuild_lease_expires_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Partial update for singleton system settings without cloning unchanged fields.
|
/// Partial update for singleton system settings without cloning unchanged fields.
|
||||||
@@ -100,6 +112,8 @@ impl SystemSettingsPatch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const INDEX_REBUILD_LEASE_TTL: &str = "6h";
|
||||||
|
|
||||||
impl SystemSettings {
|
impl SystemSettings {
|
||||||
pub const RECORD_ID: &'static str = "current";
|
pub const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
@@ -227,6 +241,89 @@ impl SystemSettings {
|
|||||||
|
|
||||||
Ok((settings, needs_update))
|
Ok((settings, needs_update))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Seeds the first rebuild checkpoint so the initial scheduled rebuild waits one interval.
|
||||||
|
pub async fn seed_index_rebuild_checkpoint(db: &SurrealDbClient) -> Result<bool, AppError> {
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"UPDATE type::thing('system_settings', $id) SET last_index_rebuild_at = time::now()
|
||||||
|
WHERE last_index_rebuild_at IS NONE
|
||||||
|
RETURN AFTER;",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||||
|
Ok(updated.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Claims the singleton index-rebuild lease when it is free or expired.
|
||||||
|
pub async fn try_acquire_index_rebuild_lease(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
owner: &str,
|
||||||
|
) -> Result<bool, AppError> {
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(format!(
|
||||||
|
"UPDATE type::thing('system_settings', $id) SET
|
||||||
|
index_rebuild_lease_owner = $owner,
|
||||||
|
index_rebuild_lease_expires_at = time::now() + {INDEX_REBUILD_LEASE_TTL}
|
||||||
|
WHERE index_rebuild_lease_expires_at IS NONE
|
||||||
|
OR index_rebuild_lease_expires_at < time::now()
|
||||||
|
RETURN AFTER;"
|
||||||
|
))
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.bind(("owner", owner.to_string()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||||
|
Ok(updated.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Releases the index-rebuild lease when held by `owner`.
|
||||||
|
pub async fn release_index_rebuild_lease(db: &SurrealDbClient, owner: &str) {
|
||||||
|
let released = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"UPDATE type::thing('system_settings', $id) SET
|
||||||
|
index_rebuild_lease_owner = NONE,
|
||||||
|
index_rebuild_lease_expires_at = NONE
|
||||||
|
WHERE index_rebuild_lease_owner = $owner;",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.bind(("owner", owner.to_string()))
|
||||||
|
.await
|
||||||
|
.and_then(surrealdb::Response::check);
|
||||||
|
|
||||||
|
if let Err(err) = released {
|
||||||
|
warn!(error = %err, "failed to release index rebuild lease");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Records a completed scheduled index rebuild and clears the lease.
|
||||||
|
pub async fn record_index_rebuild_completed(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
owner: &str,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let response = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"UPDATE type::thing('system_settings', $id) SET
|
||||||
|
last_index_rebuild_at = time::now(),
|
||||||
|
index_rebuild_lease_owner = NONE,
|
||||||
|
index_rebuild_lease_expires_at = NONE
|
||||||
|
WHERE index_rebuild_lease_owner = $owner;",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.bind(("owner", owner.to_string()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
response.check().map_err(AppError::from)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -237,6 +334,7 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::setup_test_db;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
async fn get_hnsw_index_dimension(
|
async fn get_hnsw_index_dimension(
|
||||||
@@ -320,17 +418,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_settings_initialization() -> anyhow::Result<()> {
|
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Test initialization of system settings
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get system settings".to_string())?;
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
@@ -367,19 +455,8 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_settings() -> anyhow::Result<()> {
|
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Initialize settings
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
// Test get_current method
|
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get current settings".to_string())?;
|
.with_context(|| "Failed to get current settings".to_string())?;
|
||||||
@@ -392,17 +469,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_settings() -> anyhow::Result<()> {
|
async fn test_update_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Initialize settings
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
// Create updated settings
|
// Create updated settings
|
||||||
let mut updated_settings = SystemSettings::get_current(&db)
|
let mut updated_settings = SystemSettings::get_current(&db)
|
||||||
@@ -435,13 +502,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
|
|
||||||
// Don't initialize settings and try to get them
|
// Don't initialize settings and try to get them
|
||||||
let result = SystemSettings::get_current(&db).await;
|
let result = SystemSettings::get_current(&db).await;
|
||||||
|
|
||||||
@@ -458,12 +519,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
@@ -477,12 +533,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let updated = SystemSettingsPatch {
|
let updated = SystemSettingsPatch {
|
||||||
registrations_enabled: Some(false),
|
registrations_enabled: Some(false),
|
||||||
@@ -498,12 +549,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let original = SystemSettings::get_current(&db)
|
let original = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
@@ -533,12 +579,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
@@ -552,12 +593,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let mut settings = SystemSettings::get_current(&db)
|
let mut settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
@@ -575,12 +611,7 @@ mod tests {
|
|||||||
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
||||||
use crate::utils::embedding::EmbeddingProvider;
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let provider = EmbeddingProvider::new_hashed(384)
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
@@ -607,12 +638,7 @@ mod tests {
|
|||||||
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
||||||
use crate::utils::embedding::EmbeddingProvider;
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let provider = EmbeddingProvider::new_hashed(384)
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
@@ -636,12 +662,7 @@ mod tests {
|
|||||||
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
||||||
use crate::utils::embedding::EmbeddingProvider;
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let provider = EmbeddingProvider::new_hashed(384)
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
@@ -660,12 +681,7 @@ mod tests {
|
|||||||
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
||||||
use crate::utils::embedding::EmbeddingProvider;
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
|
||||||
|
|
||||||
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
||||||
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
||||||
@@ -676,14 +692,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start DB".to_string())?;
|
|
||||||
|
|
||||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Initial migration failed".to_string())?;
|
|
||||||
|
|
||||||
let initial_chunk = TextChunk::new(
|
let initial_chunk = TextChunk::new(
|
||||||
"source1".into(),
|
"source1".into(),
|
||||||
@@ -691,7 +700,7 @@ mod tests {
|
|||||||
"user1".into(),
|
"user1".into(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], 1536, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -714,14 +723,7 @@ mod tests {
|
|||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use crate::utils::embedding::EmbeddingProvider;
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
let db = setup_test_db().await?;
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start DB".to_string())?;
|
|
||||||
|
|
||||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "Initial migration failed".to_string())?;
|
|
||||||
|
|
||||||
let mut current_settings = SystemSettings::get_current(&db)
|
let mut current_settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
@@ -802,4 +804,28 @@ mod tests {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn index_rebuild_lease_is_exclusive_on_system_settings() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?,
|
||||||
|
"first lease claim should succeed"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?,
|
||||||
|
"second lease claim should fail while lease is held"
|
||||||
|
);
|
||||||
|
|
||||||
|
SystemSettings::release_index_rebuild_lease(&db, "worker-a").await;
|
||||||
|
|
||||||
|
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?;
|
||||||
|
SystemSettings::record_index_rebuild_completed(&db, "worker-b").await?;
|
||||||
|
|
||||||
|
let settings = SystemSettings::get_current(&db).await?;
|
||||||
|
assert!(settings.last_index_rebuild_at.is_some());
|
||||||
|
assert!(settings.index_rebuild_lease_owner.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ use std::collections::HashMap;
|
|||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
||||||
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
||||||
use crate::storage::types::system_settings::SystemSettings;
|
use crate::storage::types::{
|
||||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, HasEmbedding,
|
||||||
|
};
|
||||||
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||||
|
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
stored_object!(TextChunk, "text_chunk", {
|
stored_object!(TextChunk, "text_chunk", {
|
||||||
@@ -25,6 +26,18 @@ pub struct TextChunkSearchResult {
|
|||||||
pub score: f32,
|
pub score: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl HasEmbedding for TextChunk {
|
||||||
|
type Embedding = TextChunkEmbedding;
|
||||||
|
|
||||||
|
fn source_id(&self) -> &str {
|
||||||
|
&self.source_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_id(&self) -> &str {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl TextChunk {
|
impl TextChunk {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||||
@@ -41,123 +54,39 @@ impl TextChunk {
|
|||||||
|
|
||||||
pub async fn delete_by_source_id(
|
pub async fn delete_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
db_client
|
db.delete_by_source_id::<Self>(source_id).await
|
||||||
.client
|
|
||||||
.query("BEGIN TRANSACTION;")
|
|
||||||
.query(format!(
|
|
||||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
|
||||||
TextChunkEmbedding::table_name()
|
|
||||||
))
|
|
||||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id;")
|
|
||||||
.query("COMMIT TRANSACTION;")
|
|
||||||
.bind(("source_id", source_id.to_owned()))
|
|
||||||
.bind(("table", Self::table_name()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Atomically store a text chunk and its embedding.
|
/// Atomically store one text chunk and its embedding (single-record path).
|
||||||
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
|
///
|
||||||
|
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||||
pub async fn store_with_embedding(
|
pub async fn store_with_embedding(
|
||||||
chunk: TextChunk,
|
chunk: TextChunk,
|
||||||
embedding: Vec<f32>,
|
embedding: Vec<f32>,
|
||||||
|
embedding_dimensions: usize,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let settings = SystemSettings::get_current(db).await?;
|
db.store_with_embedding(chunk, embedding, embedding_dimensions)
|
||||||
TextChunkEmbedding::validate_dimension(&embedding, settings.embedding_dimensions as usize)?;
|
|
||||||
|
|
||||||
let chunk_id = chunk.id.clone();
|
|
||||||
let emb = TextChunkEmbedding::new(
|
|
||||||
&chunk_id,
|
|
||||||
chunk.source_id.clone(),
|
|
||||||
embedding,
|
|
||||||
chunk.user_id.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let query = format!(
|
|
||||||
"
|
|
||||||
BEGIN TRANSACTION;
|
|
||||||
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
|
|
||||||
UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb;
|
|
||||||
COMMIT TRANSACTION;
|
|
||||||
",
|
|
||||||
chunk_table = Self::table_name(),
|
|
||||||
emb_table = TextChunkEmbedding::table_name(),
|
|
||||||
);
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("chunk_id", chunk_id))
|
|
||||||
.bind(("chunk", chunk))
|
|
||||||
.bind(("emb", emb))
|
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
/// Vector search over text chunks using the embedding table, fetching full chunk rows and scores.
|
||||||
pub async fn vector_search(
|
pub async fn vector_search(
|
||||||
take: usize,
|
take: usize,
|
||||||
query_embedding: &[f32],
|
query_embedding: &[f32],
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||||
#[allow(clippy::missing_docs_in_private_items)]
|
db.vector_search::<Self, TextChunkEmbedding>(take, query_embedding, user_id)
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Row {
|
|
||||||
chunk_id: Option<TextChunk>,
|
|
||||||
score: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
let sql = format!(
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
chunk_id,
|
|
||||||
embedding,
|
|
||||||
vector::similarity::cosine(embedding, $embedding) AS score
|
|
||||||
FROM {emb_table}
|
|
||||||
WHERE user_id = $user_id
|
|
||||||
AND embedding <|{take},100|> $embedding
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT {take}
|
|
||||||
FETCH chunk_id;
|
|
||||||
"#,
|
|
||||||
emb_table = TextChunkEmbedding::table_name(),
|
|
||||||
take = take
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut response = db
|
|
||||||
.query(&sql)
|
|
||||||
.bind(("embedding", query_embedding.to_vec()))
|
|
||||||
.bind(("user_id", user_id.to_string()))
|
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map(|results| {
|
||||||
|
results
|
||||||
response = response.check().map_err(AppError::from)?;
|
.into_iter()
|
||||||
|
.map(|(chunk, score)| TextChunkSearchResult { chunk, score })
|
||||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
.collect()
|
||||||
|
|
||||||
Ok(rows
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|r| {
|
|
||||||
r.chunk_id.map(|chunk| TextChunkSearchResult {
|
|
||||||
chunk,
|
|
||||||
score: r.score,
|
|
||||||
}).or_else(|| {
|
|
||||||
warn!("vector search hit orphaned text_chunk_embedding row with missing chunk");
|
|
||||||
None
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.collect())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Full-text search over text chunks using the BM25 FTS index.
|
/// Full-text search over text chunks using the BM25 FTS index.
|
||||||
@@ -393,29 +322,10 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||||
use crate::test_utils::configure_embedding_dimension;
|
use crate::test_utils::{
|
||||||
|
configure_embedding_dimension, ensure_fts_index, prepare_text_chunk_test_db, setup_test_db,
|
||||||
|
};
|
||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
|
||||||
let snowball_sql = r#"
|
|
||||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
"#;
|
|
||||||
|
|
||||||
if let Err(err) = db.client.query(snowball_sql).await {
|
|
||||||
let fallback_sql = r#"
|
|
||||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
|
||||||
"#;
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(fallback_sql)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
||||||
@@ -434,21 +344,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(5).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let user_id = "user123".to_string();
|
let user_id = "user123".to_string();
|
||||||
configure_embedding_dimension(&db, 5).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let chunk1 = TextChunk::new(
|
let chunk1 = TextChunk::new(
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
@@ -466,15 +364,16 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk1".to_string())?;
|
.with_context(|| "store chunk1".to_string())?;
|
||||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk2".to_string())?;
|
.with_context(|| "store chunk2".to_string())?;
|
||||||
TextChunk::store_with_embedding(
|
TextChunk::store_with_embedding(
|
||||||
different_chunk.clone(),
|
different_chunk.clone(),
|
||||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
5,
|
||||||
&db,
|
&db,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -516,18 +415,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(5).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
configure_embedding_dimension(&db, 5).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let real_source_id = "real_source".to_string();
|
let real_source_id = "real_source".to_string();
|
||||||
let chunk = TextChunk::new(
|
let chunk = TextChunk::new(
|
||||||
@@ -536,7 +424,7 @@ mod tests {
|
|||||||
"user123".to_string(),
|
"user123".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk".to_string())?;
|
.with_context(|| "store chunk".to_string())?;
|
||||||
|
|
||||||
@@ -560,18 +448,9 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id_resists_query_injection() {
|
async fn test_delete_by_source_id_resists_query_injection() {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(5)
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.expect("prepare test db");
|
||||||
db.apply_migrations().await.expect("migrations");
|
|
||||||
configure_embedding_dimension(&db, 5)
|
|
||||||
.await
|
|
||||||
.expect("configure dim");
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
|
||||||
.await
|
|
||||||
.expect("redefine index");
|
|
||||||
|
|
||||||
let chunk1 = TextChunk::new(
|
let chunk1 = TextChunk::new(
|
||||||
"safe_source".to_string(),
|
"safe_source".to_string(),
|
||||||
@@ -584,10 +463,10 @@ mod tests {
|
|||||||
"user123".to_string(),
|
"user123".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||||
.await
|
.await
|
||||||
.expect("store chunk1");
|
.expect("store chunk1");
|
||||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
|
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], 5, &db)
|
||||||
.await
|
.await
|
||||||
.expect("store chunk2");
|
.expect("store chunk2");
|
||||||
|
|
||||||
@@ -614,25 +493,12 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let source_id = "store-src".to_string();
|
let source_id = "store-src".to_string();
|
||||||
let user_id = "user_store".to_string();
|
let user_id = "user_store".to_string();
|
||||||
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store with embedding".to_string())?;
|
.with_context(|| "store with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -645,7 +511,7 @@ mod tests {
|
|||||||
assert_eq!(stored_chunk.user_id, user_id);
|
assert_eq!(stored_chunk.user_id, user_id);
|
||||||
|
|
||||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get embedding".to_string())?
|
.with_context(|| "get embedding".to_string())?
|
||||||
.with_context(|| "expected embedding".to_string())?;
|
.with_context(|| "expected embedding".to_string())?;
|
||||||
@@ -658,14 +524,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns_runtime";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
let embedding_dimension = 3usize;
|
let embedding_dimension = 3usize;
|
||||||
configure_embedding_dimension(
|
configure_embedding_dimension(
|
||||||
@@ -683,7 +542,7 @@ mod tests {
|
|||||||
"runtime_user".to_string(),
|
"runtime_user".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store with embedding".to_string())?;
|
.with_context(|| "store with embedding".to_string())?;
|
||||||
|
|
||||||
@@ -695,7 +554,7 @@ mod tests {
|
|||||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||||
|
|
||||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get embedding".to_string())?
|
.with_context(|| "get embedding".to_string())?
|
||||||
.with_context(|| "embedding should exist".to_string())?;
|
.with_context(|| "embedding should exist".to_string())?;
|
||||||
@@ -709,19 +568,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let results: Vec<TextChunkSearchResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||||
@@ -733,19 +580,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let source_id = "src".to_string();
|
let source_id = "src".to_string();
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
@@ -755,7 +590,7 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store".to_string())?;
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
@@ -774,28 +609,16 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||||
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk1".to_string())?;
|
.with_context(|| "store chunk1".to_string())?;
|
||||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk2".to_string())?;
|
.with_context(|| "store chunk2".to_string())?;
|
||||||
|
|
||||||
@@ -815,15 +638,8 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_chunk_ns_empty";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_chunk_fts_index(&db).await?;
|
|
||||||
rebuild(&db)
|
rebuild(&db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "rebuild indexes".to_string())?;
|
.with_context(|| "rebuild indexes".to_string())?;
|
||||||
@@ -838,15 +654,8 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_chunk_ns_single";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_chunk_fts_index(&db).await?;
|
|
||||||
|
|
||||||
let user_id = "fts_user";
|
let user_id = "fts_user";
|
||||||
let chunk = TextChunk::new(
|
let chunk = TextChunk::new(
|
||||||
@@ -874,15 +683,8 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||||
let namespace = "fts_chunk_ns_order";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
ensure_chunk_fts_index(&db).await?;
|
|
||||||
|
|
||||||
let user_id = "fts_user_order";
|
let user_id = "fts_user_order";
|
||||||
let high_score_chunk = TextChunk::new(
|
let high_score_chunk = TextChunk::new(
|
||||||
@@ -936,19 +738,12 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns_dim";
|
let db = setup_test_db().await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
configure_embedding_dimension(&db, 3).await?;
|
||||||
|
|
||||||
let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
|
let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
|
||||||
|
|
||||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db)
|
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], 3, &db)
|
||||||
.await
|
.await
|
||||||
.expect_err("expected dimension validation failure");
|
.expect_err("expected dimension validation failure");
|
||||||
assert!(matches!(err, AppError::Validation(_)));
|
assert!(matches!(err, AppError::Validation(_)));
|
||||||
@@ -958,18 +753,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns_orphan_chunk";
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.with_context(|| "migrations".to_string())?;
|
|
||||||
configure_embedding_dimension(&db, 3).await?;
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.with_context(|| "redefine index".to_string())?;
|
|
||||||
|
|
||||||
let user_id = "user".to_string();
|
let user_id = "user".to_string();
|
||||||
let chunk = TextChunk::new(
|
let chunk = TextChunk::new(
|
||||||
@@ -978,7 +762,7 @@ mod tests {
|
|||||||
user_id.clone(),
|
user_id.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "store chunk with embedding".to_string())?;
|
.with_context(|| "store chunk with embedding".to_string())?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
|
|
||||||
use crate::storage::types::text_chunk::TextChunk;
|
use crate::{storage::types::EmbeddingRecord, stored_object};
|
||||||
use crate::{
|
|
||||||
error::AppError,
|
#[cfg(test)]
|
||||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
use crate::error::AppError;
|
||||||
stored_object,
|
|
||||||
};
|
|
||||||
|
|
||||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||||
/// Record link to the owning text_chunk
|
/// Record link to the owning text_chunk
|
||||||
@@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
|||||||
user_id: String
|
user_id: String
|
||||||
});
|
});
|
||||||
|
|
||||||
impl TextChunkEmbedding {
|
impl EmbeddingRecord for TextChunkEmbedding {
|
||||||
/// Recreate the HNSW index with a new embedding dimension.
|
fn link_field() -> &'static str {
|
||||||
///
|
"chunk_id"
|
||||||
/// This is useful when the embedding length changes; Surreal requires the
|
|
||||||
/// index definition to be recreated with the updated dimension.
|
|
||||||
pub async fn redefine_hnsw_index(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
dimension: usize,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = hnsw_index_redefine_transaction_sql(
|
|
||||||
"idx_embedding_text_chunk_embedding",
|
|
||||||
Self::table_name(),
|
|
||||||
dimension,
|
|
||||||
);
|
|
||||||
|
|
||||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
|
||||||
res.check().map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
fn index_name() -> &'static str {
|
||||||
#[allow(clippy::result_large_err)]
|
"idx_embedding_text_chunk_embedding"
|
||||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
|
||||||
if embedding.len() != expected {
|
|
||||||
return Err(AppError::Validation(format!(
|
|
||||||
"embedding dimension mismatch: got {}, expected {expected}",
|
|
||||||
embedding.len()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new text chunk embedding.
|
fn source_id(&self) -> &str {
|
||||||
///
|
&self.source_id
|
||||||
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
}
|
||||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
|
||||||
#[must_use]
|
fn user_id(&self) -> &str {
|
||||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
&self.user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embedding(&self) -> &[f32] {
|
||||||
|
&self.embedding
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(
|
||||||
|
chunk_id: &str,
|
||||||
|
source_id: String,
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
user_id: String,
|
||||||
|
entity_table: &str,
|
||||||
|
) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id: chunk_id.to_owned(),
|
id: chunk_id.to_owned(),
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now,
|
updated_at: now,
|
||||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
chunk_id: RecordId::from_table_key(entity_table, chunk_id),
|
||||||
source_id,
|
source_id,
|
||||||
embedding,
|
embedding,
|
||||||
user_id,
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a single embedding by its chunk RecordId
|
|
||||||
pub async fn get_by_chunk_id(
|
|
||||||
chunk_id: &RecordId,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<Option<Self>, AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut result = db
|
|
||||||
.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("chunk_id", chunk_id.clone()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(embeddings.into_iter().next())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Delete embeddings for a given chunk RecordId
|
|
||||||
pub async fn delete_by_chunk_id(
|
|
||||||
chunk_id: &RecordId,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("chunk_id", chunk_id.clone()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
|
||||||
///
|
|
||||||
/// This uses the denormalized `source_id` on the embedding table.
|
|
||||||
pub async fn delete_by_source_id(
|
|
||||||
source_id: &str,
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<(), AppError> {
|
|
||||||
let query = format!(
|
|
||||||
"DELETE FROM {} WHERE source_id = $source_id",
|
|
||||||
Self::table_name()
|
|
||||||
);
|
|
||||||
|
|
||||||
db.client
|
|
||||||
.query(query)
|
|
||||||
.bind(("source_id", source_id.to_owned()))
|
|
||||||
.await
|
|
||||||
.map_err(AppError::from)?
|
|
||||||
.check()
|
|
||||||
.map_err(AppError::from)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -144,8 +65,31 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use crate::storage::types::text_chunk::TextChunk;
|
||||||
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||||
use surrealdb::Value as SurrealValue;
|
|
||||||
|
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||||
|
let mut info_res = db
|
||||||
|
.client
|
||||||
|
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||||
|
.await
|
||||||
|
.with_context(|| "info query failed".to_string())?;
|
||||||
|
let info: surrealdb::Value = info_res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to take info result".to_string())?;
|
||||||
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
|
.with_context(|| "failed to convert info to json".to_string())?;
|
||||||
|
let idx_sql = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||||
|
.and_then(|v| v.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_string();
|
||||||
|
Ok(idx_sql)
|
||||||
|
}
|
||||||
|
|
||||||
async fn create_text_chunk_with_id(
|
async fn create_text_chunk_with_id(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
@@ -169,29 +113,6 @@ mod tests {
|
|||||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
|
||||||
let mut info_res = db
|
|
||||||
.client
|
|
||||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
|
||||||
.await
|
|
||||||
.with_context(|| "info query failed".to_string())?;
|
|
||||||
let info: SurrealValue = info_res
|
|
||||||
.take(0)
|
|
||||||
.with_context(|| "failed to take info result".to_string())?;
|
|
||||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
|
||||||
.with_context(|| "failed to convert info to json".to_string())?;
|
|
||||||
let idx_sql = info_json
|
|
||||||
.get("Object")
|
|
||||||
.and_then(|v| v.get("indexes"))
|
|
||||||
.and_then(|v| v.get("Object"))
|
|
||||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
|
||||||
.and_then(|v| v.get("Strand"))
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or_default()
|
|
||||||
.to_string();
|
|
||||||
Ok(idx_sql)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_uses_chunk_id_as_record_id() {
|
fn new_uses_chunk_id_as_record_id() {
|
||||||
let emb = TextChunkEmbedding::new(
|
let emb = TextChunkEmbedding::new(
|
||||||
@@ -199,6 +120,7 @@ mod tests {
|
|||||||
"source-1".to_owned(),
|
"source-1".to_owned(),
|
||||||
vec![0.1, 0.2],
|
vec![0.1, 0.2],
|
||||||
"user-1".to_owned(),
|
"user-1".to_owned(),
|
||||||
|
TextChunk::table_name(),
|
||||||
);
|
);
|
||||||
assert_eq!(emb.id, "chunk-abc");
|
assert_eq!(emb.id, "chunk-abc");
|
||||||
}
|
}
|
||||||
@@ -226,13 +148,14 @@ mod tests {
|
|||||||
source_id.to_string(),
|
source_id.to_string(),
|
||||||
embedding_vec.clone(),
|
embedding_vec.clone(),
|
||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
|
TextChunk::table_name(),
|
||||||
);
|
);
|
||||||
|
|
||||||
db.upsert_item(emb)
|
db.upsert_item(emb)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store embedding".to_string())?;
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||||
@@ -259,22 +182,23 @@ mod tests {
|
|||||||
source_id.to_string(),
|
source_id.to_string(),
|
||||||
vec![0.4_f32, 0.5, 0.6],
|
vec![0.4_f32, 0.5, 0.6],
|
||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
|
TextChunk::table_name(),
|
||||||
);
|
);
|
||||||
|
|
||||||
db.upsert_item(emb)
|
db.upsert_item(emb)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store embedding".to_string())?;
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||||
|
|
||||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||||
|
|
||||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
assert!(after.is_none(), "Embedding should have been deleted");
|
assert!(after.is_none(), "Embedding should have been deleted");
|
||||||
@@ -299,21 +223,27 @@ mod tests {
|
|||||||
("chunk-s2", source_id, vec![0.2]),
|
("chunk-s2", source_id, vec![0.2]),
|
||||||
("chunk-other", other_source, vec![0.3]),
|
("chunk-other", other_source, vec![0.3]),
|
||||||
] {
|
] {
|
||||||
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
let emb = TextChunkEmbedding::new(
|
||||||
|
key,
|
||||||
|
src.to_string(),
|
||||||
|
vec,
|
||||||
|
user_id.to_string(),
|
||||||
|
TextChunk::table_name(),
|
||||||
|
);
|
||||||
db.upsert_item(emb)
|
db.upsert_item(emb)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("store embedding for {key}"))?;
|
.with_context(|| format!("store embedding for {key}"))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get chunk1".to_string())?
|
.with_context(|| "get chunk1".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get chunk2".to_string())?
|
.with_context(|| "get chunk2".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "get chunk_other".to_string())?
|
.with_context(|| "get chunk_other".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
@@ -322,15 +252,15 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "check chunk1".to_string())?
|
.with_context(|| "check chunk1".to_string())?
|
||||||
.is_none());
|
.is_none());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "check chunk2".to_string())?
|
.with_context(|| "check chunk2".to_string())?
|
||||||
.is_none());
|
.is_none());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "check chunk_other".to_string())?
|
.with_context(|| "check chunk_other".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
@@ -352,6 +282,7 @@ mod tests {
|
|||||||
source_id.to_owned(),
|
source_id.to_owned(),
|
||||||
vec![1.0_f32, 0.0, 0.0],
|
vec![1.0_f32, 0.0, 0.0],
|
||||||
user_id.to_owned(),
|
user_id.to_owned(),
|
||||||
|
TextChunk::table_name(),
|
||||||
);
|
);
|
||||||
db.upsert_item(initial)
|
db.upsert_item(initial)
|
||||||
.await
|
.await
|
||||||
@@ -362,6 +293,7 @@ mod tests {
|
|||||||
source_id.to_owned(),
|
source_id.to_owned(),
|
||||||
vec![0.0, 1.0, 0.0],
|
vec![0.0, 1.0, 0.0],
|
||||||
user_id.to_owned(),
|
user_id.to_owned(),
|
||||||
|
TextChunk::table_name(),
|
||||||
);
|
);
|
||||||
db.upsert_item(replacement)
|
db.upsert_item(replacement)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -96,6 +96,41 @@ impl TextContent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// SurrealQL deletes for ingested child rows keyed by `source_id` (no transaction wrapper).
|
||||||
|
///
|
||||||
|
/// Used inside larger transactions (e.g. ingestion `persist_artifacts`) and mirrored by
|
||||||
|
/// [`Self::clear_ingested_children`].
|
||||||
|
pub const CLEAR_INGESTED_CHILD_ROWS_SURQL: &'static str = r"
|
||||||
|
DELETE relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id;
|
||||||
|
DELETE text_chunk_embedding WHERE source_id = $source_id;
|
||||||
|
DELETE text_chunk WHERE source_id = $source_id;
|
||||||
|
DELETE knowledge_entity_embedding WHERE source_id = $source_id;
|
||||||
|
DELETE knowledge_entity WHERE source_id = $source_id;
|
||||||
|
";
|
||||||
|
|
||||||
|
/// Removes chunks, embeddings, entities, and relationships for one ingested document snapshot.
|
||||||
|
pub async fn clear_ingested_children(
|
||||||
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"BEGIN TRANSACTION;\n{} COMMIT TRANSACTION;",
|
||||||
|
Self::CLEAR_INGESTED_CHILD_ROWS_SURQL
|
||||||
|
);
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("source_id", source_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn patch(
|
pub async fn patch(
|
||||||
id: &str,
|
id: &str,
|
||||||
context: &str,
|
context: &str,
|
||||||
@@ -364,7 +399,14 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
use crate::{
|
||||||
|
storage::types::{
|
||||||
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
|
knowledge_relationship::KnowledgeRelationship,
|
||||||
|
text_chunk::TextChunk,
|
||||||
|
},
|
||||||
|
test_utils::{setup_test_db, setup_test_db_with_runtime_indexes},
|
||||||
|
};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||||
@@ -638,4 +680,81 @@ mod tests {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clear_ingested_children_removes_chunks_entities_and_relationships(
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "clear-user";
|
||||||
|
let source_id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
let entity_a = KnowledgeEntity::new(
|
||||||
|
source_id.clone(),
|
||||||
|
"entity-a".to_string(),
|
||||||
|
"desc-a".to_string(),
|
||||||
|
KnowledgeEntityType::Idea,
|
||||||
|
None,
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
let entity_b = KnowledgeEntity::new(
|
||||||
|
source_id.clone(),
|
||||||
|
"entity-b".to_string(),
|
||||||
|
"desc-b".to_string(),
|
||||||
|
KnowledgeEntityType::Idea,
|
||||||
|
None,
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
KnowledgeEntity::store_with_embedding(entity_a.clone(), vec![0.1; 3], 3, &db)
|
||||||
|
.await
|
||||||
|
.context("store entity a")?;
|
||||||
|
KnowledgeEntity::store_with_embedding(entity_b.clone(), vec![0.2; 3], 3, &db)
|
||||||
|
.await
|
||||||
|
.context("store entity b")?;
|
||||||
|
|
||||||
|
let chunk = TextChunk::new(source_id.clone(), "chunk".to_string(), user_id.to_string());
|
||||||
|
TextChunk::store_with_embedding(chunk, vec![0.3; 3], 3, &db)
|
||||||
|
.await
|
||||||
|
.context("store chunk")?;
|
||||||
|
|
||||||
|
KnowledgeRelationship::new(
|
||||||
|
entity_a.id.clone(),
|
||||||
|
entity_b.id,
|
||||||
|
user_id.to_string(),
|
||||||
|
source_id.clone(),
|
||||||
|
"relates_to".to_string(),
|
||||||
|
)
|
||||||
|
.store_relationship(&db)
|
||||||
|
.await
|
||||||
|
.context("store relationship")?;
|
||||||
|
|
||||||
|
TextContent::clear_ingested_children(&source_id, user_id, &db)
|
||||||
|
.await
|
||||||
|
.context("clear ingested children")?;
|
||||||
|
|
||||||
|
let chunks: Vec<TextChunk> = db
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM text_chunk WHERE source_id = $source_id;")
|
||||||
|
.bind(("source_id", source_id.clone()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
assert!(chunks.is_empty());
|
||||||
|
|
||||||
|
let entities: Vec<KnowledgeEntity> = db
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM knowledge_entity WHERE source_id = $source_id;")
|
||||||
|
.bind(("source_id", source_id.clone()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
assert!(entities.is_empty());
|
||||||
|
|
||||||
|
let relationships: Vec<KnowledgeRelationship> = db
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id;")
|
||||||
|
.bind(("source_id", source_id))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
assert!(relationships.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use crate::storage::{
|
|||||||
indexes::{ensure_runtime, rebuild},
|
indexes::{ensure_runtime, rebuild},
|
||||||
types::{
|
types::{
|
||||||
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||||
text_chunk_embedding::TextChunkEmbedding,
|
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -91,3 +91,47 @@ pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
|||||||
rebuild(&db).await?;
|
rebuild(&db).await?;
|
||||||
Ok(db)
|
Ok(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Ensures an FTS analyzer and BM25 indexes exist for a table.
|
||||||
|
///
|
||||||
|
/// Attempts snowball(english) tokenizer first; falls back to basic
|
||||||
|
/// lowercase+ascii when the platform lacks the snowball extension.
|
||||||
|
///
|
||||||
|
/// `indexes` is a slice of `(field_name, index_id_suffix)` pairs —
|
||||||
|
/// e.g. `&[("chunk", "chunk")]` produces index
|
||||||
|
/// `text_chunk_fts_chunk_idx` on column `chunk` of `text_chunk`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the fallback definition fails. The initial
|
||||||
|
/// snowball attempt is allowed to fail silently.
|
||||||
|
pub async fn ensure_fts_index(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
table: &str,
|
||||||
|
indexes: &[(&str, &str)],
|
||||||
|
) -> Result<()> {
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
let mut define_indexes = String::new();
|
||||||
|
for (field, suffix) in indexes {
|
||||||
|
let _ = writeln!(
|
||||||
|
define_indexes,
|
||||||
|
"DEFINE INDEX IF NOT EXISTS {table}_fts_{suffix}_idx ON TABLE {table} FIELDS {field} SEARCH ANALYZER app_en_fts_analyzer BM25;"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let snowball_sql = format!(
|
||||||
|
"DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);\n{define_indexes}"
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Err(err) = db.client.query(&snowball_sql).await {
|
||||||
|
let fallback_sql = format!(
|
||||||
|
"DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;\n{define_indexes}"
|
||||||
|
);
|
||||||
|
db.client
|
||||||
|
.query(&fallback_sql)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("define fts index fallback for {table}: {err}"))?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
@@ -135,6 +135,9 @@ pub struct AppConfig {
|
|||||||
pub ingest_max_context_bytes: usize,
|
pub ingest_max_context_bytes: usize,
|
||||||
#[serde(default = "default_ingest_max_category_bytes")]
|
#[serde(default = "default_ingest_max_category_bytes")]
|
||||||
pub ingest_max_category_bytes: usize,
|
pub ingest_max_category_bytes: usize,
|
||||||
|
/// Seconds between scheduled `REBUILD INDEX` maintainer runs (`0` disables).
|
||||||
|
#[serde(default = "default_index_rebuild_interval_secs")]
|
||||||
|
pub index_rebuild_interval_secs: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Default data directory for persisted assets.
|
/// Default data directory for persisted assets.
|
||||||
@@ -172,6 +175,10 @@ fn default_ingest_max_category_bytes() -> usize {
|
|||||||
128
|
128
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_index_rebuild_interval_secs() -> u64 {
|
||||||
|
86_400
|
||||||
|
}
|
||||||
|
|
||||||
static ORT_PATH_INIT: Once = Once::new();
|
static ORT_PATH_INIT: Once = Once::new();
|
||||||
|
|
||||||
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
||||||
@@ -238,6 +245,7 @@ impl Default for AppConfig {
|
|||||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||||
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
||||||
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
||||||
|
index_rebuild_interval_secs: default_index_rebuild_interval_secs(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use std::{
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
use async_openai::{types::embeddings::CreateEmbeddingRequestArgs, Client};
|
||||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||||
|
|
||||||
|
|||||||
+14
-14
@@ -3,10 +3,10 @@
|
|||||||
"devenv": {
|
"devenv": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"dir": "src/modules",
|
"dir": "src/modules",
|
||||||
"lastModified": 1771066302,
|
"lastModified": 1781800860,
|
||||||
"owner": "cachix",
|
"owner": "cachix",
|
||||||
"repo": "devenv",
|
"repo": "devenv",
|
||||||
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
|
"rev": "d59d872d80876d9eeb3e214d3b088bc4a14a9c4f",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -22,10 +22,10 @@
|
|||||||
"rust-analyzer-src": "rust-analyzer-src"
|
"rust-analyzer-src": "rust-analyzer-src"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1771052630,
|
"lastModified": 1781779700,
|
||||||
"owner": "nix-community",
|
"owner": "nix-community",
|
||||||
"repo": "fenix",
|
"repo": "fenix",
|
||||||
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
|
"rev": "ad30e585c7a2917325943c2b19511f5a249eff53",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -58,10 +58,10 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1770726378,
|
"lastModified": 1781733627,
|
||||||
"owner": "cachix",
|
"owner": "cachix",
|
||||||
"repo": "git-hooks.nix",
|
"repo": "git-hooks.nix",
|
||||||
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
|
"rev": "3bbec39bc90eadfa031e6f3b77272f3f60803e39",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -92,10 +92,10 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1771008912,
|
"lastModified": 1781577229,
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
|
"rev": "567a49d1913ce81ac6e9582e3553dd90a955875f",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -107,10 +107,10 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs_2": {
|
"nixpkgs_2": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1770843696,
|
"lastModified": 1781607440,
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
|
"rev": "3e41b24abd260e8f71dbe2f5737d24122f972158",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -135,10 +135,10 @@
|
|||||||
"rust-analyzer-src": {
|
"rust-analyzer-src": {
|
||||||
"flake": false,
|
"flake": false,
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1771007332,
|
"lastModified": 1781714865,
|
||||||
"owner": "rust-lang",
|
"owner": "rust-lang",
|
||||||
"repo": "rust-analyzer",
|
"repo": "rust-analyzer",
|
||||||
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
|
"rev": "abb1301c3c14a40645bb2588b1cc858fe374b527",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -155,10 +155,10 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1771038269,
|
"lastModified": 1781850613,
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
|
"rev": "4baecb43a008cd004e5220a777e1724bd8d43e43",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|||||||
+32
-5
@@ -4,17 +4,32 @@
|
|||||||
config,
|
config,
|
||||||
inputs,
|
inputs,
|
||||||
...
|
...
|
||||||
}:
|
}: let
|
||||||
let
|
ortVersion = "1.23.2";
|
||||||
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${toString ./.}/ort-version");
|
|
||||||
_ortVersionCheck =
|
_ortVersionCheck =
|
||||||
if pkgs.onnxruntime.version == ortVersion
|
if pkgs.onnxruntime.version == ortVersion
|
||||||
then null
|
then null
|
||||||
else
|
else throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ortVersion in flake.nix (${ortVersion})";
|
||||||
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
|
||||||
in {
|
in {
|
||||||
|
devenv.warnOnNewVersion = false;
|
||||||
|
|
||||||
cachix.enable = false;
|
cachix.enable = false;
|
||||||
|
|
||||||
|
git-hooks.install.enable = true;
|
||||||
|
git-hooks.hooks = {
|
||||||
|
rustfmt.enable = true;
|
||||||
|
clippy = {
|
||||||
|
enable = true;
|
||||||
|
settings.allFeatures = true;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
# Use pinned Rust toolchain from languages.rust for git-hooks wrappers
|
||||||
|
# (git-hooks.nix defaults to nixpkgs's cargo/clippy/rustfmt, ignoring the pin)
|
||||||
|
git-hooks.tools.cargo = lib.mkDefault config.languages.rust.toolchain.cargo;
|
||||||
|
git-hooks.tools.clippy = lib.mkDefault config.languages.rust.toolchain.clippy;
|
||||||
|
git-hooks.tools.rustfmt = lib.mkDefault config.languages.rust.toolchain.rustfmt;
|
||||||
|
|
||||||
packages = [
|
packages = [
|
||||||
pkgs.openssl
|
pkgs.openssl
|
||||||
pkgs.nodejs
|
pkgs.nodejs
|
||||||
@@ -26,6 +41,14 @@ in {
|
|||||||
pkgs.onnxruntime
|
pkgs.onnxruntime
|
||||||
pkgs.cargo-watch
|
pkgs.cargo-watch
|
||||||
pkgs.tailwindcss_4
|
pkgs.tailwindcss_4
|
||||||
|
pkgs.python3
|
||||||
|
pkgs.fontconfig
|
||||||
|
pkgs.fontconfig.dev
|
||||||
|
pkgs.libGL
|
||||||
|
pkgs.libGLU
|
||||||
|
pkgs.libclang
|
||||||
|
pkgs.wayland
|
||||||
|
pkgs.libxkbcommon
|
||||||
];
|
];
|
||||||
|
|
||||||
languages.rust = {
|
languages.rust = {
|
||||||
@@ -38,6 +61,10 @@ in {
|
|||||||
};
|
};
|
||||||
|
|
||||||
env = {
|
env = {
|
||||||
|
# tikv-jemalloc-sys configure flags: -O0 + -Werror triggers glibc _FORTIFY_SOURCE warning
|
||||||
|
NIX_CFLAGS_COMPILE = "-Wno-error=cpp";
|
||||||
|
LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
|
||||||
|
LD_LIBRARY_PATH = "${pkgs.wayland}/lib:${pkgs.libxkbcommon}/lib:${pkgs.pipewire}/lib:${pkgs.libglvnd}/lib";
|
||||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||||
S3_ENDPOINT = "http://127.0.0.1:19000";
|
S3_ENDPOINT = "http://127.0.0.1:19000";
|
||||||
S3_BUCKET = "minne-tests";
|
S3_BUCKET = "minne-tests";
|
||||||
|
|||||||
@@ -9,3 +9,6 @@ inputs:
|
|||||||
nixpkgs:
|
nixpkgs:
|
||||||
follows: nixpkgs
|
follows: nixpkgs
|
||||||
allowUnfree: true
|
allowUnfree: true
|
||||||
|
nixpkgs:
|
||||||
|
permittedInsecurePackages:
|
||||||
|
- "minio-2025-10-15T17-29-55Z"
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
services:
|
services:
|
||||||
minne:
|
minne:
|
||||||
build: .
|
image: ghcr.io/perstarkse/minne:latest
|
||||||
container_name: minne_app
|
container_name: minne_app
|
||||||
ports:
|
ports:
|
||||||
- "3000:3000"
|
- "3000:3000"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
| Frontend | HTML + HTMX + minimal JS |
|
| Frontend | HTML + HTMX + minimal JS |
|
||||||
| Database | SurrealDB (graph, document, vector) |
|
| Database | SurrealDB (graph, document, vector) |
|
||||||
| AI | OpenAI-compatible API |
|
| AI | OpenAI-compatible API |
|
||||||
| Web Processing | Headless Chromium |
|
| Web Processing | Servo engine (servo-fetch) + PDFium |
|
||||||
|
|
||||||
## Crate Structure
|
## Crate Structure
|
||||||
|
|
||||||
|
|||||||
+3
-1
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
Minne automatically processes saved content:
|
Minne automatically processes saved content:
|
||||||
|
|
||||||
1. **Web scraping** extracts readable text from URLs (via headless Chrome)
|
1. **Web scraping** extracts readable text from URLs (via embedded Servo engine)
|
||||||
2. **Text analysis** identifies key concepts and relationships
|
2. **Text analysis** identifies key concepts and relationships
|
||||||
3. **Graph creation** builds connections between related content
|
3. **Graph creation** builds connections between related content
|
||||||
4. **Embedding generation** enables semantic search
|
4. **Embedding generation** enables semantic search
|
||||||
@@ -43,6 +43,7 @@ Optional **reranking** can rescore fused chunk lists with a cross-encoder model;
|
|||||||
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
|
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
|
||||||
|
|
||||||
**Trade-offs:**
|
**Trade-offs:**
|
||||||
|
|
||||||
- Downloads ~1.1 GB of model data
|
- Downloads ~1.1 GB of model data
|
||||||
- Adds latency per query
|
- Adds latency per query
|
||||||
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
|
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
|
||||||
@@ -52,6 +53,7 @@ Enable via `RERANKING_ENABLED=true`. See [Configuration](./configuration.md).
|
|||||||
## Multi-Format Ingestion
|
## Multi-Format Ingestion
|
||||||
|
|
||||||
Supported content types:
|
Supported content types:
|
||||||
|
|
||||||
- Plain text and notes
|
- Plain text and notes
|
||||||
- URLs (web pages)
|
- URLs (web pages)
|
||||||
- PDF documents
|
- PDF documents
|
||||||
|
|||||||
@@ -12,13 +12,13 @@ cd minne
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
The included `docker-compose.yml` handles SurrealDB and Chromium automatically.
|
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||||
|
|
||||||
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
|
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
|
||||||
|
|
||||||
## Nix
|
## Nix
|
||||||
|
|
||||||
Run Minne directly with Nix (includes Chromium):
|
Run Minne directly with Nix:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
nix run 'github:perstarkse/minne#main'
|
nix run 'github:perstarkse/minne#main'
|
||||||
@@ -31,8 +31,9 @@ Configure via environment variables or a `config.yaml` file. See [Configuration]
|
|||||||
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||||
|
|
||||||
**Requirements:**
|
**Requirements:**
|
||||||
|
|
||||||
- SurrealDB instance (local or remote)
|
- SurrealDB instance (local or remote)
|
||||||
- Chromium (for web scraping)
|
- `libEGL` + `libfontconfig` (for servo-fetch web scraping)
|
||||||
|
|
||||||
## Build from Source
|
## Build from Source
|
||||||
|
|
||||||
@@ -45,9 +46,10 @@ cargo build --release --bin main
|
|||||||
The binary will be at `target/release/main`.
|
The binary will be at `target/release/main`.
|
||||||
|
|
||||||
**Requirements:**
|
**Requirements:**
|
||||||
|
|
||||||
- Rust toolchain
|
- Rust toolchain
|
||||||
- SurrealDB accessible at configured address
|
- SurrealDB accessible at configured address
|
||||||
- Chromium in PATH
|
- `libEGL` + `libfontconfig` for servo-fetch (web scraping) — bundled in Nix and Docker images
|
||||||
|
|
||||||
## Process Modes
|
## Process Modes
|
||||||
|
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ serde_json = { workspace = true }
|
|||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
once_cell = "1.19"
|
once_cell = "1.19"
|
||||||
serde_yaml = "0.9"
|
serde_yaml = "0.9"
|
||||||
criterion = "0.5"
|
|
||||||
state-machines = { workspace = true }
|
|
||||||
clap = { version = "4.4", features = ["derive", "env"] }
|
clap = { version = "4.4", features = ["derive", "env"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|||||||
+71
-181
@@ -1,212 +1,102 @@
|
|||||||
# Evaluations
|
# Evaluations
|
||||||
|
|
||||||
The `evaluations` crate provides a retrieval evaluation framework for benchmarking Minne's information retrieval pipeline against standard datasets.
|
The `evaluations` crate benchmarks Minne's retrieval pipeline against standard datasets.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run SQuAD v2.0 evaluation (vector-only, recommended)
|
# One-time prep (convert, slice ledger, corpus cache, DB seed)
|
||||||
cargo run --package evaluations -- --ingest-chunks-only
|
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||||
|
|
||||||
# Run a specific dataset
|
# Check readiness
|
||||||
cargo run --package evaluations -- --dataset fiqa --ingest-chunks-only
|
cargo eval --status --dataset beir --slice beir-mix-600
|
||||||
|
|
||||||
# Convert dataset only (no evaluation)
|
# Run benchmark (steady state after warm)
|
||||||
cargo run --package evaluations -- --convert-only
|
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Default dataset is `beir`. When `--slice` is omitted, the first catalog slice for the dataset is applied automatically (e.g. `beir-mix-600`).
|
||||||
|
|
||||||
|
Chunk-only ingestion is the default. Pass `--include-entities` to opt into entity extraction during ingestion (requires `OPENAI_API_KEY`).
|
||||||
|
|
||||||
|
### Custom slice sizes
|
||||||
|
|
||||||
|
`--slice` is a ledger id, not only a catalog name. You can use any id; `--limit` controls how many questions the ledger contains:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 200-case BEIR mix (default --limit is 200)
|
||||||
|
cargo eval --warm --dataset beir --slice beir-mix-200
|
||||||
|
cargo eval --dataset beir --slice beir-mix-200 --require-ready
|
||||||
|
```
|
||||||
|
|
||||||
|
The catalog slice `beir-mix-600` in `manifest.yaml` is a preset with `limit: 600` and `negative_multiplier: 9.0`.
|
||||||
|
|
||||||
|
### BEIR mix layout
|
||||||
|
|
||||||
|
`beir` is a **virtual mix** across eight subset datasets (FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR). There is no monolithic `beir-minne/` store.
|
||||||
|
|
||||||
|
1. Build an in-memory qrels-world mix from raw subset data
|
||||||
|
2. Resolve the slice ledger (`cache/slices/beir/<slice-id>.json`)
|
||||||
|
3. Materialize only ledger paragraph ids into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||||
|
4. Ingest the slice corpus and seed SurrealDB
|
||||||
|
|
||||||
|
Conversion is **qrels-closed**: only documents that appear in qrels are exported, not the full BEIR corpus.
|
||||||
|
|
||||||
|
Chunk-only mode may evaluate fewer cases than the slice ledger size when some questions are impossible or lack verifiable answer chunks.
|
||||||
|
|
||||||
|
Reports include a **Retrieved Context Volume** section: total characters and estimated tokens across all chunks returned per query (`~chars/4`, comparable across `--chunk-result-cap` sweeps). Use this to compare the cost of raising `--chunk-result-cap`.
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
### 1. SurrealDB
|
### SurrealDB
|
||||||
|
|
||||||
Start a SurrealDB instance before running evaluations:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d surrealdb
|
docker-compose up -d surrealdb
|
||||||
```
|
```
|
||||||
|
|
||||||
Or using the default endpoint configuration:
|
### Raw datasets
|
||||||
|
|
||||||
```bash
|
Place raw datasets under `evaluations/data/raw/`. See [manifest.yaml](./manifest.yaml) for paths.
|
||||||
surreal start --user root_user --pass root_password
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Download Raw Datasets
|
BEIR subsets live in sibling directories (`data/raw/fever`, `data/raw/fiqa`, …). The `data/raw/beir` entry is a virtual catalog placeholder; warm uses the subset paths.
|
||||||
|
|
||||||
Raw datasets must be downloaded manually and placed in `evaluations/data/raw/`. See [Dataset Sources](#dataset-sources) below for links and formats.
|
## Directory structure
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
|
|
||||||
```
|
```
|
||||||
evaluations/
|
evaluations/
|
||||||
├── data/
|
├── data/
|
||||||
│ ├── raw/ # Downloaded raw datasets (manual)
|
│ ├── raw/ # Downloaded datasets (manual)
|
||||||
│ │ ├── squad/ # SQuAD v2.0
|
│ │ ├── fever/ # BEIR subset raw dirs (corpus.jsonl, queries.jsonl, qrels/)
|
||||||
│ │ ├── nq-dev/ # Natural Questions
|
│ │ ├── fiqa/
|
||||||
│ │ ├── fiqa/ # BEIR: FiQA-2018
|
│ │ └── …
|
||||||
│ │ ├── fever/ # BEIR: FEVER
|
│ └── converted/ # Sharded stores (auto-generated)
|
||||||
│ │ ├── hotpotqa/ # BEIR: HotpotQA
|
│ ├── fever-minne/ # per-BEIR-subset stores
|
||||||
│ │ └── ... # Other BEIR subsets
|
│ ├── fiqa-minne/
|
||||||
│ └── converted/ # Auto-generated (Minne JSON format)
|
│ └── … # BEIR mix loads from subset stores (no monolithic beir-minne/)
|
||||||
├── cache/ # Ingestion and embedding caches
|
├── cache/
|
||||||
├── reports/ # Evaluation output (JSON + Markdown)
|
│ ├── slices/ # Slice ledgers
|
||||||
├── manifest.yaml # Dataset and slice definitions
|
│ └── ingested/ # Corpus ingestion caches (manifest includes namespace seed)
|
||||||
└── src/ # Evaluation source code
|
├── reports/ # JSON + Markdown output from benchmark runs
|
||||||
|
├── manifest.yaml
|
||||||
|
└── src/
|
||||||
```
|
```
|
||||||
|
|
||||||
## Dataset Sources
|
**After upgrading:** delete old monolithic `*-minne.json` files, any legacy `beir-minne/` merged store, `cache/snapshots/` directories, and stale `reports/history/` artifacts, then re-run `--warm`.
|
||||||
|
|
||||||
### SQuAD v2.0
|
## Common flags
|
||||||
|
|
||||||
Download and place at `data/raw/squad/dev-v2.0.json`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p evaluations/data/raw/squad
|
|
||||||
curl -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json \
|
|
||||||
-o evaluations/data/raw/squad/dev-v2.0.json
|
|
||||||
```
|
|
||||||
|
|
||||||
### Natural Questions (NQ)
|
|
||||||
|
|
||||||
Download and place at `data/raw/nq-dev/dev-all.jsonl`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p evaluations/data/raw/nq-dev
|
|
||||||
# Download from Google's Natural Questions page or HuggingFace
|
|
||||||
# File: dev-all.jsonl (simplified JSONL format)
|
|
||||||
```
|
|
||||||
|
|
||||||
Source: [Google Natural Questions](https://ai.google.com/research/NaturalQuestions)
|
|
||||||
|
|
||||||
### BEIR Datasets
|
|
||||||
|
|
||||||
All BEIR datasets follow the same format structure:
|
|
||||||
|
|
||||||
```
|
|
||||||
data/raw/<dataset>/
|
|
||||||
├── corpus.jsonl # Document corpus
|
|
||||||
├── queries.jsonl # Query set
|
|
||||||
└── qrels/
|
|
||||||
└── test.tsv # Relevance judgments (or dev.tsv)
|
|
||||||
```
|
|
||||||
|
|
||||||
Download datasets from the [BEIR Benchmark repository](https://github.com/beir-cellar/beir). Each dataset zip extracts to the required directory structure.
|
|
||||||
|
|
||||||
| Dataset | Directory |
|
|
||||||
|------------|---------------|
|
|
||||||
| FEVER | `fever/` |
|
|
||||||
| FiQA-2018 | `fiqa/` |
|
|
||||||
| HotpotQA | `hotpotqa/` |
|
|
||||||
| NFCorpus | `nfcorpus/` |
|
|
||||||
| Quora | `quora/` |
|
|
||||||
| TREC-COVID | `trec-covid/` |
|
|
||||||
| SciFact | `scifact/` |
|
|
||||||
| NQ (BEIR) | `nq/` |
|
|
||||||
|
|
||||||
Example download:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd evaluations/data/raw
|
|
||||||
curl -L https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip -o fiqa.zip
|
|
||||||
unzip fiqa.zip && rm fiqa.zip
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset Conversion
|
|
||||||
|
|
||||||
Raw datasets are automatically converted to Minne's internal JSON format on first run. To force reconversion:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --package evaluations -- --force-convert
|
|
||||||
```
|
|
||||||
|
|
||||||
Converted files are saved to `data/converted/` and cached for subsequent runs.
|
|
||||||
|
|
||||||
## CLI Reference
|
|
||||||
|
|
||||||
### Common Options
|
|
||||||
|
|
||||||
| Flag | Description | Default |
|
| Flag | Description | Default |
|
||||||
|------|-------------|---------|
|
|------|-------------|---------|
|
||||||
| `--dataset <NAME>` | Dataset to evaluate | `squad-v2` |
|
| `--dataset` | Dataset to evaluate | `beir` |
|
||||||
| `--limit <N>` | Max questions to evaluate (0 = all) | `200` |
|
| `--slice` | Slice ledger id (catalog or custom) | first catalog slice |
|
||||||
| `--k <N>` | Precision@k cutoff | `5` |
|
| `--limit` | Max questions in the slice ledger | `200` |
|
||||||
| `--slice <ID>` | Use a predefined slice from manifest | — |
|
| `--warm` | Prepare without running queries | — |
|
||||||
| `--rerank` | Enable FastEmbed reranking stage | disabled |
|
| `--status` | Print readiness | — |
|
||||||
| `--embedding-backend <BE>` | `fastembed` or `hashed` | `fastembed` |
|
| `--require-ready` | Fail if not warmed | — |
|
||||||
| `--ingest-chunks-only` | Skip entity extraction, ingest only text chunks | disabled |
|
| `--include-entities` | Entity extraction during ingestion | off |
|
||||||
|
| `--force-convert` | Rebuild converted store | — |
|
||||||
|
| `--chunk-result-cap` | Max chunks returned per query (raise with `--k`) | `5` |
|
||||||
|
| `--perf-log-console` | Print per-stage timings after a run | off |
|
||||||
|
| `--label` | Label stored in JSON/Markdown reports | — |
|
||||||
|
|
||||||
> [!TIP]
|
See [REFACTOR.md](./REFACTOR.md) for architecture notes.
|
||||||
> Use `--ingest-chunks-only` when evaluating vector-only retrieval strategies. This skips the LLM-based entity extraction and graph generation, significantly speeding up ingestion while focusing on pure chunk-based vector search.
|
|
||||||
|
|
||||||
### Available Datasets
|
|
||||||
|
|
||||||
```
|
|
||||||
squad-v2, natural-questions, beir, fever, fiqa, hotpotqa,
|
|
||||||
nfcorpus, quora, trec-covid, scifact, nq-beir
|
|
||||||
```
|
|
||||||
|
|
||||||
### Database Configuration
|
|
||||||
|
|
||||||
| Flag | Environment | Default |
|
|
||||||
|------|-------------|---------|
|
|
||||||
| `--db-endpoint` | `EVAL_DB_ENDPOINT` | `ws://127.0.0.1:8000` |
|
|
||||||
| `--db-username` | `EVAL_DB_USERNAME` | `root_user` |
|
|
||||||
| `--db-password` | `EVAL_DB_PASSWORD` | `root_password` |
|
|
||||||
| `--db-namespace` | `EVAL_DB_NAMESPACE` | auto-generated |
|
|
||||||
| `--db-database` | `EVAL_DB_DATABASE` | auto-generated |
|
|
||||||
|
|
||||||
### Example Runs
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Vector-only evaluation (recommended for benchmarking)
|
|
||||||
cargo run --package evaluations -- \
|
|
||||||
--dataset fiqa \
|
|
||||||
--ingest-chunks-only \
|
|
||||||
--limit 200
|
|
||||||
|
|
||||||
# Full FiQA evaluation with reranking
|
|
||||||
cargo run --package evaluations -- \
|
|
||||||
--dataset fiqa \
|
|
||||||
--ingest-chunks-only \
|
|
||||||
--limit 500 \
|
|
||||||
--rerank \
|
|
||||||
--k 10
|
|
||||||
|
|
||||||
# Use a predefined slice for reproducibility
|
|
||||||
cargo run --package evaluations -- --slice fiqa-test-200 --ingest-chunks-only
|
|
||||||
|
|
||||||
# Run the mixed BEIR benchmark
|
|
||||||
cargo run --package evaluations -- --dataset beir --slice beir-mix-600 --ingest-chunks-only
|
|
||||||
```
|
|
||||||
|
|
||||||
## Slices
|
|
||||||
|
|
||||||
Slices are predefined, reproducible subsets defined in `manifest.yaml`. Each slice specifies:
|
|
||||||
|
|
||||||
- **limit**: Number of questions
|
|
||||||
- **corpus_limit**: Maximum corpus size
|
|
||||||
- **seed**: Fixed RNG seed for reproducibility
|
|
||||||
|
|
||||||
View available slices in [manifest.yaml](./manifest.yaml).
|
|
||||||
|
|
||||||
## Reports
|
|
||||||
|
|
||||||
Evaluations generate reports in `reports/`:
|
|
||||||
|
|
||||||
- **JSON**: Full structured results (`*-report.json`)
|
|
||||||
- **Markdown**: Human-readable summary with sample mismatches (`*-report.md`)
|
|
||||||
- **History**: Timestamped run history (`history/`)
|
|
||||||
|
|
||||||
## Performance Tuning
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Log per-stage performance timings
|
|
||||||
cargo run --package evaluations -- --perf-log-console
|
|
||||||
|
|
||||||
# Save telemetry to file
|
|
||||||
cargo run --package evaluations -- --perf-log-json ./perf.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
See [../LICENSE](../LICENSE).
|
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# Evaluations crate refactor plan
|
||||||
|
|
||||||
|
This document records the architecture review and the simplification work applied to the
|
||||||
|
`evaluations` crate. **No backwards compatibility** is maintained for converted JSON layouts,
|
||||||
|
legacy report history, or old cache artifact formats.
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
|
||||||
|
- Smaller, linear pipeline (no state machine ceremony)
|
||||||
|
- Sharded converted store for **all** datasets (memory-efficient partial loading)
|
||||||
|
- Slice-first loading when a catalog slice is selected
|
||||||
|
- In-memory SurrealDB for ingestion (no ephemeral server namespaces)
|
||||||
|
- Single DB lifecycle module (`db/`)
|
||||||
|
- CLI helpers under `cli/`
|
||||||
|
|
||||||
|
## Primary workflow
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# One-time prep (converts raw data if needed, builds slice ledger, corpus cache, DB seed)
|
||||||
|
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||||
|
|
||||||
|
# Check readiness
|
||||||
|
cargo eval --status --dataset beir --slice beir-mix-600
|
||||||
|
|
||||||
|
# Steady-state benchmark
|
||||||
|
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||||
|
```
|
||||||
|
|
||||||
|
Default dataset is `beir`. Chunk-only ingestion is the default; pass `--include-entities` to
|
||||||
|
opt into entity extraction (requires `OPENAI_API_KEY`). Slice tuning such as
|
||||||
|
`negative_multiplier` lives in `manifest.yaml` (e.g. `beir-mix-600` uses `9.0`).
|
||||||
|
|
||||||
|
## Cache layers (after refactor)
|
||||||
|
|
||||||
|
| Layer | Location | Purpose |
|
||||||
|
|-------|----------|---------|
|
||||||
|
| Converted store | `data/converted/<name>/` | Sharded paragraphs + question catalog |
|
||||||
|
| Slice ledger | `cache/slices/<dataset>/<slice-id>.json` | Deterministic questions + paragraph set |
|
||||||
|
| Corpus cache | `cache/ingested/<dataset>/<slice-id>/` | Ingestion paragraph shards, manifest, and namespace reuse seed |
|
||||||
|
|
||||||
|
Namespace reuse state lives in the corpus manifest (`metadata.namespace_seed`), not a separate
|
||||||
|
`snapshots/` tree. After upgrading, delete old `*-minne.json` monolithic files, any
|
||||||
|
`cache/snapshots/` directories, and re-run `--warm`.
|
||||||
|
|
||||||
|
## Phases applied
|
||||||
|
|
||||||
|
### Phase 0 — dead code
|
||||||
|
|
||||||
|
- Removed unused `criterion` dependency
|
||||||
|
- Removed unused `EmbeddingCache`
|
||||||
|
- Updated README for current CLI
|
||||||
|
|
||||||
|
### Phase 1 — structure
|
||||||
|
|
||||||
|
- Flattened pipeline to linear `async fn` stages
|
||||||
|
- Removed `eval.rs` hub; imports go to owning modules
|
||||||
|
- Merged `namespace.rs`, `db_helpers.rs` → `db/`; dropped standalone `snapshot.rs`
|
||||||
|
- Moved `status.rs` → `cli/status.rs`
|
||||||
|
- Fixed catalog slice bootstrap (build ledger when explicit slice manifest is missing)
|
||||||
|
|
||||||
|
### Phase 2 — no legacy paths
|
||||||
|
|
||||||
|
- All datasets use sharded converted store only
|
||||||
|
- Removed legacy JSON layout and migration
|
||||||
|
- Removed legacy report history format
|
||||||
|
- Auto-apply first catalog slice when `--slice` omitted
|
||||||
|
- Namespace seed folded into corpus manifest (removed `cache/snapshots/`)
|
||||||
|
|
||||||
|
### Phase 3 — performance
|
||||||
|
|
||||||
|
- Ingestion always uses in-memory SurrealDB
|
||||||
|
- Slice-first partial load when ledger is complete
|
||||||
|
- Default catalog slice for dataset when `--slice` not passed
|
||||||
|
- Split `slice/` into `mod.rs`, `build.rs`, and `beir.rs`
|
||||||
|
|
||||||
|
### Phase 4 — BEIR mix slice-first
|
||||||
|
|
||||||
|
- `beir` is a virtual mix: slice ledger references prefixed ids (`fever-…`, `fiqa-…`, …)
|
||||||
|
- Conversion is **qrels-closed** per subset (only documents appearing in qrels, not full corpus)
|
||||||
|
- Slice ledger is resolved for the requested `--slice` (catalog preset or custom id + `--limit`)
|
||||||
|
- Only ledger paragraph ids are materialized into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||||
|
- No monolithic `beir-minne/` merged store
|
||||||
|
- Raw BEIR data lives in per-subset dirs under `data/raw/`; `data/raw/beir` is a catalog placeholder
|
||||||
|
|
||||||
|
## Do not re-introduce
|
||||||
|
|
||||||
|
- Monolithic `*-minne.json` converted files
|
||||||
|
- Monolithic `beir-minne/` merged converted store (use per-subset stores + virtual mix loader)
|
||||||
|
- `state-machines` pipeline for this linear flow
|
||||||
|
- `eval.rs` re-export hub
|
||||||
|
- Legacy history migration in reports
|
||||||
|
- Ephemeral `ingest_eval_*` namespaces on the shared SurrealDB server
|
||||||
|
- Separate `cache/snapshots/` namespace state files
|
||||||
|
|
||||||
|
## Open follow-ups
|
||||||
|
|
||||||
|
- Generate `DatasetKind` from `manifest.yaml` at build time
|
||||||
|
- Split `report.rs` when touching reporting again
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
default_dataset: squad-v2
|
default_dataset: beir
|
||||||
datasets:
|
datasets:
|
||||||
- id: squad-v2
|
- id: squad-v2
|
||||||
label: "SQuAD v2.0"
|
label: "SQuAD v2.0"
|
||||||
@@ -45,6 +45,7 @@ datasets:
|
|||||||
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
||||||
limit: 600
|
limit: 600
|
||||||
corpus_limit: 6000
|
corpus_limit: 6000
|
||||||
|
negative_multiplier: 9.0
|
||||||
seed: 0x5eed2025
|
seed: 0x5eed2025
|
||||||
- id: fever
|
- id: fever
|
||||||
label: "FEVER (BEIR)"
|
label: "FEVER (BEIR)"
|
||||||
|
|||||||
+66
-18
@@ -137,9 +137,9 @@ pub struct IngestConfig {
|
|||||||
#[arg(long, default_value_t = 50)]
|
#[arg(long, default_value_t = 50)]
|
||||||
pub ingest_chunk_overlap_tokens: usize,
|
pub ingest_chunk_overlap_tokens: usize,
|
||||||
|
|
||||||
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
|
/// Include entity extraction and graph generation during ingestion (uses LLM tokens)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub ingest_chunks_only: bool,
|
pub include_entities: bool,
|
||||||
|
|
||||||
/// Number of paragraphs to ingest concurrently
|
/// Number of paragraphs to ingest concurrently
|
||||||
#[arg(long, default_value_t = 10)]
|
#[arg(long, default_value_t = 10)]
|
||||||
@@ -159,6 +159,7 @@ pub struct IngestConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Args)]
|
#[derive(Debug, Clone, Args)]
|
||||||
|
#[allow(clippy::struct_field_names)]
|
||||||
pub struct DatabaseArgs {
|
pub struct DatabaseArgs {
|
||||||
/// `SurrealDB` server endpoint
|
/// `SurrealDB` server endpoint
|
||||||
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
||||||
@@ -179,10 +180,6 @@ pub struct DatabaseArgs {
|
|||||||
/// Override the database used on the `SurrealDB` server
|
/// Override the database used on the `SurrealDB` server
|
||||||
#[arg(long, env = "EVAL_DB_DATABASE")]
|
#[arg(long, env = "EVAL_DB_DATABASE")]
|
||||||
pub db_database: Option<String>,
|
pub db_database: Option<String>,
|
||||||
|
|
||||||
/// Path to inspect DB state
|
|
||||||
#[arg(long)]
|
|
||||||
pub inspect_db_state: Option<PathBuf>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@@ -233,10 +230,6 @@ pub struct Config {
|
|||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
pub sample: usize,
|
pub sample: usize,
|
||||||
|
|
||||||
/// Disable context cropping when converting datasets (ingest entire documents)
|
|
||||||
#[arg(long)]
|
|
||||||
pub full_context: bool,
|
|
||||||
|
|
||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
pub retrieval: RetrievalSettings,
|
pub retrieval: RetrievalSettings,
|
||||||
|
|
||||||
@@ -322,6 +315,18 @@ pub struct Config {
|
|||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
pub database: DatabaseArgs,
|
pub database: DatabaseArgs,
|
||||||
|
|
||||||
|
/// Require warmed corpus/namespace before running queries
|
||||||
|
#[arg(long)]
|
||||||
|
pub require_ready: bool,
|
||||||
|
|
||||||
|
/// Prepare converted data, slice, corpus, and namespace without running queries
|
||||||
|
#[arg(long, conflicts_with = "status")]
|
||||||
|
pub warm: bool,
|
||||||
|
|
||||||
|
/// Print readiness of converted data, slice, corpus, and namespace
|
||||||
|
#[arg(long, conflicts_with = "warm")]
|
||||||
|
pub status: bool,
|
||||||
|
|
||||||
// Computed fields (not arguments)
|
// Computed fields (not arguments)
|
||||||
#[arg(skip)]
|
#[arg(skip)]
|
||||||
pub raw_dataset_path: PathBuf,
|
pub raw_dataset_path: PathBuf,
|
||||||
@@ -334,11 +339,6 @@ pub struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
#[allow(clippy::unused_self)]
|
|
||||||
pub fn context_token_limit(&self) -> Option<usize> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn finalize(&mut self) -> Result<()> {
|
pub fn finalize(&mut self) -> Result<()> {
|
||||||
// Handle dataset paths
|
// Handle dataset paths
|
||||||
@@ -367,9 +367,7 @@ impl Config {
|
|||||||
// Handle retrieval settings
|
// Handle retrieval settings
|
||||||
self.retrieval.require_verified_chunks = !self.llm_mode;
|
self.retrieval.require_verified_chunks = !self.llm_mode;
|
||||||
|
|
||||||
if self.dataset == DatasetKind::Beir {
|
self.apply_catalog_slice_defaults()?;
|
||||||
self.negative_multiplier = 9.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validations
|
// Validations
|
||||||
if self.ingest.ingest_chunk_min_tokens == 0
|
if self.ingest.ingest_chunk_min_tokens == 0
|
||||||
@@ -477,6 +475,56 @@ impl Config {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_catalog_slice_defaults(&mut self) -> Result<()> {
|
||||||
|
let catalog = crate::datasets::catalog()?;
|
||||||
|
let entry = catalog.dataset(self.dataset.id())?;
|
||||||
|
|
||||||
|
if self.slice.is_none() {
|
||||||
|
if let Some(default_slice) = entry.slices.first() {
|
||||||
|
self.slice = Some(default_slice.id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(slice_id) = self.slice.as_deref() else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok((_, slice)) = catalog.slice(slice_id) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
if slice.dataset_id != self.dataset.id() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(limit) = slice.limit {
|
||||||
|
if self.limit_arg == 200 {
|
||||||
|
self.limit_arg = limit;
|
||||||
|
self.limit = Some(limit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.corpus_limit.is_none() {
|
||||||
|
self.corpus_limit = slice.corpus_limit;
|
||||||
|
}
|
||||||
|
if let Some(seed) = slice.seed {
|
||||||
|
self.slice_seed = seed;
|
||||||
|
}
|
||||||
|
if let Some(include_unanswerable) = slice.include_unanswerable {
|
||||||
|
self.llm_mode = include_unanswerable;
|
||||||
|
self.retrieval.require_verified_chunks = !include_unanswerable;
|
||||||
|
}
|
||||||
|
if let Some(multiplier) = slice.negative_multiplier {
|
||||||
|
if negative_multiplier_is_default(self.negative_multiplier) {
|
||||||
|
self.negative_multiplier = multiplier;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn negative_multiplier_is_default(value: f32) -> bool {
|
||||||
|
(value - crate::slice::DEFAULT_NEGATIVE_MULTIPLIER).abs() < f32::EPSILON
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ParsedArgs {
|
pub struct ParsedArgs {
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
use std::{
|
|
||||||
collections::HashMap,
|
|
||||||
path::Path,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
|
||||||
struct EmbeddingCacheData {
|
|
||||||
entities: HashMap<String, Vec<f32>>,
|
|
||||||
chunks: HashMap<String, Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct EmbeddingCache {
|
|
||||||
path: Arc<Path>,
|
|
||||||
data: Arc<Mutex<EmbeddingCacheData>>,
|
|
||||||
dirty: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
impl EmbeddingCache {
|
|
||||||
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
|
|
||||||
let path = path.as_ref().to_path_buf();
|
|
||||||
let data = if path.exists() {
|
|
||||||
let raw = tokio::fs::read(&path)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("reading embedding cache {}", path.display()))?;
|
|
||||||
serde_json::from_slice(&raw)
|
|
||||||
.with_context(|| format!("parsing embedding cache {}", path.display()))?
|
|
||||||
} else {
|
|
||||||
EmbeddingCacheData::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
path: Arc::from(path.as_path()),
|
|
||||||
data: Arc::new(Mutex::new(data)),
|
|
||||||
dirty: Arc::new(AtomicBool::new(false)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
|
|
||||||
let guard = self.data.lock().await;
|
|
||||||
guard.entities.get(id).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
|
|
||||||
let mut guard = self.data.lock().await;
|
|
||||||
guard.entities.insert(id, embedding);
|
|
||||||
self.dirty.store(true, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
|
|
||||||
let guard = self.data.lock().await;
|
|
||||||
guard.chunks.get(id).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
|
|
||||||
let mut guard = self.data.lock().await;
|
|
||||||
guard.chunks.insert(id, embedding);
|
|
||||||
self.dirty.store(true, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn persist(&self) -> Result<()> {
|
|
||||||
if !self.dirty.load(Ordering::Relaxed) {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let guard = self.data.lock().await;
|
|
||||||
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
|
|
||||||
if let Some(parent) = self.path.parent() {
|
|
||||||
tokio::fs::create_dir_all(parent)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("creating cache directory {}", parent.display()))?;
|
|
||||||
}
|
|
||||||
tokio::fs::write(&*self.path, body)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
|
|
||||||
self.dirty.store(false, Ordering::Relaxed);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -156,6 +156,7 @@ mod tests {
|
|||||||
chunk_min_tokens: 1,
|
chunk_min_tokens: 1,
|
||||||
chunk_max_tokens: 10,
|
chunk_max_tokens: 10,
|
||||||
chunk_only: false,
|
chunk_only: false,
|
||||||
|
namespace_seed: None,
|
||||||
},
|
},
|
||||||
paragraphs,
|
paragraphs,
|
||||||
questions,
|
questions,
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod status;
|
||||||
|
|
||||||
|
pub use status::{collect_status, ensure_query_ready, print_status, warm};
|
||||||
@@ -0,0 +1,311 @@
|
|||||||
|
#![allow(clippy::module_name_repetitions)]
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
args::Config,
|
||||||
|
corpus::{self, CorpusCacheConfig},
|
||||||
|
datasets::{
|
||||||
|
beir_subset_store_summary, beir_subset_stores_ready, content_checksum_for_layout,
|
||||||
|
detect_layout, mix_content_checksum, store_dir_for, ConvertedLayout, DatasetKind,
|
||||||
|
},
|
||||||
|
db::{connect_eval_db, default_database, default_namespace, namespace_has_corpus},
|
||||||
|
slice::{self, ledger_target},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct EvalStatus {
|
||||||
|
pub dataset: String,
|
||||||
|
pub slice: Option<String>,
|
||||||
|
pub converted: ConvertedStatus,
|
||||||
|
pub slice_ledger: SliceLedgerStatus,
|
||||||
|
pub corpus_cache: CorpusCacheStatus,
|
||||||
|
pub namespace: NamespaceStatus,
|
||||||
|
pub query_ready: bool,
|
||||||
|
pub notes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct ConvertedStatus {
|
||||||
|
pub layout: String,
|
||||||
|
pub path: String,
|
||||||
|
pub ready: bool,
|
||||||
|
pub partial_load_eligible: bool,
|
||||||
|
pub checksum: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct SliceLedgerStatus {
|
||||||
|
pub ready: bool,
|
||||||
|
pub path: Option<String>,
|
||||||
|
pub cases: Option<usize>,
|
||||||
|
pub positives: Option<usize>,
|
||||||
|
pub negatives: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct CorpusCacheStatus {
|
||||||
|
pub ready: bool,
|
||||||
|
pub path: Option<String>,
|
||||||
|
pub manifest_present: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct NamespaceStatus {
|
||||||
|
pub namespace: String,
|
||||||
|
pub database: String,
|
||||||
|
pub seeded: bool,
|
||||||
|
pub namespace_seed_recorded: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
pub async fn collect_status(config: &Config) -> Result<EvalStatus> {
|
||||||
|
let mut notes = Vec::new();
|
||||||
|
let is_beir_mix = config.dataset == DatasetKind::Beir;
|
||||||
|
let converted_path = &config.converted_dataset_path;
|
||||||
|
let layout = if is_beir_mix {
|
||||||
|
ConvertedLayout::Missing
|
||||||
|
} else {
|
||||||
|
detect_layout(converted_path)
|
||||||
|
};
|
||||||
|
let layout_label = if is_beir_mix {
|
||||||
|
"beir-mix-subset-stores"
|
||||||
|
} else {
|
||||||
|
match layout {
|
||||||
|
ConvertedLayout::ShardedStore => "sharded-store",
|
||||||
|
ConvertedLayout::Missing => "missing",
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let store_dir = store_dir_for(converted_path);
|
||||||
|
let display_path = if is_beir_mix {
|
||||||
|
beir_subset_store_summary()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|(subset, paragraphs, questions)| {
|
||||||
|
format!("{subset}-minne ({paragraphs} paragraphs, {questions} questions)")
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("; ")
|
||||||
|
} else {
|
||||||
|
store_dir.display().to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let manifest_path = slice::cached_manifest_path(config);
|
||||||
|
let slice_config = slice::slice_config_with_limit(config, ledger_target(config));
|
||||||
|
let slice_manifest = manifest_path
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|path| slice::read_manifest_if_exists(path).ok().flatten());
|
||||||
|
|
||||||
|
let slice_ledger = SliceLedgerStatus {
|
||||||
|
ready: slice_manifest
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|manifest| slice::manifest_is_complete(manifest, &slice_config)),
|
||||||
|
path: manifest_path
|
||||||
|
.as_ref()
|
||||||
|
.map(|path| path.display().to_string()),
|
||||||
|
cases: slice_manifest.as_ref().map(|manifest| manifest.case_count),
|
||||||
|
positives: slice_manifest
|
||||||
|
.as_ref()
|
||||||
|
.map(|manifest| manifest.positive_paragraphs),
|
||||||
|
negatives: slice_manifest
|
||||||
|
.as_ref()
|
||||||
|
.map(|manifest| manifest.negative_paragraphs),
|
||||||
|
};
|
||||||
|
|
||||||
|
let beir_paragraph_ids = slice_manifest.as_ref().map(|manifest| {
|
||||||
|
manifest
|
||||||
|
.paragraphs
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.id.clone())
|
||||||
|
.collect::<std::collections::HashSet<_>>()
|
||||||
|
});
|
||||||
|
|
||||||
|
let converted_ready = if is_beir_mix {
|
||||||
|
slice_ledger.ready
|
||||||
|
&& beir_paragraph_ids
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|ids| beir_subset_stores_ready(ids).unwrap_or(false))
|
||||||
|
} else {
|
||||||
|
layout == ConvertedLayout::ShardedStore
|
||||||
|
};
|
||||||
|
|
||||||
|
let checksum = if is_beir_mix {
|
||||||
|
beir_paragraph_ids
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|ids| mix_content_checksum(ids).ok())
|
||||||
|
} else if layout == ConvertedLayout::ShardedStore {
|
||||||
|
content_checksum_for_layout(converted_path).ok()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let partial_load_eligible = slice_ledger.ready && config.slice.is_some();
|
||||||
|
|
||||||
|
let corpus_cache = if let Some(manifest) = slice_manifest.as_ref() {
|
||||||
|
let cache_settings = CorpusCacheConfig::from(config);
|
||||||
|
let base_dir = corpus::cached_corpus_dir(
|
||||||
|
&cache_settings,
|
||||||
|
config.dataset.id(),
|
||||||
|
manifest.slice_id.as_str(),
|
||||||
|
);
|
||||||
|
let manifest_present = corpus::load_cached_manifest(&base_dir)?.is_some();
|
||||||
|
CorpusCacheStatus {
|
||||||
|
ready: manifest_present,
|
||||||
|
path: Some(base_dir.display().to_string()),
|
||||||
|
manifest_present,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
CorpusCacheStatus {
|
||||||
|
ready: false,
|
||||||
|
path: None,
|
||||||
|
manifest_present: false,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let namespace = config.database.db_namespace.clone().unwrap_or_else(|| {
|
||||||
|
default_namespace(config.dataset.id(), config.limit, config.slice.as_deref())
|
||||||
|
});
|
||||||
|
let database = config
|
||||||
|
.database
|
||||||
|
.db_database
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(default_database);
|
||||||
|
|
||||||
|
let namespace_seed = corpus_cache.path.as_ref().and_then(|path| {
|
||||||
|
corpus::load_cached_manifest(Path::new(path))
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
|
.and_then(|manifest| manifest.metadata.namespace_seed)
|
||||||
|
});
|
||||||
|
|
||||||
|
let (seeded, namespace_seed_recorded) =
|
||||||
|
match connect_eval_db(config, &namespace, &database).await {
|
||||||
|
Ok(db) => {
|
||||||
|
let has_corpus = namespace_has_corpus(&db).await.unwrap_or(false);
|
||||||
|
(has_corpus, namespace_seed.is_some())
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
notes.push(format!("SurrealDB unavailable: {err}"));
|
||||||
|
(false, false)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let query_ready = converted_ready
|
||||||
|
&& slice_ledger.ready
|
||||||
|
&& corpus_cache.ready
|
||||||
|
&& seeded
|
||||||
|
&& namespace_seed_recorded;
|
||||||
|
|
||||||
|
if !query_ready {
|
||||||
|
notes.push("Run `cargo eval --warm --slice <id>` to prepare corpus and namespace.".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(EvalStatus {
|
||||||
|
dataset: config.dataset.id().to_string(),
|
||||||
|
slice: config.slice.clone(),
|
||||||
|
converted: ConvertedStatus {
|
||||||
|
layout: layout_label.to_string(),
|
||||||
|
path: display_path,
|
||||||
|
ready: converted_ready,
|
||||||
|
partial_load_eligible,
|
||||||
|
checksum,
|
||||||
|
},
|
||||||
|
slice_ledger,
|
||||||
|
corpus_cache,
|
||||||
|
namespace: NamespaceStatus {
|
||||||
|
namespace,
|
||||||
|
database,
|
||||||
|
seeded,
|
||||||
|
namespace_seed_recorded,
|
||||||
|
},
|
||||||
|
query_ready,
|
||||||
|
notes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print_status(status: &EvalStatus) {
|
||||||
|
println!("Evaluation status for dataset `{}`", status.dataset);
|
||||||
|
if let Some(slice) = &status.slice {
|
||||||
|
println!("Slice: {slice}");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Converted: {} ({})",
|
||||||
|
if status.converted.ready {
|
||||||
|
"ready"
|
||||||
|
} else {
|
||||||
|
"missing"
|
||||||
|
},
|
||||||
|
status.converted.layout
|
||||||
|
);
|
||||||
|
println!("Converted path: {}", status.converted.path);
|
||||||
|
if status.converted.partial_load_eligible {
|
||||||
|
println!("Slice-first loading: eligible");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Slice ledger: {}",
|
||||||
|
if status.slice_ledger.ready {
|
||||||
|
format!(
|
||||||
|
"ready ({} cases, {} positives, {} negatives)",
|
||||||
|
status.slice_ledger.cases.unwrap_or(0),
|
||||||
|
status.slice_ledger.positives.unwrap_or(0),
|
||||||
|
status.slice_ledger.negatives.unwrap_or(0)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
"missing or incomplete".to_string()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if let Some(path) = &status.slice_ledger.path {
|
||||||
|
println!("Slice ledger path: {path}");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Corpus cache: {}",
|
||||||
|
if status.corpus_cache.ready {
|
||||||
|
"ready"
|
||||||
|
} else {
|
||||||
|
"missing"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if let Some(path) = &status.corpus_cache.path {
|
||||||
|
println!("Corpus cache path: {path}");
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Namespace `{}` / `{}`: seeded={}, namespace_seed_recorded={}",
|
||||||
|
status.namespace.namespace,
|
||||||
|
status.namespace.database,
|
||||||
|
status.namespace.seeded,
|
||||||
|
status.namespace.namespace_seed_recorded
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"Query-ready: {}",
|
||||||
|
if status.query_ready { "yes" } else { "no" }
|
||||||
|
);
|
||||||
|
for note in &status.notes {
|
||||||
|
println!("Note: {note}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn warm(config: &Config) -> Result<()> {
|
||||||
|
let loaded =
|
||||||
|
crate::datasets::prepare_dataset(config.dataset, config).context("preparing dataset")?;
|
||||||
|
crate::pipeline::warm_evaluation(&loaded.dataset, config, &loaded.content_checksum)
|
||||||
|
.await
|
||||||
|
.context("warming evaluation corpus and namespace")?;
|
||||||
|
let status = collect_status(config).await?;
|
||||||
|
print_status(&status);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ensure_query_ready(config: &Config) -> Result<()> {
|
||||||
|
let status = collect_status(config).await?;
|
||||||
|
if status.query_ready {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
print_status(&status);
|
||||||
|
anyhow::bail!(
|
||||||
|
"evaluation is not query-ready; run `cargo eval --warm --slice {}` first",
|
||||||
|
config.slice.as_deref().unwrap_or("<slice-id>")
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
#![allow(clippy::arithmetic_side_effects)]
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use common::storage::types::StoredObject;
|
||||||
|
|
||||||
|
use crate::types::EvaluationCandidate;
|
||||||
|
|
||||||
|
const TOKENIZER_LABEL: &str = "estimated (~chars/4; ingestion uses bert-base-cased)";
|
||||||
|
|
||||||
|
#[allow(clippy::struct_field_names)]
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct RetrievedContextStats {
|
||||||
|
pub chunk_count: usize,
|
||||||
|
pub char_count: usize,
|
||||||
|
pub token_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct RetrievalContextStats {
|
||||||
|
pub tokenizer: String,
|
||||||
|
pub queries: usize,
|
||||||
|
pub total_chunks: usize,
|
||||||
|
pub total_chars: usize,
|
||||||
|
pub total_tokens: usize,
|
||||||
|
pub avg_chunks_per_query: f64,
|
||||||
|
pub avg_chars_per_query: f64,
|
||||||
|
pub avg_tokens_per_query: f64,
|
||||||
|
pub p50_tokens_per_query: usize,
|
||||||
|
pub p95_tokens_per_query: usize,
|
||||||
|
pub max_tokens_per_query: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stats_for_candidates(candidates: &[EvaluationCandidate]) -> RetrievedContextStats {
|
||||||
|
let mut seen_chunk_ids = std::collections::HashSet::new();
|
||||||
|
let mut stats = RetrievedContextStats::default();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
for chunk in &candidate.chunks {
|
||||||
|
let chunk_id = chunk.chunk.id().to_string();
|
||||||
|
if !seen_chunk_ids.insert(chunk_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let text = chunk.chunk.chunk.as_str();
|
||||||
|
stats.chunk_count += 1;
|
||||||
|
stats.char_count += text.chars().count();
|
||||||
|
stats.token_count += estimate_ingestion_tokens(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::cast_precision_loss)]
|
||||||
|
pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> RetrievalContextStats {
|
||||||
|
let queries = per_query.len();
|
||||||
|
if queries == 0 {
|
||||||
|
return RetrievalContextStats {
|
||||||
|
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||||
|
queries: 0,
|
||||||
|
total_chunks: 0,
|
||||||
|
total_chars: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
avg_chunks_per_query: 0.0,
|
||||||
|
avg_chars_per_query: 0.0,
|
||||||
|
avg_tokens_per_query: 0.0,
|
||||||
|
p50_tokens_per_query: 0,
|
||||||
|
p95_tokens_per_query: 0,
|
||||||
|
max_tokens_per_query: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_chunks: usize = per_query.iter().map(|stats| stats.chunk_count).sum();
|
||||||
|
let total_chars: usize = per_query.iter().map(|stats| stats.char_count).sum();
|
||||||
|
let total_tokens: usize = per_query.iter().map(|stats| stats.token_count).sum();
|
||||||
|
let mut tokens_per_query: Vec<usize> =
|
||||||
|
per_query.iter().map(|stats| stats.token_count).collect();
|
||||||
|
tokens_per_query.sort_unstable();
|
||||||
|
let max_tokens_per_query = *tokens_per_query.last().unwrap_or(&0);
|
||||||
|
|
||||||
|
let total_chunks_f = total_chunks as f64;
|
||||||
|
let total_chars_f = total_chars as f64;
|
||||||
|
let total_tokens_f = total_tokens as f64;
|
||||||
|
let queries_f = queries as f64;
|
||||||
|
let avg_chunks_per_query = total_chunks_f / queries_f;
|
||||||
|
let avg_chars_per_query = total_chars_f / queries_f;
|
||||||
|
let avg_tokens_per_query = total_tokens_f / queries_f;
|
||||||
|
|
||||||
|
RetrievalContextStats {
|
||||||
|
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||||
|
queries,
|
||||||
|
total_chunks,
|
||||||
|
total_chars,
|
||||||
|
total_tokens,
|
||||||
|
avg_chunks_per_query,
|
||||||
|
avg_chars_per_query,
|
||||||
|
avg_tokens_per_query,
|
||||||
|
p50_tokens_per_query: percentile_usize(&tokens_per_query, 0.50),
|
||||||
|
p95_tokens_per_query: percentile_usize(&tokens_per_query, 0.95),
|
||||||
|
max_tokens_per_query,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_ingestion_tokens(text: &str) -> usize {
|
||||||
|
let chars = text.chars().count();
|
||||||
|
if chars == 0 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
chars.div_ceil(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(
|
||||||
|
clippy::cast_precision_loss,
|
||||||
|
clippy::cast_sign_loss,
|
||||||
|
clippy::cast_possible_truncation,
|
||||||
|
clippy::indexing_slicing,
|
||||||
|
clippy::arithmetic_side_effects
|
||||||
|
)]
|
||||||
|
fn percentile_usize(sorted: &[usize], fraction: f64) -> usize {
|
||||||
|
if sorted.is_empty() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
let clamped = fraction.clamp(0.0, 1.0);
|
||||||
|
let index = ((sorted.len() - 1) as f64 * clamped).round() as usize;
|
||||||
|
sorted[index.min(sorted.len() - 1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use common::storage::types::text_chunk::TextChunk;
|
||||||
|
use retrieval_pipeline::RetrievedChunk;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deduplicates_chunks_when_counting_context() {
|
||||||
|
let shared = Arc::new(TextChunk::new(
|
||||||
|
"src".into(),
|
||||||
|
"hello world".into(),
|
||||||
|
"user".into(),
|
||||||
|
));
|
||||||
|
let candidates = vec![
|
||||||
|
EvaluationCandidate {
|
||||||
|
entity_id: "a".into(),
|
||||||
|
source_id: "src".into(),
|
||||||
|
entity_name: "A".into(),
|
||||||
|
entity_description: None,
|
||||||
|
entity_category: None,
|
||||||
|
score: 1.0,
|
||||||
|
chunks: vec![RetrievedChunk {
|
||||||
|
chunk: Arc::clone(&shared),
|
||||||
|
score: 1.0,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
EvaluationCandidate {
|
||||||
|
entity_id: "b".into(),
|
||||||
|
source_id: "src".into(),
|
||||||
|
entity_name: "B".into(),
|
||||||
|
entity_description: None,
|
||||||
|
entity_category: None,
|
||||||
|
score: 0.9,
|
||||||
|
chunks: vec![RetrievedChunk {
|
||||||
|
chunk: shared,
|
||||||
|
score: 0.9,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let stats = stats_for_candidates(&candidates);
|
||||||
|
assert_eq!(stats.chunk_count, 1);
|
||||||
|
assert_eq!(stats.char_count, "hello world".chars().count());
|
||||||
|
assert_eq!(stats.token_count, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn aggregates_per_query_token_totals() {
|
||||||
|
let per_query = vec![
|
||||||
|
RetrievedContextStats {
|
||||||
|
chunk_count: 2,
|
||||||
|
char_count: 100,
|
||||||
|
token_count: 40,
|
||||||
|
},
|
||||||
|
RetrievedContextStats {
|
||||||
|
chunk_count: 5,
|
||||||
|
char_count: 250,
|
||||||
|
token_count: 100,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let aggregate = aggregate_context_stats(&per_query);
|
||||||
|
assert_eq!(aggregate.queries, 2);
|
||||||
|
assert_eq!(aggregate.total_chunks, 7);
|
||||||
|
assert_eq!(aggregate.total_tokens, 140);
|
||||||
|
assert_eq!(aggregate.max_tokens_per_query, 100);
|
||||||
|
assert!((aggregate.avg_tokens_per_query - 70.0).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,32 +11,14 @@ pub struct CorpusCacheConfig {
|
|||||||
pub ingestion_max_retries: usize,
|
pub ingestion_max_retries: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CorpusCacheConfig {
|
impl From<&Config> for CorpusCacheConfig {
|
||||||
pub fn new(
|
fn from(config: &Config) -> Self {
|
||||||
ingestion_cache_dir: impl Into<PathBuf>,
|
|
||||||
force_refresh: bool,
|
|
||||||
refresh_embeddings_only: bool,
|
|
||||||
ingestion_batch_size: usize,
|
|
||||||
ingestion_max_retries: usize,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
ingestion_cache_dir: ingestion_cache_dir.into(),
|
ingestion_cache_dir: config.ingest.ingestion_cache_dir.clone(),
|
||||||
force_refresh,
|
force_refresh: config.force_convert || config.ingest.slice_reset_ingestion,
|
||||||
refresh_embeddings_only,
|
refresh_embeddings_only: config.ingest.refresh_embeddings_only,
|
||||||
ingestion_batch_size,
|
ingestion_batch_size: config.ingest.ingestion_batch_size,
|
||||||
ingestion_max_retries,
|
ingestion_max_retries: config.ingest.ingestion_max_retries,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&Config> for CorpusCacheConfig {
|
|
||||||
fn from(config: &Config) -> Self {
|
|
||||||
CorpusCacheConfig::new(
|
|
||||||
config.ingest.ingestion_cache_dir.clone(),
|
|
||||||
config.force_convert || config.ingest.slice_reset_ingestion,
|
|
||||||
config.ingest.refresh_embeddings_only,
|
|
||||||
config.ingest.ingestion_batch_size,
|
|
||||||
config.ingest.ingestion_max_retries,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ pub(crate) mod store;
|
|||||||
pub use config::CorpusCacheConfig;
|
pub use config::CorpusCacheConfig;
|
||||||
pub use orchestrator::{
|
pub use orchestrator::{
|
||||||
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
|
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
|
||||||
load_cached_manifest,
|
load_cached_manifest, persist_corpus_manifest,
|
||||||
};
|
};
|
||||||
pub use store::{
|
pub use store::{
|
||||||
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
|
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||||
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
CorpusQuestion, NamespaceSeedRecord, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
|
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
|
||||||
@@ -20,6 +20,6 @@ pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline
|
|||||||
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
chunk_only: config.ingest.ingest_chunks_only,
|
chunk_only: !config.ingest.include_entities,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ use std::{
|
|||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
#[cfg(not(test))]
|
|
||||||
use common::utils::config::get_config;
|
|
||||||
use common::{
|
use common::{
|
||||||
storage::{
|
storage::{
|
||||||
db::SurrealDbClient,
|
db::SurrealDbClient,
|
||||||
@@ -125,10 +123,14 @@ pub async fn ensure_corpus(
|
|||||||
openai: Arc<OpenAIClient>,
|
openai: Arc<OpenAIClient>,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
converted_path: &Path,
|
converted_path: &Path,
|
||||||
|
precomputed_checksum: Option<&str>,
|
||||||
ingestion_config: IngestionConfig,
|
ingestion_config: IngestionConfig,
|
||||||
) -> Result<CorpusHandle> {
|
) -> Result<CorpusHandle> {
|
||||||
let checksum = compute_file_checksum(converted_path)
|
let checksum = match precomputed_checksum {
|
||||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
|
Some(value) => value.to_string(),
|
||||||
|
None => crate::datasets::content_checksum_for_layout(converted_path)
|
||||||
|
.with_context(|| format!("computing checksum for {}", converted_path.display()))?,
|
||||||
|
};
|
||||||
let ingestion_fingerprint =
|
let ingestion_fingerprint =
|
||||||
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
|
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
|
||||||
|
|
||||||
@@ -381,6 +383,7 @@ pub async fn ensure_corpus(
|
|||||||
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
|
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
|
||||||
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
|
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
|
||||||
chunk_only: ingestion_config.chunk_only,
|
chunk_only: ingestion_config.chunk_only,
|
||||||
|
namespace_seed: None,
|
||||||
},
|
},
|
||||||
paragraphs: corpus_paragraphs,
|
paragraphs: corpus_paragraphs,
|
||||||
questions: corpus_questions,
|
questions: corpus_questions,
|
||||||
@@ -415,7 +418,7 @@ pub async fn ensure_corpus(
|
|||||||
negative_ingested: stats.negative_ingested,
|
negative_ingested: stats.negative_ingested,
|
||||||
};
|
};
|
||||||
|
|
||||||
persist_manifest(&handle).context("persisting corpus manifest")?;
|
persist_corpus_manifest(&handle).context("persisting corpus manifest")?;
|
||||||
|
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
}
|
}
|
||||||
@@ -501,7 +504,6 @@ async fn ingest_paragraph_batch(
|
|||||||
Ok(shards)
|
Ok(shards)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||||
let db = SurrealDbClient::memory(namespace, "corpus")
|
let db = SurrealDbClient::memory(namespace, "corpus")
|
||||||
.await
|
.await
|
||||||
@@ -509,21 +511,6 @@ async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
|||||||
Ok(Arc::new(db))
|
Ok(Arc::new(db))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(test))]
|
|
||||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
|
||||||
let config = get_config().context("loading app config for ingestion database")?;
|
|
||||||
let db = SurrealDbClient::new(
|
|
||||||
&config.surrealdb_address,
|
|
||||||
&config.surrealdb_username,
|
|
||||||
&config.surrealdb_password,
|
|
||||||
namespace,
|
|
||||||
"corpus",
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.context("creating surrealdb database for ingestion")?;
|
|
||||||
Ok(Arc::new(db))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn ingest_single_paragraph(
|
async fn ingest_single_paragraph(
|
||||||
pipeline: Arc<IngestionPipeline>,
|
pipeline: Arc<IngestionPipeline>,
|
||||||
@@ -631,8 +618,12 @@ pub fn compute_ingestion_fingerprint(
|
|||||||
slice: &ResolvedSlice<'_>,
|
slice: &ResolvedSlice<'_>,
|
||||||
converted_path: &Path,
|
converted_path: &Path,
|
||||||
ingestion_config: &IngestionConfig,
|
ingestion_config: &IngestionConfig,
|
||||||
|
precomputed_checksum: Option<&str>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let checksum = compute_file_checksum(converted_path)?;
|
let checksum = match precomputed_checksum {
|
||||||
|
Some(value) => value.to_string(),
|
||||||
|
None => crate::datasets::content_checksum_for_layout(converted_path)?,
|
||||||
|
};
|
||||||
Ok(build_ingestion_fingerprint(
|
Ok(build_ingestion_fingerprint(
|
||||||
dataset,
|
dataset,
|
||||||
slice,
|
slice,
|
||||||
@@ -641,7 +632,7 @@ pub fn compute_ingestion_fingerprint(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
pub fn load_cached_manifest(base_dir: &std::path::Path) -> Result<Option<CorpusManifest>> {
|
||||||
let path = base_dir.join("manifest.json");
|
let path = base_dir.join("manifest.json");
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@@ -656,7 +647,7 @@ pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
|||||||
Ok(Some(manifest))
|
Ok(Some(manifest))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
|
pub fn persist_corpus_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||||
let path = handle.path.join("manifest.json");
|
let path = handle.path.join("manifest.json");
|
||||||
if let Some(parent) = path.parent() {
|
if let Some(parent) = path.parent() {
|
||||||
fs::create_dir_all(parent)
|
fs::create_dir_all(parent)
|
||||||
@@ -685,24 +676,6 @@ pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::indexing_slicing)]
|
|
||||||
fn compute_file_checksum(path: &Path) -> Result<String> {
|
|
||||||
let mut file = fs::File::open(path)
|
|
||||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
|
||||||
let mut hasher = Sha256::new();
|
|
||||||
let mut buffer = [0u8; 8192];
|
|
||||||
loop {
|
|
||||||
let read = file
|
|
||||||
.read(&mut buffer)
|
|
||||||
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
|
||||||
if read == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
hasher.update(&buffer[..read]);
|
|
||||||
}
|
|
||||||
Ok(format!("{:x}", hasher.finalize()))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -728,11 +701,7 @@ mod tests {
|
|||||||
|
|
||||||
ConvertedDataset {
|
ConvertedDataset {
|
||||||
generated_at: Utc::now(),
|
generated_at: Utc::now(),
|
||||||
metadata: crate::datasets::DatasetMetadata::for_kind(
|
metadata: crate::datasets::DatasetMetadata::for_kind(DatasetKind::default(), false),
|
||||||
DatasetKind::default(),
|
|
||||||
false,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
source: "src".to_string(),
|
source: "src".to_string(),
|
||||||
paragraphs: vec![paragraph],
|
paragraphs: vec![paragraph],
|
||||||
}
|
}
|
||||||
|
|||||||
+36
-287
@@ -7,33 +7,21 @@ use std::{
|
|||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use common::storage::types::StoredObject;
|
|
||||||
use common::storage::{
|
use common::storage::{
|
||||||
db::SurrealDbClient,
|
db::SurrealDbClient,
|
||||||
types::{
|
types::{
|
||||||
knowledge_entity::KnowledgeEntity,
|
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
text_chunk::TextChunk, text_content::TextContent, StoredObject,
|
||||||
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
|
|
||||||
text_chunk::TextChunk,
|
|
||||||
text_chunk_embedding::TextChunkEmbedding,
|
|
||||||
text_content::TextContent,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use ingestion_pipeline::{persist_artifacts, IngestionTuning, PipelineArtifacts};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
|
||||||
use surrealdb::sql::Thing;
|
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||||
|
|
||||||
pub const MANIFEST_VERSION: u32 = 3;
|
pub const MANIFEST_VERSION: u32 = 3;
|
||||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 3;
|
pub const PARAGRAPH_SHARD_VERSION: u32 = 3;
|
||||||
const MANIFEST_BATCH_SIZE: usize = 100;
|
|
||||||
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
|
|
||||||
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
|
|
||||||
const MAX_BATCHES_PER_REQUEST: usize = 24;
|
|
||||||
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
|
|
||||||
|
|
||||||
fn current_manifest_version() -> u32 {
|
fn current_manifest_version() -> u32 {
|
||||||
MANIFEST_VERSION
|
MANIFEST_VERSION
|
||||||
}
|
}
|
||||||
@@ -51,7 +39,7 @@ fn default_chunk_max_tokens() -> usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn default_chunk_only() -> bool {
|
fn default_chunk_only() -> bool {
|
||||||
false
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reuse the pipeline's canonical embedded-artifact types so the on-disk corpus
|
// Reuse the pipeline's canonical embedded-artifact types so the on-disk corpus
|
||||||
@@ -131,6 +119,14 @@ pub struct CorpusManifest {
|
|||||||
pub questions: Vec<CorpusQuestion>,
|
pub questions: Vec<CorpusQuestion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct NamespaceSeedRecord {
|
||||||
|
pub namespace: String,
|
||||||
|
pub database: String,
|
||||||
|
pub slice_case_count: usize,
|
||||||
|
pub seeded_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
pub struct CorpusMetadata {
|
pub struct CorpusMetadata {
|
||||||
pub dataset_id: String,
|
pub dataset_id: String,
|
||||||
@@ -153,6 +149,8 @@ pub struct CorpusMetadata {
|
|||||||
pub chunk_max_tokens: usize,
|
pub chunk_max_tokens: usize,
|
||||||
#[serde(default = "default_chunk_only")]
|
#[serde(default = "default_chunk_only")]
|
||||||
pub chunk_only: bool,
|
pub chunk_only: bool,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub namespace_seed: Option<NamespaceSeedRecord>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
@@ -251,130 +249,6 @@ pub fn window_manifest(
|
|||||||
Ok(narrowed)
|
Ok(narrowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
struct RelationInsert {
|
|
||||||
#[serde(rename = "in")]
|
|
||||||
pub in_: Thing,
|
|
||||||
#[serde(rename = "out")]
|
|
||||||
pub out: Thing,
|
|
||||||
pub id: String,
|
|
||||||
pub metadata: RelationshipMetadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct SizedBatch<T> {
|
|
||||||
approx_bytes: usize,
|
|
||||||
items: Vec<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ManifestBatches {
|
|
||||||
text_contents: Vec<SizedBatch<TextContent>>,
|
|
||||||
entities: Vec<SizedBatch<KnowledgeEntity>>,
|
|
||||||
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
|
|
||||||
relationships: Vec<SizedBatch<RelationInsert>>,
|
|
||||||
chunks: Vec<SizedBatch<TextChunk>>,
|
|
||||||
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
|
|
||||||
let mut text_contents = Vec::new();
|
|
||||||
let mut entities = Vec::new();
|
|
||||||
let mut entity_embeddings = Vec::new();
|
|
||||||
let mut relationships = Vec::new();
|
|
||||||
let mut chunks = Vec::new();
|
|
||||||
let mut chunk_embeddings = Vec::new();
|
|
||||||
|
|
||||||
let mut seen_text_content = HashSet::new();
|
|
||||||
let mut seen_entities = HashSet::new();
|
|
||||||
let mut seen_relationships = HashSet::new();
|
|
||||||
let mut seen_chunks = HashSet::new();
|
|
||||||
|
|
||||||
for paragraph in &manifest.paragraphs {
|
|
||||||
if seen_text_content.insert(paragraph.text_content.id.clone()) {
|
|
||||||
text_contents.push(paragraph.text_content.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
for embedded_entity in ¶graph.entities {
|
|
||||||
if seen_entities.insert(embedded_entity.entity.id.clone()) {
|
|
||||||
let entity = embedded_entity.entity.clone();
|
|
||||||
entities.push(entity.clone());
|
|
||||||
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
|
||||||
&entity.id,
|
|
||||||
entity.source_id.clone(),
|
|
||||||
embedded_entity.embedding.clone(),
|
|
||||||
entity.user_id.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for relationship in ¶graph.relationships {
|
|
||||||
if seen_relationships.insert(relationship.id.clone()) {
|
|
||||||
let table = KnowledgeEntity::table_name();
|
|
||||||
let in_id = relationship
|
|
||||||
.in_
|
|
||||||
.strip_prefix(&format!("{table}:"))
|
|
||||||
.unwrap_or(&relationship.in_);
|
|
||||||
let out_id = relationship
|
|
||||||
.out
|
|
||||||
.strip_prefix(&format!("{table}:"))
|
|
||||||
.unwrap_or(&relationship.out);
|
|
||||||
let in_thing = Thing::from((table, in_id));
|
|
||||||
let out_thing = Thing::from((table, out_id));
|
|
||||||
relationships.push(RelationInsert {
|
|
||||||
in_: in_thing,
|
|
||||||
out: out_thing,
|
|
||||||
id: relationship.id.clone(),
|
|
||||||
metadata: relationship.metadata.clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for embedded_chunk in ¶graph.chunks {
|
|
||||||
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
|
|
||||||
let chunk = embedded_chunk.chunk.clone();
|
|
||||||
chunks.push(chunk.clone());
|
|
||||||
chunk_embeddings.push(TextChunkEmbedding::new(
|
|
||||||
&chunk.id,
|
|
||||||
chunk.source_id.clone(),
|
|
||||||
embedded_chunk.embedding.clone(),
|
|
||||||
chunk.user_id.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ManifestBatches {
|
|
||||||
text_contents: chunk_items(
|
|
||||||
&text_contents,
|
|
||||||
MANIFEST_BATCH_SIZE,
|
|
||||||
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
|
|
||||||
)
|
|
||||||
.context("chunking text_content payloads")?,
|
|
||||||
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
|
||||||
.context("chunking knowledge_entity payloads")?,
|
|
||||||
entity_embeddings: chunk_items(
|
|
||||||
&entity_embeddings,
|
|
||||||
MANIFEST_BATCH_SIZE,
|
|
||||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
||||||
)
|
|
||||||
.context("chunking knowledge_entity_embedding payloads")?,
|
|
||||||
relationships: chunk_items(
|
|
||||||
&relationships,
|
|
||||||
MANIFEST_BATCH_SIZE,
|
|
||||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
||||||
)
|
|
||||||
.context("chunking relationship payloads")?,
|
|
||||||
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
|
||||||
.context("chunking text_chunk payloads")?,
|
|
||||||
chunk_embeddings: chunk_items(
|
|
||||||
&chunk_embeddings,
|
|
||||||
MANIFEST_BATCH_SIZE,
|
|
||||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
||||||
)
|
|
||||||
.context("chunking text_chunk_embedding payloads")?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
pub struct ParagraphShard {
|
pub struct ParagraphShard {
|
||||||
#[serde(default = "current_paragraph_shard_version")]
|
#[serde(default = "current_paragraph_shard_version")]
|
||||||
@@ -599,157 +473,28 @@ fn normalize_answer_text(text: &str) -> String {
|
|||||||
.join(" ")
|
.join(" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
|
||||||
fn chunk_items<T: Clone + Serialize>(
|
|
||||||
items: &[T],
|
|
||||||
max_items: usize,
|
|
||||||
max_bytes: usize,
|
|
||||||
) -> Result<Vec<SizedBatch<T>>> {
|
|
||||||
if items.is_empty() {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut batches = Vec::new();
|
|
||||||
let mut current = Vec::new();
|
|
||||||
let mut current_bytes = 0usize;
|
|
||||||
|
|
||||||
for item in items {
|
|
||||||
let size = serde_json::to_vec(item)
|
|
||||||
.map(|buf| buf.len())
|
|
||||||
.context("serialising batch item for sizing")?;
|
|
||||||
|
|
||||||
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
|
|
||||||
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
|
|
||||||
|
|
||||||
if would_overflow_items || would_overflow_bytes {
|
|
||||||
batches.push(SizedBatch {
|
|
||||||
approx_bytes: current_bytes.max(1),
|
|
||||||
items: std::mem::take(&mut current),
|
|
||||||
});
|
|
||||||
current_bytes = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
current_bytes += size;
|
|
||||||
current.push(item.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !current.is_empty() {
|
|
||||||
batches.push(SizedBatch {
|
|
||||||
approx_bytes: current_bytes.max(1),
|
|
||||||
items: current,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(batches)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
|
||||||
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
statement: impl AsRef<str>,
|
|
||||||
prefix: &str,
|
|
||||||
batches: &[SizedBatch<T>],
|
|
||||||
) -> Result<()> {
|
|
||||||
if batches.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut start = 0;
|
|
||||||
while start < batches.len() {
|
|
||||||
let mut group_bytes = 0usize;
|
|
||||||
let mut group_end = start;
|
|
||||||
let mut group_count = 0usize;
|
|
||||||
|
|
||||||
while group_end < batches.len() {
|
|
||||||
let batch_bytes = batches[group_end].approx_bytes.max(1);
|
|
||||||
if group_count > 0
|
|
||||||
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|
|
||||||
|| group_count >= MAX_BATCHES_PER_REQUEST)
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
group_bytes += batch_bytes;
|
|
||||||
group_end += 1;
|
|
||||||
group_count += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let slice = &batches[start..group_end];
|
|
||||||
let mut query = db.client.query("BEGIN TRANSACTION;");
|
|
||||||
for (bind_index, batch) in slice.iter().enumerate() {
|
|
||||||
let name = format!("{prefix}{bind_index}");
|
|
||||||
query = query
|
|
||||||
.query(format!("{} ${};", statement.as_ref(), name))
|
|
||||||
.bind((name, batch.items.clone()));
|
|
||||||
}
|
|
||||||
let response = query
|
|
||||||
.query("COMMIT TRANSACTION;")
|
|
||||||
.await
|
|
||||||
.context("executing batched insert transaction")?;
|
|
||||||
if let Err(err) = response.check() {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"batched insert failed for statement '{}': {err:?}",
|
|
||||||
statement.as_ref()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
start = group_end;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||||
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
|
let tuning = IngestionTuning::default();
|
||||||
|
let embedding_dimensions = manifest.metadata.embedding_dimension;
|
||||||
|
let mut seen_text_content = HashSet::new();
|
||||||
|
|
||||||
let result = async {
|
let result = async {
|
||||||
execute_batched_inserts(
|
for paragraph in &manifest.paragraphs {
|
||||||
db,
|
if !seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||||
format!("INSERT INTO {}", TextContent::table_name()),
|
continue;
|
||||||
"tc",
|
}
|
||||||
&batches.text_contents,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
execute_batched_inserts(
|
let artifacts = PipelineArtifacts {
|
||||||
db,
|
text_content: paragraph.text_content.clone(),
|
||||||
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
|
entities: paragraph.entities.clone(),
|
||||||
"ke",
|
relationships: paragraph.relationships.clone(),
|
||||||
&batches.entities,
|
chunks: paragraph.chunks.clone(),
|
||||||
)
|
};
|
||||||
.await?;
|
|
||||||
|
|
||||||
execute_batched_inserts(
|
|
||||||
db,
|
|
||||||
format!("INSERT INTO {}", TextChunk::table_name()),
|
|
||||||
"ch",
|
|
||||||
&batches.chunks,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
execute_batched_inserts(
|
|
||||||
db,
|
|
||||||
"INSERT RELATION INTO relates_to",
|
|
||||||
"rel",
|
|
||||||
&batches.relationships,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
execute_batched_inserts(
|
|
||||||
db,
|
|
||||||
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
|
|
||||||
"kee",
|
|
||||||
&batches.entity_embeddings,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
execute_batched_inserts(
|
|
||||||
db,
|
|
||||||
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
|
|
||||||
"tce",
|
|
||||||
&batches.chunk_embeddings,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
|
persist_artifacts(db, &tuning, embedding_dimensions, artifacts)
|
||||||
|
.await
|
||||||
|
.map_err(|err| anyhow!("persist manifest paragraph: {err}"))?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
@@ -778,7 +523,10 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use common::storage::types::knowledge_entity::KnowledgeEntityType;
|
use common::storage::types::{
|
||||||
|
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||||
|
text_chunk::TextChunk,
|
||||||
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
@@ -888,6 +636,7 @@ mod tests {
|
|||||||
chunk_min_tokens: 1,
|
chunk_min_tokens: 1,
|
||||||
chunk_max_tokens: 10,
|
chunk_max_tokens: 10,
|
||||||
chunk_only: false,
|
chunk_only: false,
|
||||||
|
namespace_seed: None,
|
||||||
},
|
},
|
||||||
paragraphs: vec![paragraph_one, paragraph_two],
|
paragraphs: vec![paragraph_one, paragraph_two],
|
||||||
questions: vec![question],
|
questions: vec![question],
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, HashMap},
|
collections::{BTreeMap, HashMap, HashSet},
|
||||||
fs::File,
|
fs::File,
|
||||||
io::{BufRead, BufReader},
|
io::{BufRead, BufReader},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
@@ -47,20 +47,71 @@ struct QrelEntry {
|
|||||||
score: i32,
|
score: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert only documents that appear in qrels (the BEIR evaluation closed world).
|
||||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||||
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
|
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
|
||||||
|
convert_beir_documents(raw_dir, dataset, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a subset of qrels-world documents. `doc_ids` use corpus ids (unprefixed).
|
||||||
|
#[allow(
|
||||||
|
clippy::too_many_lines,
|
||||||
|
clippy::arithmetic_side_effects,
|
||||||
|
clippy::indexing_slicing
|
||||||
|
)]
|
||||||
|
pub fn convert_beir_documents(
|
||||||
|
raw_dir: &Path,
|
||||||
|
dataset: DatasetKind,
|
||||||
|
doc_ids: Option<&HashSet<String>>,
|
||||||
|
) -> Result<Vec<ConvertedParagraph>> {
|
||||||
let corpus_path = raw_dir.join("corpus.jsonl");
|
let corpus_path = raw_dir.join("corpus.jsonl");
|
||||||
let queries_path = raw_dir.join("queries.jsonl");
|
let queries_path = raw_dir.join("queries.jsonl");
|
||||||
let qrels_path = resolve_qrels_path(raw_dir)?;
|
let qrels_path = resolve_qrels_path(raw_dir)?;
|
||||||
|
|
||||||
let corpus = load_corpus(&corpus_path)?;
|
|
||||||
let queries = load_queries(&queries_path)?;
|
let queries = load_queries(&queries_path)?;
|
||||||
let qrels = load_qrels(&qrels_path)?;
|
let qrels = load_qrels(&qrels_path)?;
|
||||||
|
|
||||||
let mut paragraphs = Vec::with_capacity(corpus.len());
|
let mut qrels_doc_ids = HashSet::new();
|
||||||
|
for entries in qrels.values() {
|
||||||
|
for entry in entries {
|
||||||
|
qrels_doc_ids.insert(entry.doc_id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let target_doc_ids: HashSet<String> = match doc_ids {
|
||||||
|
Some(ids) => ids
|
||||||
|
.iter()
|
||||||
|
.filter(|id| qrels_doc_ids.contains(*id))
|
||||||
|
.cloned()
|
||||||
|
.collect(),
|
||||||
|
None => qrels_doc_ids.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if target_doc_ids.is_empty() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"no qrels documents to convert for {} at {}",
|
||||||
|
dataset.id(),
|
||||||
|
raw_dir.display()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let corpus = load_corpus_filtered(&corpus_path, &target_doc_ids)?;
|
||||||
|
|
||||||
|
let mut doc_ids_sorted: Vec<String> = target_doc_ids.into_iter().collect();
|
||||||
|
doc_ids_sorted.sort();
|
||||||
|
|
||||||
|
let mut paragraphs = Vec::with_capacity(doc_ids_sorted.len());
|
||||||
let mut paragraph_index = HashMap::new();
|
let mut paragraph_index = HashMap::new();
|
||||||
|
|
||||||
for (doc_id, entry) in &corpus {
|
for doc_id in &doc_ids_sorted {
|
||||||
|
let Some(entry) = corpus.get(doc_id) else {
|
||||||
|
warn!(
|
||||||
|
doc_id = %doc_id,
|
||||||
|
dataset = %dataset.id(),
|
||||||
|
"Skipping qrels document missing from corpus"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
|
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
|
||||||
let paragraph = ConvertedParagraph {
|
let paragraph = ConvertedParagraph {
|
||||||
id: paragraph_id.clone(),
|
id: paragraph_id.clone(),
|
||||||
@@ -87,6 +138,12 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
|||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Some(filter) = doc_ids {
|
||||||
|
if !filter.contains(&best.doc_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let Some(¶graph_slot) = paragraph_index.get(&best.doc_id) else {
|
let Some(¶graph_slot) = paragraph_index.get(&best.doc_id) else {
|
||||||
missing_docs += 1;
|
missing_docs += 1;
|
||||||
warn!(
|
warn!(
|
||||||
@@ -106,7 +163,6 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
|||||||
);
|
);
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
let answers = vec![snippet];
|
|
||||||
|
|
||||||
let question_id = format!("{}-{query_id}", dataset.source_prefix());
|
let question_id = format!("{}-{query_id}", dataset.source_prefix());
|
||||||
paragraphs[paragraph_slot]
|
paragraphs[paragraph_slot]
|
||||||
@@ -114,7 +170,7 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
|||||||
.push(ConvertedQuestion {
|
.push(ConvertedQuestion {
|
||||||
id: question_id,
|
id: question_id,
|
||||||
question: query.text.clone(),
|
question: query.text.clone(),
|
||||||
answers,
|
answers: vec![snippet],
|
||||||
is_impossible: false,
|
is_impossible: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -122,13 +178,21 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
|||||||
if missing_queries + missing_docs + skipped_answers > 0 {
|
if missing_queries + missing_docs + skipped_answers > 0 {
|
||||||
warn!(
|
warn!(
|
||||||
missing_queries,
|
missing_queries,
|
||||||
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
|
missing_docs,
|
||||||
|
skipped_answers,
|
||||||
|
dataset = %dataset.id(),
|
||||||
|
"Skipped some BEIR qrels entries during conversion"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(paragraphs)
|
Ok(paragraphs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn corpus_doc_id(paragraph_id: &str, dataset: DatasetKind) -> Option<String> {
|
||||||
|
let prefix = format!("{}-", dataset.source_prefix());
|
||||||
|
paragraph_id.strip_prefix(&prefix).map(str::to_string)
|
||||||
|
}
|
||||||
|
|
||||||
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||||
let qrels_dir = raw_dir.join("qrels");
|
let qrels_dir = raw_dir.join("qrels");
|
||||||
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
|
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
|
||||||
@@ -148,7 +212,10 @@ fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::arithmetic_side_effects)]
|
#[allow(clippy::arithmetic_side_effects)]
|
||||||
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
fn load_corpus_filtered(
|
||||||
|
path: &Path,
|
||||||
|
doc_ids: &HashSet<String>,
|
||||||
|
) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||||
let file =
|
let file =
|
||||||
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
|
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
|
||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
@@ -167,6 +234,9 @@ fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
|||||||
path.display()
|
path.display()
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
if !doc_ids.contains(&corpus_row.id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let title = corpus_row.title.unwrap_or_else(|| corpus_row.id.clone());
|
let title = corpus_row.title.unwrap_or_else(|| corpus_row.id.clone());
|
||||||
let text = corpus_row.text.unwrap_or_default();
|
let text = corpus_row.text.unwrap_or_default();
|
||||||
let context = build_context(&title, &text);
|
let context = build_context(&title, &text);
|
||||||
@@ -296,10 +366,8 @@ mod tests {
|
|||||||
use std::fs;
|
use std::fs;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
#[test]
|
#[allow(clippy::unwrap_used)]
|
||||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
fn write_fixture(dir: &tempfile::TempDir) {
|
||||||
fn converts_basic_beir_layout() {
|
|
||||||
let dir = tempdir().unwrap();
|
|
||||||
let corpus = r#"
|
let corpus = r#"
|
||||||
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
|
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
|
||||||
{"_id":"d2","title":"Doc 2","text":"Second document content."}
|
{"_id":"d2","title":"Doc 2","text":"Second document content."}
|
||||||
@@ -313,24 +381,34 @@ mod tests {
|
|||||||
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
|
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
|
||||||
fs::create_dir_all(dir.path().join("qrels")).unwrap();
|
fs::create_dir_all(dir.path().join("qrels")).unwrap();
|
||||||
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
|
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||||
|
fn converts_qrels_world_only() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
write_fixture(&dir);
|
||||||
|
|
||||||
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
|
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
|
||||||
|
|
||||||
assert_eq!(paragraphs.len(), 2);
|
assert_eq!(paragraphs.len(), 1);
|
||||||
let doc_one = paragraphs
|
let doc_one = ¶graphs[0];
|
||||||
.iter()
|
assert_eq!(doc_one.id, "fever-d1");
|
||||||
.find(|p| p.id == "fever-d1")
|
|
||||||
.expect("missing paragraph for d1");
|
|
||||||
assert_eq!(doc_one.questions.len(), 1);
|
assert_eq!(doc_one.questions.len(), 1);
|
||||||
let question = &doc_one.questions[0];
|
assert_eq!(doc_one.questions[0].id, "fever-q1");
|
||||||
assert_eq!(question.id, "fever-q1");
|
}
|
||||||
assert!(!question.answers.is_empty());
|
|
||||||
assert!(doc_one.context.contains(&question.answers[0]));
|
|
||||||
|
|
||||||
let doc_two = paragraphs
|
#[test]
|
||||||
.iter()
|
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||||
.find(|p| p.id == "fever-d2")
|
fn converts_filtered_doc_ids() {
|
||||||
.expect("missing paragraph for d2");
|
let dir = tempdir().unwrap();
|
||||||
assert!(doc_two.questions.is_empty());
|
write_fixture(&dir);
|
||||||
|
|
||||||
|
let mut ids = HashSet::new();
|
||||||
|
ids.insert("d1".to_string());
|
||||||
|
let paragraphs =
|
||||||
|
convert_beir_documents(dir.path(), DatasetKind::Fever, Some(&ids)).unwrap();
|
||||||
|
assert_eq!(paragraphs.len(), 1);
|
||||||
|
assert_eq!(paragraphs[0].id, "fever-d1");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,258 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
beir,
|
||||||
|
checksum::hash_file,
|
||||||
|
store::{
|
||||||
|
self, build_dataset_from_catalog, paragraph_path, read_meta, store_dir_for,
|
||||||
|
upsert_sharded_paragraphs, write_sharded,
|
||||||
|
},
|
||||||
|
ConvertedDataset, DatasetKind, DatasetMetadata, BEIR_DATASETS,
|
||||||
|
};
|
||||||
|
use crate::{args::Config, slice};
|
||||||
|
|
||||||
|
pub fn subset_for_paragraph_id(paragraph_id: &str) -> Option<DatasetKind> {
|
||||||
|
let mut kinds: Vec<DatasetKind> = BEIR_DATASETS.to_vec();
|
||||||
|
kinds.sort_by_key(|kind| std::cmp::Reverse(kind.source_prefix().len()));
|
||||||
|
for kind in kinds {
|
||||||
|
let prefix = format!("{}-", kind.source_prefix());
|
||||||
|
if paragraph_id.starts_with(&prefix) {
|
||||||
|
return Some(kind);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_beir_mix_qrels_dataset(include_unanswerable: bool) -> Result<ConvertedDataset> {
|
||||||
|
if include_unanswerable {
|
||||||
|
tracing::warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut paragraphs = Vec::new();
|
||||||
|
for subset in BEIR_DATASETS {
|
||||||
|
let entry = super::dataset_entry_for_kind(subset)?;
|
||||||
|
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||||
|
paragraphs.extend(subset_paragraphs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ConvertedDataset {
|
||||||
|
generated_at: super::base_timestamp(),
|
||||||
|
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, include_unanswerable),
|
||||||
|
source: "beir-mix".to_string(),
|
||||||
|
paragraphs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prepare_beir_mix(config: &Config) -> Result<super::loader::LoadedDataset> {
|
||||||
|
let virtual_ds = build_beir_mix_qrels_dataset(config.llm_mode)?;
|
||||||
|
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||||
|
let resolved = slice::resolve_slice(&virtual_ds, &slice_config)
|
||||||
|
.context("resolving BEIR mix slice ledger (check --slice and --limit match your intent)")?;
|
||||||
|
|
||||||
|
let unique: HashSet<String> = resolved
|
||||||
|
.manifest
|
||||||
|
.paragraphs
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.id.clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
materialize_subset_stores(&unique, config.force_convert)?;
|
||||||
|
|
||||||
|
let dataset = load_beir_mix_from_subsets(&unique)?;
|
||||||
|
let checksum = mix_content_checksum(&unique)?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
slice = resolved.manifest.slice_id.as_str(),
|
||||||
|
paragraphs = unique.len(),
|
||||||
|
checksum = %checksum,
|
||||||
|
"Prepared BEIR mix from per-subset converted stores"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(super::loader::LoadedDataset {
|
||||||
|
dataset,
|
||||||
|
content_checksum: checksum,
|
||||||
|
partial: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn materialize_subset_stores(paragraph_ids: &HashSet<String>, force: bool) -> Result<()> {
|
||||||
|
let mut by_subset: HashMap<DatasetKind, Vec<String>> = HashMap::new();
|
||||||
|
for paragraph_id in paragraph_ids {
|
||||||
|
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||||
|
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||||
|
})?;
|
||||||
|
by_subset
|
||||||
|
.entry(kind)
|
||||||
|
.or_default()
|
||||||
|
.push(paragraph_id.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (kind, ids) in by_subset {
|
||||||
|
let entry = super::dataset_entry_for_kind(kind)?;
|
||||||
|
let store_dir = store_dir_for(&entry.converted_path);
|
||||||
|
let existing = if store_dir.join("meta.json").is_file() {
|
||||||
|
store::load_paragraph_ids_set(&store_dir)?
|
||||||
|
} else {
|
||||||
|
HashSet::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
let missing: Vec<String> = if force {
|
||||||
|
ids
|
||||||
|
} else {
|
||||||
|
ids.into_iter()
|
||||||
|
.filter(|paragraph_id| !existing.contains(paragraph_id))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
if missing.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let corpus_ids: HashSet<String> = missing
|
||||||
|
.iter()
|
||||||
|
.filter_map(|paragraph_id| beir::corpus_doc_id(paragraph_id, kind))
|
||||||
|
.collect();
|
||||||
|
let paragraphs = beir::convert_beir_documents(&entry.raw_path, kind, Some(&corpus_ids))?;
|
||||||
|
|
||||||
|
if store_dir.join("meta.json").is_file() {
|
||||||
|
upsert_sharded_paragraphs(&store_dir, ¶graphs)?;
|
||||||
|
} else {
|
||||||
|
let question_count = paragraphs
|
||||||
|
.iter()
|
||||||
|
.map(|paragraph| paragraph.questions.len())
|
||||||
|
.sum::<usize>();
|
||||||
|
let dataset = ConvertedDataset {
|
||||||
|
generated_at: super::base_timestamp(),
|
||||||
|
metadata: DatasetMetadata::for_kind(kind, false),
|
||||||
|
source: entry.raw_path.display().to_string(),
|
||||||
|
paragraphs,
|
||||||
|
};
|
||||||
|
write_sharded(&dataset, &store_dir)?;
|
||||||
|
info!(
|
||||||
|
subset = kind.id(),
|
||||||
|
store = %store_dir.display(),
|
||||||
|
paragraphs = dataset.paragraphs.len(),
|
||||||
|
questions = question_count,
|
||||||
|
"Created subset converted store for BEIR mix"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_beir_mix_from_subsets(paragraph_ids: &HashSet<String>) -> Result<ConvertedDataset> {
|
||||||
|
let mut by_subset: HashMap<DatasetKind, HashSet<String>> = HashMap::new();
|
||||||
|
for paragraph_id in paragraph_ids {
|
||||||
|
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||||
|
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||||
|
})?;
|
||||||
|
by_subset
|
||||||
|
.entry(kind)
|
||||||
|
.or_default()
|
||||||
|
.insert(paragraph_id.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut paragraphs = Vec::with_capacity(paragraph_ids.len());
|
||||||
|
for (kind, subset_ids) in by_subset {
|
||||||
|
let entry = super::dataset_entry_for_kind(kind)?;
|
||||||
|
let store_dir = store_dir_for(&entry.converted_path);
|
||||||
|
let partial = build_dataset_from_catalog(&store_dir, &subset_ids)?;
|
||||||
|
paragraphs.extend(partial.paragraphs);
|
||||||
|
}
|
||||||
|
|
||||||
|
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||||
|
|
||||||
|
Ok(ConvertedDataset {
|
||||||
|
generated_at: super::base_timestamp(),
|
||||||
|
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, false),
|
||||||
|
source: "beir-mix".to_string(),
|
||||||
|
paragraphs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mix_content_checksum(paragraph_ids: &HashSet<String>) -> Result<String> {
|
||||||
|
let mut ids: Vec<String> = paragraph_ids.iter().cloned().collect();
|
||||||
|
ids.sort();
|
||||||
|
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
for paragraph_id in ids {
|
||||||
|
let kind = subset_for_paragraph_id(¶graph_id)
|
||||||
|
.ok_or_else(|| anyhow!("unknown BEIR subset for paragraph '{paragraph_id}'"))?;
|
||||||
|
let entry = super::dataset_entry_for_kind(kind)?;
|
||||||
|
let store_dir = store_dir_for(&entry.converted_path);
|
||||||
|
let path = paragraph_path(&store_dir, ¶graph_id);
|
||||||
|
if !path.is_file() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"missing converted paragraph {} at {}",
|
||||||
|
paragraph_id,
|
||||||
|
path.display()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
hasher.update(paragraph_id.as_bytes());
|
||||||
|
hasher.update([0]);
|
||||||
|
hasher.update(hash_file(&path)?.as_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(format!("{:x}", hasher.finalize()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn beir_subset_stores_ready(paragraph_ids: &HashSet<String>) -> Result<bool> {
|
||||||
|
for paragraph_id in paragraph_ids {
|
||||||
|
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||||
|
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||||
|
})?;
|
||||||
|
let entry = super::dataset_entry_for_kind(kind)?;
|
||||||
|
let store_dir = store_dir_for(&entry.converted_path);
|
||||||
|
if !store_dir.join("meta.json").is_file() {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
if !paragraph_path(&store_dir, paragraph_id).is_file() {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn beir_subset_store_summary() -> Result<Vec<(String, usize, usize)>> {
|
||||||
|
let mut summary = Vec::new();
|
||||||
|
for kind in BEIR_DATASETS {
|
||||||
|
let entry = super::dataset_entry_for_kind(kind)?;
|
||||||
|
let store_dir = store_dir_for(&entry.converted_path);
|
||||||
|
if store_dir.join("meta.json").is_file() {
|
||||||
|
let meta = read_meta(&store_dir)?;
|
||||||
|
summary.push((
|
||||||
|
kind.id().to_string(),
|
||||||
|
meta.paragraph_count,
|
||||||
|
meta.question_count,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn routes_prefixed_paragraph_ids() {
|
||||||
|
assert_eq!(
|
||||||
|
subset_for_paragraph_id("fever-doc-1"),
|
||||||
|
Some(DatasetKind::Fever)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
subset_for_paragraph_id("nq-beir-doc-1"),
|
||||||
|
Some(DatasetKind::NqBeir)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
subset_for_paragraph_id("trec-covid-doc-1"),
|
||||||
|
Some(DatasetKind::TrecCovid)
|
||||||
|
);
|
||||||
|
assert!(subset_for_paragraph_id("unknown-doc").is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
use std::{
|
||||||
|
fs::{self, File},
|
||||||
|
io::Read,
|
||||||
|
path::Path,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
const SIDECAR_VERSION: u32 = 1;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChecksumSidecar {
|
||||||
|
pub version: u32,
|
||||||
|
pub sha256: String,
|
||||||
|
pub size_bytes: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub modified_unix_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChecksumSidecar {
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn sidecar_path(content_path: &Path) -> PathBuf {
|
||||||
|
content_path.with_extension("sha256")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn is_valid_for(&self, content_path: &Path) -> bool {
|
||||||
|
if self.version != SIDECAR_VERSION {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
let Ok(metadata) = fs::metadata(content_path) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if metadata.len() != self.size_bytes {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if self.modified_unix_secs != 0 {
|
||||||
|
let Ok(modified) = metadata.modified() else {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
let Ok(secs) = modified.duration_since(std::time::UNIX_EPOCH) else {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if secs.as_secs() != self.modified_unix_secs {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::indexing_slicing)]
|
||||||
|
pub fn hash_file(path: &Path) -> Result<String> {
|
||||||
|
let mut file = File::open(path)
|
||||||
|
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
let mut buffer = vec![0u8; 65_536];
|
||||||
|
loop {
|
||||||
|
let read = file
|
||||||
|
.read(&mut buffer)
|
||||||
|
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
||||||
|
if read == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
hasher.update(&buffer[..read]);
|
||||||
|
}
|
||||||
|
Ok(format!("{:x}", hasher.finalize()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_sidecar(path: &Path) -> Result<Option<ChecksumSidecar>> {
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let raw = fs::read_to_string(path)
|
||||||
|
.with_context(|| format!("reading checksum sidecar {}", path.display()))?;
|
||||||
|
let sidecar: ChecksumSidecar = serde_json::from_str(&raw)
|
||||||
|
.with_context(|| format!("parsing checksum sidecar {}", path.display()))?;
|
||||||
|
Ok(Some(sidecar))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn write_sidecar(content_path: &Path, sha256: &str) -> Result<()> {
|
||||||
|
let metadata = fs::metadata(content_path)
|
||||||
|
.with_context(|| format!("reading metadata for {}", content_path.display()))?;
|
||||||
|
let modified_unix_secs = metadata
|
||||||
|
.modified()
|
||||||
|
.ok()
|
||||||
|
.and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
|
||||||
|
.map_or(0, |duration| duration.as_secs());
|
||||||
|
let sidecar = ChecksumSidecar {
|
||||||
|
version: SIDECAR_VERSION,
|
||||||
|
sha256: sha256.to_string(),
|
||||||
|
size_bytes: metadata.len(),
|
||||||
|
modified_unix_secs,
|
||||||
|
};
|
||||||
|
let path = ChecksumSidecar::sidecar_path(content_path);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.with_context(|| format!("creating checksum sidecar directory {}", parent.display()))?;
|
||||||
|
}
|
||||||
|
let blob = serde_json::to_vec_pretty(&sidecar).context("serialising checksum sidecar")?;
|
||||||
|
fs::write(&path, blob)
|
||||||
|
.with_context(|| format!("writing checksum sidecar {}", path.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn content_checksum(content_path: &Path) -> Result<String> {
|
||||||
|
let sidecar_path = ChecksumSidecar::sidecar_path(content_path);
|
||||||
|
if let Some(sidecar) = read_sidecar(&sidecar_path)? {
|
||||||
|
if sidecar.is_valid_for(content_path) {
|
||||||
|
return Ok(sidecar.sha256);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let sha256 = hash_file(content_path)?;
|
||||||
|
write_sidecar(content_path, &sha256)?;
|
||||||
|
Ok(sha256)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_aggregate_checksum(store_dir: &Path) -> Result<String> {
|
||||||
|
let marker = store_dir.join("checksum.sha256");
|
||||||
|
let meta = store_dir.join("meta.json");
|
||||||
|
if marker.is_file() && meta.is_file() {
|
||||||
|
if let (Ok(marker_meta), Ok(meta_meta)) = (marker.metadata(), meta.metadata()) {
|
||||||
|
if marker_meta
|
||||||
|
.modified()
|
||||||
|
.ok()
|
||||||
|
.zip(meta_meta.modified().ok())
|
||||||
|
.is_some_and(|(marker_modified, meta_modified)| marker_modified >= meta_modified)
|
||||||
|
{
|
||||||
|
if let Some(sidecar) = read_sidecar(&marker)? {
|
||||||
|
return Ok(sidecar.sha256);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
collect_store_files(store_dir, store_dir, &mut entries)?;
|
||||||
|
entries.sort();
|
||||||
|
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
for relative in &entries {
|
||||||
|
let path = store_dir.join(relative);
|
||||||
|
if path == marker {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
hasher.update(relative.as_bytes());
|
||||||
|
hasher.update([0]);
|
||||||
|
let file_hash = hash_file(&path)?;
|
||||||
|
hasher.update(file_hash.as_bytes());
|
||||||
|
}
|
||||||
|
let digest = format!("{:x}", hasher.finalize());
|
||||||
|
|
||||||
|
let sidecar = ChecksumSidecar {
|
||||||
|
version: SIDECAR_VERSION,
|
||||||
|
sha256: digest.clone(),
|
||||||
|
size_bytes: entries.len() as u64,
|
||||||
|
modified_unix_secs: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.map_or(0, |duration| duration.as_secs()),
|
||||||
|
};
|
||||||
|
if let Some(parent) = marker.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(&marker, serde_json::to_vec_pretty(&sidecar)?)?;
|
||||||
|
Ok(digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_store_files(base: &Path, current: &Path, entries: &mut Vec<String>) -> Result<()> {
|
||||||
|
for entry in fs::read_dir(current)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
if path
|
||||||
|
.file_name()
|
||||||
|
.is_some_and(|name| name == "checksum.sha256")
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if path.is_dir() {
|
||||||
|
collect_store_files(base, &path, entries)?;
|
||||||
|
} else if path.is_file() {
|
||||||
|
let relative = path
|
||||||
|
.strip_prefix(base)
|
||||||
|
.unwrap_or(&path)
|
||||||
|
.to_string_lossy()
|
||||||
|
.replace('\\', "/");
|
||||||
|
entries.push(relative);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sidecar_round_trip() -> Result<()> {
|
||||||
|
let dir = tempdir()?;
|
||||||
|
let file = dir.path().join("sample.json");
|
||||||
|
fs::write(&file, br#"{"hello":"world"}"#)?;
|
||||||
|
|
||||||
|
let first = content_checksum(&file)?;
|
||||||
|
let second = content_checksum(&file)?;
|
||||||
|
assert_eq!(first, second);
|
||||||
|
|
||||||
|
fs::write(&file, br#"{"hello":"world!"}"#)?;
|
||||||
|
let third = content_checksum(&file)?;
|
||||||
|
assert_ne!(first, third);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,195 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
catalog,
|
||||||
|
store::{
|
||||||
|
self, build_dataset_from_catalog, detect_layout, read_meta, store_dir_for, write_sharded,
|
||||||
|
ConvertedLayout,
|
||||||
|
},
|
||||||
|
ConvertedDataset, DatasetKind,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
args::Config,
|
||||||
|
slice::{self, SliceConfig},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LoadedDataset {
|
||||||
|
pub dataset: ConvertedDataset,
|
||||||
|
pub content_checksum: String,
|
||||||
|
pub partial: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prepare_dataset(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||||
|
if dataset_kind == DatasetKind::Beir {
|
||||||
|
return super::beir_mix::prepare_beir_mix(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
let converted_path = &config.converted_dataset_path;
|
||||||
|
let layout = detect_layout(converted_path);
|
||||||
|
let store_dir = store_dir_for(converted_path);
|
||||||
|
|
||||||
|
if layout == ConvertedLayout::Missing || config.force_convert {
|
||||||
|
return convert_and_load(dataset_kind, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
load_from_store(dataset_kind, config, &store_dir, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_and_load(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||||
|
let dataset = super::convert(
|
||||||
|
config.raw_dataset_path.as_path(),
|
||||||
|
dataset_kind,
|
||||||
|
config.llm_mode,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("converting {} dataset", dataset_kind.label()))?;
|
||||||
|
|
||||||
|
let store_dir = store_dir_for(&config.converted_dataset_path);
|
||||||
|
write_sharded(&dataset, &store_dir)?;
|
||||||
|
prebuild_catalog_slices(&dataset, config)?;
|
||||||
|
let checksum = crate::datasets::store_aggregate_checksum(&store_dir)?;
|
||||||
|
|
||||||
|
Ok(LoadedDataset {
|
||||||
|
dataset,
|
||||||
|
content_checksum: checksum,
|
||||||
|
partial: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_from_store(
|
||||||
|
dataset_kind: DatasetKind,
|
||||||
|
config: &Config,
|
||||||
|
store_dir: &std::path::Path,
|
||||||
|
allow_partial: bool,
|
||||||
|
) -> Result<LoadedDataset> {
|
||||||
|
let checksum = crate::datasets::store_aggregate_checksum(store_dir)?;
|
||||||
|
let meta = read_meta(store_dir)?;
|
||||||
|
validate_metadata_fields(&meta.metadata, dataset_kind, config)?;
|
||||||
|
|
||||||
|
if allow_partial {
|
||||||
|
if let Some(paragraph_ids) = slice_paragraph_ids_for_fast_path(config)? {
|
||||||
|
let unique: HashSet<String> = paragraph_ids.into_iter().collect();
|
||||||
|
info!(
|
||||||
|
paragraphs = unique.len(),
|
||||||
|
store = %store_dir.display(),
|
||||||
|
"Loading slice-addressed paragraphs from sharded converted store"
|
||||||
|
);
|
||||||
|
let dataset = build_dataset_from_catalog(store_dir, &unique)?;
|
||||||
|
return Ok(LoadedDataset {
|
||||||
|
dataset,
|
||||||
|
content_checksum: checksum,
|
||||||
|
partial: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
store = %store_dir.display(),
|
||||||
|
paragraphs = meta.paragraph_count,
|
||||||
|
"Loading full sharded converted store"
|
||||||
|
);
|
||||||
|
let dataset = store::load_sharded_full(store_dir)?;
|
||||||
|
Ok(LoadedDataset {
|
||||||
|
dataset,
|
||||||
|
content_checksum: checksum,
|
||||||
|
partial: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slice_paragraph_ids_for_fast_path(config: &Config) -> Result<Option<Vec<String>>> {
|
||||||
|
let Some(manifest_path) = slice::cached_manifest_path(config) else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
let Some(manifest) = slice::read_manifest_if_exists(&manifest_path)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||||
|
if !slice::manifest_is_complete(&manifest, &slice_config) {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Ok(Some(
|
||||||
|
manifest
|
||||||
|
.paragraphs
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.id.clone())
|
||||||
|
.collect(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_metadata_fields(
|
||||||
|
metadata: &super::DatasetMetadata,
|
||||||
|
dataset_kind: DatasetKind,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<()> {
|
||||||
|
if metadata.id != dataset_kind.id() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"converted dataset targets '{}', expected '{}'",
|
||||||
|
metadata.id,
|
||||||
|
dataset_kind.id()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if metadata.include_unanswerable != config.llm_mode {
|
||||||
|
anyhow::bail!(
|
||||||
|
"converted dataset include_unanswerable mismatch (expected {}, found {})",
|
||||||
|
config.llm_mode,
|
||||||
|
metadata.include_unanswerable
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prebuild_catalog_slices(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||||
|
let catalog = catalog()?;
|
||||||
|
let entry = catalog.dataset(dataset.metadata.id.as_str())?;
|
||||||
|
if entry.slices.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
dataset = dataset.metadata.id.as_str(),
|
||||||
|
slices = entry.slices.len(),
|
||||||
|
"Prebuilding catalog slice ledgers"
|
||||||
|
);
|
||||||
|
|
||||||
|
for slice_entry in &entry.slices {
|
||||||
|
let slice_config = slice_config_for_catalog_entry(config, slice_entry);
|
||||||
|
match slice::resolve_slice(dataset, &slice_config) {
|
||||||
|
Ok(resolved) => info!(
|
||||||
|
slice = resolved.manifest.slice_id.as_str(),
|
||||||
|
cases = resolved.manifest.case_count,
|
||||||
|
positives = resolved.manifest.positive_paragraphs,
|
||||||
|
negatives = resolved.manifest.negative_paragraphs,
|
||||||
|
"Prebuilt catalog slice ledger"
|
||||||
|
),
|
||||||
|
Err(err) => tracing::warn!(
|
||||||
|
slice = slice_entry.id.as_str(),
|
||||||
|
error = %err,
|
||||||
|
"Failed to prebuild catalog slice ledger"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slice_config_for_catalog_entry<'a>(
|
||||||
|
config: &'a Config,
|
||||||
|
slice_entry: &'a super::SliceEntry,
|
||||||
|
) -> SliceConfig<'a> {
|
||||||
|
SliceConfig {
|
||||||
|
cache_dir: config.cache_dir.as_path(),
|
||||||
|
force_convert: config.force_convert,
|
||||||
|
explicit_slice: Some(slice_entry.id.as_str()),
|
||||||
|
limit: slice_entry.limit,
|
||||||
|
corpus_limit: slice_entry.corpus_limit,
|
||||||
|
slice_seed: slice_entry.seed.unwrap_or(config.slice_seed),
|
||||||
|
llm_mode: slice_entry.include_unanswerable.unwrap_or(config.llm_mode),
|
||||||
|
negative_multiplier: slice_entry
|
||||||
|
.negative_multiplier
|
||||||
|
.unwrap_or(config.negative_multiplier),
|
||||||
|
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||||
|
}
|
||||||
|
}
|
||||||
+36
-143
@@ -1,6 +1,10 @@
|
|||||||
mod beir;
|
mod beir;
|
||||||
|
mod beir_mix;
|
||||||
|
mod checksum;
|
||||||
|
mod loader;
|
||||||
mod nq;
|
mod nq;
|
||||||
mod squad;
|
mod squad;
|
||||||
|
mod store;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, HashMap},
|
collections::{BTreeMap, HashMap},
|
||||||
@@ -20,38 +24,31 @@ const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"
|
|||||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
|
||||||
pub struct DatasetCatalog {
|
pub struct DatasetCatalog {
|
||||||
datasets: BTreeMap<String, DatasetEntry>,
|
datasets: BTreeMap<String, DatasetEntry>,
|
||||||
slices: HashMap<String, SliceLocation>,
|
slices: HashMap<String, SliceLocation>,
|
||||||
default_dataset: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
|
||||||
pub struct DatasetEntry {
|
pub struct DatasetEntry {
|
||||||
pub metadata: DatasetMetadata,
|
pub metadata: DatasetMetadata,
|
||||||
pub raw_path: PathBuf,
|
pub raw_path: PathBuf,
|
||||||
pub converted_path: PathBuf,
|
pub converted_path: PathBuf,
|
||||||
pub include_unanswerable: bool,
|
|
||||||
pub slices: Vec<SliceEntry>,
|
pub slices: Vec<SliceEntry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
|
||||||
pub struct SliceEntry {
|
pub struct SliceEntry {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub dataset_id: String,
|
pub dataset_id: String,
|
||||||
pub label: String,
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub limit: Option<usize>,
|
pub limit: Option<usize>,
|
||||||
pub corpus_limit: Option<usize>,
|
pub corpus_limit: Option<usize>,
|
||||||
pub include_unanswerable: Option<bool>,
|
pub include_unanswerable: Option<bool>,
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
|
pub negative_multiplier: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
|
||||||
struct SliceLocation {
|
struct SliceLocation {
|
||||||
dataset_id: String,
|
dataset_id: String,
|
||||||
slice_index: usize,
|
slice_index: usize,
|
||||||
@@ -59,7 +56,6 @@ struct SliceLocation {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ManifestFile {
|
struct ManifestFile {
|
||||||
default_dataset: Option<String>,
|
|
||||||
datasets: Vec<ManifestDataset>,
|
datasets: Vec<ManifestDataset>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,6 +77,7 @@ struct ManifestDataset {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[allow(dead_code)]
|
||||||
struct ManifestSlice {
|
struct ManifestSlice {
|
||||||
id: String,
|
id: String,
|
||||||
label: String,
|
label: String,
|
||||||
@@ -94,6 +91,8 @@ struct ManifestSlice {
|
|||||||
include_unanswerable: Option<bool>,
|
include_unanswerable: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
seed: Option<u64>,
|
seed: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
negative_multiplier: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatasetCatalog {
|
impl DatasetCatalog {
|
||||||
@@ -111,18 +110,19 @@ impl DatasetCatalog {
|
|||||||
let raw_path = resolve_path(root, &dataset.raw);
|
let raw_path = resolve_path(root, &dataset.raw);
|
||||||
let converted_path = resolve_path(root, &dataset.converted);
|
let converted_path = resolve_path(root, &dataset.converted);
|
||||||
|
|
||||||
if !raw_path.exists() {
|
if !raw_path.exists() && dataset.id != "beir" {
|
||||||
bail!(
|
bail!(
|
||||||
"dataset '{}' raw file missing at {}",
|
"dataset '{}' raw file missing at {}",
|
||||||
dataset.id,
|
dataset.id,
|
||||||
raw_path.display()
|
raw_path.display()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if !converted_path.exists() {
|
let store_dir = store::store_dir_for(&converted_path);
|
||||||
|
if !converted_path.exists() && !store_dir.join("meta.json").is_file() {
|
||||||
warn!(
|
warn!(
|
||||||
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
|
"dataset '{}' converted store missing at {}; the next conversion run will regenerate it",
|
||||||
dataset.id,
|
dataset.id,
|
||||||
converted_path.display()
|
store_dir.display()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +139,6 @@ impl DatasetCatalog {
|
|||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| dataset.id.clone()),
|
.unwrap_or_else(|| dataset.id.clone()),
|
||||||
include_unanswerable: dataset.include_unanswerable,
|
include_unanswerable: dataset.include_unanswerable,
|
||||||
context_token_limit: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
|
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
|
||||||
@@ -154,12 +153,11 @@ impl DatasetCatalog {
|
|||||||
entry_slices.push(SliceEntry {
|
entry_slices.push(SliceEntry {
|
||||||
id: manifest_slice.id.clone(),
|
id: manifest_slice.id.clone(),
|
||||||
dataset_id: dataset.id.clone(),
|
dataset_id: dataset.id.clone(),
|
||||||
label: manifest_slice.label,
|
|
||||||
description: manifest_slice.description,
|
|
||||||
limit: manifest_slice.limit,
|
limit: manifest_slice.limit,
|
||||||
corpus_limit: manifest_slice.corpus_limit,
|
corpus_limit: manifest_slice.corpus_limit,
|
||||||
include_unanswerable: manifest_slice.include_unanswerable,
|
include_unanswerable: manifest_slice.include_unanswerable,
|
||||||
seed: manifest_slice.seed,
|
seed: manifest_slice.seed,
|
||||||
|
negative_multiplier: manifest_slice.negative_multiplier,
|
||||||
});
|
});
|
||||||
slices.insert(
|
slices.insert(
|
||||||
manifest_slice.id,
|
manifest_slice.id,
|
||||||
@@ -176,22 +174,16 @@ impl DatasetCatalog {
|
|||||||
metadata,
|
metadata,
|
||||||
raw_path,
|
raw_path,
|
||||||
converted_path,
|
converted_path,
|
||||||
include_unanswerable: dataset.include_unanswerable,
|
|
||||||
slices: entry_slices,
|
slices: entry_slices,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let default_dataset = manifest
|
if datasets.is_empty() {
|
||||||
.default_dataset
|
bail!("dataset manifest does not include any datasets");
|
||||||
.or_else(|| datasets.keys().next().cloned())
|
}
|
||||||
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self { datasets, slices })
|
||||||
datasets,
|
|
||||||
slices,
|
|
||||||
default_dataset,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn global() -> Result<&'static Self> {
|
pub fn global() -> Result<&'static Self> {
|
||||||
@@ -204,12 +196,6 @@ impl DatasetCatalog {
|
|||||||
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
|
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
|
|
||||||
self.dataset(&self.default_dataset)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
|
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
|
||||||
let location = self
|
let location = self
|
||||||
.slices
|
.slices
|
||||||
@@ -236,20 +222,27 @@ fn resolve_path(root: &Path, value: &str) -> PathBuf {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub use beir_mix::{beir_subset_store_summary, beir_subset_stores_ready, mix_content_checksum};
|
||||||
|
pub use checksum::store_aggregate_checksum;
|
||||||
|
pub use loader::{prebuild_catalog_slices, prepare_dataset};
|
||||||
|
pub use store::{
|
||||||
|
content_checksum_for_layout, detect_layout, store_dir_for, write_sharded, ConvertedLayout,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn catalog() -> Result<&'static DatasetCatalog> {
|
pub fn catalog() -> Result<&'static DatasetCatalog> {
|
||||||
DatasetCatalog::global()
|
DatasetCatalog::global()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
pub(crate) fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||||
let catalog = catalog()?;
|
let catalog = catalog()?;
|
||||||
catalog.dataset(kind.id())
|
catalog.dataset(kind.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum, Default)]
|
||||||
pub enum DatasetKind {
|
pub enum DatasetKind {
|
||||||
#[default]
|
|
||||||
SquadV2,
|
SquadV2,
|
||||||
NaturalQuestions,
|
NaturalQuestions,
|
||||||
|
#[default]
|
||||||
Beir,
|
Beir,
|
||||||
#[value(name = "fever")]
|
#[value(name = "fever")]
|
||||||
Fever,
|
Fever,
|
||||||
@@ -416,16 +409,10 @@ pub struct DatasetMetadata {
|
|||||||
pub source_prefix: String,
|
pub source_prefix: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub include_unanswerable: bool,
|
pub include_unanswerable: bool,
|
||||||
#[serde(default)]
|
|
||||||
pub context_token_limit: Option<usize>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatasetMetadata {
|
impl DatasetMetadata {
|
||||||
pub fn for_kind(
|
pub fn for_kind(kind: DatasetKind, include_unanswerable: bool) -> Self {
|
||||||
kind: DatasetKind,
|
|
||||||
include_unanswerable: bool,
|
|
||||||
context_token_limit: Option<usize>,
|
|
||||||
) -> Self {
|
|
||||||
if let Ok(entry) = dataset_entry_for_kind(kind) {
|
if let Ok(entry) = dataset_entry_for_kind(kind) {
|
||||||
return Self {
|
return Self {
|
||||||
id: entry.metadata.id.clone(),
|
id: entry.metadata.id.clone(),
|
||||||
@@ -434,7 +421,6 @@ impl DatasetMetadata {
|
|||||||
entity_suffix: entry.metadata.entity_suffix.clone(),
|
entity_suffix: entry.metadata.entity_suffix.clone(),
|
||||||
source_prefix: entry.metadata.source_prefix.clone(),
|
source_prefix: entry.metadata.source_prefix.clone(),
|
||||||
include_unanswerable,
|
include_unanswerable,
|
||||||
context_token_limit,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,13 +431,12 @@ impl DatasetMetadata {
|
|||||||
entity_suffix: kind.entity_suffix().to_string(),
|
entity_suffix: kind.entity_suffix().to_string(),
|
||||||
source_prefix: kind.source_prefix().to_string(),
|
source_prefix: kind.source_prefix().to_string(),
|
||||||
include_unanswerable,
|
include_unanswerable,
|
||||||
context_token_limit,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_metadata() -> DatasetMetadata {
|
fn default_metadata() -> DatasetMetadata {
|
||||||
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
|
DatasetMetadata::for_kind(DatasetKind::default(), false)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -483,14 +468,15 @@ pub fn convert(
|
|||||||
raw_path: &Path,
|
raw_path: &Path,
|
||||||
dataset: DatasetKind,
|
dataset: DatasetKind,
|
||||||
include_unanswerable: bool,
|
include_unanswerable: bool,
|
||||||
context_token_limit: Option<usize>,
|
|
||||||
) -> Result<ConvertedDataset> {
|
) -> Result<ConvertedDataset> {
|
||||||
let paragraphs = match dataset {
|
let paragraphs = match dataset {
|
||||||
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
|
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
|
||||||
DatasetKind::NaturalQuestions => {
|
DatasetKind::NaturalQuestions => nq::convert_nq(raw_path, include_unanswerable)?,
|
||||||
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
DatasetKind::Beir => {
|
||||||
|
bail!(
|
||||||
|
"BEIR mix is prepared via slice-first subset stores; use prepare_beir_mix instead of convert"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
|
|
||||||
DatasetKind::Fever
|
DatasetKind::Fever
|
||||||
| DatasetKind::Fiqa
|
| DatasetKind::Fiqa
|
||||||
| DatasetKind::HotpotQa
|
| DatasetKind::HotpotQa
|
||||||
@@ -501,11 +487,6 @@ pub fn convert(
|
|||||||
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let metadata_limit = match dataset {
|
|
||||||
DatasetKind::NaturalQuestions => None,
|
|
||||||
_ => context_token_limit,
|
|
||||||
};
|
|
||||||
|
|
||||||
let generated_at = match dataset {
|
let generated_at = match dataset {
|
||||||
DatasetKind::Beir
|
DatasetKind::Beir
|
||||||
| DatasetKind::Fever
|
| DatasetKind::Fever
|
||||||
@@ -526,100 +507,12 @@ pub fn convert(
|
|||||||
|
|
||||||
Ok(ConvertedDataset {
|
Ok(ConvertedDataset {
|
||||||
generated_at,
|
generated_at,
|
||||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable),
|
||||||
source: source_label,
|
source: source_label,
|
||||||
paragraphs,
|
paragraphs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_beir_mix(
|
|
||||||
include_unanswerable: bool,
|
|
||||||
_context_token_limit: Option<usize>,
|
|
||||||
) -> Result<Vec<ConvertedParagraph>> {
|
|
||||||
if include_unanswerable {
|
|
||||||
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut paragraphs = Vec::new();
|
|
||||||
for subset in BEIR_DATASETS {
|
|
||||||
let entry = dataset_entry_for_kind(subset)?;
|
|
||||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
|
||||||
paragraphs.extend(subset_paragraphs);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(paragraphs)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ensure_parent(path: &Path) -> Result<()> {
|
|
||||||
if let Some(parent) = path.parent() {
|
|
||||||
fs::create_dir_all(parent)
|
|
||||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
|
|
||||||
ensure_parent(converted_path)?;
|
|
||||||
let json =
|
|
||||||
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
|
|
||||||
fs::write(converted_path, json)
|
|
||||||
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
|
|
||||||
let raw = fs::read_to_string(converted_path)
|
|
||||||
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
|
|
||||||
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
|
|
||||||
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
|
|
||||||
if dataset.metadata.id.trim().is_empty() {
|
|
||||||
dataset.metadata = default_metadata();
|
|
||||||
}
|
|
||||||
if dataset.source.is_empty() {
|
|
||||||
dataset.source = converted_path.display().to_string();
|
|
||||||
}
|
|
||||||
Ok(dataset)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ensure_converted(
|
|
||||||
dataset_kind: DatasetKind,
|
|
||||||
raw_path: &Path,
|
|
||||||
converted_path: &Path,
|
|
||||||
force: bool,
|
|
||||||
include_unanswerable: bool,
|
|
||||||
context_token_limit: Option<usize>,
|
|
||||||
) -> Result<ConvertedDataset> {
|
|
||||||
if force || !converted_path.exists() {
|
|
||||||
let dataset = convert(
|
|
||||||
raw_path,
|
|
||||||
dataset_kind,
|
|
||||||
include_unanswerable,
|
|
||||||
context_token_limit,
|
|
||||||
)?;
|
|
||||||
write_converted(&dataset, converted_path)?;
|
|
||||||
return Ok(dataset);
|
|
||||||
}
|
|
||||||
|
|
||||||
match read_converted(converted_path) {
|
|
||||||
Ok(dataset)
|
|
||||||
if dataset.metadata.id == dataset_kind.id()
|
|
||||||
&& dataset.metadata.include_unanswerable == include_unanswerable
|
|
||||||
&& dataset.metadata.context_token_limit == context_token_limit =>
|
|
||||||
{
|
|
||||||
Ok(dataset)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let dataset = convert(
|
|
||||||
raw_path,
|
|
||||||
dataset_kind,
|
|
||||||
include_unanswerable,
|
|
||||||
context_token_limit,
|
|
||||||
)?;
|
|
||||||
write_converted(&dataset, converted_path)?;
|
|
||||||
Ok(dataset)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn base_timestamp() -> DateTime<Utc> {
|
pub fn base_timestamp() -> DateTime<Utc> {
|
||||||
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
|
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,11 +16,7 @@ use super::{ConvertedParagraph, ConvertedQuestion};
|
|||||||
clippy::arithmetic_side_effects,
|
clippy::arithmetic_side_effects,
|
||||||
clippy::cast_sign_loss
|
clippy::cast_sign_loss
|
||||||
)]
|
)]
|
||||||
pub fn convert_nq(
|
pub fn convert_nq(raw_path: &Path, include_unanswerable: bool) -> Result<Vec<ConvertedParagraph>> {
|
||||||
raw_path: &Path,
|
|
||||||
include_unanswerable: bool,
|
|
||||||
_context_token_limit: Option<usize>,
|
|
||||||
) -> Result<Vec<ConvertedParagraph>> {
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct NqExample {
|
struct NqExample {
|
||||||
|
|||||||
@@ -0,0 +1,412 @@
|
|||||||
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
fs::{self, File, OpenOptions},
|
||||||
|
io::{BufRead, BufReader, Write},
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
checksum::store_aggregate_checksum, ConvertedDataset, ConvertedParagraph, ConvertedQuestion,
|
||||||
|
DatasetMetadata,
|
||||||
|
};
|
||||||
|
use crate::slice;
|
||||||
|
|
||||||
|
pub const SHARDED_STORE_VERSION: u32 = 1;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ShardedMeta {
|
||||||
|
pub version: u32,
|
||||||
|
pub generated_at: DateTime<Utc>,
|
||||||
|
pub metadata: DatasetMetadata,
|
||||||
|
pub source: String,
|
||||||
|
pub paragraph_count: usize,
|
||||||
|
pub question_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub(crate) struct QuestionRecord {
|
||||||
|
paragraph_id: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
question: ConvertedQuestion,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct QuestionCatalog {
|
||||||
|
pub entries: Vec<QuestionRecord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ConvertedLayout {
|
||||||
|
ShardedStore,
|
||||||
|
Missing,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_dir_for(converted_path: &Path) -> PathBuf {
|
||||||
|
converted_path
|
||||||
|
.parent()
|
||||||
|
.unwrap_or_else(|| Path::new("."))
|
||||||
|
.join(converted_path.file_stem().map_or_else(
|
||||||
|
|| "dataset".to_string(),
|
||||||
|
|stem| stem.to_string_lossy().into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn detect_layout(converted_path: &Path) -> ConvertedLayout {
|
||||||
|
let store_dir = store_dir_for(converted_path);
|
||||||
|
if store_dir.join("meta.json").is_file() {
|
||||||
|
ConvertedLayout::ShardedStore
|
||||||
|
} else {
|
||||||
|
ConvertedLayout::Missing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn paragraph_file_name(paragraph_id: &str) -> String {
|
||||||
|
format!("{}.json", slice::paragraph_storage_key(paragraph_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn paragraph_path(store_dir: &Path, paragraph_id: &str) -> PathBuf {
|
||||||
|
store_dir
|
||||||
|
.join("paragraphs")
|
||||||
|
.join(paragraph_file_name(paragraph_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_sharded(dataset: &ConvertedDataset, store_dir: &Path) -> Result<String> {
|
||||||
|
if store_dir.exists() {
|
||||||
|
fs::remove_dir_all(store_dir)
|
||||||
|
.with_context(|| format!("clearing sharded store {}", store_dir.display()))?;
|
||||||
|
}
|
||||||
|
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||||
|
.with_context(|| format!("creating sharded store {}", store_dir.display()))?;
|
||||||
|
|
||||||
|
let question_count = dataset
|
||||||
|
.paragraphs
|
||||||
|
.iter()
|
||||||
|
.map(|paragraph| paragraph.questions.len())
|
||||||
|
.sum::<usize>();
|
||||||
|
|
||||||
|
let meta = ShardedMeta {
|
||||||
|
version: SHARDED_STORE_VERSION,
|
||||||
|
generated_at: dataset.generated_at,
|
||||||
|
metadata: dataset.metadata.clone(),
|
||||||
|
source: dataset.source.clone(),
|
||||||
|
paragraph_count: dataset.paragraphs.len(),
|
||||||
|
question_count,
|
||||||
|
};
|
||||||
|
let meta_path = store_dir.join("meta.json");
|
||||||
|
fs::write(
|
||||||
|
&meta_path,
|
||||||
|
serde_json::to_vec_pretty(&meta).context("serialising sharded store metadata")?,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("writing sharded metadata {}", meta_path.display()))?;
|
||||||
|
|
||||||
|
let mut questions_file = File::create(store_dir.join("questions.jsonl"))
|
||||||
|
.context("creating questions.jsonl for sharded store")?;
|
||||||
|
let mut paragraph_ids_file = File::create(store_dir.join("paragraph_ids.jsonl"))
|
||||||
|
.context("creating paragraph_ids.jsonl for sharded store")?;
|
||||||
|
|
||||||
|
for paragraph in &dataset.paragraphs {
|
||||||
|
writeln!(paragraph_ids_file, "{}", paragraph.id)
|
||||||
|
.context("writing paragraph id to paragraph_ids.jsonl")?;
|
||||||
|
for question in ¶graph.questions {
|
||||||
|
let record = QuestionRecord {
|
||||||
|
paragraph_id: paragraph.id.clone(),
|
||||||
|
question: question.clone(),
|
||||||
|
};
|
||||||
|
serde_json::to_writer(&mut questions_file, &record)
|
||||||
|
.context("writing question record to questions.jsonl")?;
|
||||||
|
questions_file.write_all(b"\n")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = paragraph_path(store_dir, ¶graph.id);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(
|
||||||
|
&path,
|
||||||
|
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let digest = store_aggregate_checksum(store_dir)?;
|
||||||
|
info!(
|
||||||
|
store = %store_dir.display(),
|
||||||
|
paragraphs = dataset.paragraphs.len(),
|
||||||
|
questions = question_count,
|
||||||
|
checksum = %digest,
|
||||||
|
"Wrote sharded converted dataset"
|
||||||
|
);
|
||||||
|
Ok(digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_meta(store_dir: &Path) -> Result<ShardedMeta> {
|
||||||
|
let path = store_dir.join("meta.json");
|
||||||
|
let raw = fs::read_to_string(&path)
|
||||||
|
.with_context(|| format!("reading sharded metadata {}", path.display()))?;
|
||||||
|
serde_json::from_str(&raw)
|
||||||
|
.with_context(|| format!("parsing sharded metadata {}", path.display()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn content_checksum_for_layout(converted_path: &Path) -> Result<String> {
|
||||||
|
match detect_layout(converted_path) {
|
||||||
|
ConvertedLayout::ShardedStore => {
|
||||||
|
crate::datasets::store_aggregate_checksum(&store_dir_for(converted_path))
|
||||||
|
}
|
||||||
|
ConvertedLayout::Missing => Err(anyhow!(
|
||||||
|
"converted dataset missing at {}",
|
||||||
|
converted_path.display()
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_paragraph(store_dir: &Path, paragraph_id: &str) -> Result<ConvertedParagraph> {
|
||||||
|
let path = paragraph_path(store_dir, paragraph_id);
|
||||||
|
let raw =
|
||||||
|
fs::read(&path).with_context(|| format!("reading sharded paragraph {}", path.display()))?;
|
||||||
|
serde_json::from_slice(&raw)
|
||||||
|
.with_context(|| format!("parsing sharded paragraph {}", path.display()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_paragraphs(store_dir: &Path, paragraph_ids: &[String]) -> Result<Vec<ConvertedParagraph>> {
|
||||||
|
paragraph_ids
|
||||||
|
.iter()
|
||||||
|
.map(|paragraph_id| load_paragraph(store_dir, paragraph_id))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_sharded_partial(
|
||||||
|
store_dir: &Path,
|
||||||
|
paragraph_ids: &[String],
|
||||||
|
) -> Result<ConvertedDataset> {
|
||||||
|
let meta = read_meta(store_dir)?;
|
||||||
|
let mut paragraphs = load_paragraphs(store_dir, paragraph_ids)?;
|
||||||
|
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||||
|
Ok(ConvertedDataset {
|
||||||
|
generated_at: meta.generated_at,
|
||||||
|
metadata: meta.metadata,
|
||||||
|
source: meta.source,
|
||||||
|
paragraphs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_sharded_full(store_dir: &Path) -> Result<ConvertedDataset> {
|
||||||
|
let meta = read_meta(store_dir)?;
|
||||||
|
let ids = load_paragraph_ids(store_dir)?;
|
||||||
|
let paragraphs = load_paragraphs(store_dir, &ids)?;
|
||||||
|
Ok(ConvertedDataset {
|
||||||
|
generated_at: meta.generated_at,
|
||||||
|
metadata: meta.metadata,
|
||||||
|
source: meta.source,
|
||||||
|
paragraphs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_paragraph_ids_set(store_dir: &Path) -> Result<HashSet<String>> {
|
||||||
|
Ok(load_paragraph_ids(store_dir)?.into_iter().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::arithmetic_side_effects)]
|
||||||
|
pub fn upsert_sharded_paragraphs(
|
||||||
|
store_dir: &Path,
|
||||||
|
paragraphs: &[ConvertedParagraph],
|
||||||
|
) -> Result<()> {
|
||||||
|
if paragraphs.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
if !store_dir.join("meta.json").is_file() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"cannot upsert into missing sharded store at {}",
|
||||||
|
store_dir.display()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||||
|
.with_context(|| format!("creating paragraphs directory in {}", store_dir.display()))?;
|
||||||
|
|
||||||
|
let existing = load_paragraph_ids_set(store_dir)?;
|
||||||
|
let questions_path = store_dir.join("questions.jsonl");
|
||||||
|
let mut questions_file = OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&questions_path)
|
||||||
|
.with_context(|| format!("opening question catalog {}", questions_path.display()))?;
|
||||||
|
|
||||||
|
let mut ids_file = None;
|
||||||
|
let mut new_paragraphs = 0usize;
|
||||||
|
let mut new_questions = 0usize;
|
||||||
|
|
||||||
|
for paragraph in paragraphs {
|
||||||
|
let is_new = !existing.contains(¶graph.id);
|
||||||
|
let path = paragraph_path(store_dir, ¶graph.id);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
fs::write(
|
||||||
|
&path,
|
||||||
|
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||||
|
|
||||||
|
if is_new {
|
||||||
|
if ids_file.is_none() {
|
||||||
|
ids_file = Some(
|
||||||
|
OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(store_dir.join("paragraph_ids.jsonl"))
|
||||||
|
.context("opening paragraph_ids.jsonl for append")?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(file) = ids_file.as_mut() {
|
||||||
|
writeln!(file, "{}", paragraph.id).context("appending paragraph id")?;
|
||||||
|
}
|
||||||
|
new_paragraphs += 1;
|
||||||
|
|
||||||
|
for question in ¶graph.questions {
|
||||||
|
let record = QuestionRecord {
|
||||||
|
paragraph_id: paragraph.id.clone(),
|
||||||
|
question: question.clone(),
|
||||||
|
};
|
||||||
|
serde_json::to_writer(&mut questions_file, &record)
|
||||||
|
.context("writing question record to questions.jsonl")?;
|
||||||
|
questions_file.write_all(b"\n")?;
|
||||||
|
new_questions += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if new_paragraphs > 0 || new_questions > 0 {
|
||||||
|
let meta = read_meta(store_dir)?;
|
||||||
|
let updated = ShardedMeta {
|
||||||
|
paragraph_count: meta.paragraph_count + new_paragraphs,
|
||||||
|
question_count: meta.question_count + new_questions,
|
||||||
|
..meta
|
||||||
|
};
|
||||||
|
fs::write(
|
||||||
|
store_dir.join("meta.json"),
|
||||||
|
serde_json::to_vec_pretty(&updated).context("serialising updated sharded metadata")?,
|
||||||
|
)?;
|
||||||
|
store_aggregate_checksum(store_dir)?;
|
||||||
|
info!(
|
||||||
|
store = %store_dir.display(),
|
||||||
|
new_paragraphs,
|
||||||
|
new_questions,
|
||||||
|
"Upserted paragraphs into sharded converted store"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_paragraph_ids(store_dir: &Path) -> Result<Vec<String>> {
|
||||||
|
let path = store_dir.join("paragraph_ids.jsonl");
|
||||||
|
let file = File::open(&path)
|
||||||
|
.with_context(|| format!("opening paragraph id index {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
reader
|
||||||
|
.lines()
|
||||||
|
.map(|line| {
|
||||||
|
line.context("reading paragraph id index line")
|
||||||
|
.and_then(|value| {
|
||||||
|
let trimmed = value.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
Err(anyhow!("empty paragraph id in index"))
|
||||||
|
} else {
|
||||||
|
Ok(trimmed.to_string())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_question_catalog(store_dir: &Path) -> Result<QuestionCatalog> {
|
||||||
|
let path = store_dir.join("questions.jsonl");
|
||||||
|
let file = File::open(&path)
|
||||||
|
.with_context(|| format!("opening question catalog {}", path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line.context("reading question catalog line")?;
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let record: QuestionRecord =
|
||||||
|
serde_json::from_str(&line).context("parsing question catalog record")?;
|
||||||
|
entries.push(record);
|
||||||
|
}
|
||||||
|
Ok(QuestionCatalog { entries })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_dataset_from_catalog(
|
||||||
|
store_dir: &Path,
|
||||||
|
paragraph_ids: &HashSet<String>,
|
||||||
|
) -> Result<ConvertedDataset> {
|
||||||
|
let catalog = load_question_catalog(store_dir)?;
|
||||||
|
let mut questions_by_paragraph: HashMap<String, Vec<ConvertedQuestion>> = HashMap::new();
|
||||||
|
for entry in catalog.entries {
|
||||||
|
if paragraph_ids.contains(&entry.paragraph_id) {
|
||||||
|
questions_by_paragraph
|
||||||
|
.entry(entry.paragraph_id.clone())
|
||||||
|
.or_default()
|
||||||
|
.push(entry.question);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut dataset = load_sharded_partial(
|
||||||
|
store_dir,
|
||||||
|
¶graph_ids.iter().cloned().collect::<Vec<_>>(),
|
||||||
|
)?;
|
||||||
|
for paragraph in &mut dataset.paragraphs {
|
||||||
|
if let Some(questions) = questions_by_paragraph.remove(¶graph.id) {
|
||||||
|
paragraph.questions = questions;
|
||||||
|
} else {
|
||||||
|
paragraph.questions.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(dataset)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::datasets::{DatasetKind, DatasetMetadata};
|
||||||
|
|
||||||
|
fn sample_dataset() -> ConvertedDataset {
|
||||||
|
ConvertedDataset {
|
||||||
|
generated_at: Utc::now(),
|
||||||
|
metadata: DatasetMetadata::for_kind(DatasetKind::SquadV2, false),
|
||||||
|
source: "test".to_string(),
|
||||||
|
paragraphs: vec![ConvertedParagraph {
|
||||||
|
id: "p1".to_string(),
|
||||||
|
title: "Title".to_string(),
|
||||||
|
context: "Body".to_string(),
|
||||||
|
questions: vec![ConvertedQuestion {
|
||||||
|
id: "q1".to_string(),
|
||||||
|
question: "Question?".to_string(),
|
||||||
|
answers: vec!["Answer".to_string()],
|
||||||
|
is_impossible: false,
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(clippy::indexing_slicing)]
|
||||||
|
fn sharded_round_trip() -> Result<()> {
|
||||||
|
let dir = tempfile::tempdir()?;
|
||||||
|
let store_dir = dir.path().join("sample");
|
||||||
|
let dataset = sample_dataset();
|
||||||
|
write_sharded(&dataset, &store_dir)?;
|
||||||
|
|
||||||
|
let loaded = load_sharded_full(&store_dir)?;
|
||||||
|
assert_eq!(loaded.paragraphs.len(), 1);
|
||||||
|
assert_eq!(loaded.paragraphs[0].questions[0].id, "q1");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,22 +1,22 @@
|
|||||||
//! Database namespace management utilities.
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use common::storage::{
|
use common::{
|
||||||
db::SurrealDbClient,
|
storage::{
|
||||||
types::user::{Theme, User},
|
db::SurrealDbClient,
|
||||||
types::StoredObject,
|
types::user::{Theme, User},
|
||||||
|
types::StoredObject,
|
||||||
|
},
|
||||||
|
utils::embedding::EmbeddingProvider,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
args::Config,
|
args::Config,
|
||||||
|
corpus::{self, CorpusHandle, CorpusManifest, NamespaceSeedRecord},
|
||||||
datasets,
|
datasets,
|
||||||
snapshot::{self, DbSnapshotState},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Connect to the evaluation database with fallback auth strategies.
|
|
||||||
pub(crate) async fn connect_eval_db(
|
pub(crate) async fn connect_eval_db(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
namespace: &str,
|
namespace: &str,
|
||||||
@@ -73,7 +73,6 @@ pub(crate) async fn connect_eval_db(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if the namespace contains any corpus data.
|
|
||||||
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct CountRow {
|
struct CountRow {
|
||||||
@@ -89,41 +88,51 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
|||||||
Ok(rows.first().map_or(0, |row| row.count) > 0)
|
Ok(rows.first().map_or(0, |row| row.count) > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determine if we can reuse an existing namespace based on cached state.
|
fn manifest_matches_runtime(
|
||||||
|
manifest: &CorpusManifest,
|
||||||
|
embedding_provider: &EmbeddingProvider,
|
||||||
|
ingestion_fingerprint: &str,
|
||||||
|
) -> bool {
|
||||||
|
let metadata = &manifest.metadata;
|
||||||
|
metadata.ingestion_fingerprint == ingestion_fingerprint
|
||||||
|
&& metadata.embedding_backend == embedding_provider.backend_label()
|
||||||
|
&& metadata.embedding_model == embedding_provider.model_code()
|
||||||
|
&& metadata.embedding_dimension == embedding_provider.dimension()
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) async fn can_reuse_namespace(
|
pub(crate) async fn can_reuse_namespace(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
descriptor: &snapshot::Descriptor,
|
manifest: &CorpusManifest,
|
||||||
|
embedding_provider: &EmbeddingProvider,
|
||||||
namespace: &str,
|
namespace: &str,
|
||||||
database: &str,
|
database: &str,
|
||||||
dataset_id: &str,
|
|
||||||
slice_id: &str,
|
|
||||||
ingestion_fingerprint: &str,
|
ingestion_fingerprint: &str,
|
||||||
slice_case_count: usize,
|
slice_case_count: usize,
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
let Some(state) = descriptor.load_db_state().await? else {
|
if !manifest_matches_runtime(manifest, embedding_provider, ingestion_fingerprint) {
|
||||||
info!("No namespace state recorded; reseeding corpus from cached shards");
|
info!("Corpus manifest metadata mismatch; rebuilding namespace from cached shards");
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(seed) = manifest.metadata.namespace_seed.as_ref() else {
|
||||||
|
info!("No namespace seed recorded in corpus manifest; reseeding");
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
if state.slice_case_count != slice_case_count {
|
if seed.slice_case_count != slice_case_count {
|
||||||
info!(
|
info!(
|
||||||
requested_cases = slice_case_count,
|
requested_cases = slice_case_count,
|
||||||
stored_cases = state.slice_case_count,
|
stored_cases = seed.slice_case_count,
|
||||||
"Skipping live namespace reuse; cached state does not match requested window"
|
"Skipping namespace reuse; case window mismatch"
|
||||||
);
|
);
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.dataset_id != dataset_id
|
if seed.namespace != namespace || seed.database != database {
|
||||||
|| state.slice_id != slice_id
|
|
||||||
|| state.ingestion_fingerprint != ingestion_fingerprint
|
|
||||||
|| state.namespace.as_deref() != Some(namespace)
|
|
||||||
|| state.database.as_deref() != Some(database)
|
|
||||||
{
|
|
||||||
info!(
|
info!(
|
||||||
namespace,
|
namespace,
|
||||||
database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
|
database, "Corpus manifest namespace metadata mismatch; reseeding"
|
||||||
);
|
);
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
@@ -140,28 +149,20 @@ pub(crate) async fn can_reuse_namespace(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Record the current namespace state to allow future reuse checks.
|
pub(crate) async fn record_namespace_seed(
|
||||||
pub(crate) async fn record_namespace_state(
|
handle: &mut CorpusHandle,
|
||||||
descriptor: &snapshot::Descriptor,
|
|
||||||
dataset_id: &str,
|
|
||||||
slice_id: &str,
|
|
||||||
ingestion_fingerprint: &str,
|
|
||||||
namespace: &str,
|
namespace: &str,
|
||||||
database: &str,
|
database: &str,
|
||||||
slice_case_count: usize,
|
slice_case_count: usize,
|
||||||
) {
|
) {
|
||||||
let state = DbSnapshotState {
|
handle.manifest.metadata.namespace_seed = Some(NamespaceSeedRecord {
|
||||||
dataset_id: dataset_id.to_string(),
|
namespace: namespace.to_string(),
|
||||||
slice_id: slice_id.to_string(),
|
database: database.to_string(),
|
||||||
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
|
||||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
|
||||||
updated_at: Utc::now(),
|
|
||||||
namespace: Some(namespace.to_string()),
|
|
||||||
database: Some(database.to_string()),
|
|
||||||
slice_case_count,
|
slice_case_count,
|
||||||
};
|
seeded_at: Utc::now(),
|
||||||
if let Err(err) = descriptor.store_db_state(&state).await {
|
});
|
||||||
warn!(error = %err, "Failed to record namespace state");
|
if let Err(err) = corpus::persist_corpus_manifest(handle) {
|
||||||
|
warn!(error = %err, "Failed to record namespace seed in corpus manifest");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,8 +186,17 @@ fn sanitize_identifier(input: &str) -> String {
|
|||||||
cleaned
|
cleaned
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a default namespace name based on dataset and limit.
|
pub(crate) fn default_namespace(
|
||||||
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
|
dataset_id: &str,
|
||||||
|
limit: Option<usize>,
|
||||||
|
slice_id: Option<&str>,
|
||||||
|
) -> String {
|
||||||
|
if let Some(slice_id) = slice_id {
|
||||||
|
let sanitized = sanitize_identifier(slice_id);
|
||||||
|
if !sanitized.is_empty() {
|
||||||
|
return format!("eval_{sanitized}");
|
||||||
|
}
|
||||||
|
}
|
||||||
let dataset_component = sanitize_identifier(dataset_id);
|
let dataset_component = sanitize_identifier(dataset_id);
|
||||||
let limit_component = match limit {
|
let limit_component = match limit {
|
||||||
Some(value) if value > 0 => format!("limit{value}"),
|
Some(value) if value > 0 => format!("limit{value}"),
|
||||||
@@ -195,12 +205,10 @@ pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> Strin
|
|||||||
format!("eval_{dataset_component}_{limit_component}")
|
format!("eval_{dataset_component}_{limit_component}")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate the default database name for evaluations.
|
|
||||||
pub(crate) fn default_database() -> String {
|
pub(crate) fn default_database() -> String {
|
||||||
"retrieval_eval".to_string()
|
"retrieval_eval".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ensure the evaluation user exists in the database.
|
|
||||||
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||||
let timestamp = datasets::base_timestamp();
|
let timestamp = datasets::base_timestamp();
|
||||||
let user = User {
|
let user = User {
|
||||||
@@ -225,3 +233,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
|||||||
.context("storing evaluation user")?;
|
.context("storing evaluation user")?;
|
||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
||||||
|
sanitize_identifier(code)
|
||||||
|
}
|
||||||
@@ -2,13 +2,6 @@ use anyhow::{Context, Result};
|
|||||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
|
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
// Helper functions for index management during namespace reseed
|
|
||||||
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
|
||||||
let _ = db;
|
|
||||||
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||||
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
||||||
ensure_runtime(db, dimension)
|
ensure_runtime(db, dimension)
|
||||||
@@ -34,14 +27,39 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Test helper to force index dimension change
|
#[allow(clippy::cast_precision_loss)]
|
||||||
// #[allow(dead_code)]
|
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||||
// pub async fn change_embedding_length_in_hnsw_indexes(
|
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
||||||
// db: &SurrealDbClient,
|
|
||||||
// dimension: usize,
|
info!("Warming HNSW caches with sample queries");
|
||||||
// ) -> Result<()> {
|
|
||||||
// recreate_indexes(db, dimension).await
|
let _ = db
|
||||||
// }
|
.client
|
||||||
|
.query(
|
||||||
|
r#"SELECT chunk_id
|
||||||
|
FROM text_chunk_embedding
|
||||||
|
WHERE embedding <|1,1|> $embedding
|
||||||
|
LIMIT 5"#,
|
||||||
|
)
|
||||||
|
.bind(("embedding", dummy_embedding.clone()))
|
||||||
|
.await
|
||||||
|
.context("warming text chunk HNSW cache")?;
|
||||||
|
|
||||||
|
let _ = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
r#"SELECT entity_id
|
||||||
|
FROM knowledge_entity_embedding
|
||||||
|
WHERE embedding <|1,1|> $embedding
|
||||||
|
LIMIT 5"#,
|
||||||
|
)
|
||||||
|
.bind(("embedding", dummy_embedding))
|
||||||
|
.await
|
||||||
|
.context("warming knowledge entity HNSW cache")?;
|
||||||
|
|
||||||
|
info!("HNSW cache warming completed");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
mod connect;
|
||||||
|
mod lifecycle;
|
||||||
|
|
||||||
|
pub(crate) use connect::{
|
||||||
|
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
||||||
|
namespace_has_corpus, record_namespace_seed, sanitize_model_code,
|
||||||
|
};
|
||||||
|
pub(crate) use lifecycle::warm_hnsw_cache;
|
||||||
|
pub use lifecycle::{recreate_indexes, reset_namespace};
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
//! Evaluation utilities module - re-exports from focused submodules.
|
|
||||||
|
|
||||||
// Re-export types from the root types module
|
|
||||||
pub use crate::types::*;
|
|
||||||
|
|
||||||
// Re-export from focused modules at crate root (crate-internal only)
|
|
||||||
pub(crate) use crate::cases::{cases_from_manifest, SeededCase};
|
|
||||||
pub(crate) use crate::namespace::{
|
|
||||||
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
|
||||||
record_namespace_state,
|
|
||||||
};
|
|
||||||
pub(crate) use crate::settings::{enforce_system_settings, load_or_init_system_settings};
|
|
||||||
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use common::storage::db::SurrealDbClient;
|
|
||||||
use tokio::io::AsyncWriteExt;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
args::{self, Config},
|
|
||||||
datasets::ConvertedDataset,
|
|
||||||
slice::{self},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Grow the slice ledger to contain the target number of cases.
|
|
||||||
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
|
||||||
let ledger_limit = ledger_target(config);
|
|
||||||
let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
|
|
||||||
let slice =
|
|
||||||
slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
|
||||||
info!(
|
|
||||||
slice = slice.manifest.slice_id.as_str(),
|
|
||||||
cases = slice.manifest.case_count,
|
|
||||||
positives = slice.manifest.positive_paragraphs,
|
|
||||||
negatives = slice.manifest.negative_paragraphs,
|
|
||||||
total_paragraphs = slice.manifest.total_paragraphs,
|
|
||||||
"Slice ledger ready"
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
|
||||||
slice.manifest.slice_id,
|
|
||||||
slice.manifest.case_count,
|
|
||||||
slice.manifest.positive_paragraphs,
|
|
||||||
slice.manifest.negative_paragraphs
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
|
|
||||||
match (config.slice_grow, config.limit) {
|
|
||||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
|
||||||
(Some(grow), None) => Some(grow),
|
|
||||||
(None, limit) => limit,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
|
||||||
args::ensure_parent(path)?;
|
|
||||||
let mut file = tokio::fs::File::create(path)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
|
||||||
for case in cases {
|
|
||||||
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
|
||||||
file.write_all(&line).await?;
|
|
||||||
file.write_all(b"\n").await?;
|
|
||||||
}
|
|
||||||
file.flush().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::cast_precision_loss)]
|
|
||||||
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
|
||||||
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
|
||||||
|
|
||||||
info!("Warming HNSW caches with sample queries");
|
|
||||||
|
|
||||||
// Warm up chunk embedding index - just query the embedding table to load HNSW index
|
|
||||||
let _ = db
|
|
||||||
.client
|
|
||||||
.query(
|
|
||||||
r#"SELECT chunk_id
|
|
||||||
FROM text_chunk_embedding
|
|
||||||
WHERE embedding <|1,1|> $embedding
|
|
||||||
LIMIT 5"#,
|
|
||||||
)
|
|
||||||
.bind(("embedding", dummy_embedding.clone()))
|
|
||||||
.await
|
|
||||||
.context("warming text chunk HNSW cache")?;
|
|
||||||
|
|
||||||
// Warm up entity embedding index
|
|
||||||
let _ = db
|
|
||||||
.client
|
|
||||||
.query(
|
|
||||||
r#"SELECT entity_id
|
|
||||||
FROM knowledge_entity_embedding
|
|
||||||
WHERE embedding <|1,1|> $embedding
|
|
||||||
LIMIT 5"#,
|
|
||||||
)
|
|
||||||
.bind(("embedding", dummy_embedding))
|
|
||||||
.await
|
|
||||||
.context("warming knowledge entity HNSW cache")?;
|
|
||||||
|
|
||||||
info!("HNSW cache warming completed");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
use chrono::{DateTime, SecondsFormat, Utc};
|
|
||||||
|
|
||||||
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
|
|
||||||
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
|
||||||
code.chars()
|
|
||||||
.map(|ch| {
|
|
||||||
if ch.is_ascii_alphanumeric() {
|
|
||||||
ch.to_ascii_lowercase()
|
|
||||||
} else {
|
|
||||||
'_'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-export run_evaluation from the pipeline module at crate root
|
|
||||||
pub use crate::pipeline::run_evaluation;
|
|
||||||
@@ -1,13 +1,9 @@
|
|||||||
use std::{
|
use std::{collections::HashMap, fs, path::Path};
|
||||||
collections::HashMap,
|
|
||||||
fs,
|
|
||||||
path::{Path, PathBuf},
|
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
|
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
|
||||||
|
|
||||||
use crate::{args::Config, corpus, eval::connect_eval_db, snapshot::DbSnapshotState};
|
use crate::{args::Config, corpus, db::connect_eval_db};
|
||||||
|
|
||||||
pub async fn inspect_question(config: &Config) -> Result<()> {
|
pub async fn inspect_question(config: &Config) -> Result<()> {
|
||||||
let question_id = config
|
let question_id = config
|
||||||
@@ -64,39 +60,26 @@ pub async fn inspect_question(config: &Config) -> Result<()> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let db_state_path = config
|
if let Some(seed) = manifest.metadata.namespace_seed.as_ref() {
|
||||||
.database
|
let ns = seed.namespace.as_str();
|
||||||
.inspect_db_state
|
let db_name = seed.database.as_str();
|
||||||
.clone()
|
match connect_eval_db(config, ns, db_name).await {
|
||||||
.unwrap_or_else(|| default_state_path(config, &manifest));
|
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
||||||
if let Some(state) = load_db_state(&db_state_path)? {
|
MissingChunks::None => println!(
|
||||||
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
|
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
||||||
match connect_eval_db(config, ns, db_name).await {
|
),
|
||||||
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
MissingChunks::Missing(list) => {
|
||||||
MissingChunks::None => println!(
|
println!("Missing chunks in namespace '{ns}', database '{db_name}': {list:?}");
|
||||||
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
|
||||||
),
|
|
||||||
MissingChunks::Missing(list) => println!(
|
|
||||||
"Missing chunks in namespace '{ns}', database '{db_name}': {list:?}"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
println!(
|
|
||||||
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
println!(
|
||||||
|
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
println!(
|
|
||||||
"State file {} is missing namespace/database fields; skipping live DB validation",
|
|
||||||
db_state_path.display()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
println!(
|
println!("Corpus manifest has no namespace seed; skipping live DB validation");
|
||||||
"State file {} not found; skipping live DB validation",
|
|
||||||
db_state_path.display()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -137,25 +120,6 @@ fn build_chunk_lookup(manifest: &corpus::CorpusManifest) -> HashMap<String, Chun
|
|||||||
lookup
|
lookup
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_state_path(config: &Config, manifest: &corpus::CorpusManifest) -> PathBuf {
|
|
||||||
config
|
|
||||||
.cache_dir
|
|
||||||
.join("snapshots")
|
|
||||||
.join(&manifest.metadata.dataset_id)
|
|
||||||
.join(&manifest.metadata.slice_id)
|
|
||||||
.join("db/state.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
|
|
||||||
if !path.exists() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
|
|
||||||
let state = serde_json::from_slice(&bytes)
|
|
||||||
.with_context(|| format!("parsing db state {}", path.display()))?;
|
|
||||||
Ok(Some(state))
|
|
||||||
}
|
|
||||||
|
|
||||||
enum MissingChunks {
|
enum MissingChunks {
|
||||||
None,
|
None,
|
||||||
Missing(Vec<String>),
|
Missing(Vec<String>),
|
||||||
|
|||||||
+49
-46
@@ -1,19 +1,17 @@
|
|||||||
mod args;
|
mod args;
|
||||||
mod cache;
|
|
||||||
mod cases;
|
mod cases;
|
||||||
|
mod cli;
|
||||||
|
mod context_stats;
|
||||||
mod corpus;
|
mod corpus;
|
||||||
mod datasets;
|
mod datasets;
|
||||||
mod db_helpers;
|
mod db;
|
||||||
mod eval;
|
|
||||||
mod inspection;
|
mod inspection;
|
||||||
mod namespace;
|
|
||||||
mod openai;
|
mod openai;
|
||||||
mod perf;
|
mod perf;
|
||||||
mod pipeline;
|
mod pipeline;
|
||||||
mod report;
|
mod report;
|
||||||
mod settings;
|
mod settings;
|
||||||
mod slice;
|
mod slice;
|
||||||
mod snapshot;
|
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
@@ -24,7 +22,6 @@ use tracing_subscriber::{fmt, EnvFilter};
|
|||||||
/// Configure `SurrealDB` environment variables for optimal performance
|
/// Configure `SurrealDB` environment variables for optimal performance
|
||||||
#[allow(clippy::arithmetic_side_effects, clippy::unwrap_used)]
|
#[allow(clippy::arithmetic_side_effects, clippy::unwrap_used)]
|
||||||
fn configure_surrealdb_performance(cpu_count: usize) {
|
fn configure_surrealdb_performance(cpu_count: usize) {
|
||||||
// Set environment variables only if they're not already set
|
|
||||||
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
|
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
|
||||||
.unwrap_or_else(|_| (cpu_count * 2).to_string());
|
.unwrap_or_else(|_| (cpu_count * 2).to_string());
|
||||||
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
||||||
@@ -62,12 +59,11 @@ fn configure_surrealdb_performance(cpu_count: usize) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
// Create an explicit multi-threaded runtime with optimized configuration
|
|
||||||
let runtime = Builder::new_multi_thread()
|
let runtime = Builder::new_multi_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.worker_threads(std::thread::available_parallelism()?.get())
|
.worker_threads(std::thread::available_parallelism()?.get())
|
||||||
.max_blocking_threads(std::thread::available_parallelism()?.get())
|
.max_blocking_threads(std::thread::available_parallelism()?.get())
|
||||||
.thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
|
.thread_stack_size(10 * 1024 * 1024)
|
||||||
.thread_name("eval-retrieval-worker")
|
.thread_name("eval-retrieval-worker")
|
||||||
.build()
|
.build()
|
||||||
.context("failed to create tokio runtime")?;
|
.context("failed to create tokio runtime")?;
|
||||||
@@ -77,7 +73,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
async fn async_main() -> anyhow::Result<()> {
|
async fn async_main() -> anyhow::Result<()> {
|
||||||
// Log runtime configuration
|
|
||||||
let cpu_count = std::thread::available_parallelism()?.get();
|
let cpu_count = std::thread::available_parallelism()?.get();
|
||||||
info!(
|
info!(
|
||||||
cpu_cores = cpu_count,
|
cpu_cores = cpu_count,
|
||||||
@@ -87,7 +82,6 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
"Started multi-threaded tokio runtime"
|
"Started multi-threaded tokio runtime"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Configure SurrealDB environment variables for better performance
|
|
||||||
configure_surrealdb_performance(cpu_count);
|
configure_surrealdb_performance(cpu_count);
|
||||||
|
|
||||||
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
||||||
@@ -97,13 +91,22 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let parsed = args::parse()?;
|
let parsed = args::parse()?;
|
||||||
|
|
||||||
// Clap handles help automatically, so we don't need to check for it manually
|
|
||||||
|
|
||||||
if parsed.config.inspect_question.is_some() {
|
if parsed.config.inspect_question.is_some() {
|
||||||
inspection::inspect_question(&parsed.config).await?;
|
inspection::inspect_question(&parsed.config).await?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if parsed.config.status {
|
||||||
|
let status = cli::collect_status(&parsed.config).await?;
|
||||||
|
cli::print_status(&status);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.config.warm {
|
||||||
|
cli::warm(&parsed.config).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
let dataset_kind = parsed.config.dataset;
|
let dataset_kind = parsed.config.dataset;
|
||||||
|
|
||||||
if parsed.config.convert_only {
|
if parsed.config.convert_only {
|
||||||
@@ -115,7 +118,6 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
parsed.config.raw_dataset_path.as_path(),
|
parsed.config.raw_dataset_path.as_path(),
|
||||||
dataset_kind,
|
dataset_kind,
|
||||||
parsed.config.llm_mode,
|
parsed.config.llm_mode,
|
||||||
parsed.config.context_token_limit(),
|
|
||||||
)
|
)
|
||||||
.with_context(|| {
|
.with_context(|| {
|
||||||
format!(
|
format!(
|
||||||
@@ -124,56 +126,52 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
parsed.config.raw_dataset_path.display()
|
parsed.config.raw_dataset_path.display()
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
|
let store_dir = datasets::store_dir_for(&parsed.config.converted_dataset_path);
|
||||||
.with_context(|| {
|
datasets::write_sharded(&dataset, &store_dir)?;
|
||||||
format!(
|
datasets::prebuild_catalog_slices(&dataset, &parsed.config)?;
|
||||||
"writing converted dataset to {}",
|
println!("Converted dataset written under {}", store_dir.display());
|
||||||
parsed.config.converted_dataset_path.display()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
println!(
|
|
||||||
"Converted dataset written to {}",
|
|
||||||
parsed.config.converted_dataset_path.display()
|
|
||||||
);
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if parsed.config.require_ready {
|
||||||
|
cli::ensure_query_ready(&parsed.config).await?;
|
||||||
|
}
|
||||||
|
|
||||||
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
|
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
|
||||||
let dataset = crate::datasets::ensure_converted(
|
let loaded =
|
||||||
dataset_kind,
|
crate::datasets::prepare_dataset(dataset_kind, &parsed.config).with_context(|| {
|
||||||
parsed.config.raw_dataset_path.as_path(),
|
format!(
|
||||||
parsed.config.converted_dataset_path.as_path(),
|
"preparing converted dataset at {}",
|
||||||
parsed.config.force_convert,
|
parsed.config.converted_dataset_path.display()
|
||||||
parsed.config.llm_mode,
|
)
|
||||||
parsed.config.context_token_limit(),
|
})?;
|
||||||
)
|
|
||||||
.with_context(|| {
|
|
||||||
format!(
|
|
||||||
"preparing converted dataset at {}",
|
|
||||||
parsed.config.converted_dataset_path.display()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
questions = dataset
|
questions = loaded
|
||||||
|
.dataset
|
||||||
.paragraphs
|
.paragraphs
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| p.questions.len())
|
.map(|p| p.questions.len())
|
||||||
.sum::<usize>(),
|
.sum::<usize>(),
|
||||||
paragraphs = dataset.paragraphs.len(),
|
paragraphs = loaded.dataset.paragraphs.len(),
|
||||||
dataset = dataset.metadata.id.as_str(),
|
partial = loaded.partial,
|
||||||
|
dataset = loaded.dataset.metadata.id.as_str(),
|
||||||
"Dataset ready"
|
"Dataset ready"
|
||||||
);
|
);
|
||||||
|
|
||||||
if parsed.config.slice_grow.is_some() {
|
if parsed.config.slice_grow.is_some() {
|
||||||
eval::grow_slice(&dataset, &parsed.config).context("growing slice ledger")?;
|
slice::grow_slice(&loaded.dataset, &parsed.config).context("growing slice ledger")?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Running retrieval evaluation");
|
info!("Running retrieval evaluation");
|
||||||
let summary = eval::run_evaluation(&dataset, &parsed.config)
|
let summary = pipeline::run_evaluation(
|
||||||
.await
|
&loaded.dataset,
|
||||||
.context("running retrieval evaluation")?;
|
&parsed.config,
|
||||||
|
Some(loaded.content_checksum.as_str()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.context("running retrieval evaluation")?;
|
||||||
|
|
||||||
let report = report::write_reports(
|
let report = report::write_reports(
|
||||||
&summary,
|
&summary,
|
||||||
@@ -226,12 +224,17 @@ async fn async_main() -> anyhow::Result<()> {
|
|||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
println!(
|
println!(
|
||||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) | Retrieved context: {chunks} chunks, {tokens} tokens ({tokenizer}, avg {avg_tokens:.0}/query, p95 {p95}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
||||||
summary.dataset_label,
|
summary.dataset_label,
|
||||||
k = summary.k,
|
k = summary.k,
|
||||||
precision = summary.precision,
|
precision = summary.precision,
|
||||||
correct = summary.correct,
|
correct = summary.correct,
|
||||||
retrieval_total = summary.retrieval_cases,
|
retrieval_total = summary.retrieval_cases,
|
||||||
|
chunks = summary.retrieved_context.total_chunks,
|
||||||
|
tokens = summary.retrieved_context.total_tokens,
|
||||||
|
tokenizer = summary.retrieved_context.tokenizer,
|
||||||
|
avg_tokens = summary.retrieved_context.avg_tokens_per_query,
|
||||||
|
p95 = summary.retrieved_context.p95_tokens_per_query,
|
||||||
json = report.paths.json.display(),
|
json = report.paths.json.display(),
|
||||||
md = report.paths.markdown.display(),
|
md = report.paths.markdown.display(),
|
||||||
history = report.history_path.display(),
|
history = report.history_path.display(),
|
||||||
|
|||||||
@@ -1,9 +1,24 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use async_openai::{config::OpenAIConfig, Client};
|
use async_openai::{config::OpenAIConfig, Client};
|
||||||
|
|
||||||
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
|
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
pub fn ingestion_openai_client(
|
||||||
|
include_entities: bool,
|
||||||
|
) -> Result<(Arc<Client<OpenAIConfig>>, Option<String>)> {
|
||||||
|
if include_entities {
|
||||||
|
let (client, base_url) = build_client_from_env().context(
|
||||||
|
"OPENAI_API_KEY must be set when --include-entities is enabled (entity extraction uses OpenAI)",
|
||||||
|
)?;
|
||||||
|
Ok((Arc::new(client), Some(base_url)))
|
||||||
|
} else {
|
||||||
|
Ok((Arc::new(Client::with_config(OpenAIConfig::default())), None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||||
let api_key = std::env::var("OPENAI_API_KEY")
|
let api_key = std::env::var("OPENAI_API_KEY")
|
||||||
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
|
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
|
||||||
let base_url =
|
let base_url =
|
||||||
|
|||||||
+11
-7
@@ -7,8 +7,8 @@ use anyhow::{Context, Result};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
args,
|
args,
|
||||||
eval::EvaluationSummary,
|
|
||||||
report::{self, EvaluationReport},
|
report::{self, EvaluationReport},
|
||||||
|
types::EvaluationSummary,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn mirror_perf_outputs(
|
pub fn mirror_perf_outputs(
|
||||||
@@ -91,23 +91,26 @@ fn format_duration(value: Option<u128>) -> String {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
|
use crate::types::{
|
||||||
|
EvaluationStageTimings, LatencyStats, PerformanceTimings, StageLatency,
|
||||||
|
StageLatencyBreakdown,
|
||||||
|
};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
fn sample_latency() -> crate::eval::LatencyStats {
|
fn sample_latency() -> LatencyStats {
|
||||||
crate::eval::LatencyStats {
|
LatencyStats {
|
||||||
avg: 10.0,
|
avg: 10.0,
|
||||||
p50: 8,
|
p50: 8,
|
||||||
p95: 15,
|
p95: 15,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
fn sample_stage_latency() -> StageLatencyBreakdown {
|
||||||
crate::eval::StageLatencyBreakdown {
|
StageLatencyBreakdown {
|
||||||
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|stage| crate::eval::StageLatency {
|
.map(|stage| StageLatency {
|
||||||
stage: stage.to_string(),
|
stage: stage.to_string(),
|
||||||
stats: sample_latency(),
|
stats: sample_latency(),
|
||||||
})
|
})
|
||||||
@@ -206,6 +209,7 @@ mod tests {
|
|||||||
chunk_vector_take: 20,
|
chunk_vector_take: 20,
|
||||||
chunk_fts_take: 20,
|
chunk_fts_take: 20,
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
|
retrieved_context: crate::context_stats::aggregate_context_stats(&[]),
|
||||||
cases: Vec::new(),
|
cases: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ use retrieval_pipeline::{
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
args::Config,
|
args::Config,
|
||||||
cache::EmbeddingCache,
|
cases::SeededCase,
|
||||||
corpus,
|
corpus,
|
||||||
datasets::ConvertedDataset,
|
datasets::ConvertedDataset,
|
||||||
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
|
slice,
|
||||||
slice, snapshot,
|
types::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[allow(clippy::struct_excessive_bools)]
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
@@ -41,12 +41,10 @@ pub(super) struct EvaluationContext<'a> {
|
|||||||
pub namespace: String,
|
pub namespace: String,
|
||||||
pub database: String,
|
pub database: String,
|
||||||
pub db: Option<SurrealDbClient>,
|
pub db: Option<SurrealDbClient>,
|
||||||
pub descriptor: Option<snapshot::Descriptor>,
|
|
||||||
pub settings: Option<SystemSettings>,
|
pub settings: Option<SystemSettings>,
|
||||||
pub settings_missing: bool,
|
pub settings_missing: bool,
|
||||||
pub must_reapply_settings: bool,
|
pub must_reapply_settings: bool,
|
||||||
pub embedding_provider: Option<EmbeddingProvider>,
|
pub embedding_provider: Option<EmbeddingProvider>,
|
||||||
pub embedding_cache: Option<EmbeddingCache>,
|
|
||||||
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||||
pub openai_base_url: Option<String>,
|
pub openai_base_url: Option<String>,
|
||||||
pub expected_fingerprint: Option<String>,
|
pub expected_fingerprint: Option<String>,
|
||||||
@@ -67,13 +65,19 @@ pub(super) struct EvaluationContext<'a> {
|
|||||||
pub summary: Option<EvaluationSummary>,
|
pub summary: Option<EvaluationSummary>,
|
||||||
pub diagnostics_path: Option<PathBuf>,
|
pub diagnostics_path: Option<PathBuf>,
|
||||||
pub diagnostics_enabled: bool,
|
pub diagnostics_enabled: bool,
|
||||||
|
pub content_checksum: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> EvaluationContext<'a> {
|
impl<'a> EvaluationContext<'a> {
|
||||||
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
|
pub fn new(
|
||||||
|
dataset: &'a ConvertedDataset,
|
||||||
|
config: &'a Config,
|
||||||
|
content_checksum: Option<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
dataset,
|
dataset,
|
||||||
config,
|
config,
|
||||||
|
content_checksum,
|
||||||
stage_timings: EvaluationStageTimings::default(),
|
stage_timings: EvaluationStageTimings::default(),
|
||||||
ledger_limit: None,
|
ledger_limit: None,
|
||||||
slice_settings: None,
|
slice_settings: None,
|
||||||
@@ -84,12 +88,10 @@ impl<'a> EvaluationContext<'a> {
|
|||||||
namespace: String::new(),
|
namespace: String::new(),
|
||||||
database: String::new(),
|
database: String::new(),
|
||||||
db: None,
|
db: None,
|
||||||
descriptor: None,
|
|
||||||
settings: None,
|
settings: None,
|
||||||
settings_missing: false,
|
settings_missing: false,
|
||||||
must_reapply_settings: false,
|
must_reapply_settings: false,
|
||||||
embedding_provider: None,
|
embedding_provider: None,
|
||||||
embedding_cache: None,
|
|
||||||
openai_client: None,
|
openai_client: None,
|
||||||
openai_base_url: None,
|
openai_base_url: None,
|
||||||
expected_fingerprint: None,
|
expected_fingerprint: None,
|
||||||
@@ -133,12 +135,6 @@ impl<'a> EvaluationContext<'a> {
|
|||||||
.ok_or_else(|| anyhow!("database connection missing"))
|
.ok_or_else(|| anyhow!("database connection missing"))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn descriptor(&self) -> Result<&snapshot::Descriptor> {
|
|
||||||
self.descriptor
|
|
||||||
.as_ref()
|
|
||||||
.ok_or_else(|| anyhow!("snapshot descriptor unavailable"))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn embedding_provider(&self) -> Result<&EmbeddingProvider> {
|
pub fn embedding_provider(&self) -> Result<&EmbeddingProvider> {
|
||||||
self.embedding_provider
|
self.embedding_provider
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@@ -159,6 +155,10 @@ impl<'a> EvaluationContext<'a> {
|
|||||||
.ok_or_else(|| anyhow!("corpus handle missing"))
|
.ok_or_else(|| anyhow!("corpus handle missing"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn content_checksum(&self) -> Option<&str> {
|
||||||
|
self.content_checksum.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn evaluation_user(&self) -> Result<&User> {
|
pub fn evaluation_user(&self) -> Result<&User> {
|
||||||
self.eval_user
|
self.eval_user
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
use crate::{args, types::CaseDiagnostics};
|
||||||
|
|
||||||
|
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||||
|
args::ensure_parent(path)?;
|
||||||
|
let mut file = tokio::fs::File::create(path)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
||||||
|
for case in cases {
|
||||||
|
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
||||||
|
file.write_all(&line).await?;
|
||||||
|
file.write_all(b"\n").await?;
|
||||||
|
}
|
||||||
|
file.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
mod context;
|
mod context;
|
||||||
|
mod diagnostics;
|
||||||
mod stages;
|
mod stages;
|
||||||
mod state;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
@@ -8,20 +8,40 @@ use crate::{args::Config, datasets::ConvertedDataset, types::EvaluationSummary};
|
|||||||
|
|
||||||
use context::EvaluationContext;
|
use context::EvaluationContext;
|
||||||
|
|
||||||
|
async fn run_through_namespace<'a>(
|
||||||
|
dataset: &'a ConvertedDataset,
|
||||||
|
config: &'a Config,
|
||||||
|
content_checksum: Option<String>,
|
||||||
|
) -> Result<EvaluationContext<'a>> {
|
||||||
|
let mut ctx = EvaluationContext::new(dataset, config, content_checksum);
|
||||||
|
stages::prepare_slice(&mut ctx).await?;
|
||||||
|
stages::prepare_db(&mut ctx).await?;
|
||||||
|
stages::prepare_corpus(&mut ctx).await?;
|
||||||
|
stages::prepare_namespace(&mut ctx).await?;
|
||||||
|
Ok(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn warm_evaluation(
|
||||||
|
dataset: &ConvertedDataset,
|
||||||
|
config: &Config,
|
||||||
|
content_checksum: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
let _ctx = run_through_namespace(dataset, config, Some(content_checksum.to_string())).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn run_evaluation(
|
pub async fn run_evaluation(
|
||||||
dataset: &ConvertedDataset,
|
dataset: &ConvertedDataset,
|
||||||
config: &Config,
|
config: &Config,
|
||||||
|
content_checksum: Option<&str>,
|
||||||
) -> Result<EvaluationSummary> {
|
) -> Result<EvaluationSummary> {
|
||||||
let mut ctx = EvaluationContext::new(dataset, config);
|
let mut ctx = EvaluationContext::new(dataset, config, content_checksum.map(str::to_string));
|
||||||
let machine = state::ready();
|
stages::prepare_slice(&mut ctx).await?;
|
||||||
|
stages::prepare_db(&mut ctx).await?;
|
||||||
let machine = stages::prepare_slice(machine, &mut ctx).await?;
|
stages::prepare_corpus(&mut ctx).await?;
|
||||||
let machine = stages::prepare_db(machine, &mut ctx).await?;
|
stages::prepare_namespace(&mut ctx).await?;
|
||||||
let machine = stages::prepare_corpus(machine, &mut ctx).await?;
|
stages::run_queries(&mut ctx).await?;
|
||||||
let machine = stages::prepare_namespace(machine, &mut ctx).await?;
|
stages::summarize(&mut ctx).await?;
|
||||||
let machine = stages::run_queries(machine, &mut ctx).await?;
|
stages::finalize(&mut ctx).await?;
|
||||||
let machine = stages::summarize(machine, &mut ctx).await?;
|
|
||||||
let _ = stages::finalize(machine, &mut ctx).await?;
|
|
||||||
|
|
||||||
ctx.into_summary()
|
ctx.into_summary()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,18 +3,12 @@ use std::time::Instant;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::eval::write_chunk_diagnostics;
|
|
||||||
|
|
||||||
use super::super::{
|
use super::super::{
|
||||||
context::{EvalStage, EvaluationContext},
|
context::{EvalStage, EvaluationContext},
|
||||||
state::{Completed, EvaluationMachine, Summarized},
|
diagnostics::write_chunk_diagnostics,
|
||||||
};
|
};
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
pub(crate) async fn finalize(
|
pub(crate) async fn finalize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), Summarized>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<Completed> {
|
|
||||||
let stage = EvalStage::Finalize;
|
let stage = EvalStage::Finalize;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -22,13 +16,6 @@ pub(crate) async fn finalize(
|
|||||||
);
|
);
|
||||||
let started = Instant::now();
|
let started = Instant::now();
|
||||||
|
|
||||||
if let Some(cache) = ctx.embedding_cache.as_ref() {
|
|
||||||
cache
|
|
||||||
.persist()
|
|
||||||
.await
|
|
||||||
.context("persisting embedding cache")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(path) = ctx.diagnostics_path.as_ref() {
|
if let Some(path) = ctx.diagnostics_path.as_ref() {
|
||||||
if ctx.diagnostics_enabled {
|
if ctx.diagnostics_enabled {
|
||||||
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
||||||
@@ -53,7 +40,5 @@ pub(crate) async fn finalize(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.finalize()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("finalize", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,14 +13,3 @@ pub(crate) use prepare_namespace::prepare_namespace;
|
|||||||
pub(crate) use prepare_slice::prepare_slice;
|
pub(crate) use prepare_slice::prepare_slice;
|
||||||
pub(crate) use run_queries::run_queries;
|
pub(crate) use run_queries::run_queries;
|
||||||
pub(crate) use summarize::summarize;
|
pub(crate) use summarize::summarize;
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use state_machines::core::GuardError;
|
|
||||||
|
|
||||||
use super::state::EvaluationMachine;
|
|
||||||
|
|
||||||
fn map_guard_error(event: &str, guard: &GuardError) -> anyhow::Error {
|
|
||||||
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
|
|
||||||
}
|
|
||||||
|
|
||||||
type StageResult<S> = Result<EvaluationMachine<(), S>>;
|
|
||||||
|
|||||||
@@ -3,19 +3,12 @@ use std::time::Instant;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{corpus, eval::can_reuse_namespace, slice, snapshot};
|
use crate::{corpus, db::can_reuse_namespace, slice};
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{CorpusReady, DbReady, EvaluationMachine},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub(crate) async fn prepare_corpus(
|
pub(crate) async fn prepare_corpus(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), DbReady>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<CorpusReady> {
|
|
||||||
let stage = EvalStage::PrepareCorpus;
|
let stage = EvalStage::PrepareCorpus;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -31,13 +24,13 @@ pub(crate) async fn prepare_corpus(
|
|||||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||||
.context("selecting slice window for corpus preparation")?;
|
.context("selecting slice window for corpus preparation")?;
|
||||||
|
|
||||||
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider()?);
|
|
||||||
let ingestion_config = corpus::make_ingestion_config(config);
|
let ingestion_config = corpus::make_ingestion_config(config);
|
||||||
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
|
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
|
||||||
ctx.dataset(),
|
ctx.dataset(),
|
||||||
slice,
|
slice,
|
||||||
config.converted_dataset_path.as_path(),
|
config.converted_dataset_path.as_path(),
|
||||||
&ingestion_config,
|
&ingestion_config,
|
||||||
|
ctx.content_checksum(),
|
||||||
)?;
|
)?;
|
||||||
let base_dir = corpus::cached_corpus_dir(
|
let base_dir = corpus::cached_corpus_dir(
|
||||||
&cache_settings,
|
&cache_settings,
|
||||||
@@ -47,19 +40,18 @@ pub(crate) async fn prepare_corpus(
|
|||||||
|
|
||||||
if !config.reseed_slice {
|
if !config.reseed_slice {
|
||||||
let requested_cases = window.cases.len();
|
let requested_cases = window.cases.len();
|
||||||
if can_reuse_namespace(
|
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
|
||||||
ctx.db()?,
|
if can_reuse_namespace(
|
||||||
&descriptor,
|
ctx.db()?,
|
||||||
&ctx.namespace,
|
&manifest,
|
||||||
&ctx.database,
|
&embedding_provider,
|
||||||
ctx.dataset().metadata.id.as_str(),
|
&ctx.namespace,
|
||||||
slice.manifest.slice_id.as_str(),
|
&ctx.database,
|
||||||
expected_fingerprint.as_str(),
|
expected_fingerprint.as_str(),
|
||||||
requested_cases,
|
requested_cases,
|
||||||
)
|
)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
|
|
||||||
info!(
|
info!(
|
||||||
cache = %base_dir.display(),
|
cache = %base_dir.display(),
|
||||||
namespace = ctx.namespace.as_str(),
|
namespace = ctx.namespace.as_str(),
|
||||||
@@ -70,7 +62,6 @@ pub(crate) async fn prepare_corpus(
|
|||||||
ctx.corpus_handle = Some(corpus_handle);
|
ctx.corpus_handle = Some(corpus_handle);
|
||||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||||
ctx.ingestion_duration_ms = 0;
|
ctx.ingestion_duration_ms = 0;
|
||||||
ctx.descriptor = Some(descriptor);
|
|
||||||
|
|
||||||
let elapsed = started.elapsed();
|
let elapsed = started.elapsed();
|
||||||
ctx.record_stage_duration(stage, elapsed);
|
ctx.record_stage_duration(stage, elapsed);
|
||||||
@@ -80,14 +71,8 @@ pub(crate) async fn prepare_corpus(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
return machine
|
return Ok(());
|
||||||
.prepare_corpus()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard));
|
|
||||||
}
|
}
|
||||||
info!(
|
|
||||||
cache = %base_dir.display(),
|
|
||||||
"Namespace reusable but cached manifest missing; regenerating corpus"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +88,7 @@ pub(crate) async fn prepare_corpus(
|
|||||||
openai_client,
|
openai_client,
|
||||||
&eval_user_id,
|
&eval_user_id,
|
||||||
config.converted_dataset_path.as_path(),
|
config.converted_dataset_path.as_path(),
|
||||||
|
ctx.content_checksum(),
|
||||||
ingestion_config.clone(),
|
ingestion_config.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -126,7 +112,6 @@ pub(crate) async fn prepare_corpus(
|
|||||||
ctx.corpus_handle = Some(corpus_handle);
|
ctx.corpus_handle = Some(corpus_handle);
|
||||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||||
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
||||||
ctx.descriptor = Some(descriptor);
|
|
||||||
|
|
||||||
let elapsed = started.elapsed();
|
let elapsed = started.elapsed();
|
||||||
ctx.record_stage_duration(stage, elapsed);
|
ctx.record_stage_duration(stage, elapsed);
|
||||||
@@ -136,7 +121,5 @@ pub(crate) async fn prepare_corpus(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.prepare_corpus()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +1,19 @@
|
|||||||
use std::{sync::Arc, time::Instant};
|
use std::time::Instant;
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
args::EmbeddingBackend,
|
args::EmbeddingBackend,
|
||||||
cache::EmbeddingCache,
|
db::{connect_eval_db, sanitize_model_code},
|
||||||
eval::{
|
|
||||||
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
|
|
||||||
},
|
|
||||||
openai,
|
openai,
|
||||||
|
settings::{enforce_system_settings, load_or_init_system_settings},
|
||||||
};
|
};
|
||||||
use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider};
|
use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider};
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{DbReady, EvaluationMachine, SlicePrepared},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
pub(crate) async fn prepare_db(
|
pub(crate) async fn prepare_db(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), SlicePrepared>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<DbReady> {
|
|
||||||
let stage = EvalStage::PrepareDb;
|
let stage = EvalStage::PrepareDb;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -36,19 +27,18 @@ pub(crate) async fn prepare_db(
|
|||||||
|
|
||||||
let db = connect_eval_db(config, &namespace, &database).await?;
|
let db = connect_eval_db(config, &namespace, &database).await?;
|
||||||
|
|
||||||
let (raw_openai_client, openai_base_url) =
|
let (openai_client, openai_base_url) =
|
||||||
openai::build_client_from_env().context("building OpenAI client")?;
|
openai::ingestion_openai_client(config.ingest.include_entities)
|
||||||
let openai_client = Arc::new(raw_openai_client);
|
.context("building OpenAI client for ingestion")?;
|
||||||
|
|
||||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
|
||||||
let embedding_provider = match config.embedding_backend {
|
let embedding_provider = match config.embedding_backend {
|
||||||
crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||||
config.embedding_model.clone(),
|
config.embedding_model.clone(),
|
||||||
default_embedding_pool_size(),
|
default_embedding_pool_size(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("creating FastEmbed provider")?,
|
.context("creating FastEmbed provider")?,
|
||||||
crate::args::EmbeddingBackend::Hashed => {
|
EmbeddingBackend::Hashed => {
|
||||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -68,12 +58,14 @@ pub(crate) async fn prepare_db(
|
|||||||
dimension = provider_dimension,
|
dimension = provider_dimension,
|
||||||
"Embedding provider initialised"
|
"Embedding provider initialised"
|
||||||
);
|
);
|
||||||
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
|
if let Some(base_url) = &openai_base_url {
|
||||||
|
info!(openai_base_url = %base_url, "OpenAI client configured for entity ingestion");
|
||||||
|
}
|
||||||
|
|
||||||
let (mut settings, settings_missing) =
|
let (mut settings, settings_missing) =
|
||||||
load_or_init_system_settings(&db, provider_dimension).await?;
|
load_or_init_system_settings(&db, provider_dimension).await?;
|
||||||
|
|
||||||
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||||
if let Some(model_code) = embedding_provider.model_code() {
|
if let Some(model_code) = embedding_provider.model_code() {
|
||||||
let sanitized = sanitize_model_code(&model_code);
|
let sanitized = sanitize_model_code(&model_code);
|
||||||
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
||||||
@@ -83,15 +75,8 @@ pub(crate) async fn prepare_db(
|
|||||||
.with_context(|| format!("removing stale cache {}", path.display()))
|
.with_context(|| format!("removing stale cache {}", path.display()))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
let cache = EmbeddingCache::load(&path).await?;
|
|
||||||
info!(path = %path.display(), "Embedding cache ready");
|
|
||||||
Some(cache)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let must_reapply_settings = settings_missing;
|
let must_reapply_settings = settings_missing;
|
||||||
let defer_initial_enforce = settings_missing && !config.reseed_slice;
|
let defer_initial_enforce = settings_missing && !config.reseed_slice;
|
||||||
@@ -104,9 +89,8 @@ pub(crate) async fn prepare_db(
|
|||||||
ctx.must_reapply_settings = must_reapply_settings;
|
ctx.must_reapply_settings = must_reapply_settings;
|
||||||
ctx.settings = Some(settings);
|
ctx.settings = Some(settings);
|
||||||
ctx.embedding_provider = Some(embedding_provider);
|
ctx.embedding_provider = Some(embedding_provider);
|
||||||
ctx.embedding_cache = embedding_cache;
|
|
||||||
ctx.openai_client = Some(openai_client);
|
ctx.openai_client = Some(openai_client);
|
||||||
ctx.openai_base_url = Some(openai_base_url);
|
ctx.openai_base_url = openai_base_url;
|
||||||
|
|
||||||
let elapsed = started.elapsed();
|
let elapsed = started.elapsed();
|
||||||
ctx.record_stage_duration(stage, elapsed);
|
ctx.record_stage_duration(stage, elapsed);
|
||||||
@@ -116,7 +100,5 @@ pub(crate) async fn prepare_db(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.prepare_db()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("prepare_db", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,25 +5,19 @@ use common::storage::types::system_settings::SystemSettings;
|
|||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
cases::cases_from_manifest,
|
||||||
corpus,
|
corpus,
|
||||||
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
|
db::{
|
||||||
eval::{
|
can_reuse_namespace, ensure_eval_user, record_namespace_seed, recreate_indexes,
|
||||||
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
|
reset_namespace, warm_hnsw_cache,
|
||||||
record_namespace_state, warm_hnsw_cache,
|
|
||||||
},
|
},
|
||||||
|
settings::enforce_system_settings,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{CorpusReady, EvaluationMachine, NamespaceReady},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub(crate) async fn prepare_namespace(
|
pub(crate) async fn prepare_namespace(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), CorpusReady>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<NamespaceReady> {
|
|
||||||
let stage = EvalStage::PrepareNamespace;
|
let stage = EvalStage::PrepareNamespace;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -32,7 +26,6 @@ pub(crate) async fn prepare_namespace(
|
|||||||
let started = Instant::now();
|
let started = Instant::now();
|
||||||
|
|
||||||
let config = ctx.config();
|
let config = ctx.config();
|
||||||
let dataset = ctx.dataset();
|
|
||||||
let expected_fingerprint = ctx
|
let expected_fingerprint = ctx
|
||||||
.expected_fingerprint
|
.expected_fingerprint
|
||||||
.as_deref()
|
.as_deref()
|
||||||
@@ -60,20 +53,16 @@ pub(crate) async fn prepare_namespace(
|
|||||||
|
|
||||||
let mut namespace_reused = false;
|
let mut namespace_reused = false;
|
||||||
if !config.reseed_slice {
|
if !config.reseed_slice {
|
||||||
namespace_reused = {
|
namespace_reused = can_reuse_namespace(
|
||||||
let slice = ctx.slice()?;
|
ctx.db()?,
|
||||||
can_reuse_namespace(
|
base_manifest,
|
||||||
ctx.db()?,
|
&embedding_provider,
|
||||||
ctx.descriptor()?,
|
&namespace,
|
||||||
&namespace,
|
&database,
|
||||||
&database,
|
expected_fingerprint.as_str(),
|
||||||
dataset.metadata.id.as_str(),
|
requested_cases,
|
||||||
slice.manifest.slice_id.as_str(),
|
)
|
||||||
expected_fingerprint.as_str(),
|
.await?;
|
||||||
requested_cases,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut namespace_seed_ms = None;
|
let mut namespace_seed_ms = None;
|
||||||
@@ -114,34 +103,20 @@ pub(crate) async fn prepare_namespace(
|
|||||||
"Seeding ingestion corpus into SurrealDB"
|
"Seeding ingestion corpus into SurrealDB"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let indexes_disabled = remove_all_indexes(ctx.db()?).await.is_ok();
|
|
||||||
|
|
||||||
let seed_start = Instant::now();
|
let seed_start = Instant::now();
|
||||||
corpus::seed_manifest_into_db(ctx.db()?, &manifest_for_seed)
|
corpus::seed_manifest_into_db(ctx.db()?, &manifest_for_seed)
|
||||||
.await
|
.await
|
||||||
.context("seeding ingestion corpus from manifest")?;
|
.context("seeding ingestion corpus from manifest")?;
|
||||||
namespace_seed_ms = Some(seed_start.elapsed().as_millis());
|
namespace_seed_ms = Some(seed_start.elapsed().as_millis());
|
||||||
|
|
||||||
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
|
info!("Recreating indexes after seeding data");
|
||||||
if indexes_disabled {
|
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
||||||
info!("Recreating indexes after seeding data");
|
.await
|
||||||
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
.context("recreating indexes with correct dimension")?;
|
||||||
.await
|
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
||||||
.context("recreating indexes with correct dimension")?;
|
|
||||||
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
if let Some(handle) = ctx.corpus_handle.as_mut() {
|
||||||
}
|
record_namespace_seed(handle, &namespace, &database, requested_cases).await;
|
||||||
{
|
|
||||||
let slice = ctx.slice()?;
|
|
||||||
record_namespace_state(
|
|
||||||
ctx.descriptor()?,
|
|
||||||
dataset.metadata.id.as_str(),
|
|
||||||
slice.manifest.slice_id.as_str(),
|
|
||||||
expected_fingerprint.as_str(),
|
|
||||||
&namespace,
|
|
||||||
&database,
|
|
||||||
requested_cases,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +173,5 @@ pub(crate) async fn prepare_namespace(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.prepare_namespace()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("prepare_namespace", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,20 +4,13 @@ use anyhow::Context;
|
|||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
eval::{default_database, default_namespace, ledger_target},
|
db::{default_database, default_namespace},
|
||||||
slice,
|
slice,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{EvaluationMachine, Ready, SlicePrepared},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
pub(crate) async fn prepare_slice(
|
pub(crate) async fn prepare_slice(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), Ready>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<SlicePrepared> {
|
|
||||||
let stage = EvalStage::PrepareSlice;
|
let stage = EvalStage::PrepareSlice;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -25,7 +18,7 @@ pub(crate) async fn prepare_slice(
|
|||||||
);
|
);
|
||||||
let started = Instant::now();
|
let started = Instant::now();
|
||||||
|
|
||||||
let ledger_limit = ledger_target(ctx.config());
|
let ledger_limit = slice::ledger_target(ctx.config());
|
||||||
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
|
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
|
||||||
let resolved_slice =
|
let resolved_slice =
|
||||||
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
|
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
|
||||||
@@ -49,7 +42,11 @@ pub(crate) async fn prepare_slice(
|
|||||||
.db_namespace
|
.db_namespace
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
|
default_namespace(
|
||||||
|
ctx.dataset().metadata.id.as_str(),
|
||||||
|
ctx.config().limit,
|
||||||
|
ctx.config().slice.as_deref(),
|
||||||
|
)
|
||||||
});
|
});
|
||||||
ctx.database = ctx
|
ctx.database = ctx
|
||||||
.config()
|
.config()
|
||||||
@@ -66,7 +63,5 @@ pub(crate) async fn prepare_slice(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.prepare_slice()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("prepare_slice", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,13 @@ use common::storage::types::StoredObject;
|
|||||||
use futures::stream::{self, StreamExt};
|
use futures::stream::{self, StreamExt};
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::eval::{
|
use crate::{
|
||||||
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
cases::SeededCase,
|
||||||
CaseSummary, RetrievedSummary,
|
context_stats,
|
||||||
|
types::{
|
||||||
|
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||||
|
CaseSummary, RetrievedSummary,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
pipeline::{self, RetrievalConfig, StageTimings},
|
pipeline::{self, RetrievalConfig, StageTimings},
|
||||||
@@ -15,17 +19,10 @@ use retrieval_pipeline::{
|
|||||||
};
|
};
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||||
pub(crate) async fn run_queries(
|
pub(crate) async fn run_queries(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), NamespaceReady>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<QueriesFinished> {
|
|
||||||
let stage = EvalStage::RunQueries;
|
let stage = EvalStage::RunQueries;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -153,7 +150,7 @@ pub(crate) async fn run_queries(
|
|||||||
.await
|
.await
|
||||||
.context("acquiring query semaphore permit")?;
|
.context("acquiring query semaphore permit")?;
|
||||||
|
|
||||||
let crate::eval::SeededCase {
|
let SeededCase {
|
||||||
question_id,
|
question_id,
|
||||||
question,
|
question,
|
||||||
expected_source,
|
expected_source,
|
||||||
@@ -197,6 +194,7 @@ pub(crate) async fn run_queries(
|
|||||||
let query_latency = query_start.elapsed().as_millis();
|
let query_latency = query_start.elapsed().as_millis();
|
||||||
|
|
||||||
let candidates = adapt_retrieval_output(result_output);
|
let candidates = adapt_retrieval_output(result_output);
|
||||||
|
let retrieved_context = context_stats::stats_for_candidates(&candidates);
|
||||||
let mut retrieved = Vec::new();
|
let mut retrieved = Vec::new();
|
||||||
let mut match_rank = None;
|
let mut match_rank = None;
|
||||||
let answers_lower: Vec<String> =
|
let answers_lower: Vec<String> =
|
||||||
@@ -288,6 +286,7 @@ pub(crate) async fn run_queries(
|
|||||||
reciprocal_rank: Some(reciprocal_rank),
|
reciprocal_rank: Some(reciprocal_rank),
|
||||||
ndcg: Some(ndcg),
|
ndcg: Some(ndcg),
|
||||||
latency_ms: query_latency,
|
latency_ms: query_latency,
|
||||||
|
retrieved_context,
|
||||||
retrieved,
|
retrieved,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -353,9 +352,7 @@ pub(crate) async fn run_queries(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.run_queries()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("run_queries", &guard))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::arithmetic_side_effects, clippy::cast_precision_loss)]
|
#[allow(clippy::arithmetic_side_effects, clippy::cast_precision_loss)]
|
||||||
|
|||||||
@@ -3,25 +3,19 @@ use std::time::Instant;
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::eval::{
|
use crate::types::{
|
||||||
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
|
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
|
||||||
|
RetrievedContextStats,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::super::{
|
use super::super::context::{EvalStage, EvaluationContext};
|
||||||
context::{EvalStage, EvaluationContext},
|
|
||||||
state::{EvaluationMachine, QueriesFinished, Summarized},
|
|
||||||
};
|
|
||||||
use super::{map_guard_error, StageResult};
|
|
||||||
|
|
||||||
#[allow(
|
#[allow(
|
||||||
clippy::too_many_lines,
|
clippy::too_many_lines,
|
||||||
clippy::arithmetic_side_effects,
|
clippy::arithmetic_side_effects,
|
||||||
clippy::cast_precision_loss
|
clippy::cast_precision_loss
|
||||||
)]
|
)]
|
||||||
pub(crate) async fn summarize(
|
pub(crate) async fn summarize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||||
machine: EvaluationMachine<(), QueriesFinished>,
|
|
||||||
ctx: &mut EvaluationContext<'_>,
|
|
||||||
) -> StageResult<Summarized> {
|
|
||||||
let stage = EvalStage::Summarize;
|
let stage = EvalStage::Summarize;
|
||||||
info!(
|
info!(
|
||||||
evaluation_stage = stage.label(),
|
evaluation_stage = stage.label(),
|
||||||
@@ -123,6 +117,12 @@ pub(crate) async fn summarize(
|
|||||||
sum_ndcg / (retrieval_cases as f64)
|
sum_ndcg / (retrieval_cases as f64)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let per_query_context: Vec<RetrievedContextStats> = summaries
|
||||||
|
.iter()
|
||||||
|
.map(|summary| summary.retrieved_context)
|
||||||
|
.collect();
|
||||||
|
let retrieved_context = crate::context_stats::aggregate_context_stats(&per_query_context);
|
||||||
|
|
||||||
let active_tuning = ctx
|
let active_tuning = ctx
|
||||||
.retrieval_config
|
.retrieval_config
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@@ -133,7 +133,7 @@ pub(crate) async fn summarize(
|
|||||||
openai_base_url: ctx
|
openai_base_url: ctx
|
||||||
.openai_base_url
|
.openai_base_url
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "<unknown>".to_string()),
|
.unwrap_or_else(|| "n/a (chunk-only ingestion)".to_string()),
|
||||||
ingestion_ms: ctx.ingestion_duration_ms,
|
ingestion_ms: ctx.ingestion_duration_ms,
|
||||||
namespace_seed_ms: ctx.namespace_seed_ms,
|
namespace_seed_ms: ctx.namespace_seed_ms,
|
||||||
evaluation_stage_ms: ctx.stage_timings.clone(),
|
evaluation_stage_ms: ctx.stage_timings.clone(),
|
||||||
@@ -217,11 +217,12 @@ pub(crate) async fn summarize(
|
|||||||
chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
||||||
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
|
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
|
||||||
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
|
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
|
||||||
ingest_chunks_only: config.ingest.ingest_chunks_only,
|
ingest_chunks_only: !config.ingest.include_entities,
|
||||||
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||||
|
retrieved_context,
|
||||||
cases: summaries,
|
cases: summaries,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -233,7 +234,5 @@ pub(crate) async fn summarize(
|
|||||||
"completed evaluation stage"
|
"completed evaluation stage"
|
||||||
);
|
);
|
||||||
|
|
||||||
machine
|
Ok(())
|
||||||
.summarize()
|
|
||||||
.map_err(|(_, guard)| map_guard_error("summarize", &guard))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
use state_machines::state_machine;
|
|
||||||
|
|
||||||
state_machine! {
|
|
||||||
name: EvaluationMachine,
|
|
||||||
state: EvaluationState,
|
|
||||||
initial: Ready,
|
|
||||||
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
|
|
||||||
events {
|
|
||||||
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
|
|
||||||
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
|
|
||||||
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
|
|
||||||
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
|
|
||||||
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
|
|
||||||
summarize { transition: { from: QueriesFinished, to: Summarized } }
|
|
||||||
finalize { transition: { from: Summarized, to: Completed } }
|
|
||||||
abort {
|
|
||||||
transition: { from: Ready, to: Failed }
|
|
||||||
transition: { from: SlicePrepared, to: Failed }
|
|
||||||
transition: { from: DbReady, to: Failed }
|
|
||||||
transition: { from: CorpusReady, to: Failed }
|
|
||||||
transition: { from: NamespaceReady, to: Failed }
|
|
||||||
transition: { from: QueriesFinished, to: Failed }
|
|
||||||
transition: { from: Summarized, to: Failed }
|
|
||||||
transition: { from: Completed, to: Failed }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ready() -> EvaluationMachine<(), Ready> {
|
|
||||||
EvaluationMachine::new(())
|
|
||||||
}
|
|
||||||
+81
-212
@@ -7,12 +7,10 @@ use std::{
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::eval::{
|
use crate::types::{
|
||||||
format_timestamp, CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats,
|
format_timestamp, CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats,
|
||||||
StageLatencyBreakdown,
|
RetrievalContextStats, StageLatencyBreakdown,
|
||||||
};
|
};
|
||||||
use chrono::Utc;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ReportPaths {
|
pub struct ReportPaths {
|
||||||
@@ -108,6 +106,7 @@ pub struct RetrievalSection {
|
|||||||
pub ingest_chunk_max_tokens: usize,
|
pub ingest_chunk_max_tokens: usize,
|
||||||
pub ingest_chunk_overlap_tokens: usize,
|
pub ingest_chunk_overlap_tokens: usize,
|
||||||
pub ingest_chunks_only: bool,
|
pub ingest_chunks_only: bool,
|
||||||
|
pub retrieved_context: RetrievalContextStats,
|
||||||
}
|
}
|
||||||
|
|
||||||
const fn default_chunk_rrf_k() -> f32 {
|
const fn default_chunk_rrf_k() -> f32 {
|
||||||
@@ -242,6 +241,7 @@ impl EvaluationReport {
|
|||||||
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
||||||
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
||||||
ingest_chunks_only: summary.ingest_chunks_only,
|
ingest_chunks_only: summary.ingest_chunks_only,
|
||||||
|
retrieved_context: summary.retrieved_context.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let llm = if summary.llm_cases > 0 {
|
let llm = if summary.llm_cases > 0 {
|
||||||
@@ -345,7 +345,7 @@ impl LlmCaseEntry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RetrievedSnippet {
|
impl RetrievedSnippet {
|
||||||
fn from_summary(entry: &crate::eval::RetrievedSummary) -> Self {
|
fn from_summary(entry: &crate::types::RetrievedSummary) -> Self {
|
||||||
Self {
|
Self {
|
||||||
rank: entry.rank,
|
rank: entry.rank,
|
||||||
source_id: entry.source_id.clone(),
|
source_id: entry.source_id.clone(),
|
||||||
@@ -558,6 +558,65 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
|||||||
} else {
|
} else {
|
||||||
md.push_str("| Rerank | disabled |\\n");
|
md.push_str("| Rerank | disabled |\\n");
|
||||||
}
|
}
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Chunk result cap | {} |\\n",
|
||||||
|
report.retrieval.chunk_result_cap
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
md.push_str("\\n## Retrieved Context Volume\\n\\n");
|
||||||
|
md.push_str("| Metric | Value |\\n| --- | --- |\\n");
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Tokenizer | {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.tokenizer
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Queries measured | {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.queries
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Total chunks returned | {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.total_chunks
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Total characters | {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.total_chars
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Total tokens | {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.total_tokens
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Avg chunks / query | {:.1} |\\n",
|
||||||
|
report.retrieval.retrieved_context.avg_chunks_per_query
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| Avg tokens / query | {:.1} |\\n",
|
||||||
|
report.retrieval.retrieved_context.avg_tokens_per_query
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
write!(
|
||||||
|
md,
|
||||||
|
"| P50 / P95 / max tokens / query | {} / {} / {} |\\n",
|
||||||
|
report.retrieval.retrieved_context.p50_tokens_per_query,
|
||||||
|
report.retrieval.retrieved_context.p95_tokens_per_query,
|
||||||
|
report.retrieval.retrieved_context.max_tokens_per_query
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
if let Some(llm) = &report.llm {
|
if let Some(llm) = &report.llm {
|
||||||
md.push_str("\\n## LLM Mode Metrics\\n\\n");
|
md.push_str("\\n## LLM Mode Metrics\\n\\n");
|
||||||
@@ -797,182 +856,6 @@ pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
|
|||||||
report_dir.join(sanitize_component(dataset_id))
|
report_dir.join(sanitize_component(dataset_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct LegacyHistoryEntry {
|
|
||||||
generated_at: String,
|
|
||||||
run_label: Option<String>,
|
|
||||||
dataset_id: String,
|
|
||||||
dataset_label: String,
|
|
||||||
slice_id: String,
|
|
||||||
slice_seed: u64,
|
|
||||||
slice_window_offset: usize,
|
|
||||||
slice_window_length: usize,
|
|
||||||
slice_cases: usize,
|
|
||||||
slice_total_cases: usize,
|
|
||||||
k: usize,
|
|
||||||
limit: Option<usize>,
|
|
||||||
precision: f64,
|
|
||||||
precision_at_1: f64,
|
|
||||||
precision_at_2: f64,
|
|
||||||
precision_at_3: f64,
|
|
||||||
#[serde(default)]
|
|
||||||
mrr: f64,
|
|
||||||
#[serde(default)]
|
|
||||||
average_ndcg: f64,
|
|
||||||
#[serde(default)]
|
|
||||||
retrieval_cases: usize,
|
|
||||||
#[serde(default)]
|
|
||||||
retrieval_precision: f64,
|
|
||||||
#[serde(default)]
|
|
||||||
llm_cases: usize,
|
|
||||||
#[serde(default)]
|
|
||||||
llm_precision: f64,
|
|
||||||
duration_ms: u128,
|
|
||||||
latency_ms: LatencyStats,
|
|
||||||
embedding_backend: String,
|
|
||||||
embedding_model: Option<String>,
|
|
||||||
ingestion_reused: bool,
|
|
||||||
ingestion_embeddings_reused: bool,
|
|
||||||
rerank_enabled: bool,
|
|
||||||
rerank_keep_top: usize,
|
|
||||||
rerank_pool_size: Option<usize>,
|
|
||||||
#[serde(default)]
|
|
||||||
chunk_result_cap: Option<usize>,
|
|
||||||
#[serde(default)]
|
|
||||||
ingest_chunk_min_tokens: Option<usize>,
|
|
||||||
#[serde(default)]
|
|
||||||
ingest_chunk_max_tokens: Option<usize>,
|
|
||||||
#[serde(default)]
|
|
||||||
ingest_chunk_overlap_tokens: Option<usize>,
|
|
||||||
#[serde(default)]
|
|
||||||
ingest_chunks_only: Option<bool>,
|
|
||||||
#[serde(default)]
|
|
||||||
delta: Option<LegacyHistoryDelta>,
|
|
||||||
openai_base_url: String,
|
|
||||||
ingestion_ms: u128,
|
|
||||||
#[serde(default)]
|
|
||||||
namespace_seed_ms: Option<u128>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct LegacyHistoryDelta {
|
|
||||||
precision: f64,
|
|
||||||
precision_at_1: f64,
|
|
||||||
latency_avg_ms: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
|
||||||
let overview = OverviewSection {
|
|
||||||
generated_at: entry.generated_at,
|
|
||||||
run_label: entry.run_label,
|
|
||||||
total_cases: entry.slice_cases,
|
|
||||||
filtered_questions: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let dataset = DatasetSection {
|
|
||||||
id: entry.dataset_id,
|
|
||||||
label: entry.dataset_label,
|
|
||||||
source: String::new(),
|
|
||||||
includes_unanswerable: entry.llm_cases > 0,
|
|
||||||
require_verified_chunks: true,
|
|
||||||
embedding_backend: entry.embedding_backend,
|
|
||||||
embedding_model: entry.embedding_model,
|
|
||||||
embedding_dimension: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let slice = SliceSection {
|
|
||||||
id: entry.slice_id,
|
|
||||||
seed: entry.slice_seed,
|
|
||||||
window_offset: entry.slice_window_offset,
|
|
||||||
window_length: entry.slice_window_length,
|
|
||||||
slice_cases: entry.slice_cases,
|
|
||||||
ledger_total_cases: entry.slice_total_cases,
|
|
||||||
positives: 0,
|
|
||||||
negatives: 0,
|
|
||||||
total_paragraphs: 0,
|
|
||||||
negative_multiplier: 0.0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let retrieval_cases = if entry.retrieval_cases > 0 {
|
|
||||||
entry.retrieval_cases
|
|
||||||
} else {
|
|
||||||
entry.slice_cases.saturating_sub(entry.llm_cases)
|
|
||||||
};
|
|
||||||
let retrieval_precision = if entry.retrieval_precision > 0.0 {
|
|
||||||
entry.retrieval_precision
|
|
||||||
} else {
|
|
||||||
entry.precision
|
|
||||||
};
|
|
||||||
|
|
||||||
let retrieval = RetrievalSection {
|
|
||||||
k: entry.k,
|
|
||||||
cases: retrieval_cases,
|
|
||||||
correct: 0,
|
|
||||||
precision: retrieval_precision,
|
|
||||||
precision_at_1: entry.precision_at_1,
|
|
||||||
precision_at_2: entry.precision_at_2,
|
|
||||||
precision_at_3: entry.precision_at_3,
|
|
||||||
mrr: entry.mrr,
|
|
||||||
average_ndcg: entry.average_ndcg,
|
|
||||||
latency: entry.latency_ms,
|
|
||||||
concurrency: 0,
|
|
||||||
resolve_entities: false,
|
|
||||||
rerank_enabled: entry.rerank_enabled,
|
|
||||||
rerank_pool_size: entry.rerank_pool_size,
|
|
||||||
rerank_keep_top: entry.rerank_keep_top,
|
|
||||||
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
|
||||||
chunk_rrf_k: default_chunk_rrf_k(),
|
|
||||||
chunk_rrf_vector_weight: default_chunk_rrf_weight(),
|
|
||||||
chunk_rrf_fts_weight: default_chunk_rrf_weight(),
|
|
||||||
chunk_rrf_use_vector: default_chunk_rrf_use(),
|
|
||||||
chunk_rrf_use_fts: default_chunk_rrf_use(),
|
|
||||||
chunk_vector_take: 0,
|
|
||||||
chunk_fts_take: 0,
|
|
||||||
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
|
||||||
ingest_chunk_max_tokens: entry.ingest_chunk_max_tokens.unwrap_or(512),
|
|
||||||
ingest_chunk_overlap_tokens: entry.ingest_chunk_overlap_tokens.unwrap_or(50),
|
|
||||||
ingest_chunks_only: entry.ingest_chunks_only.unwrap_or(false),
|
|
||||||
};
|
|
||||||
|
|
||||||
let llm = if entry.llm_cases > 0 {
|
|
||||||
Some(LlmSection {
|
|
||||||
cases: entry.llm_cases,
|
|
||||||
answered: 0,
|
|
||||||
precision: entry.llm_precision,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let performance = PerformanceSection {
|
|
||||||
openai_base_url: entry.openai_base_url,
|
|
||||||
ingestion_ms: entry.ingestion_ms,
|
|
||||||
namespace_seed_ms: entry.namespace_seed_ms,
|
|
||||||
evaluation_stages_ms: EvaluationStageTimings::default(),
|
|
||||||
stage_latency: StageLatencyBreakdown::default(),
|
|
||||||
namespace_reused: false,
|
|
||||||
ingestion_reused: entry.ingestion_reused,
|
|
||||||
embeddings_reused: entry.ingestion_embeddings_reused,
|
|
||||||
ingestion_cache_path: String::new(),
|
|
||||||
corpus_paragraphs: 0,
|
|
||||||
positive_paragraphs_reused: 0,
|
|
||||||
negative_paragraphs_reused: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
EvaluationReport {
|
|
||||||
overview,
|
|
||||||
dataset,
|
|
||||||
slice,
|
|
||||||
retrieval,
|
|
||||||
llm,
|
|
||||||
performance,
|
|
||||||
misses: Vec::new(),
|
|
||||||
llm_cases: Vec::new(),
|
|
||||||
detailed_report: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
@@ -981,34 +864,12 @@ fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
|||||||
let contents =
|
let contents =
|
||||||
fs::read(path).with_context(|| format!("reading evaluation log {}", path.display()))?;
|
fs::read(path).with_context(|| format!("reading evaluation log {}", path.display()))?;
|
||||||
|
|
||||||
if let Ok(entries) = serde_json::from_slice::<Vec<EvaluationReport>>(&contents) {
|
serde_json::from_slice(&contents).with_context(|| {
|
||||||
return Ok(entries);
|
format!(
|
||||||
}
|
"parsing evaluation history at {}; delete the file and re-run if upgrading from an older format",
|
||||||
|
path.display()
|
||||||
match serde_json::from_slice::<Vec<LegacyHistoryEntry>>(&contents) {
|
)
|
||||||
Ok(entries) => Ok(entries.into_iter().map(convert_legacy_entry).collect()),
|
})
|
||||||
Err(err) => {
|
|
||||||
let timestamp = Utc::now().format("%Y%m%dT%H%M%S");
|
|
||||||
let backup_path = path
|
|
||||||
.parent()
|
|
||||||
.unwrap_or_else(|| Path::new("."))
|
|
||||||
.join(format!("evaluations.json.corrupted.{timestamp}"));
|
|
||||||
warn!(
|
|
||||||
path = %path.display(),
|
|
||||||
backup = %backup_path.display(),
|
|
||||||
error = %err,
|
|
||||||
"Evaluation history file is corrupted; backing up and starting fresh"
|
|
||||||
);
|
|
||||||
if let Err(e) = fs::rename(path, &backup_path) {
|
|
||||||
warn!(
|
|
||||||
path = %path.display(),
|
|
||||||
error = %e,
|
|
||||||
"Failed to backup corrupted evaluation history"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(Vec::new())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBuf> {
|
fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBuf> {
|
||||||
@@ -1024,9 +885,9 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::eval::{
|
use crate::types::{
|
||||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
|
EvaluationStageTimings, PerformanceTimings, RetrievedContextStats, RetrievedSummary,
|
||||||
StageLatencyBreakdown,
|
StageLatency, StageLatencyBreakdown,
|
||||||
};
|
};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
@@ -1101,6 +962,7 @@ mod tests {
|
|||||||
has_verified_chunks: !is_impossible,
|
has_verified_chunks: !is_impossible,
|
||||||
match_rank: if matched { Some(1) } else { None },
|
match_rank: if matched { Some(1) } else { None },
|
||||||
latency_ms: 42,
|
latency_ms: 42,
|
||||||
|
retrieved_context: RetrievedContextStats::default(),
|
||||||
retrieved: vec![RetrievedSummary {
|
retrieved: vec![RetrievedSummary {
|
||||||
rank: 1,
|
rank: 1,
|
||||||
entity_id: "entity1".into(),
|
entity_id: "entity1".into(),
|
||||||
@@ -1199,6 +1061,13 @@ mod tests {
|
|||||||
chunk_vector_take: 50,
|
chunk_vector_take: 50,
|
||||||
chunk_fts_take: 50,
|
chunk_fts_take: 50,
|
||||||
max_chunks_per_entity: 4,
|
max_chunks_per_entity: 4,
|
||||||
|
retrieved_context: crate::context_stats::aggregate_context_stats(&[
|
||||||
|
RetrievedContextStats {
|
||||||
|
chunk_count: 1,
|
||||||
|
char_count: 10,
|
||||||
|
token_count: 3,
|
||||||
|
},
|
||||||
|
]),
|
||||||
cases,
|
cases,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,174 @@
|
|||||||
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::datasets::{ConvertedDataset, BEIR_DATASETS};
|
||||||
|
|
||||||
|
use super::build::{mix_seed, BuildParams};
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||||
|
pub(super) fn ordered_question_refs_beir(
|
||||||
|
dataset: &ConvertedDataset,
|
||||||
|
params: &BuildParams,
|
||||||
|
target_cases: usize,
|
||||||
|
) -> Result<Vec<(usize, usize)>> {
|
||||||
|
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||||
|
.iter()
|
||||||
|
.map(|kind| kind.source_prefix())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||||
|
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||||
|
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||||
|
let include = if params.include_impossible {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
!question.is_impossible && !question.answers.is_empty()
|
||||||
|
};
|
||||||
|
if !include {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(prefix) = question_prefix(&question.id) else {
|
||||||
|
warn!(
|
||||||
|
question_id = %question.id,
|
||||||
|
"Skipping BEIR question without expected prefix"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if !prefixes.contains(&prefix) {
|
||||||
|
warn!(
|
||||||
|
question_id = %question.id,
|
||||||
|
prefix = %prefix,
|
||||||
|
"Skipping BEIR question with unknown subset prefix"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if grouped.values().all(std::vec::Vec::is_empty) {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"no eligible BEIR questions found; cannot build slice"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
for prefix in &prefixes {
|
||||||
|
if let Some(entries) = grouped.get_mut(prefix) {
|
||||||
|
let seed = mix_seed(
|
||||||
|
&format!("{}::{prefix}", dataset.metadata.id),
|
||||||
|
params.base_seed,
|
||||||
|
);
|
||||||
|
let mut rng = StdRng::seed_from_u64(seed);
|
||||||
|
entries.shuffle(&mut rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dataset_count = prefixes.len().max(1);
|
||||||
|
let base_quota = target_cases / dataset_count;
|
||||||
|
let mut remainder = target_cases % dataset_count;
|
||||||
|
|
||||||
|
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let mut quota = base_quota;
|
||||||
|
if remainder > 0 {
|
||||||
|
quota += 1;
|
||||||
|
remainder -= 1;
|
||||||
|
}
|
||||||
|
quotas.insert(*prefix, quota);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||||
|
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||||
|
let mut shortfall = 0usize;
|
||||||
|
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
||||||
|
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||||
|
let take = quota.min(available);
|
||||||
|
let missing = quota.saturating_sub(take);
|
||||||
|
shortfall += missing;
|
||||||
|
take_counts.insert(*prefix, take);
|
||||||
|
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||||
|
}
|
||||||
|
|
||||||
|
while shortfall > 0 {
|
||||||
|
let mut allocated = false;
|
||||||
|
for prefix in &prefixes {
|
||||||
|
if shortfall == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||||
|
if spare == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(count) = take_counts.get_mut(prefix) {
|
||||||
|
*count += 1;
|
||||||
|
}
|
||||||
|
spare_slots.insert(*prefix, spare - 1);
|
||||||
|
shortfall = shortfall.saturating_sub(1);
|
||||||
|
allocated = true;
|
||||||
|
}
|
||||||
|
if !allocated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||||
|
let mut total_selected = 0usize;
|
||||||
|
for prefix in &prefixes {
|
||||||
|
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||||
|
let mut deque = VecDeque::new();
|
||||||
|
if let Some(entries) = grouped.get(prefix) {
|
||||||
|
for item in entries.iter().take(take) {
|
||||||
|
deque.push_back(*item);
|
||||||
|
total_selected += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
queues.push(deque);
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_selected < target_cases {
|
||||||
|
warn!(
|
||||||
|
requested = target_cases,
|
||||||
|
available = total_selected,
|
||||||
|
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = Vec::with_capacity(total_selected);
|
||||||
|
loop {
|
||||||
|
let mut progressed = false;
|
||||||
|
for queue in &mut queues {
|
||||||
|
if let Some(item) = queue.pop_front() {
|
||||||
|
output.push(item);
|
||||||
|
progressed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !progressed {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if output.is_empty() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"no eligible BEIR questions found; cannot build slice"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||||
|
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||||
|
if let Some(rest) = question_id.strip_prefix(prefix) {
|
||||||
|
if rest.starts_with('-') {
|
||||||
|
return Some(prefix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) struct BuildParams {
|
||||||
|
pub include_impossible: bool,
|
||||||
|
pub base_seed: u64,
|
||||||
|
pub rng_seed: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::indexing_slicing)]
|
||||||
|
pub(super) fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
hasher.update(dataset_id.as_bytes());
|
||||||
|
hasher.update(seed.to_le_bytes());
|
||||||
|
let digest = hasher.finalize();
|
||||||
|
let mut bytes = [0u8; 8];
|
||||||
|
bytes.copy_from_slice(&digest[..8]);
|
||||||
|
u64::from_le_bytes(bytes)
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet, VecDeque},
|
collections::{HashMap, HashSet},
|
||||||
fmt::Write,
|
fmt::Write,
|
||||||
fs,
|
fs,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
@@ -12,10 +12,16 @@ use serde::{Deserialize, Serialize};
|
|||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::datasets::{
|
use crate::{
|
||||||
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, BEIR_DATASETS,
|
args::Config,
|
||||||
|
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
mod beir;
|
||||||
|
mod build;
|
||||||
|
|
||||||
|
use build::{mix_seed, BuildParams};
|
||||||
|
|
||||||
const SLICE_VERSION: u32 = 2;
|
const SLICE_VERSION: u32 = 2;
|
||||||
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
||||||
|
|
||||||
@@ -80,8 +86,12 @@ pub enum SliceParagraphKind {
|
|||||||
Negative,
|
Negative,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn paragraph_storage_key(paragraph_id: &str) -> String {
|
||||||
|
sanitize_identifier(paragraph_id)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
|
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
|
||||||
let sanitized = sanitize_identifier(paragraph_id);
|
let sanitized = paragraph_storage_key(paragraph_id);
|
||||||
format!("paragraphs/{sanitized}.json")
|
format!("paragraphs/{sanitized}.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,13 +220,6 @@ struct SliceKey<'a> {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct BuildParams {
|
|
||||||
include_impossible: bool,
|
|
||||||
base_seed: u64,
|
|
||||||
rng_seed: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn resolve_slice<'a>(
|
pub fn resolve_slice<'a>(
|
||||||
dataset: &'a ConvertedDataset,
|
dataset: &'a ConvertedDataset,
|
||||||
@@ -225,15 +228,28 @@ pub fn resolve_slice<'a>(
|
|||||||
let index = DatasetIndex::build(dataset);
|
let index = DatasetIndex::build(dataset);
|
||||||
|
|
||||||
if let Some(slice_arg) = config.explicit_slice {
|
if let Some(slice_arg) = config.explicit_slice {
|
||||||
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
let path = explicit_slice_path(dataset, config, slice_arg);
|
||||||
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
if path.exists() {
|
||||||
|
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
||||||
|
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
||||||
|
info!(
|
||||||
|
slice = %resolved.manifest.slice_id,
|
||||||
|
path = %resolved.path.display(),
|
||||||
|
cases = resolved.manifest.case_count,
|
||||||
|
positives = resolved.manifest.positive_paragraphs,
|
||||||
|
negatives = resolved.manifest.negative_paragraphs,
|
||||||
|
"Using explicitly selected slice"
|
||||||
|
);
|
||||||
|
return Ok(resolved);
|
||||||
|
}
|
||||||
|
let resolved = materialize_slice_ledger(dataset, config, &index, slice_arg, path)?;
|
||||||
info!(
|
info!(
|
||||||
slice = %resolved.manifest.slice_id,
|
slice = %resolved.manifest.slice_id,
|
||||||
path = %resolved.path.display(),
|
path = %resolved.path.display(),
|
||||||
cases = resolved.manifest.case_count,
|
cases = resolved.manifest.case_count,
|
||||||
positives = resolved.manifest.positive_paragraphs,
|
positives = resolved.manifest.positive_paragraphs,
|
||||||
negatives = resolved.manifest.negative_paragraphs,
|
negatives = resolved.manifest.negative_paragraphs,
|
||||||
"Using explicitly selected slice"
|
"Built catalog slice ledger"
|
||||||
);
|
);
|
||||||
return Ok(resolved);
|
return Ok(resolved);
|
||||||
}
|
}
|
||||||
@@ -256,6 +272,82 @@ pub fn resolve_slice<'a>(
|
|||||||
.join("slices")
|
.join("slices")
|
||||||
.join(dataset.metadata.id.as_str());
|
.join(dataset.metadata.id.as_str());
|
||||||
let path = base.join(format!("{slice_id}.json"));
|
let path = base.join(format!("{slice_id}.json"));
|
||||||
|
materialize_slice_ledger(dataset, config, &index, &slice_id, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||||
|
pub fn select_window<'a>(
|
||||||
|
resolved: &'a ResolvedSlice<'a>,
|
||||||
|
offset: usize,
|
||||||
|
limit: Option<usize>,
|
||||||
|
) -> Result<SliceWindow<'a>> {
|
||||||
|
let total = resolved.manifest.case_count;
|
||||||
|
if total == 0 {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"slice '{}' contains no cases",
|
||||||
|
resolved.manifest.slice_id
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if offset >= total {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"slice offset {offset} exceeds available cases ({total})",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let available = total - offset;
|
||||||
|
let requested = limit.unwrap_or(available).max(1);
|
||||||
|
let length = requested.min(available);
|
||||||
|
let cases = resolved.cases[offset..offset + length].to_vec();
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
let mut positive_ids = Vec::new();
|
||||||
|
for case in &cases {
|
||||||
|
if seen.insert(case.paragraph.id.as_str()) {
|
||||||
|
positive_ids.push(case.paragraph.id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(SliceWindow {
|
||||||
|
offset,
|
||||||
|
length,
|
||||||
|
total_cases: total,
|
||||||
|
cases,
|
||||||
|
positive_paragraph_ids: positive_ids,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
||||||
|
select_window(resolved, 0, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn explicit_slice_path(
|
||||||
|
dataset: &ConvertedDataset,
|
||||||
|
config: &SliceConfig<'_>,
|
||||||
|
slice_arg: &str,
|
||||||
|
) -> PathBuf {
|
||||||
|
let explicit_path = Path::new(slice_arg);
|
||||||
|
if explicit_path.exists() {
|
||||||
|
explicit_path.to_path_buf()
|
||||||
|
} else {
|
||||||
|
config
|
||||||
|
.cache_dir
|
||||||
|
.join("slices")
|
||||||
|
.join(dataset.metadata.id.as_str())
|
||||||
|
.join(format!("{slice_arg}.json"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
fn materialize_slice_ledger<'a>(
|
||||||
|
dataset: &'a ConvertedDataset,
|
||||||
|
config: &SliceConfig<'_>,
|
||||||
|
index: &DatasetIndex,
|
||||||
|
slice_id: &str,
|
||||||
|
path: PathBuf,
|
||||||
|
) -> Result<ResolvedSlice<'a>> {
|
||||||
|
let requested_corpus = config
|
||||||
|
.corpus_limit
|
||||||
|
.unwrap_or(dataset.paragraphs.len())
|
||||||
|
.min(dataset.paragraphs.len())
|
||||||
|
.max(1);
|
||||||
|
|
||||||
let total_questions = dataset
|
let total_questions = dataset
|
||||||
.paragraphs
|
.paragraphs
|
||||||
@@ -339,7 +431,7 @@ pub fn resolve_slice<'a>(
|
|||||||
let mut manifest = manifest.unwrap_or_else(|| {
|
let mut manifest = manifest.unwrap_or_else(|| {
|
||||||
empty_manifest(
|
empty_manifest(
|
||||||
dataset,
|
dataset,
|
||||||
slice_id.clone(),
|
slice_id.to_string(),
|
||||||
¶ms,
|
¶ms,
|
||||||
requested_corpus,
|
requested_corpus,
|
||||||
config.negative_multiplier,
|
config.negative_multiplier,
|
||||||
@@ -396,52 +488,7 @@ pub fn resolve_slice<'a>(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let resolved = manifest_to_resolved(dataset, &index, manifest.clone(), path)?;
|
manifest_to_resolved(dataset, index, manifest, path)
|
||||||
|
|
||||||
Ok(resolved)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
|
||||||
pub fn select_window<'a>(
|
|
||||||
resolved: &'a ResolvedSlice<'a>,
|
|
||||||
offset: usize,
|
|
||||||
limit: Option<usize>,
|
|
||||||
) -> Result<SliceWindow<'a>> {
|
|
||||||
let total = resolved.manifest.case_count;
|
|
||||||
if total == 0 {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"slice '{}' contains no cases",
|
|
||||||
resolved.manifest.slice_id
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if offset >= total {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"slice offset {offset} exceeds available cases ({total})",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let available = total - offset;
|
|
||||||
let requested = limit.unwrap_or(available).max(1);
|
|
||||||
let length = requested.min(available);
|
|
||||||
let cases = resolved.cases[offset..offset + length].to_vec();
|
|
||||||
let mut seen = HashSet::new();
|
|
||||||
let mut positive_ids = Vec::new();
|
|
||||||
for case in &cases {
|
|
||||||
if seen.insert(case.paragraph.id.as_str()) {
|
|
||||||
positive_ids.push(case.paragraph.id.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(SliceWindow {
|
|
||||||
offset,
|
|
||||||
length,
|
|
||||||
total_cases: total,
|
|
||||||
cases,
|
|
||||||
positive_paragraph_ids: positive_ids,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
|
||||||
select_window(resolved, 0, None)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_explicit_slice(
|
fn load_explicit_slice(
|
||||||
@@ -450,16 +497,7 @@ fn load_explicit_slice(
|
|||||||
config: &SliceConfig<'_>,
|
config: &SliceConfig<'_>,
|
||||||
slice_arg: &str,
|
slice_arg: &str,
|
||||||
) -> Result<(PathBuf, SliceManifest)> {
|
) -> Result<(PathBuf, SliceManifest)> {
|
||||||
let explicit_path = Path::new(slice_arg);
|
let candidate_path = explicit_slice_path(dataset, config, slice_arg);
|
||||||
let candidate_path = if explicit_path.exists() {
|
|
||||||
explicit_path.to_path_buf()
|
|
||||||
} else {
|
|
||||||
config
|
|
||||||
.cache_dir
|
|
||||||
.join("slices")
|
|
||||||
.join(dataset.metadata.id.as_str())
|
|
||||||
.join(format!("{slice_arg}.json"))
|
|
||||||
};
|
|
||||||
|
|
||||||
let manifest = read_manifest(&candidate_path)
|
let manifest = read_manifest(&candidate_path)
|
||||||
.with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
|
.with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
|
||||||
@@ -613,7 +651,7 @@ fn ordered_question_refs(
|
|||||||
target_cases: usize,
|
target_cases: usize,
|
||||||
) -> Result<Vec<(usize, usize)>> {
|
) -> Result<Vec<(usize, usize)>> {
|
||||||
if dataset.metadata.id == DatasetKind::Beir.id() {
|
if dataset.metadata.id == DatasetKind::Beir.id() {
|
||||||
return ordered_question_refs_beir(dataset, params, target_cases);
|
return beir::ordered_question_refs_beir(dataset, params, target_cases);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut question_refs = Vec::new();
|
let mut question_refs = Vec::new();
|
||||||
@@ -642,171 +680,6 @@ fn ordered_question_refs(
|
|||||||
Ok(question_refs)
|
Ok(question_refs)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
|
||||||
fn ordered_question_refs_beir(
|
|
||||||
dataset: &ConvertedDataset,
|
|
||||||
params: &BuildParams,
|
|
||||||
target_cases: usize,
|
|
||||||
) -> Result<Vec<(usize, usize)>> {
|
|
||||||
let prefixes: Vec<&str> = BEIR_DATASETS
|
|
||||||
.iter()
|
|
||||||
.map(|kind| kind.source_prefix())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
|
||||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
|
||||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
|
||||||
let include = if params.include_impossible {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
!question.is_impossible && !question.answers.is_empty()
|
|
||||||
};
|
|
||||||
if !include {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(prefix) = question_prefix(&question.id) else {
|
|
||||||
warn!(
|
|
||||||
question_id = %question.id,
|
|
||||||
"Skipping BEIR question without expected prefix"
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
if !prefixes.contains(&prefix) {
|
|
||||||
warn!(
|
|
||||||
question_id = %question.id,
|
|
||||||
prefix = %prefix,
|
|
||||||
"Skipping BEIR question with unknown subset prefix"
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if grouped.values().all(std::vec::Vec::is_empty) {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"no eligible BEIR questions found; cannot build slice"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
for prefix in &prefixes {
|
|
||||||
if let Some(entries) = grouped.get_mut(prefix) {
|
|
||||||
let seed = mix_seed(
|
|
||||||
&format!("{}::{prefix}", dataset.metadata.id),
|
|
||||||
params.base_seed,
|
|
||||||
);
|
|
||||||
let mut rng = StdRng::seed_from_u64(seed);
|
|
||||||
entries.shuffle(&mut rng);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let dataset_count = prefixes.len().max(1);
|
|
||||||
let base_quota = target_cases / dataset_count;
|
|
||||||
let mut remainder = target_cases % dataset_count;
|
|
||||||
|
|
||||||
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
|
||||||
for prefix in &prefixes {
|
|
||||||
let mut quota = base_quota;
|
|
||||||
if remainder > 0 {
|
|
||||||
quota += 1;
|
|
||||||
remainder -= 1;
|
|
||||||
}
|
|
||||||
quotas.insert(*prefix, quota);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
|
||||||
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
|
||||||
let mut shortfall = 0usize;
|
|
||||||
|
|
||||||
for prefix in &prefixes {
|
|
||||||
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
|
||||||
let quota = *quotas.get(prefix).unwrap_or(&0);
|
|
||||||
let take = quota.min(available);
|
|
||||||
let missing = quota.saturating_sub(take);
|
|
||||||
shortfall += missing;
|
|
||||||
take_counts.insert(*prefix, take);
|
|
||||||
spare_slots.insert(*prefix, available.saturating_sub(take));
|
|
||||||
}
|
|
||||||
|
|
||||||
while shortfall > 0 {
|
|
||||||
let mut allocated = false;
|
|
||||||
for prefix in &prefixes {
|
|
||||||
if shortfall == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
|
||||||
if spare == 0 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Some(count) = take_counts.get_mut(prefix) {
|
|
||||||
*count += 1;
|
|
||||||
}
|
|
||||||
spare_slots.insert(*prefix, spare - 1);
|
|
||||||
shortfall = shortfall.saturating_sub(1);
|
|
||||||
allocated = true;
|
|
||||||
}
|
|
||||||
if !allocated {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
|
||||||
let mut total_selected = 0usize;
|
|
||||||
for prefix in &prefixes {
|
|
||||||
let take = *take_counts.get(prefix).unwrap_or(&0);
|
|
||||||
let mut deque = VecDeque::new();
|
|
||||||
if let Some(entries) = grouped.get(prefix) {
|
|
||||||
for item in entries.iter().take(take) {
|
|
||||||
deque.push_back(*item);
|
|
||||||
total_selected += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
queues.push(deque);
|
|
||||||
}
|
|
||||||
|
|
||||||
if total_selected < target_cases {
|
|
||||||
warn!(
|
|
||||||
requested = target_cases,
|
|
||||||
available = total_selected,
|
|
||||||
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut output = Vec::with_capacity(total_selected);
|
|
||||||
loop {
|
|
||||||
let mut progressed = false;
|
|
||||||
for queue in &mut queues {
|
|
||||||
if let Some(item) = queue.pop_front() {
|
|
||||||
output.push(item);
|
|
||||||
progressed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !progressed {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if output.is_empty() {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"no eligible BEIR questions found; cannot build slice"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn question_prefix(question_id: &str) -> Option<&'static str> {
|
|
||||||
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
|
||||||
if let Some(rest) = question_id.strip_prefix(prefix) {
|
|
||||||
if rest.starts_with('-') {
|
|
||||||
return Some(prefix);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::indexing_slicing)]
|
#[allow(clippy::indexing_slicing)]
|
||||||
fn ensure_negative_pool(
|
fn ensure_negative_pool(
|
||||||
dataset: &ConvertedDataset,
|
dataset: &ConvertedDataset,
|
||||||
@@ -1028,15 +901,47 @@ fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::indexing_slicing)]
|
pub fn read_manifest_if_exists(path: &Path) -> Result<Option<SliceManifest>> {
|
||||||
fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
if !path.exists() {
|
||||||
let mut hasher = Sha256::new();
|
return Ok(None);
|
||||||
hasher.update(dataset_id.as_bytes());
|
}
|
||||||
hasher.update(seed.to_le_bytes());
|
read_manifest(path).map(Some)
|
||||||
let digest = hasher.finalize();
|
}
|
||||||
let mut bytes = [0u8; 8];
|
|
||||||
bytes.copy_from_slice(&digest[..8]);
|
pub fn cached_manifest_path(config: &crate::args::Config) -> Option<PathBuf> {
|
||||||
u64::from_le_bytes(bytes)
|
let slice_arg = config.slice.as_deref()?;
|
||||||
|
let explicit_path = Path::new(slice_arg);
|
||||||
|
if explicit_path.exists() {
|
||||||
|
return Some(explicit_path.to_path_buf());
|
||||||
|
}
|
||||||
|
Some(
|
||||||
|
config
|
||||||
|
.cache_dir
|
||||||
|
.join("slices")
|
||||||
|
.join(config.dataset.id())
|
||||||
|
.join(format!("{slice_arg}.json")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool {
|
||||||
|
let requested_limit = config.limit.unwrap_or(manifest.case_count.max(1)).max(1);
|
||||||
|
if manifest.case_count < requested_limit {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let requested_corpus = config
|
||||||
|
.corpus_limit
|
||||||
|
.unwrap_or(manifest.total_paragraphs.max(1))
|
||||||
|
.max(1);
|
||||||
|
let desired_negatives = desired_negative_target(
|
||||||
|
manifest.positive_paragraphs,
|
||||||
|
requested_corpus,
|
||||||
|
manifest
|
||||||
|
.total_paragraphs
|
||||||
|
.max(manifest.positive_paragraphs.max(1)),
|
||||||
|
config.negative_multiplier,
|
||||||
|
);
|
||||||
|
manifest.negative_paragraphs >= desired_negatives
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
||||||
@@ -1057,14 +962,37 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::args::Config;
|
pub fn ledger_target(config: &Config) -> Option<usize> {
|
||||||
|
match (config.slice_grow, config.limit) {
|
||||||
impl<'a> From<&'a Config> for SliceConfig<'a> {
|
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||||
fn from(config: &'a Config) -> Self {
|
(Some(grow), None) => Some(grow),
|
||||||
slice_config_with_limit(config, None)
|
(None, limit) => limit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Grow the slice ledger to contain the target number of cases.
|
||||||
|
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||||
|
let ledger_limit = ledger_target(config);
|
||||||
|
let slice_settings = slice_config_with_limit(config, ledger_limit);
|
||||||
|
let slice = resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
||||||
|
info!(
|
||||||
|
slice = slice.manifest.slice_id.as_str(),
|
||||||
|
cases = slice.manifest.case_count,
|
||||||
|
positives = slice.manifest.positive_paragraphs,
|
||||||
|
negatives = slice.manifest.negative_paragraphs,
|
||||||
|
total_paragraphs = slice.manifest.total_paragraphs,
|
||||||
|
"Slice ledger ready"
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
||||||
|
slice.manifest.slice_id,
|
||||||
|
slice.manifest.case_count,
|
||||||
|
slice.manifest.positive_paragraphs,
|
||||||
|
slice.manifest.negative_paragraphs
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
|
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
|
||||||
SliceConfig {
|
SliceConfig {
|
||||||
cache_dir: config.cache_dir.as_path(),
|
cache_dir: config.cache_dir.as_path(),
|
||||||
@@ -1088,7 +1016,7 @@ mod tests {
|
|||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
fn sample_dataset() -> ConvertedDataset {
|
fn sample_dataset() -> ConvertedDataset {
|
||||||
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None);
|
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false);
|
||||||
ConvertedDataset {
|
ConvertedDataset {
|
||||||
generated_at: Utc::now(),
|
generated_at: Utc::now(),
|
||||||
metadata,
|
metadata,
|
||||||
@@ -1226,7 +1154,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false, None);
|
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false);
|
||||||
let dataset = ConvertedDataset {
|
let dataset = ConvertedDataset {
|
||||||
generated_at: Utc::now(),
|
generated_at: Utc::now(),
|
||||||
metadata,
|
metadata,
|
||||||
@@ -1240,11 +1168,11 @@ mod tests {
|
|||||||
rng_seed: 0xBB,
|
rng_seed: 0xBB,
|
||||||
};
|
};
|
||||||
|
|
||||||
let refs = ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
let refs = beir::ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
||||||
let mut per_prefix: HashMap<String, usize> = HashMap::new();
|
let mut per_prefix: HashMap<String, usize> = HashMap::new();
|
||||||
for (p_idx, q_idx) in refs {
|
for (p_idx, q_idx) in refs {
|
||||||
let question = &dataset.paragraphs[p_idx].questions[q_idx];
|
let question = &dataset.paragraphs[p_idx].questions[q_idx];
|
||||||
let prefix = question_prefix(&question.id).unwrap_or("unknown");
|
let prefix = beir::question_prefix(&question.id).unwrap_or("unknown");
|
||||||
*per_prefix.entry(prefix.to_string()).or_default() += 1;
|
*per_prefix.entry(prefix.to_string()).or_default() += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use sha2::{Digest, Sha256};
|
|
||||||
use tokio::fs;
|
|
||||||
|
|
||||||
use crate::{args::Config, slice};
|
|
||||||
use common::utils::embedding::EmbeddingProvider;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct SnapshotMetadata {
|
|
||||||
pub dataset_id: String,
|
|
||||||
pub slice_id: String,
|
|
||||||
pub embedding_backend: String,
|
|
||||||
pub embedding_model: Option<String>,
|
|
||||||
pub embedding_dimension: usize,
|
|
||||||
pub rerank_enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DbSnapshotState {
|
|
||||||
pub dataset_id: String,
|
|
||||||
pub slice_id: String,
|
|
||||||
pub ingestion_fingerprint: String,
|
|
||||||
pub snapshot_hash: String,
|
|
||||||
pub updated_at: DateTime<Utc>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub namespace: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub database: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub slice_case_count: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Descriptor {
|
|
||||||
#[allow(dead_code)]
|
|
||||||
metadata: SnapshotMetadata,
|
|
||||||
dir: PathBuf,
|
|
||||||
metadata_hash: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Descriptor {
|
|
||||||
pub fn new(
|
|
||||||
config: &Config,
|
|
||||||
slice: &slice::ResolvedSlice<'_>,
|
|
||||||
embedding_provider: &EmbeddingProvider,
|
|
||||||
) -> Self {
|
|
||||||
let metadata = SnapshotMetadata {
|
|
||||||
dataset_id: slice.manifest.dataset_id.clone(),
|
|
||||||
slice_id: slice.manifest.slice_id.clone(),
|
|
||||||
embedding_backend: embedding_provider.backend_label().to_string(),
|
|
||||||
embedding_model: embedding_provider.model_code(),
|
|
||||||
embedding_dimension: embedding_provider.dimension(),
|
|
||||||
rerank_enabled: config.retrieval.rerank,
|
|
||||||
};
|
|
||||||
|
|
||||||
let dir = config
|
|
||||||
.cache_dir
|
|
||||||
.join("snapshots")
|
|
||||||
.join(&metadata.dataset_id)
|
|
||||||
.join(&metadata.slice_id);
|
|
||||||
let metadata_hash = compute_hash(&metadata);
|
|
||||||
|
|
||||||
Self {
|
|
||||||
metadata,
|
|
||||||
dir,
|
|
||||||
metadata_hash,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn metadata_hash(&self) -> &str {
|
|
||||||
&self.metadata_hash
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
|
|
||||||
let path = self.db_state_path();
|
|
||||||
if !path.exists() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
let bytes = fs::read(&path)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("reading namespace state {}", path.display()))?;
|
|
||||||
let state = serde_json::from_slice(&bytes)
|
|
||||||
.with_context(|| format!("deserialising namespace state {}", path.display()))?;
|
|
||||||
Ok(Some(state))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
|
|
||||||
let path = self.db_state_path();
|
|
||||||
if let Some(parent) = path.parent() {
|
|
||||||
fs::create_dir_all(parent).await.with_context(|| {
|
|
||||||
format!("creating namespace state directory {}", parent.display())
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
let blob =
|
|
||||||
serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
|
|
||||||
fs::write(&path, blob)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("writing namespace state {}", path.display()))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_dir(&self) -> PathBuf {
|
|
||||||
self.dir.join("db")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_state_path(&self) -> PathBuf {
|
|
||||||
self.db_dir().join("state.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
|
|
||||||
let metadata_hash = compute_hash(&metadata);
|
|
||||||
Self {
|
|
||||||
metadata,
|
|
||||||
dir,
|
|
||||||
metadata_hash,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::expect_used)]
|
|
||||||
fn compute_hash(metadata: &SnapshotMetadata) -> String {
|
|
||||||
let mut hasher = Sha256::new();
|
|
||||||
hasher.update(
|
|
||||||
serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"),
|
|
||||||
);
|
|
||||||
format!("{:x}", hasher.finalize())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[allow(clippy::unwrap_used, clippy::expect_used)]
|
|
||||||
async fn state_round_trip() {
|
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
|
||||||
let metadata = SnapshotMetadata {
|
|
||||||
dataset_id: "dataset".into(),
|
|
||||||
slice_id: "slice".into(),
|
|
||||||
embedding_backend: "hashed".into(),
|
|
||||||
embedding_model: None,
|
|
||||||
embedding_dimension: 128,
|
|
||||||
rerank_enabled: true,
|
|
||||||
};
|
|
||||||
let descriptor = Descriptor::from_parts(
|
|
||||||
metadata,
|
|
||||||
temp_dir
|
|
||||||
.path()
|
|
||||||
.join("snapshots")
|
|
||||||
.join("dataset")
|
|
||||||
.join("slice"),
|
|
||||||
);
|
|
||||||
|
|
||||||
let state = DbSnapshotState {
|
|
||||||
dataset_id: "dataset".into(),
|
|
||||||
slice_id: "slice".into(),
|
|
||||||
ingestion_fingerprint: "fingerprint".into(),
|
|
||||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
|
||||||
updated_at: Utc::now(),
|
|
||||||
namespace: Some("ns".into()),
|
|
||||||
database: Some("db".into()),
|
|
||||||
slice_case_count: 42,
|
|
||||||
};
|
|
||||||
descriptor.store_db_state(&state).await.unwrap();
|
|
||||||
|
|
||||||
let loaded = descriptor.load_db_state().await.unwrap().unwrap();
|
|
||||||
assert_eq!(loaded.dataset_id, state.dataset_id);
|
|
||||||
assert_eq!(loaded.slice_id, state.slice_id);
|
|
||||||
assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
|
|
||||||
assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
|
|
||||||
assert_eq!(loaded.namespace, state.namespace);
|
|
||||||
assert_eq!(loaded.database, state.database);
|
|
||||||
assert_eq!(loaded.slice_case_count, state.slice_case_count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, SecondsFormat, Utc};
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
use retrieval_pipeline::{
|
use retrieval_pipeline::{
|
||||||
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
|
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
|
||||||
@@ -8,6 +8,8 @@ use retrieval_pipeline::{
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use unicode_normalization::UnicodeNormalization;
|
use unicode_normalization::UnicodeNormalization;
|
||||||
|
|
||||||
|
pub use crate::context_stats::{RetrievalContextStats, RetrievedContextStats};
|
||||||
|
|
||||||
#[allow(clippy::struct_excessive_bools)]
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct EvaluationSummary {
|
pub struct EvaluationSummary {
|
||||||
@@ -83,6 +85,7 @@ pub struct EvaluationSummary {
|
|||||||
pub chunk_vector_take: usize,
|
pub chunk_vector_take: usize,
|
||||||
pub chunk_fts_take: usize,
|
pub chunk_fts_take: usize,
|
||||||
pub max_chunks_per_entity: usize,
|
pub max_chunks_per_entity: usize,
|
||||||
|
pub retrieved_context: RetrievalContextStats,
|
||||||
pub cases: Vec<CaseSummary>,
|
pub cases: Vec<CaseSummary>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,6 +111,7 @@ pub struct CaseSummary {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub ndcg: Option<f64>,
|
pub ndcg: Option<f64>,
|
||||||
pub latency_ms: u128,
|
pub latency_ms: u128,
|
||||||
|
pub retrieved_context: RetrievedContextStats,
|
||||||
pub retrieved: Vec<RetrievedSummary>,
|
pub retrieved: Vec<RetrievedSummary>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -483,3 +487,7 @@ pub fn build_case_diagnostics(
|
|||||||
pipeline: pipeline_stats,
|
pipeline: pipeline_stats,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
|
||||||
|
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
crane,
|
crane,
|
||||||
}: let
|
}: let
|
||||||
inherit (nixpkgs.legacyPackages.x86_64-linux) lib;
|
inherit (nixpkgs.legacyPackages.x86_64-linux) lib;
|
||||||
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${self}/ort-version");
|
ortVersion = "1.23.2";
|
||||||
in
|
in
|
||||||
flake-utils.lib.eachDefaultSystem (system: let
|
flake-utils.lib.eachDefaultSystem (system: let
|
||||||
pkgs = nixpkgs.legacyPackages.${system};
|
pkgs = nixpkgs.legacyPackages.${system};
|
||||||
@@ -24,83 +24,182 @@
|
|||||||
if pkgs.stdenv.isDarwin
|
if pkgs.stdenv.isDarwin
|
||||||
then "dylib"
|
then "dylib"
|
||||||
else "so";
|
else "so";
|
||||||
minne-pkg =
|
minneVersion = "1.0.4";
|
||||||
if pkgs.onnxruntime.version == ortVersion then
|
|
||||||
craneLib.buildPackage {
|
# Pre-download mozjs binary archive for mozjs_sys (servo dep).
|
||||||
|
# When updating mozjs_sys version in Cargo.lock, update this URL too.
|
||||||
|
mozjsArchive = pkgs.fetchurl {
|
||||||
|
url = "https://github.com/servo/mozjs/releases/download/mozjs-sys-v140.10.1-0/libmozjs-x86_64-unknown-linux-gnu.tar.gz";
|
||||||
|
hash = "sha256-e5kW8HTg6Hrd3sGgU9bqFNTTf7wJCChFOwKE3xyYT4Q=";
|
||||||
|
};
|
||||||
|
|
||||||
|
# Extra paths (common/db, html-router/templates, html-router/assets) are
|
||||||
|
# embedded at compile time via include_dir! / minijinja_embed.
|
||||||
|
commonArgs = {
|
||||||
|
version = minneVersion;
|
||||||
src = lib.cleanSourceWith {
|
src = lib.cleanSourceWith {
|
||||||
src = ./.;
|
src = ./.;
|
||||||
filter = let
|
filter = path: type:
|
||||||
extraPaths = [
|
craneLib.filterCargoSources path type
|
||||||
|
|| lib.any (x: lib.hasPrefix (toString x) (toString path)) [
|
||||||
(toString ./Cargo.lock)
|
(toString ./Cargo.lock)
|
||||||
(toString ./common/db)
|
(toString ./common/db)
|
||||||
(toString ./html-router/templates)
|
(toString ./html-router/templates)
|
||||||
(toString ./html-router/assets)
|
(toString ./html-router/assets)
|
||||||
];
|
];
|
||||||
in
|
|
||||||
path: type: let
|
|
||||||
p = toString path;
|
|
||||||
in
|
|
||||||
craneLib.filterCargoSources path type
|
|
||||||
|| lib.any (x: lib.hasPrefix x p) extraPaths;
|
|
||||||
};
|
};
|
||||||
|
strictDeps = true;
|
||||||
|
|
||||||
pname = "minne";
|
buildInputs = [
|
||||||
version = "1.0.3";
|
pkgs.openssl
|
||||||
# Uses nixpkgs rustc (stable). Release/Docker pin: rust-toolchain.toml (1.91.1).
|
pkgs.libglvnd
|
||||||
doCheck = false;
|
pkgs.onnxruntime
|
||||||
|
pkgs.fontconfig # .pc for yeslogic-fontconfig-sys (servo dep)
|
||||||
|
pkgs.libclang.lib # libclang.so for bindgen (servo dep)
|
||||||
|
];
|
||||||
|
|
||||||
nativeBuildInputs = [pkgs.pkg-config pkgs.rustfmt pkgs.makeWrapper];
|
nativeBuildInputs = [
|
||||||
buildInputs = [pkgs.openssl pkgs.chromium pkgs.onnxruntime];
|
pkgs.pkg-config
|
||||||
|
pkgs.rustfmt
|
||||||
|
pkgs.makeWrapper
|
||||||
|
pkgs.python3 # needed by servo's stylo crate build.rs
|
||||||
|
pkgs.llvmPackages.llvm # llvm-objdump for mozjs_sys (servo dep)
|
||||||
|
pkgs.rustPlatform.bindgenHook # configures bindgen (servo deps)
|
||||||
|
];
|
||||||
|
|
||||||
postInstall = ''
|
# Provide pre-downloaded mozjs archive so it doesn't need network
|
||||||
wrapProgram $out/bin/main \
|
MOZJS_ARCHIVE = "${mozjsArchive}";
|
||||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
};
|
||||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
|
||||||
for b in worker server; do
|
# cargoBuild (not buildDepsOnly) avoids mkDummySrc breaking native build scripts.
|
||||||
if [ -x "$out/bin/$b" ]; then
|
cargoArtifacts = craneLib.cargoBuild (commonArgs
|
||||||
wrapProgram $out/bin/$b \
|
// {
|
||||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
cargoArtifacts = null;
|
||||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
pname = "minne-deps";
|
||||||
fi
|
cargoExtraArgs = "--workspace";
|
||||||
done
|
doCheck = false;
|
||||||
'';
|
doInstallCargoArtifacts = true;
|
||||||
}
|
installPhaseCommand = "";
|
||||||
else
|
});
|
||||||
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
|
||||||
|
minne-pkg =
|
||||||
|
if pkgs.onnxruntime.version == ortVersion
|
||||||
|
then
|
||||||
|
craneLib.buildPackage (commonArgs
|
||||||
|
// {
|
||||||
|
pname = "minne";
|
||||||
|
version = minneVersion;
|
||||||
|
inherit cargoArtifacts;
|
||||||
|
doCheck = false; # checks are in separate derivations
|
||||||
|
doInstallCargoArtifacts = true; # for reuse by check derivations
|
||||||
|
|
||||||
|
postInstall = ''
|
||||||
|
wrapProgram $out/bin/main \
|
||||||
|
--prefix LD_LIBRARY_PATH : ${pkgs.libglvnd}/lib \
|
||||||
|
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||||
|
for b in worker server; do
|
||||||
|
if [ -x "$out/bin/$b" ]; then
|
||||||
|
wrapProgram $out/bin/$b \
|
||||||
|
--prefix LD_LIBRARY_PATH : ${pkgs.libglvnd}/lib \
|
||||||
|
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
'';
|
||||||
|
})
|
||||||
|
else throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ortVersion in flake.nix (${ortVersion})";
|
||||||
|
|
||||||
|
dockerImage = pkgs.dockerTools.buildLayeredImage {
|
||||||
|
name = "minne";
|
||||||
|
tag = minneVersion;
|
||||||
|
created = "now";
|
||||||
|
|
||||||
|
contents = [
|
||||||
|
minne-pkg
|
||||||
|
pkgs.cacert
|
||||||
|
pkgs.bashInteractive
|
||||||
|
pkgs.libglvnd
|
||||||
|
pkgs.fontconfig.lib
|
||||||
|
pkgs.freetype
|
||||||
|
pkgs.stdenv.cc.cc.lib # libgomp (OpenMP) for ONNX Runtime
|
||||||
|
];
|
||||||
|
|
||||||
|
maxLayers = 25;
|
||||||
|
|
||||||
|
config = {
|
||||||
|
Cmd = ["${minne-pkg}/bin/main"];
|
||||||
|
Env = [
|
||||||
|
"SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-certificates.crt"
|
||||||
|
"ORT_DYLIB_PATH=${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}"
|
||||||
|
];
|
||||||
|
ExposedPorts = {"3000/tcp" = {};};
|
||||||
|
User = "appuser";
|
||||||
|
};
|
||||||
|
};
|
||||||
in {
|
in {
|
||||||
packages = {
|
packages = {
|
||||||
minne-pkg = minne-pkg;
|
inherit minne-pkg dockerImage;
|
||||||
default = minne-pkg;
|
default = minne-pkg;
|
||||||
};
|
};
|
||||||
|
|
||||||
apps = {
|
apps = {
|
||||||
main = flake-utils.lib.mkApp {
|
main = {
|
||||||
drv = minne-pkg;
|
type = "app";
|
||||||
name = "main";
|
program = "${minne-pkg}/bin/main";
|
||||||
|
meta.description = "Minne main server — API, web UI, and background worker";
|
||||||
};
|
};
|
||||||
worker = flake-utils.lib.mkApp {
|
worker = {
|
||||||
drv = minne-pkg;
|
type = "app";
|
||||||
name = "worker";
|
program = "${minne-pkg}/bin/worker";
|
||||||
|
meta.description = "Minne standalone background worker (ingestion, indexing, maintenance)";
|
||||||
};
|
};
|
||||||
server = flake-utils.lib.mkApp {
|
server = {
|
||||||
drv = minne-pkg;
|
type = "app";
|
||||||
name = "server";
|
program = "${minne-pkg}/bin/server";
|
||||||
|
meta.description = "Minne API-only server (no background worker)";
|
||||||
};
|
};
|
||||||
default = flake-utils.lib.mkApp {
|
default = {
|
||||||
drv = minne-pkg;
|
type = "app";
|
||||||
name = "main";
|
program = "${minne-pkg}/bin/main";
|
||||||
|
meta.description = "Minne main server — API, web UI, and background worker";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
checks = {
|
checks = {
|
||||||
ortVersion = pkgs.runCommand "ort-version-check" {} ''
|
ortVersion = pkgs.runCommand "ort-version-check" {} ''
|
||||||
if [ "${pkgs.onnxruntime.version}" != "${ortVersion}" ]; then
|
if [ "${pkgs.onnxruntime.version}" != "${ortVersion}" ]; then
|
||||||
echo "pkgs.onnxruntime.version is ${pkgs.onnxruntime.version}, but ort-version pins ${ortVersion}" >&2
|
echo "pkgs.onnxruntime.version is ${pkgs.onnxruntime.version}, but flake pins ${ortVersion}" >&2
|
||||||
echo "Update ort-version or wait for nixpkgs to catch up." >&2
|
echo "Update ortVersion in flake.nix or wait for nixpkgs to catch up." >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
touch $out
|
touch $out
|
||||||
'';
|
'';
|
||||||
|
|
||||||
|
minne-clippy = craneLib.cargoClippy (commonArgs
|
||||||
|
// {
|
||||||
|
cargoArtifacts = minne-pkg;
|
||||||
|
pname = "minne";
|
||||||
|
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||||
|
});
|
||||||
|
|
||||||
|
minne-test = craneLib.cargoTest (commonArgs
|
||||||
|
// {
|
||||||
|
cargoArtifacts = minne-pkg;
|
||||||
|
pname = "minne";
|
||||||
|
buildInputs = commonArgs.buildInputs ++ [pkgs.cacert];
|
||||||
|
SSL_CERT_FILE = "${pkgs.cacert}/etc/ssl/certs/ca-certificates.crt";
|
||||||
|
cargoTestExtraArgs = "--lib --bins";
|
||||||
|
});
|
||||||
|
|
||||||
|
minne-fmt = craneLib.cargoFmt {
|
||||||
|
pname = "minne-fmt";
|
||||||
|
version = minneVersion;
|
||||||
|
src = craneLib.cleanCargoSource ./.;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
})
|
||||||
|
// {
|
||||||
|
lib = {
|
||||||
|
inherit ortVersion;
|
||||||
};
|
};
|
||||||
}) // {
|
|
||||||
ortVersion = ortVersion;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
+182
-159
@@ -1,6 +1,6 @@
|
|||||||
/**
|
/**
|
||||||
* Design Polishing Pass - Interactive Effects
|
* Design Polishing Pass - Interactive Effects
|
||||||
*
|
*
|
||||||
* Includes:
|
* Includes:
|
||||||
* - Scroll-Linked Navbar Shadow
|
* - Scroll-Linked Navbar Shadow
|
||||||
* - HTMX Swap Animation
|
* - HTMX Swap Animation
|
||||||
@@ -8,183 +8,207 @@
|
|||||||
* - Rubberbanding Scroll
|
* - Rubberbanding Scroll
|
||||||
*/
|
*/
|
||||||
|
|
||||||
(function() {
|
(() => {
|
||||||
'use strict';
|
// === SCROLL-LINKED NAVBAR SHADOW ===
|
||||||
|
function initScrollShadow() {
|
||||||
|
const mainContent = document.querySelector("main");
|
||||||
|
const navbar = document.querySelector("nav");
|
||||||
|
if (!mainContent || !navbar) return;
|
||||||
|
|
||||||
// === SCROLL-LINKED NAVBAR SHADOW ===
|
mainContent.addEventListener(
|
||||||
function initScrollShadow() {
|
"scroll",
|
||||||
const mainContent = document.querySelector('main');
|
() => {
|
||||||
const navbar = document.querySelector('nav');
|
const scrollTop = mainContent.scrollTop;
|
||||||
if (!mainContent || !navbar) return;
|
const scrollHeight =
|
||||||
|
mainContent.scrollHeight - mainContent.clientHeight;
|
||||||
|
const scrollDepth = scrollHeight > 0 ? Math.min(scrollTop / 200, 1) : 0;
|
||||||
|
navbar.style.setProperty("--scroll-depth", scrollDepth.toFixed(2));
|
||||||
|
},
|
||||||
|
{ passive: true },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
mainContent.addEventListener('scroll', () => {
|
// === HTMX SWAP ANIMATION ===
|
||||||
const scrollTop = mainContent.scrollTop;
|
function initHtmxSwapAnimation() {
|
||||||
const scrollHeight = mainContent.scrollHeight - mainContent.clientHeight;
|
document.body.addEventListener("htmx:afterSwap", (event) => {
|
||||||
const scrollDepth = scrollHeight > 0 ? Math.min(scrollTop / 200, 1) : 0;
|
let target = event.detail.target;
|
||||||
navbar.style.setProperty('--scroll-depth', scrollDepth.toFixed(2));
|
if (!target) return;
|
||||||
}, { passive: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
// === HTMX SWAP ANIMATION ===
|
// If full body swap (hx-boost), animate only the main content
|
||||||
function initHtmxSwapAnimation() {
|
if (target.tagName === "BODY") {
|
||||||
document.body.addEventListener('htmx:afterSwap', (event) => {
|
const main = document.querySelector("main");
|
||||||
let target = event.detail.target;
|
if (main) target = main;
|
||||||
if (!target) return;
|
}
|
||||||
|
|
||||||
// If full body swap (hx-boost), animate only the main content
|
// Only animate if target is valid and inside/is main content or a card/panel
|
||||||
if (target.tagName === 'BODY') {
|
// Avoid animating sidebar or navbar updates
|
||||||
const main = document.querySelector('main');
|
if (target && (target.tagName === "MAIN" || target.closest("main"))) {
|
||||||
if (main) target = main;
|
if (!target.classList.contains("animate-fade-up")) {
|
||||||
}
|
target.classList.add("animate-fade-up");
|
||||||
|
// Remove class after animation completes to allow re-animation
|
||||||
|
setTimeout(() => {
|
||||||
|
target.classList.remove("animate-fade-up");
|
||||||
|
}, 250);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Only animate if target is valid and inside/is main content or a card/panel
|
// === TYPEWRITER AI RESPONSE ===
|
||||||
// Avoid animating sidebar or navbar updates
|
// Works with SSE streaming - buffers text and reveals character by character
|
||||||
if (target && (target.tagName === 'MAIN' || target.closest('main'))) {
|
window.initTypewriter = (element, options = {}) => {
|
||||||
if (!target.classList.contains('animate-fade-up')) {
|
const { minDelay = 5, maxDelay = 15, showCursor = true } = options;
|
||||||
target.classList.add('animate-fade-up');
|
|
||||||
// Remove class after animation completes to allow re-animation
|
|
||||||
setTimeout(() => {
|
|
||||||
target.classList.remove('animate-fade-up');
|
|
||||||
}, 250);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// === TYPEWRITER AI RESPONSE ===
|
let buffer = "";
|
||||||
// Works with SSE streaming - buffers text and reveals character by character
|
let isTyping = false;
|
||||||
window.initTypewriter = function(element, options = {}) {
|
let cursorElement = null;
|
||||||
const {
|
|
||||||
minDelay = 5,
|
|
||||||
maxDelay = 15,
|
|
||||||
showCursor = true
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
let buffer = '';
|
if (showCursor) {
|
||||||
let isTyping = false;
|
cursorElement = document.createElement("span");
|
||||||
let cursorElement = null;
|
cursorElement.className = "typewriter-cursor";
|
||||||
|
cursorElement.textContent = "▌";
|
||||||
|
cursorElement.style.animation = "blink 1s step-end infinite";
|
||||||
|
element.appendChild(cursorElement);
|
||||||
|
}
|
||||||
|
|
||||||
if (showCursor) {
|
function typeNextChar() {
|
||||||
cursorElement = document.createElement('span');
|
if (buffer.length === 0) {
|
||||||
cursorElement.className = 'typewriter-cursor';
|
isTyping = false;
|
||||||
cursorElement.textContent = '▌';
|
return;
|
||||||
cursorElement.style.animation = 'blink 1s step-end infinite';
|
}
|
||||||
element.appendChild(cursorElement);
|
|
||||||
}
|
|
||||||
|
|
||||||
function typeNextChar() {
|
isTyping = true;
|
||||||
if (buffer.length === 0) {
|
const char = buffer.charAt(0);
|
||||||
isTyping = false;
|
buffer = buffer.slice(1);
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
isTyping = true;
|
// Insert before cursor
|
||||||
const char = buffer.charAt(0);
|
if (cursorElement && cursorElement.parentNode) {
|
||||||
buffer = buffer.slice(1);
|
const textNode = document.createTextNode(char);
|
||||||
|
element.insertBefore(textNode, cursorElement);
|
||||||
|
} else {
|
||||||
|
element.textContent += char;
|
||||||
|
}
|
||||||
|
|
||||||
// Insert before cursor
|
const delay = minDelay + Math.random() * (maxDelay - minDelay);
|
||||||
if (cursorElement && cursorElement.parentNode) {
|
setTimeout(typeNextChar, delay);
|
||||||
const textNode = document.createTextNode(char);
|
}
|
||||||
element.insertBefore(textNode, cursorElement);
|
|
||||||
} else {
|
|
||||||
element.textContent += char;
|
|
||||||
}
|
|
||||||
|
|
||||||
const delay = minDelay + Math.random() * (maxDelay - minDelay);
|
return {
|
||||||
setTimeout(typeNextChar, delay);
|
append: (text) => {
|
||||||
}
|
buffer += text;
|
||||||
|
if (!isTyping) {
|
||||||
|
typeNextChar();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
complete: () => {
|
||||||
|
// Flush remaining buffer immediately
|
||||||
|
if (cursorElement && cursorElement.parentNode) {
|
||||||
|
const textNode = document.createTextNode(buffer);
|
||||||
|
element.insertBefore(textNode, cursorElement);
|
||||||
|
cursorElement.remove();
|
||||||
|
} else {
|
||||||
|
element.textContent += buffer;
|
||||||
|
}
|
||||||
|
buffer = "";
|
||||||
|
isTyping = false;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
return {
|
// === RUBBERBANDING SCROLL ===
|
||||||
append: function(text) {
|
function attachRubberbanding(
|
||||||
buffer += text;
|
container,
|
||||||
if (!isTyping) {
|
{ maxPull = 60, resistance = 0.4 } = {},
|
||||||
typeNextChar();
|
) {
|
||||||
}
|
let startY = 0;
|
||||||
},
|
let pulling = false;
|
||||||
complete: function() {
|
|
||||||
// Flush remaining buffer immediately
|
|
||||||
if (cursorElement && cursorElement.parentNode) {
|
|
||||||
const textNode = document.createTextNode(buffer);
|
|
||||||
element.insertBefore(textNode, cursorElement);
|
|
||||||
cursorElement.remove();
|
|
||||||
} else {
|
|
||||||
element.textContent += buffer;
|
|
||||||
}
|
|
||||||
buffer = '';
|
|
||||||
isTyping = false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
// === RUBBERBANDING SCROLL ===
|
function applyPull(distance) {
|
||||||
function initRubberbanding() {
|
container.style.transform = `translateY(${distance}px)`;
|
||||||
const containers = document.querySelectorAll('#chat-scroll-container, .content-scroll-container');
|
}
|
||||||
|
|
||||||
containers.forEach(container => {
|
|
||||||
let startY = 0;
|
|
||||||
let pulling = false;
|
|
||||||
let pullDistance = 0;
|
|
||||||
const maxPull = 60;
|
|
||||||
const resistance = 0.4;
|
|
||||||
|
|
||||||
container.addEventListener('touchstart', (e) => {
|
function release() {
|
||||||
startY = e.touches[0].clientY;
|
container.style.transition =
|
||||||
}, { passive: true });
|
"transform 300ms cubic-bezier(0.25, 1, 0.5, 1)";
|
||||||
|
container.style.transform = "translateY(0)";
|
||||||
|
setTimeout(() => {
|
||||||
|
container.style.transition = "";
|
||||||
|
}, 300);
|
||||||
|
pulling = false;
|
||||||
|
}
|
||||||
|
|
||||||
container.addEventListener('touchmove', (e) => {
|
function isAtTop() {
|
||||||
const currentY = e.touches[0].clientY;
|
return container.scrollTop <= 0;
|
||||||
const diff = currentY - startY;
|
}
|
||||||
|
function isAtBottom() {
|
||||||
// At top boundary, pulling down
|
return (
|
||||||
if (container.scrollTop <= 0 && diff > 0) {
|
container.scrollTop + container.clientHeight >= container.scrollHeight
|
||||||
pulling = true;
|
);
|
||||||
pullDistance = Math.min(diff * resistance, maxPull);
|
}
|
||||||
container.style.transform = `translateY(${pullDistance}px)`;
|
|
||||||
}
|
|
||||||
// At bottom boundary, pulling up
|
|
||||||
else if (container.scrollTop + container.clientHeight >= container.scrollHeight && diff < 0) {
|
|
||||||
pulling = true;
|
|
||||||
pullDistance = Math.max(diff * resistance, -maxPull);
|
|
||||||
container.style.transform = `translateY(${pullDistance}px)`;
|
|
||||||
}
|
|
||||||
}, { passive: true });
|
|
||||||
|
|
||||||
container.addEventListener('touchend', () => {
|
container.addEventListener(
|
||||||
if (pulling) {
|
"touchstart",
|
||||||
container.style.transition = 'transform 300ms cubic-bezier(0.25, 1, 0.5, 1)';
|
(e) => {
|
||||||
container.style.transform = 'translateY(0)';
|
startY = e.touches[0].clientY;
|
||||||
setTimeout(() => {
|
},
|
||||||
container.style.transition = '';
|
{ passive: true },
|
||||||
}, 300);
|
);
|
||||||
pulling = false;
|
|
||||||
pullDistance = 0;
|
|
||||||
}
|
|
||||||
}, { passive: true });
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// === INITIALIZATION ===
|
container.addEventListener(
|
||||||
function init() {
|
"touchmove",
|
||||||
initScrollShadow();
|
(e) => {
|
||||||
initHtmxSwapAnimation();
|
const diff = e.touches[0].clientY - startY;
|
||||||
initRubberbanding();
|
const isPullingDown = diff > 0 && isAtTop();
|
||||||
}
|
const isPullingUp = diff < 0 && isAtBottom();
|
||||||
|
|
||||||
// Run on DOMContentLoaded
|
if (isPullingDown) {
|
||||||
if (document.readyState === 'loading') {
|
pulling = true;
|
||||||
document.addEventListener('DOMContentLoaded', init);
|
applyPull(Math.min(diff * resistance, maxPull));
|
||||||
} else {
|
} else if (isPullingUp) {
|
||||||
init();
|
pulling = true;
|
||||||
}
|
applyPull(Math.max(diff * resistance, -maxPull));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ passive: true },
|
||||||
|
);
|
||||||
|
|
||||||
// Re-init rubberbanding after HTMX navigations
|
container.addEventListener(
|
||||||
document.body.addEventListener('htmx:afterSettle', () => {
|
"touchend",
|
||||||
initRubberbanding();
|
() => {
|
||||||
});
|
if (pulling) release();
|
||||||
|
},
|
||||||
|
{ passive: true },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Add typewriter cursor blink animation
|
function initRubberbanding() {
|
||||||
const style = document.createElement('style');
|
document
|
||||||
style.textContent = `
|
.querySelectorAll("#chat-scroll-container, .content-scroll-container")
|
||||||
|
.forEach((container) => attachRubberbanding(container));
|
||||||
|
}
|
||||||
|
|
||||||
|
// === INITIALIZATION ===
|
||||||
|
function init() {
|
||||||
|
initScrollShadow();
|
||||||
|
initHtmxSwapAnimation();
|
||||||
|
initRubberbanding();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run on DOMContentLoaded
|
||||||
|
if (document.readyState === "loading") {
|
||||||
|
document.addEventListener("DOMContentLoaded", init);
|
||||||
|
} else {
|
||||||
|
init();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-init rubberbanding after HTMX navigations
|
||||||
|
document.body.addEventListener("htmx:afterSettle", () => {
|
||||||
|
initRubberbanding();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add typewriter cursor blink animation
|
||||||
|
const style = document.createElement("style");
|
||||||
|
style.textContent = `
|
||||||
@keyframes blink {
|
@keyframes blink {
|
||||||
0%, 100% { opacity: 1; }
|
0%, 100% { opacity: 1; }
|
||||||
50% { opacity: 0; }
|
50% { opacity: 0; }
|
||||||
@@ -194,6 +218,5 @@
|
|||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
`;
|
`;
|
||||||
document.head.appendChild(style);
|
document.head.appendChild(style);
|
||||||
|
|
||||||
})();
|
})();
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,6 @@
|
|||||||
--leading-snug: 1.375;
|
--leading-snug: 1.375;
|
||||||
--leading-relaxed: 1.625;
|
--leading-relaxed: 1.625;
|
||||||
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
||||||
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
|
|
||||||
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||||
--default-transition-duration: 150ms;
|
--default-transition-duration: 150ms;
|
||||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||||
@@ -285,37 +284,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.drawer-open {
|
|
||||||
> .drawer-side {
|
|
||||||
overflow-y: auto;
|
|
||||||
}
|
|
||||||
> .drawer-toggle {
|
|
||||||
display: none;
|
|
||||||
& ~ .drawer-side {
|
|
||||||
pointer-events: auto;
|
|
||||||
visibility: visible;
|
|
||||||
position: sticky;
|
|
||||||
display: block;
|
|
||||||
width: auto;
|
|
||||||
overscroll-behavior: auto;
|
|
||||||
opacity: 100%;
|
|
||||||
& > .drawer-overlay {
|
|
||||||
cursor: default;
|
|
||||||
background-color: transparent;
|
|
||||||
}
|
|
||||||
& > *:not(.drawer-overlay) {
|
|
||||||
translate: 0%;
|
|
||||||
[dir="rtl"] & {
|
|
||||||
translate: 0%;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
&:checked ~ .drawer-side {
|
|
||||||
pointer-events: auto;
|
|
||||||
visibility: visible;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.drawer-toggle {
|
.drawer-toggle {
|
||||||
position: fixed;
|
position: fixed;
|
||||||
height: calc(0.25rem * 0);
|
height: calc(0.25rem * 0);
|
||||||
@@ -1074,22 +1042,6 @@
|
|||||||
grid-row-start: 1;
|
grid-row-start: 1;
|
||||||
min-width: calc(0.25rem * 0);
|
min-width: calc(0.25rem * 0);
|
||||||
}
|
}
|
||||||
.chat-image {
|
|
||||||
grid-row: span 2 / span 2;
|
|
||||||
align-self: flex-end;
|
|
||||||
}
|
|
||||||
.chat-footer {
|
|
||||||
grid-row-start: 3;
|
|
||||||
display: flex;
|
|
||||||
gap: calc(0.25rem * 1);
|
|
||||||
font-size: 0.6875rem;
|
|
||||||
}
|
|
||||||
.chat-header {
|
|
||||||
grid-row-start: 1;
|
|
||||||
display: flex;
|
|
||||||
gap: calc(0.25rem * 1);
|
|
||||||
font-size: 0.6875rem;
|
|
||||||
}
|
|
||||||
.container {
|
.container {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
@media (width >= 40rem) {
|
@media (width >= 40rem) {
|
||||||
@@ -1796,9 +1748,6 @@
|
|||||||
.w-10 {
|
.w-10 {
|
||||||
width: calc(var(--spacing) * 10);
|
width: calc(var(--spacing) * 10);
|
||||||
}
|
}
|
||||||
.w-11 {
|
|
||||||
width: calc(var(--spacing) * 11);
|
|
||||||
}
|
|
||||||
.w-11\/12 {
|
.w-11\/12 {
|
||||||
width: calc(11/12 * 100%);
|
width: calc(11/12 * 100%);
|
||||||
}
|
}
|
||||||
@@ -1862,9 +1811,6 @@
|
|||||||
.flex-none {
|
.flex-none {
|
||||||
flex: none;
|
flex: none;
|
||||||
}
|
}
|
||||||
.flex-shrink {
|
|
||||||
flex-shrink: 1;
|
|
||||||
}
|
|
||||||
.flex-shrink-0 {
|
.flex-shrink-0 {
|
||||||
flex-shrink: 0;
|
flex-shrink: 0;
|
||||||
}
|
}
|
||||||
@@ -1877,13 +1823,6 @@
|
|||||||
.grow {
|
.grow {
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
}
|
}
|
||||||
.border-collapse {
|
|
||||||
border-collapse: collapse;
|
|
||||||
}
|
|
||||||
.-translate-y-1 {
|
|
||||||
--tw-translate-y: calc(var(--spacing) * -1);
|
|
||||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
|
||||||
}
|
|
||||||
.-translate-y-1\/2 {
|
.-translate-y-1\/2 {
|
||||||
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
||||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||||
@@ -1956,9 +1895,6 @@
|
|||||||
.justify-start {
|
.justify-start {
|
||||||
justify-content: flex-start;
|
justify-content: flex-start;
|
||||||
}
|
}
|
||||||
.gap-0 {
|
|
||||||
gap: calc(var(--spacing) * 0);
|
|
||||||
}
|
|
||||||
.gap-0\.5 {
|
.gap-0\.5 {
|
||||||
gap: calc(var(--spacing) * 0.5);
|
gap: calc(var(--spacing) * 0.5);
|
||||||
}
|
}
|
||||||
@@ -2091,9 +2027,6 @@
|
|||||||
.border-base-200 {
|
.border-base-200 {
|
||||||
border-color: var(--color-base-200);
|
border-color: var(--color-base-200);
|
||||||
}
|
}
|
||||||
.border-base-content {
|
|
||||||
border-color: var(--color-base-content);
|
|
||||||
}
|
|
||||||
.border-base-content\/10 {
|
.border-base-content\/10 {
|
||||||
border-color: var(--color-base-content);
|
border-color: var(--color-base-content);
|
||||||
@supports (color: color-mix(in lab, red, red)) {
|
@supports (color: color-mix(in lab, red, red)) {
|
||||||
@@ -2130,9 +2063,6 @@
|
|||||||
.bg-transparent {
|
.bg-transparent {
|
||||||
background-color: transparent;
|
background-color: transparent;
|
||||||
}
|
}
|
||||||
.bg-warning {
|
|
||||||
background-color: var(--color-warning);
|
|
||||||
}
|
|
||||||
.bg-warning\/10 {
|
.bg-warning\/10 {
|
||||||
background-color: var(--color-warning);
|
background-color: var(--color-warning);
|
||||||
@supports (color: color-mix(in lab, red, red)) {
|
@supports (color: color-mix(in lab, red, red)) {
|
||||||
@@ -2151,9 +2081,6 @@
|
|||||||
.loading-spinner {
|
.loading-spinner {
|
||||||
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
||||||
}
|
}
|
||||||
.mask-repeat {
|
|
||||||
mask-repeat: repeat;
|
|
||||||
}
|
|
||||||
.fill-current {
|
.fill-current {
|
||||||
fill: currentcolor;
|
fill: currentcolor;
|
||||||
}
|
}
|
||||||
@@ -2184,9 +2111,6 @@
|
|||||||
.p-8 {
|
.p-8 {
|
||||||
padding: calc(var(--spacing) * 8);
|
padding: calc(var(--spacing) * 8);
|
||||||
}
|
}
|
||||||
.px-1 {
|
|
||||||
padding-inline: calc(var(--spacing) * 1);
|
|
||||||
}
|
|
||||||
.px-1\.5 {
|
.px-1\.5 {
|
||||||
padding-inline: calc(var(--spacing) * 1.5);
|
padding-inline: calc(var(--spacing) * 1.5);
|
||||||
}
|
}
|
||||||
@@ -2341,9 +2265,6 @@
|
|||||||
--tw-tracking: var(--tracking-widest);
|
--tw-tracking: var(--tracking-widest);
|
||||||
letter-spacing: var(--tracking-widest);
|
letter-spacing: var(--tracking-widest);
|
||||||
}
|
}
|
||||||
.text-wrap {
|
|
||||||
text-wrap: wrap;
|
|
||||||
}
|
|
||||||
.break-words {
|
.break-words {
|
||||||
overflow-wrap: break-word;
|
overflow-wrap: break-word;
|
||||||
}
|
}
|
||||||
@@ -2410,17 +2331,6 @@
|
|||||||
.italic {
|
.italic {
|
||||||
font-style: italic;
|
font-style: italic;
|
||||||
}
|
}
|
||||||
.underline {
|
|
||||||
text-decoration-line: underline;
|
|
||||||
}
|
|
||||||
.swap-active {
|
|
||||||
.swap-off {
|
|
||||||
opacity: 0%;
|
|
||||||
}
|
|
||||||
.swap-on {
|
|
||||||
opacity: 100%;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.opacity-0 {
|
.opacity-0 {
|
||||||
opacity: 0%;
|
opacity: 0%;
|
||||||
}
|
}
|
||||||
@@ -2514,10 +2424,6 @@
|
|||||||
--tw-duration: 300ms;
|
--tw-duration: 300ms;
|
||||||
transition-duration: 300ms;
|
transition-duration: 300ms;
|
||||||
}
|
}
|
||||||
.ease-in-out {
|
|
||||||
--tw-ease: var(--ease-in-out);
|
|
||||||
transition-timing-function: var(--ease-in-out);
|
|
||||||
}
|
|
||||||
.ease-out {
|
.ease-out {
|
||||||
--tw-ease: var(--ease-out);
|
--tw-ease: var(--ease-out);
|
||||||
transition-timing-function: var(--ease-out);
|
transition-timing-function: var(--ease-out);
|
||||||
|
|||||||
Generated
+3
-3
@@ -958,9 +958,9 @@
|
|||||||
"license": "ISC"
|
"license": "ISC"
|
||||||
},
|
},
|
||||||
"node_modules/picomatch": {
|
"node_modules/picomatch": {
|
||||||
"version": "2.3.1",
|
"version": "2.3.2",
|
||||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||||
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
|
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=8.6"
|
"node": ">=8.6"
|
||||||
|
|||||||
@@ -4,6 +4,9 @@
|
|||||||
//! the template middleware renders them with shared layout context. Route composition
|
//! the template middleware renders them with shared layout context. Route composition
|
||||||
//! and middleware layering are handled by [`router_factory::RouterFactory`].
|
//! and middleware layering are handled by [`router_factory::RouterFactory`].
|
||||||
|
|
||||||
|
// minijinja_embed output (release builds) triggers these lints.
|
||||||
|
#![allow(unused_variables, clippy::expect_used, clippy::missing_panics_doc)]
|
||||||
|
|
||||||
pub mod html_state;
|
pub mod html_state;
|
||||||
pub mod middlewares;
|
pub mod middlewares;
|
||||||
pub mod router_factory;
|
pub mod router_factory;
|
||||||
|
|||||||
@@ -22,13 +22,13 @@ macro_rules! create_asset_service {
|
|||||||
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||||
let assets_path = crate_dir.join($relative_path);
|
let assets_path = crate_dir.join($relative_path);
|
||||||
tracing::debug!("Assets: Serving from filesystem: {:?}", assets_path);
|
tracing::debug!("Assets: Serving from filesystem: {:?}", assets_path);
|
||||||
tower_http::services::ServeDir::new(assets_path)
|
tower_http::services::ServeDir::new(&assets_path)
|
||||||
}
|
}
|
||||||
#[cfg(not(debug_assertions))]
|
#[cfg(not(debug_assertions))]
|
||||||
{
|
{
|
||||||
tracing::debug!("Assets: Serving embedded directory");
|
|
||||||
static ASSETS_DIR: include_dir::Dir<'static> =
|
static ASSETS_DIR: include_dir::Dir<'static> =
|
||||||
include_dir::include_dir!("$CARGO_MANIFEST_DIR/assets");
|
include_dir::include_dir!("$CARGO_MANIFEST_DIR/assets");
|
||||||
|
tracing::debug!(directory = %$relative_path, "Assets: Serving embedded directory");
|
||||||
tower_serve_static::ServeDir::new(&ASSETS_DIR)
|
tower_serve_static::ServeDir::new(&ASSETS_DIR)
|
||||||
}
|
}
|
||||||
}};
|
}};
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use async_openai::types::ListModelResponse;
|
use async_openai::types::models::ListModelResponse;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
Form,
|
Form,
|
||||||
@@ -350,6 +350,9 @@ mod tests {
|
|||||||
image_processing_model: "gpt-4o-mini".into(),
|
image_processing_model: "gpt-4o-mini".into(),
|
||||||
image_processing_prompt: "p".into(),
|
image_processing_prompt: "p".into(),
|
||||||
voice_processing_model: "whisper-1".into(),
|
voice_processing_model: "whisper-1".into(),
|
||||||
|
last_index_rebuild_at: None,
|
||||||
|
index_rebuild_lease_owner: None,
|
||||||
|
index_rebuild_lease_expires_at: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ fn build_chat_event_stream(
|
|||||||
state: HtmlState,
|
state: HtmlState,
|
||||||
openai_stream: impl Stream<
|
openai_stream: impl Stream<
|
||||||
Item = Result<
|
Item = Result<
|
||||||
async_openai::types::CreateChatCompletionStreamResponse,
|
async_openai::types::chat::CreateChatCompletionStreamResponse,
|
||||||
async_openai::error::OpenAIError,
|
async_openai::error::OpenAIError,
|
||||||
>,
|
>,
|
||||||
> + Send
|
> + Send
|
||||||
@@ -342,7 +342,7 @@ async fn prepare_chat_request(
|
|||||||
history: &[Message],
|
history: &[Message],
|
||||||
) -> Result<
|
) -> Result<
|
||||||
(
|
(
|
||||||
async_openai::types::CreateChatCompletionRequest,
|
async_openai::types::chat::CreateChatCompletionRequest,
|
||||||
Vec<String>,
|
Vec<String>,
|
||||||
),
|
),
|
||||||
SseResponse,
|
SseResponse,
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ use axum::{
|
|||||||
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
|
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use common::storage::types::{
|
use common::storage::types::{file_info::FileInfo, text_content::TextContent, user::User};
|
||||||
file_info::FileInfo, knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
|
|
||||||
text_content::TextContent, user::User,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
html_state::HtmlState,
|
html_state::HtmlState,
|
||||||
@@ -180,9 +177,7 @@ pub async fn delete_text_content(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete related knowledge entities and text chunks
|
TextContent::clear_ingested_children(&id, &user.id, &state.db).await?;
|
||||||
KnowledgeEntity::delete_by_source_id(&id, &state.db).await?;
|
|
||||||
TextChunk::delete_by_source_id(&id, &state.db).await?;
|
|
||||||
|
|
||||||
// Delete the text content
|
// Delete the text content
|
||||||
state.db.delete_item::<TextContent>(&id).await?;
|
state.db.delete_item::<TextContent>(&id).await?;
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ use common::storage::types::user::DashboardStats;
|
|||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::types::{
|
storage::types::{
|
||||||
file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
|
file_info::FileInfo, ingestion_task::IngestionTask, text_content::TextContent, user::User,
|
||||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
|
||||||
text_content::TextContent, user::User,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -81,11 +79,7 @@ pub async fn delete_text_content(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the text content and any related data
|
TextContent::clear_ingested_children(&text_content.id, &user.id, &state.db).await?;
|
||||||
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
|
||||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
|
||||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db)
|
|
||||||
.await?;
|
|
||||||
state
|
state
|
||||||
.db
|
.db
|
||||||
.delete_item::<TextContent>(&text_content.id)
|
.delete_item::<TextContent>(&text_content.id)
|
||||||
|
|||||||
@@ -203,7 +203,13 @@ pub async fn create_knowledge_entity(
|
|||||||
);
|
);
|
||||||
let new_entity_id = new_entity.id.clone();
|
let new_entity_id = new_entity.id.clone();
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(new_entity, embedding, &state.db).await?;
|
KnowledgeEntity::store_with_embedding(
|
||||||
|
new_entity,
|
||||||
|
embedding,
|
||||||
|
state.embedding_provider.dimension(),
|
||||||
|
&state.db,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let relationship_type = relationship_type_or_default(form.relationship_type.as_deref());
|
let relationship_type = relationship_type_or_default(form.relationship_type.as_deref());
|
||||||
let user_id = user.id.clone();
|
let user_id = user.id.clone();
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
{# Default: one outer #modal_form. Modals with multiple forms (scratchpad editor)
|
{# Default: one outer #modal_form. Modals with multiple forms (scratchpad editor)
|
||||||
override modal_form_open / modal_form_close — nested <form> is invalid HTML. #}
|
override modal_form_open / modal_form_close — nested <form> is invalid HTML. #}
|
||||||
{% block modal_form_open %}<form id="modal_form" hx-on::after-request="if(event.detail.successful) document.getElementById('body_modal').close()" {% block form_attributes %}{% endblock %}>{% endblock %}
|
{% block modal_form_open %}<form id="modal_form" hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()" {% block form_attributes %}{% endblock %}>{% endblock %}
|
||||||
<div class="flex flex-col flex-1 gap-5">
|
<div class="flex flex-col flex-1 gap-5">
|
||||||
{% block modal_content %}{% endblock %}
|
{% block modal_content %}{% endblock %}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -333,6 +333,22 @@ async fn snapshot_new_entity_modal() {
|
|||||||
snapshot_settings().bind(|| insta::assert_snapshot!("new_entity_modal", body));
|
snapshot_settings().bind(|| insta::assert_snapshot!("new_entity_modal", body));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn modal_form_after_request_ignores_inner_htmx_requests() {
|
||||||
|
let (app, db) = build_test_app().await;
|
||||||
|
let cookie = seeded_cookie(&app, &db).await;
|
||||||
|
let modal = get_html(&app, "/knowledge-entity/new", Some(&cookie)).await;
|
||||||
|
|
||||||
|
// Inner buttons (e.g. Suggest Relationships) bubble htmx:afterRequest to
|
||||||
|
// #modal_form; closing must only run when the form itself submitted.
|
||||||
|
assert!(
|
||||||
|
modal.contains(
|
||||||
|
r#"hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()"#
|
||||||
|
),
|
||||||
|
"#modal_form should ignore bubbled after-request events from child elements"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
|
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
|
||||||
let response = app
|
let response = app
|
||||||
.clone()
|
.clone()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
---
|
---
|
||||||
source: html-router/tests/router_integration.rs
|
source: html-router/tests/router_integration.rs
|
||||||
|
assertion_line: 333
|
||||||
expression: body
|
expression: body
|
||||||
---
|
---
|
||||||
<dialog id="body_modal" class="modal">
|
<dialog id="body_modal" class="modal">
|
||||||
@@ -18,7 +19,7 @@ expression: body
|
|||||||
</button>
|
</button>
|
||||||
|
|
||||||
|
|
||||||
<form id="modal_form" hx-on::after-request="if(event.detail.successful) document.getElementById('body_modal').close()"
|
<form id="modal_form" hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()"
|
||||||
hx-post="/knowledge-entity"
|
hx-post="/knowledge-entity"
|
||||||
hx-target="#knowledge_pane"
|
hx-target="#knowledge_pane"
|
||||||
hx-swap="outerHTML"
|
hx-swap="outerHTML"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user