mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-25 03:16:26 +02:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bddb603bd9 | |||
| aee9136f1e | |||
| 253ac9819d | |||
| 511e42a078 | |||
| d8e839bf46 | |||
| 588e616baf | |||
| 87e6fa14b2 | |||
| 09e545816e | |||
| 01ef1bcb7a | |||
| 530cd0a8f1 | |||
| 3b20adc50f | |||
| b3d42d2586 | |||
| fb51a8b55f | |||
| adc04d8c6d | |||
| 1013035731 | |||
| cf69cb7b05 | |||
| ead17530bd | |||
| 4e8a58fff1 |
+1
-1
@@ -1,2 +1,2 @@
|
||||
[alias]
|
||||
eval = "run -p evaluations --"
|
||||
eval = "run -p evaluations --release --"
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
# Git stuff
|
||||
.git/
|
||||
.gitignore
|
||||
.github
|
||||
|
||||
# Node build artifacts
|
||||
**/node_modules/
|
||||
|
||||
# Nix/Devenv environment files
|
||||
.direnv/
|
||||
.devenv/
|
||||
devenv.lock
|
||||
devenv.nix
|
||||
devenv.yaml
|
||||
docker-compose.yml
|
||||
.envrc
|
||||
.devenv.flake.nix
|
||||
flake.lock
|
||||
flake.nix
|
||||
|
||||
# Rust build artifacts (crucial for multi-stage builds)
|
||||
**/target/
|
||||
|
||||
# Runtime data directories
|
||||
data/
|
||||
database/
|
||||
|
||||
# Local environment config (sensitive)
|
||||
.env
|
||||
|
||||
# IDE specific
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# OS specific
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs / Temporary files
|
||||
*.log
|
||||
@@ -0,0 +1,30 @@
|
||||
name: CI
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Format, lint, build & test
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
|
||||
- name: Check formatting, clippy lint, unit tests & ort version
|
||||
run: nix flake check --show-trace
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
pull_request:
|
||||
push:
|
||||
tags:
|
||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
||||
- "**[0-9]+.[0-9]+.[0-9]+*"
|
||||
|
||||
jobs:
|
||||
plan:
|
||||
@@ -17,6 +17,7 @@ jobs:
|
||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||
tag-flag: ${{ !github.event.pull_request && format('--tag={0}', github.ref_name) || '' }}
|
||||
publishing: ${{ !github.event.pull_request }}
|
||||
ort-version: ${{ steps.ort_version.outputs.value }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
@@ -25,13 +26,20 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v27
|
||||
uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
extra_nix_config: |
|
||||
experimental-features = nix-command flakes
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
|
||||
- name: Read ORT version from flake
|
||||
id: ort_version
|
||||
run: echo "value=$(nix eval .#lib.ortVersion --raw)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Verify ort-version matches nixpkgs onnxruntime
|
||||
run: nix flake check --system x86_64-linux -L
|
||||
run: nix flake check --system x86_64-linux
|
||||
|
||||
- name: Install dist
|
||||
shell: bash
|
||||
@@ -78,7 +86,7 @@ jobs:
|
||||
|
||||
- name: Load ONNX Runtime version
|
||||
shell: bash
|
||||
run: echo "ORT_VER=$(tr -d '[:space:]' < ort-version)" >> "$GITHUB_ENV"
|
||||
run: echo "ORT_VER=${{ needs.plan.outputs.ort-version }}" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Install Rust non-interactively if not already installed
|
||||
if: ${{ matrix.container }}
|
||||
@@ -108,7 +116,7 @@ jobs:
|
||||
run: |
|
||||
mkdir -p lib
|
||||
rm -f lib/*
|
||||
|
||||
|
||||
# Windows PowerShell
|
||||
- name: Prepare lib dir (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
@@ -158,7 +166,6 @@ jobs:
|
||||
echo "lib/ contents:"
|
||||
ls -l lib || dir lib
|
||||
# ===== END: Injected ORT staging =====
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
${{ matrix.packages_install }}
|
||||
@@ -186,21 +193,31 @@ jobs:
|
||||
${{ env.BUILD_MANIFEST_NAME }}
|
||||
|
||||
build_and_push_docker_image:
|
||||
name: Build and Push Docker Image
|
||||
name: Build and Push Docker Image (Nix)
|
||||
runs-on: ubuntu-latest
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
packages: write
|
||||
actions: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
|
||||
- name: Build Docker image with Nix
|
||||
run: nix build .#dockerImage -L --show-trace
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
@@ -215,15 +232,16 @@ jobs:
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
- name: Load and push Docker image
|
||||
env:
|
||||
IMAGE_NAME: ghcr.io/${{ github.repository }}
|
||||
IMAGE_TAG: ${{ needs.plan.outputs.tag }}
|
||||
run: |
|
||||
docker load < result
|
||||
docker tag "minne:1.0.3" "$IMAGE_NAME:$IMAGE_TAG"
|
||||
docker tag "minne:1.0.3" "$IMAGE_NAME:latest"
|
||||
docker push "$IMAGE_NAME:$IMAGE_TAG"
|
||||
docker push "$IMAGE_NAME:latest"
|
||||
|
||||
build-global-artifacts:
|
||||
needs: [plan, build-local-artifacts]
|
||||
|
||||
+41
-1
@@ -1,7 +1,31 @@
|
||||
# Changelog
|
||||
|
||||
## Unreleased
|
||||
|
||||
- Infra: CI workflow fixes. CI is now a nix flake check which includes compilation, caching and running tests, clippy, fmt, validation for ort version.
|
||||
- Docker-compose: The example now references the ghcr image, this is so we can remove the Dockerfile and reducing maintenance scope.
|
||||
- Refactor: web scraping now uses `servo-fetch` (pure-Rust Servo engine) and PDF rendering uses `pdfium-render` (direct PDFium bindings) — reduces Docker image size by ~300MB, improves startup latency by ~100× for PDF rendering, and provides more stable output
|
||||
- Fix: added `pkgs.libglvnd` to `LD_LIBRARY_PATH` in devenv so Servo engine can find `libEGL.so` at runtime
|
||||
- Fix: updated Dockerfile to add `libegl1 libegl-mesa0 libgles2 libfontconfig1 libfreetype6` runtime dependencies for servo-fetch
|
||||
- Docs: updated architecture, features, and installation docs to reflect the new web processing stack
|
||||
- Fix: added pre-commit hooks to further maintain code consistency.
|
||||
- Security: updated some deps because dependabot told me, good bot.
|
||||
- Security: bump `async-openai` to 0.41.1 (feature-gated types, transcription API rename; removes `backoff` transitive dep)
|
||||
- Refactor: deduplicated test database setup across common/src/storage/.
|
||||
- Refactor: split knowledge-graph.js monolith into focused functions.
|
||||
- Evaluations: simplified crate layout — linear pipeline, sharded-only converted store, in-memory ingestion, `db/` and `cli/` modules; namespace reuse state in corpus manifest (removed `cache/snapshots/`); no legacy JSON/history compatibility (re-run `--warm` after upgrade)
|
||||
- Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables)
|
||||
- Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id)
|
||||
- Performance: entity embeddings during ingestion use batched `embed_batch`, matching chunk embedding
|
||||
- Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed
|
||||
- Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children`
|
||||
- Fix: regression re suggestion of relationships
|
||||
- Internal: extracted duplicate entity+embedding patterns into `HasEmbedding` and `EmbeddingRecord` traits with generic `store_with_embedding`, `delete_by_source_id`, and `vector_search` on `SurrealDbClient`.
|
||||
- Infra: `ort-version` file removed — version inlined in `flake.nix` and `devenv.nix`; `release.yml` reads it via `nix eval .#lib.ortVersion` from the plan job
|
||||
- Infra: `screenshot-graph.webp` and `.dockerignore` deleted — stale artifacts from Dockerfile era
|
||||
|
||||
## 1.0.3 (2026-06-12)
|
||||
|
||||
- Search: filter results by type — knowledge entities, ingested content, or both
|
||||
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
||||
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
||||
@@ -11,6 +35,7 @@
|
||||
- Fix: API key revocation now correctly clears the stored key
|
||||
|
||||
## 1.0.2 (2026-02-15)
|
||||
|
||||
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||
- Fix: chat references now validate and render more reliably
|
||||
- Fix: improved admin access checks for restricted routes
|
||||
@@ -19,73 +44,88 @@
|
||||
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
||||
|
||||
## 1.0.1 (2026-02-11)
|
||||
|
||||
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
||||
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
||||
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
|
||||
|
||||
## 1.0.0 (2026-01-02)
|
||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||
|
||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||
- Added a benchmarks create for evaluating the retrieval process
|
||||
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
|
||||
- Embeddings stored on own table.
|
||||
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
||||
|
||||
## Version 0.2.7 (2025-12-04)
|
||||
|
||||
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||
- Fix: timezone aware info in scratchpad
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
- Fix: default name for relationships harmonized across application
|
||||
|
||||
## Version 0.2.5 (2025-10-24)
|
||||
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
- Scratchpad feature, with the feature to convert scratchpads to content.
|
||||
- Added knowledge entity search results to the global search
|
||||
- Backend fixes for improved performance when ingesting and retrieval
|
||||
|
||||
## Version 0.2.4 (2025-10-15)
|
||||
|
||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||
- Ingestion task archive
|
||||
|
||||
## Version 0.2.3 (2025-10-12)
|
||||
|
||||
- Fix changing vector dimensions on a fresh database (#3)
|
||||
|
||||
## Version 0.2.2 (2025-10-07)
|
||||
|
||||
- Support for ingestion of PDF files
|
||||
- Improved ingestion speed
|
||||
- Fix deletion of items work as expected
|
||||
- Fix enabling GPT-5 use via OpenAI API
|
||||
|
||||
## Version 0.2.1 (2025-09-24)
|
||||
|
||||
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||
|
||||
## Version 0.2.0 (2025-09-23)
|
||||
|
||||
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||
- Enabled audio ingestion and merged the new storage backend.
|
||||
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||
|
||||
## Version 0.1.4 (2025-07-01)
|
||||
|
||||
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||
|
||||
## Version 0.1.3 (2025-06-08)
|
||||
|
||||
- Added support for AI providers beyond OpenAI.
|
||||
- Made the HTTP port configurable for deployments.
|
||||
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||
|
||||
## Version 0.1.2 (2025-05-26)
|
||||
|
||||
- Introduced full-text search across indexed knowledge.
|
||||
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||
|
||||
## Version 0.1.1 (2025-05-13)
|
||||
|
||||
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||
- Made the data storage path configurable.
|
||||
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||
|
||||
## Version 0.1.0 (2025-05-06)
|
||||
|
||||
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||
|
||||
Generated
+6203
-364
File diff suppressed because it is too large
Load Diff
+26
-13
@@ -7,19 +7,24 @@ members = [
|
||||
"ingestion-pipeline",
|
||||
"retrieval-pipeline",
|
||||
"json-stream-parser",
|
||||
"evaluations"
|
||||
"evaluations",
|
||||
]
|
||||
resolver = "2"
|
||||
resolver = "3"
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.94"
|
||||
async-openai = "0.29.3"
|
||||
async-openai = { version = "0.41.1", features = [
|
||||
"chat-completion",
|
||||
"embedding",
|
||||
"audio",
|
||||
"model",
|
||||
] }
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1.88"
|
||||
axum-htmx = "0.7.0"
|
||||
axum_session = "0.16"
|
||||
axum_session_auth = "0.16"
|
||||
axum_session_surreal = "0.4"
|
||||
axum_session = "0.18"
|
||||
axum_session_auth = "0.18"
|
||||
axum_session_surreal = "0.6"
|
||||
axum_typed_multipart = "0.16"
|
||||
axum = { version = "0.8", features = ["multipart", "macros"] }
|
||||
chrono-tz = "0.10.1"
|
||||
@@ -27,7 +32,6 @@ chrono = { version = "0.4.39", features = ["serde"] }
|
||||
config = "0.15.4"
|
||||
dom_smoothie = "0.10.0"
|
||||
futures = "0.3.31"
|
||||
headless_chrome = "1.0.17"
|
||||
include_dir = "0.7.4"
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
@@ -35,12 +39,12 @@ minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
reqwest = { version = "0.12.12", features = ["charset", "json"] }
|
||||
serde_json = "1.0.128"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sha2 = "0.10.8"
|
||||
surrealdb-migrations = "2.2.2"
|
||||
surrealdb = { version = "2" }
|
||||
surrealdb-migrations = "2.4.0"
|
||||
surrealdb = { version = "2.6" }
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||
@@ -61,14 +65,24 @@ bytes = "1.7.1"
|
||||
state-machines = "0.9"
|
||||
pdf-extract = "0.9"
|
||||
lopdf = "0.32"
|
||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||
pdfium-auto = "0.3"
|
||||
pdfium-render = "0.8"
|
||||
servo-fetch = "0.13"
|
||||
tendril = "0.4"
|
||||
image = { version = "0.25", default-features = false, features = ["png"] }
|
||||
fastembed = { version = "5.2.0", default-features = false, features = [
|
||||
"hf-hub-native-tls",
|
||||
"ort-load-dynamic",
|
||||
] }
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
lto = "thin"
|
||||
|
||||
[workspace.lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(feature, values(\"inspect\"))"] }
|
||||
unexpected_cfgs = { level = "warn", check-cfg = [
|
||||
"cfg(feature, values(\"inspect\"))",
|
||||
] }
|
||||
|
||||
[workspace.lints.clippy]
|
||||
# Performance-focused lints
|
||||
@@ -118,4 +132,3 @@ needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
cargo_common_metadata = "allow"
|
||||
multiple-crate-versions = "allow"
|
||||
|
||||
|
||||
-53
@@ -1,53 +0,0 @@
|
||||
# === Builder ===
|
||||
FROM rust:1.91.1-bookworm AS builder
|
||||
WORKDIR /usr/src/minne
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cache deps
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
|
||||
COPY api-router/Cargo.toml ./api-router/
|
||||
COPY common/Cargo.toml ./common/
|
||||
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
|
||||
COPY html-router/Cargo.toml ./html-router/
|
||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
||||
COPY main/Cargo.toml ./main/
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Build
|
||||
COPY . .
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
||||
|
||||
# === Runtime ===
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Chromium + runtime deps + OpenMP for ORT
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
||||
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
||||
libgomp1 libstdc++6 curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ONNX Runtime (CPU). Version is read from ort-version (override with --build-arg ORT_VERSION=...).
|
||||
COPY ort-version /tmp/ort-version
|
||||
ARG ORT_VERSION
|
||||
RUN ORT_VERSION="${ORT_VERSION:-$(tr -d '[:space:]' < /tmp/ort-version)}" && \
|
||||
mkdir -p /opt/onnxruntime && \
|
||||
curl -fsSL -o /tmp/ort.tgz \
|
||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||
|
||||
ENV CHROME_BIN=/usr/bin/chromium \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
||||
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
||||
|
||||
# Non-root
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
WORKDIR /home/appuser
|
||||
|
||||
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
||||
EXPOSE 3000
|
||||
CMD ["main"]
|
||||
@@ -121,7 +121,7 @@ fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults t
|
||||
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
||||
- **Database:** SurrealDB (graph, document, and vector search)
|
||||
- **AI Integration:** OpenAI-compatible API with structured outputs
|
||||
- **Web Processing:** Headless Chrome for robust webpage content extraction
|
||||
- **Web Processing:** Embedded Servo engine (servo-fetch) for webpage content extraction + PDFium for PDF rendering
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -172,7 +172,7 @@ cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
|
||||
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||
|
||||
### 2. Nix
|
||||
|
||||
@@ -180,13 +180,13 @@ The included `docker-compose.yml` handles SurrealDB and Chromium dependencies au
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
This fetches Minne and all dependencies, including Chromium.
|
||||
This fetches Minne and all dependencies.
|
||||
|
||||
### 3. Pre-built Binaries
|
||||
|
||||
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
|
||||
**Requirements:** You'll need to provide SurrealDB separately.
|
||||
|
||||
### 4. Build from Source
|
||||
|
||||
@@ -196,7 +196,7 @@ cd minne
|
||||
cargo run --release --bin main
|
||||
```
|
||||
|
||||
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
|
||||
**Requirements:** SurrealDB must be installed and accessible in your PATH.
|
||||
|
||||
## Application Architecture
|
||||
|
||||
|
||||
+1
-1
@@ -24,7 +24,7 @@ dom_smoothie = { workspace = true }
|
||||
axum_session = { workspace = true }
|
||||
axum_session_auth = { workspace = true }
|
||||
axum_session_surreal = { workspace = true}
|
||||
axum_typed_multipart = { workspace = true}
|
||||
axum_typed_multipart = { workspace = true}
|
||||
include_dir = { workspace = true }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-autoreload = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Track scheduled runtime index rebuild state on the system_settings singleton.
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -201,6 +201,10 @@\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||
@@ -14,3 +14,7 @@ DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||
|
||||
+157
-119
@@ -1,9 +1,11 @@
|
||||
use super::types::StoredObject;
|
||||
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
|
||||
use crate::error::AppError;
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::Stream;
|
||||
use include_dir::{include_dir, Dir};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// Initialize a new database client.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||
/// * `username` — Root username for authentication.
|
||||
/// * `password` — Root password for authentication.
|
||||
/// * `namespace` — SurrealDB namespace to use.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||
/// In-memory (`mem://`) connections skip authentication.
|
||||
pub async fn new(
|
||||
address: &str,
|
||||
username: &str,
|
||||
@@ -49,30 +37,15 @@ impl SurrealDbClient {
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
|
||||
// Skip sign-in for in-memory engine (no auth support)
|
||||
if !address.starts_with("mem://") {
|
||||
db.signin(Root { username, password }).await?;
|
||||
}
|
||||
|
||||
// Set namespace
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Initialize a new database client using namespace-level authentication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string.
|
||||
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||
/// * `username` — Namespace username for authentication.
|
||||
/// * `password` — Namespace password for authentication.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
@@ -91,11 +64,6 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Create an Axum session store backed by SurrealDB.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
@@ -109,15 +77,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
|
||||
///
|
||||
/// This function should be called during application startup, after connecting to
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
@@ -129,15 +88,6 @@ impl SurrealDbClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store an object in SurrealDB.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database create operation fails.
|
||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -148,13 +98,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||
///
|
||||
/// Useful for idempotent ingestion flows.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database upsert operation fails.
|
||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -166,11 +109,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieve all objects from a table.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -178,16 +116,6 @@ impl SurrealDbClient {
|
||||
self.client.select(T::table_name()).await
|
||||
}
|
||||
|
||||
/// Retrieve a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to retrieve.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -195,16 +123,6 @@ impl SurrealDbClient {
|
||||
self.client.select((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Delete a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to delete.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database delete operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -212,11 +130,6 @@ impl SurrealDbClient {
|
||||
self.client.delete((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Listen to a table for real-time updates via a live query stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database live query subscription fails.
|
||||
pub async fn listen<T>(
|
||||
&self,
|
||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||
@@ -225,6 +138,156 @@ impl SurrealDbClient {
|
||||
{
|
||||
self.client.select(T::table_name()).live().await
|
||||
}
|
||||
|
||||
/// Atomically store an entity and its embedding vector in a single
|
||||
/// SurrealDB transaction.
|
||||
///
|
||||
/// Creates (or overwrites) the entity row and upserts the linked
|
||||
/// embedding record. The embedding dimension is validated against
|
||||
/// `embedding_dimensions` before the query is issued.
|
||||
pub async fn store_with_embedding<E>(
|
||||
&self,
|
||||
entity: E,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding + Serialize + Send + Sync + 'static,
|
||||
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
|
||||
{
|
||||
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let entity_id = entity.id().to_string();
|
||||
let emb = <E as HasEmbedding>::Embedding::new(
|
||||
&entity_id,
|
||||
entity.source_id().to_string(),
|
||||
embedding,
|
||||
entity.user_id().to_string(),
|
||||
E::table_name(),
|
||||
);
|
||||
|
||||
let sql = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{et}', $id) CONTENT $entity;
|
||||
UPSERT type::thing('{emt}', $id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
et = E::table_name(),
|
||||
emt = <E as HasEmbedding>::Embedding::table_name(),
|
||||
);
|
||||
|
||||
self.client
|
||||
.query(sql)
|
||||
.bind(("id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all entity and embedding rows matching a given `source_id`.
|
||||
///
|
||||
/// Runs inside a SurrealDB transaction so that entity and embedding
|
||||
/// deletes are atomic.
|
||||
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding,
|
||||
E::Embedding: Send + Sync,
|
||||
{
|
||||
self.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::Embedding::table_name()
|
||||
))
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::table_name()
|
||||
))
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector similarity search over entities using HNSW index.
|
||||
///
|
||||
/// Performs a cosine-similarity search against the embedding table,
|
||||
/// fetches the corresponding entity rows server-side via `FETCH`,
|
||||
/// and returns `(entity, score)` pairs ordered by descending
|
||||
/// similarity. Orphaned embeddings (entity deleted but its
|
||||
/// embedding row remains) are logged as a warning and dropped.
|
||||
///
|
||||
/// This is a single round-trip — SurrealDB resolves the link field
|
||||
/// (`entity_id` or `chunk_id`) inside the query engine.
|
||||
pub async fn vector_search<E, Emb>(
|
||||
&self,
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
user_id: &str,
|
||||
) -> Result<Vec<(E, f32)>, AppError>
|
||||
where
|
||||
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
|
||||
Emb: EmbeddingRecord + Send + Sync,
|
||||
{
|
||||
// Generic row that works with both `entity_id` and `chunk_id` link
|
||||
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
|
||||
// server-side so we get the full entity in a single round-trip.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct FetchRow<Ent> {
|
||||
score: f32,
|
||||
#[serde(alias = "entity_id", alias = "chunk_id")]
|
||||
entity: Option<Ent>,
|
||||
}
|
||||
|
||||
let link_field = Emb::link_field();
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
{link_field},
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH {link_field}
|
||||
"#,
|
||||
link_field = link_field,
|
||||
emb_table = Emb::table_name(),
|
||||
take = take,
|
||||
);
|
||||
|
||||
let mut response = self
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?;
|
||||
|
||||
response = response.check()?;
|
||||
|
||||
let rows: Vec<FetchRow<E>> = response.take(0)?;
|
||||
|
||||
let mut results = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
if let Some(entity) = r.entity {
|
||||
results.push((entity, r.score));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Vector search hit orphaned {} row with missing {link_field}",
|
||||
Emb::table_name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
impl SurrealDbClient {
|
||||
/// Create an in-memory SurrealDB client for testing.
|
||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||
let db = connect("mem://").await?;
|
||||
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
}
|
||||
@@ -253,8 +313,7 @@ mod tests {
|
||||
use crate::stored_object;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
stored_object!(Dummy, "dummy", {
|
||||
name: String
|
||||
@@ -262,15 +321,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -314,15 +365,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -371,12 +414,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to build indexes".to_string())?;
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::future::try_join_all;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map, Value};
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
||||
@@ -204,6 +208,9 @@ pub async fn ensure_runtime(
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
///
|
||||
/// Uses `DEFINE INDEX OVERWRITE` and is reserved for dimension migrations, re-embed
|
||||
/// flows, and tests. Routine optimization should use [`rebuild_runtime`].
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||
@@ -211,6 +218,115 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Rebuilds existing runtime FTS and HNSW indexes in place via SurrealQL `REBUILD INDEX`.
|
||||
///
|
||||
/// SurrealDB maintains ready indexes incrementally on writes; this is for periodic
|
||||
/// optimization (for example a nightly maintainer job), not ingest correctness.
|
||||
/// On SurrealDB 2.6 this runs synchronously (`CONCURRENTLY` is not supported on `REBUILD`).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any rebuild operation fails.
|
||||
pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_runtime_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Returns whether a scheduled index rebuild is due based on the persisted last-run time.
|
||||
#[must_use]
|
||||
pub fn scheduled_index_rebuild_due(
|
||||
last_run: Option<DateTime<Utc>>,
|
||||
interval_secs: u64,
|
||||
now: DateTime<Utc>,
|
||||
) -> bool {
|
||||
if interval_secs == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let Some(last_run) = last_run else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let elapsed = now.signed_duration_since(last_run);
|
||||
elapsed.num_seconds() >= i64::try_from(interval_secs).unwrap_or(i64::MAX)
|
||||
}
|
||||
|
||||
/// Runs a scheduled native `REBUILD INDEX` pass when due, using a DB lock so only one
|
||||
/// maintainer rebuilds at a time. Seeds a checkpoint on first run so the initial rebuild
|
||||
/// waits one full interval after worker startup.
|
||||
pub async fn maybe_run_scheduled_index_rebuild(
|
||||
db: &SurrealDbClient,
|
||||
worker_id: &str,
|
||||
interval_secs: u64,
|
||||
) {
|
||||
if interval_secs == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
let settings = match SystemSettings::get_current(db).await {
|
||||
Ok(settings) => settings,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to load system settings for index rebuild schedule");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let last_run = settings.last_index_rebuild_at;
|
||||
|
||||
if last_run.is_none() {
|
||||
match SystemSettings::seed_index_rebuild_checkpoint(db).await {
|
||||
Ok(true) => debug!("seeded index rebuild checkpoint; first rebuild deferred"),
|
||||
Ok(false) => {}
|
||||
Err(err) => warn!(error = %err, "failed to seed index rebuild checkpoint"),
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !scheduled_index_rebuild_due(last_run, interval_secs, now) {
|
||||
return;
|
||||
}
|
||||
|
||||
let lock_owner = format!("{worker_id}-index-rebuild");
|
||||
let acquired = match SystemSettings::try_acquire_index_rebuild_lease(db, &lock_owner).await {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to acquire index rebuild lease");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !acquired {
|
||||
debug!("another maintainer is rebuilding indexes");
|
||||
return;
|
||||
}
|
||||
|
||||
let started = Instant::now();
|
||||
info!(interval_secs, "starting scheduled runtime index rebuild");
|
||||
let rebuild_result = rebuild_runtime(db).await;
|
||||
|
||||
match rebuild_result {
|
||||
Ok(()) => {
|
||||
if let Err(err) = SystemSettings::record_index_rebuild_completed(db, &lock_owner).await
|
||||
{
|
||||
warn!(error = %err, "failed to persist index rebuild checkpoint");
|
||||
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||
}
|
||||
info!(
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"scheduled runtime index rebuild completed"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||
error!(
|
||||
error = %err,
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"scheduled runtime index rebuild failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||
///
|
||||
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||
@@ -382,6 +498,45 @@ async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn rebuild_runtime_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
debug!("Rebuilding runtime indexes with REBUILD INDEX");
|
||||
|
||||
for spec in fts_index_specs() {
|
||||
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await?;
|
||||
}
|
||||
|
||||
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await
|
||||
});
|
||||
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn rebuild_existing_index_in_place(
|
||||
db: &SurrealDbClient,
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
) -> Result<()> {
|
||||
if !index_exists(db, table, index_name).await? {
|
||||
debug!(
|
||||
index = index_name,
|
||||
table, "Skipping in-place rebuild because index is missing"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let query = format!("REBUILD INDEX IF EXISTS {index_name} ON {table};");
|
||||
let res = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| format!("rebuilding index {index_name} on table {table}"))?;
|
||||
res.check()
|
||||
.with_context(|| format!("rebuild index {index_name} on table {table} failed"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn existing_hnsw_dimension(
|
||||
db: &SurrealDbClient,
|
||||
spec: &HnswIndexSpec,
|
||||
@@ -906,6 +1061,43 @@ mod tests {
|
||||
assert_eq!(extract_dimension(definition), Some(1536));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scheduled_index_rebuild_due_respects_interval_and_disabled() {
|
||||
let now = Utc::now();
|
||||
let last = now - chrono::Duration::hours(25);
|
||||
|
||||
assert!(!scheduled_index_rebuild_due(None, 86_400, now));
|
||||
assert!(!scheduled_index_rebuild_due(Some(last), 0, now));
|
||||
assert!(!scheduled_index_rebuild_due(
|
||||
Some(now - chrono::Duration::hours(1)),
|
||||
86_400,
|
||||
now
|
||||
));
|
||||
assert!(scheduled_index_rebuild_due(Some(last), 86_400, now));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rebuild_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_in_place_rebuild";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations().await.context("migrations")?;
|
||||
ensure_runtime(&db, 8)
|
||||
.await
|
||||
.context("ensure runtime indexes")?;
|
||||
|
||||
rebuild_runtime(&db)
|
||||
.await
|
||||
.context("first in-place rebuild")?;
|
||||
rebuild_runtime(&db)
|
||||
.await
|
||||
.context("second in-place rebuild")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_ns";
|
||||
|
||||
@@ -108,6 +108,7 @@ mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -120,10 +121,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -145,10 +143,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -165,10 +160,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -190,10 +182,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -214,11 +203,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 0);
|
||||
@@ -246,10 +231,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
@@ -259,10 +241,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
@@ -272,10 +251,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(after_visitors.visitors, 1);
|
||||
assert_eq!(after_visitors.page_loads, 0);
|
||||
@@ -293,10 +269,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_record_page_view() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let first_view = Analytics::record_page_view(&db, true).await?;
|
||||
assert_eq!(first_view.visitors, 1);
|
||||
assert_eq!(first_view.page_loads, 1);
|
||||
@@ -310,11 +283,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
|
||||
@@ -157,6 +157,7 @@ impl Conversation {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use crate::storage::types::message::MessageRole;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
@@ -181,11 +182,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
@@ -214,11 +211,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
@@ -234,11 +227,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
@@ -264,11 +253,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_1";
|
||||
let original_title = "Original Title";
|
||||
@@ -297,11 +282,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||
|
||||
@@ -316,11 +297,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user_id = "intruder";
|
||||
@@ -345,11 +322,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let user_id = "sidebar_user";
|
||||
let other_user_id = "other_user";
|
||||
@@ -398,11 +371,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let user_id = "sidebar_patch_user";
|
||||
let base = Utc::now();
|
||||
@@ -440,11 +409,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
@@ -527,11 +492,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let owner = "sidebar_owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||
@@ -551,9 +512,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner = "owner_user";
|
||||
let intruder = "intruder_user";
|
||||
@@ -590,9 +549,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "order_user";
|
||||
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||
@@ -637,9 +594,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner = "owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||
|
||||
@@ -327,6 +327,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
use std::{io::Write, path::Path};
|
||||
@@ -378,15 +379,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file for StorageManager operations";
|
||||
let file_name = "storage_manager_test.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -435,15 +428,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"filename sanitization";
|
||||
let original_name = "Complex name (1).txt";
|
||||
let expected_sanitized = "Complex_name__1_.txt";
|
||||
@@ -470,15 +455,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file for StorageManager duplicate detection";
|
||||
let file_name = "storage_manager_duplicate.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -538,15 +515,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -585,15 +554,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
let file_name = "original.txt";
|
||||
@@ -692,12 +653,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
@@ -710,12 +666,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let db = setup_test_db().await.expect("setup test db");
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -740,15 +691,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"shared content across users";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
@@ -783,10 +726,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let now = Utc::now();
|
||||
let sha = "abc123sha";
|
||||
let owner = "owner_user";
|
||||
@@ -816,9 +756,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let field_data = create_test_file_without_name(b"data")?;
|
||||
@@ -832,9 +770,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let file_info = FileInfo::new_with_storage(
|
||||
@@ -856,10 +792,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
db.apply_migrations().await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
let storage = test_storage.storage();
|
||||
let user_id = "dedup_user";
|
||||
@@ -901,12 +834,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let db = setup_test_db().await.expect("setup test db");
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
@@ -939,15 +867,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Create and persist a test file via FileInfo::new_with_storage
|
||||
let user_id = "user123";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
@@ -985,12 +905,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Try to delete a file that doesn't exist
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
@@ -1006,12 +921,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_id = Uuid::new_v4().to_string();
|
||||
@@ -1045,12 +955,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Try to retrieve a non-existent ID
|
||||
let non_existent_id = "non-existent-file-id";
|
||||
let result = FileInfo::get_by_id(non_existent_id, &db).await;
|
||||
|
||||
@@ -630,6 +630,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
@@ -641,11 +642,7 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "in-memory surrealdb".to_string())
|
||||
setup_test_db().await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -4,10 +4,13 @@ use std::fmt::Write;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::SurrealDbClient,
|
||||
storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::hnsw_index_overwrite_sql,
|
||||
types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
types::system_settings::SystemSettings,
|
||||
types::{EmbeddingRecord, HasEmbedding},
|
||||
},
|
||||
stored_object,
|
||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||
};
|
||||
@@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl HasEmbedding for KnowledgeEntity {
|
||||
type Embedding = KnowledgeEntityEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl KnowledgeEntity {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
@@ -227,67 +242,22 @@ impl KnowledgeEntity {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store a knowledge entity and its embedding.
|
||||
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
|
||||
/// Atomically store one knowledge entity and its embedding (single-record path).
|
||||
///
|
||||
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||
pub async fn store_with_embedding(
|
||||
entity: KnowledgeEntity,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
KnowledgeEntityEmbedding::validate_dimension(
|
||||
&embedding,
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let entity_id = entity.id.clone();
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
&entity_id,
|
||||
entity.source_id.clone(),
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
||||
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
entity_table = Self::table_name(),
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(entity, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||
@@ -297,48 +267,14 @@ impl KnowledgeEntity {
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: Option<KnowledgeEntity>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
entity_id,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH entity_id;
|
||||
"#,
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, KnowledgeEntityEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.entity_id.map(|entity| KnowledgeEntitySearchResult {
|
||||
entity,
|
||||
score: r.score,
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(entity, score)| KnowledgeEntitySearchResult { entity, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
@@ -364,7 +300,13 @@ impl KnowledgeEntity {
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
id,
|
||||
entity.source_id,
|
||||
embedding,
|
||||
entity.user_id,
|
||||
Self::table_name(),
|
||||
);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
@@ -554,9 +496,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::indexes::rebuild;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use crate::test_utils::{ensure_fts_index, prepare_knowledge_entity_test_db, setup_test_db};
|
||||
use anyhow::{self, Context};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn embedding_input_text_uses_canonical_type_label() {
|
||||
@@ -568,27 +509,6 @@ mod tests {
|
||||
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
|
||||
}
|
||||
|
||||
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define entity fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -675,19 +595,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(5).await?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
@@ -722,13 +630,13 @@ mod tests {
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity 1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity 2".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store different entity".to_string())?;
|
||||
|
||||
@@ -783,21 +691,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_knowledge_entity_test_db(3)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
configure_embedding_dimension(&db, 3)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.expect("prepare test db");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
@@ -819,10 +715,10 @@ mod tests {
|
||||
user_id,
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.expect("store entity1");
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], 3, &db)
|
||||
.await
|
||||
.expect("store entity2");
|
||||
|
||||
@@ -849,18 +745,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_knowledge_entity_test_db(3)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.expect("prepare test db");
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
@@ -870,19 +757,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -895,7 +770,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
@@ -918,7 +793,7 @@ mod tests {
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
@@ -938,19 +813,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
@@ -970,10 +833,10 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store e1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store e2".to_string())?;
|
||||
|
||||
@@ -1001,11 +864,11 @@ mod tests {
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1)
|
||||
.await
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2)
|
||||
.await
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
@@ -1037,19 +900,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -1062,7 +913,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
@@ -1089,15 +940,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_entities() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
@@ -1112,15 +961,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -1151,15 +998,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_entity = KnowledgeEntity::new(
|
||||
|
||||
@@ -4,7 +4,7 @@ use surrealdb::RecordId;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
storage::{db::SurrealDbClient, types::EmbeddingRecord},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
@@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for KnowledgeEntityEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"entity_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_knowledge_entity_embedding"
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding.
|
||||
///
|
||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||
#[must_use]
|
||||
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
entity_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: entity_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
entity_id: RecordId::from_table_key(entity_table, entity_id),
|
||||
embedding,
|
||||
source_id,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding by entity ID
|
||||
pub async fn get_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Get embeddings for multiple entities in batch
|
||||
pub async fn get_by_entity_ids(
|
||||
entity_ids: &[RecordId],
|
||||
@@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding {
|
||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete embedding by entity ID
|
||||
pub async fn delete_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings with the given denormalized `source_id`.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -184,6 +122,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "entity-abc");
|
||||
}
|
||||
@@ -205,13 +144,13 @@ mod tests {
|
||||
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
@@ -234,22 +173,22 @@ mod tests {
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none());
|
||||
@@ -266,7 +205,7 @@ mod tests {
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
@@ -277,7 +216,7 @@ mod tests {
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding =
|
||||
@@ -295,7 +234,7 @@ mod tests {
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], 3, &db).await;
|
||||
|
||||
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||
|
||||
@@ -313,15 +252,20 @@ mod tests {
|
||||
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
||||
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(
|
||||
entity_other.clone(),
|
||||
vec![3.0_f32, 3.1, 3.2],
|
||||
3,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
@@ -332,18 +276,18 @@ mod tests {
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid)
|
||||
.await
|
||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid)
|
||||
.await
|
||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
assert!(KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid)
|
||||
.await
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some());
|
||||
@@ -403,7 +347,7 @@ mod tests {
|
||||
let source_id = "source-fetch";
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
@@ -441,7 +385,7 @@ mod tests {
|
||||
let source_id = "source-upsert";
|
||||
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "initial store".to_string())?;
|
||||
|
||||
@@ -450,6 +394,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -78,7 +78,7 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
#[tokio::test]
|
||||
@@ -106,11 +106,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![allow(clippy::unsafe_derive_deserialize)]
|
||||
#![allow(async_fn_in_trait)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod analytics;
|
||||
pub mod conversation;
|
||||
@@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||
fn id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An entity that has an associated embedding record for vector search.
|
||||
pub trait HasEmbedding: StoredObject {
|
||||
/// The embedding record type paired with this entity.
|
||||
type Embedding: EmbeddingRecord;
|
||||
|
||||
fn source_id(&self) -> &str;
|
||||
fn user_id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An embedding record linked to a `HasEmbedding` entity.
|
||||
pub trait EmbeddingRecord: StoredObject {
|
||||
/// The field name in the embedding table that links back to the entity
|
||||
/// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses.
|
||||
fn link_field() -> &'static str;
|
||||
|
||||
/// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`).
|
||||
fn index_name() -> &'static str;
|
||||
|
||||
fn source_id(&self) -> &str;
|
||||
fn user_id(&self) -> &str;
|
||||
fn embedding(&self) -> &[f32];
|
||||
|
||||
/// Construct a new embedding record.
|
||||
///
|
||||
/// * `id` – shared record id (same as the entity id).
|
||||
/// * `source_id` – denormalised source id for bulk deletes.
|
||||
/// * `embedding` – the embedding vector.
|
||||
/// * `user_id` – denormalised user id for query scoping.
|
||||
/// * `entity_table` – the entity's table name (used to build the link `RecordId`).
|
||||
fn new(
|
||||
id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self;
|
||||
|
||||
/// Validate that an embedding vector matches the expected dimension.
|
||||
fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
if embedding.len() != expected {
|
||||
return Err(crate::error::AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recreate the HNSW vector index with a new dimension.
|
||||
///
|
||||
/// This drops and recreates the index inside a transaction.
|
||||
async fn redefine_hnsw_index(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql(
|
||||
Self::index_name(),
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
db.client.query(query).await?.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch a single embedding record by its link `RecordId`.
|
||||
async fn get_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<Option<Self>, crate::error::AppError>
|
||||
where
|
||||
Self: Sized + serde::de::DeserializeOwned,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE {} = $rid LIMIT 1",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
let mut result = db.client.query(query).bind(("rid", rid.clone())).await?;
|
||||
Ok(result.take(0)?)
|
||||
}
|
||||
|
||||
/// Delete an embedding record by its link `RecordId`.
|
||||
async fn delete_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE {} = $rid",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("rid", rid.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embedding records with a given `source_id`.
|
||||
async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! stored_object {
|
||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
|
||||
@@ -221,19 +221,11 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
@@ -271,15 +263,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
@@ -333,15 +317,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -368,15 +344,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -398,15 +366,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
@@ -428,15 +388,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -461,15 +413,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
@@ -498,13 +442,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to create test database".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::utils::config::EmbeddingBackend;
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -22,6 +25,15 @@ pub struct SystemSettings {
|
||||
pub image_processing_model: String,
|
||||
pub image_processing_prompt: String,
|
||||
pub voice_processing_model: String,
|
||||
/// When the maintainer last completed a scheduled `REBUILD INDEX` pass.
|
||||
#[serde(default)]
|
||||
pub last_index_rebuild_at: Option<DateTime<Utc>>,
|
||||
/// Worker id holding the index-rebuild lease, if any.
|
||||
#[serde(default)]
|
||||
pub index_rebuild_lease_owner: Option<String>,
|
||||
/// Lease expiry for in-flight scheduled index rebuilds.
|
||||
#[serde(default)]
|
||||
pub index_rebuild_lease_expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Partial update for singleton system settings without cloning unchanged fields.
|
||||
@@ -100,6 +112,8 @@ impl SystemSettingsPatch {
|
||||
}
|
||||
}
|
||||
|
||||
const INDEX_REBUILD_LEASE_TTL: &str = "6h";
|
||||
|
||||
impl SystemSettings {
|
||||
pub const RECORD_ID: &'static str = "current";
|
||||
|
||||
@@ -227,6 +241,89 @@ impl SystemSettings {
|
||||
|
||||
Ok((settings, needs_update))
|
||||
}
|
||||
|
||||
/// Seeds the first rebuild checkpoint so the initial scheduled rebuild waits one interval.
|
||||
pub async fn seed_index_rebuild_checkpoint(db: &SurrealDbClient) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET last_index_rebuild_at = time::now()
|
||||
WHERE last_index_rebuild_at IS NONE
|
||||
RETURN AFTER;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||
Ok(updated.is_some())
|
||||
}
|
||||
|
||||
/// Claims the singleton index-rebuild lease when it is free or expired.
|
||||
pub async fn try_acquire_index_rebuild_lease(
|
||||
db: &SurrealDbClient,
|
||||
owner: &str,
|
||||
) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(format!(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
index_rebuild_lease_owner = $owner,
|
||||
index_rebuild_lease_expires_at = time::now() + {INDEX_REBUILD_LEASE_TTL}
|
||||
WHERE index_rebuild_lease_expires_at IS NONE
|
||||
OR index_rebuild_lease_expires_at < time::now()
|
||||
RETURN AFTER;"
|
||||
))
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||
Ok(updated.is_some())
|
||||
}
|
||||
|
||||
/// Releases the index-rebuild lease when held by `owner`.
|
||||
pub async fn release_index_rebuild_lease(db: &SurrealDbClient, owner: &str) {
|
||||
let released = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
index_rebuild_lease_owner = NONE,
|
||||
index_rebuild_lease_expires_at = NONE
|
||||
WHERE index_rebuild_lease_owner = $owner;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.and_then(surrealdb::Response::check);
|
||||
|
||||
if let Err(err) = released {
|
||||
warn!(error = %err, "failed to release index rebuild lease");
|
||||
}
|
||||
}
|
||||
|
||||
/// Records a completed scheduled index rebuild and clears the lease.
|
||||
pub async fn record_index_rebuild_completed(
|
||||
db: &SurrealDbClient,
|
||||
owner: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let response = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
last_index_rebuild_at = time::now(),
|
||||
index_rebuild_lease_owner = NONE,
|
||||
index_rebuild_lease_expires_at = NONE
|
||||
WHERE index_rebuild_lease_owner = $owner;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
response.check().map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -237,6 +334,7 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn get_hnsw_index_dimension(
|
||||
@@ -320,17 +418,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Test initialization of system settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to get system settings".to_string())?;
|
||||
@@ -367,19 +455,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to get current settings".to_string())?;
|
||||
@@ -392,17 +469,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::get_current(&db)
|
||||
@@ -435,13 +502,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
|
||||
@@ -458,12 +519,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -477,12 +533,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let updated = SystemSettingsPatch {
|
||||
registrations_enabled: Some(false),
|
||||
@@ -498,12 +549,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let original = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -533,12 +579,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -552,12 +593,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -575,12 +611,7 @@ mod tests {
|
||||
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -607,12 +638,7 @@ mod tests {
|
||||
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -636,12 +662,7 @@ mod tests {
|
||||
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -660,12 +681,7 @@ mod tests {
|
||||
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
||||
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
||||
@@ -676,14 +692,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
@@ -691,7 +700,7 @@ mod tests {
|
||||
"user1".into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], 1536, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||
|
||||
@@ -714,14 +723,7 @@ mod tests {
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -802,4 +804,28 @@ mod tests {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn index_rebuild_lease_is_exclusive_on_system_settings() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
assert!(
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?,
|
||||
"first lease claim should succeed"
|
||||
);
|
||||
assert!(
|
||||
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?,
|
||||
"second lease claim should fail while lease is held"
|
||||
);
|
||||
|
||||
SystemSettings::release_index_rebuild_lease(&db, "worker-a").await;
|
||||
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?;
|
||||
SystemSettings::record_index_rebuild_completed(&db, "worker-b").await?;
|
||||
|
||||
let settings = SystemSettings::get_current(&db).await?;
|
||||
assert!(settings.last_index_rebuild_at.is_some());
|
||||
assert!(settings.index_rebuild_lease_owner.is_none());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::storage::types::{
|
||||
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord, HasEmbedding,
|
||||
};
|
||||
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
@@ -25,6 +26,18 @@ pub struct TextChunkSearchResult {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl HasEmbedding for TextChunk {
|
||||
type Embedding = TextChunkEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl TextChunk {
|
||||
#[must_use]
|
||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||
@@ -41,123 +54,39 @@ impl TextChunk {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db_client
|
||||
.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
TextChunkEmbedding::table_name()
|
||||
))
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id;")
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.bind(("table", Self::table_name()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store a text chunk and its embedding.
|
||||
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
|
||||
/// Atomically store one text chunk and its embedding (single-record path).
|
||||
///
|
||||
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||
pub async fn store_with_embedding(
|
||||
chunk: TextChunk,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
TextChunkEmbedding::validate_dimension(&embedding, settings.embedding_dimensions as usize)?;
|
||||
|
||||
let chunk_id = chunk.id.clone();
|
||||
let emb = TextChunkEmbedding::new(
|
||||
&chunk_id,
|
||||
chunk.source_id.clone(),
|
||||
embedding,
|
||||
chunk.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
|
||||
UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
chunk_table = Self::table_name(),
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id))
|
||||
.bind(("chunk", chunk))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(chunk, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and scores.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
chunk_id: Option<TextChunk>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
chunk_id,
|
||||
embedding,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH chunk_id;
|
||||
"#,
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, TextChunkEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.chunk_id.map(|chunk| TextChunkSearchResult {
|
||||
chunk,
|
||||
score: r.score,
|
||||
}).or_else(|| {
|
||||
warn!("vector search hit orphaned text_chunk_embedding row with missing chunk");
|
||||
None
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(chunk, score)| TextChunkSearchResult { chunk, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Full-text search over text chunks using the BM25 FTS index.
|
||||
@@ -393,29 +322,10 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use crate::test_utils::{
|
||||
configure_embedding_dimension, ensure_fts_index, prepare_text_chunk_test_db, setup_test_db,
|
||||
};
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
||||
@@ -434,21 +344,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = prepare_text_chunk_test_db(5).await?;
|
||||
let source_id = "source123".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
@@ -466,15 +364,16 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
5,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -516,18 +415,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(5).await?;
|
||||
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -536,7 +424,7 @@ mod tests {
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk".to_string())?;
|
||||
|
||||
@@ -560,18 +448,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_text_chunk_test_db(5)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
configure_embedding_dimension(&db, 5)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.expect("prepare test db");
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
"safe_source".to_string(),
|
||||
@@ -584,10 +463,10 @@ mod tests {
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], 5, &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
@@ -614,25 +493,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
@@ -645,7 +511,7 @@ mod tests {
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "expected embedding".to_string())?;
|
||||
@@ -658,14 +524,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_runtime";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let embedding_dimension = 3usize;
|
||||
configure_embedding_dimension(
|
||||
@@ -683,7 +542,7 @@ mod tests {
|
||||
"runtime_user".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
@@ -695,7 +554,7 @@ mod tests {
|
||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "embedding should exist".to_string())?;
|
||||
@@ -709,19 +568,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
@@ -733,19 +580,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
@@ -755,7 +590,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
@@ -774,28 +609,16 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
|
||||
@@ -815,15 +638,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
@@ -838,15 +654,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -874,15 +683,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_chunk = TextChunk::new(
|
||||
@@ -936,19 +738,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_dim";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
|
||||
let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
|
||||
|
||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db)
|
||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], 3, &db)
|
||||
.await
|
||||
.expect_err("expected dimension validation failure");
|
||||
assert!(matches!(err, AppError::Validation(_)));
|
||||
@@ -958,18 +753,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan_chunk";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -978,7 +762,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk with embedding".to_string())?;
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
stored_object,
|
||||
};
|
||||
use crate::{storage::types::EmbeddingRecord, stored_object};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::error::AppError;
|
||||
|
||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
/// Record link to the owning text_chunk
|
||||
@@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunkEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
///
|
||||
/// This is useful when the embedding length changes; Surreal requires the
|
||||
/// index definition to be recreated with the updated dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for TextChunkEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"chunk_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_text_chunk_embedding"
|
||||
}
|
||||
|
||||
/// Create a new text chunk embedding.
|
||||
///
|
||||
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
||||
#[must_use]
|
||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
chunk_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
id: chunk_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||
chunk_id: RecordId::from_table_key(entity_table, chunk_id),
|
||||
source_id,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a single embedding by its chunk RecordId
|
||||
pub async fn get_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Delete embeddings for a given chunk RecordId
|
||||
pub async fn delete_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||
///
|
||||
/// This uses the denormalized `source_id` on the embedding table.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -144,8 +65,31 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||
use surrealdb::Value as SurrealValue;
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: surrealdb::Value = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
@@ -169,29 +113,6 @@ mod tests {
|
||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||
}
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_chunk_id_as_record_id() {
|
||||
let emb = TextChunkEmbedding::new(
|
||||
@@ -199,6 +120,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "chunk-abc");
|
||||
}
|
||||
@@ -226,13 +148,14 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
embedding_vec.clone(),
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||
@@ -259,22 +182,23 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
vec![0.4_f32, 0.5, 0.6],
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
@@ -299,21 +223,27 @@ mod tests {
|
||||
("chunk-s2", source_id, vec![0.2]),
|
||||
("chunk-other", other_source, vec![0.3]),
|
||||
] {
|
||||
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||
let emb = TextChunkEmbedding::new(
|
||||
key,
|
||||
src.to_string(),
|
||||
vec,
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| format!("store embedding for {key}"))?;
|
||||
}
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some());
|
||||
@@ -322,15 +252,15 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
assert!(TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some());
|
||||
@@ -352,6 +282,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![1.0_f32, 0.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(initial)
|
||||
.await
|
||||
@@ -362,6 +293,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -96,6 +96,41 @@ impl TextContent {
|
||||
}
|
||||
}
|
||||
|
||||
/// SurrealQL deletes for ingested child rows keyed by `source_id` (no transaction wrapper).
|
||||
///
|
||||
/// Used inside larger transactions (e.g. ingestion `persist_artifacts`) and mirrored by
|
||||
/// [`Self::clear_ingested_children`].
|
||||
pub const CLEAR_INGESTED_CHILD_ROWS_SURQL: &'static str = r"
|
||||
DELETE relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id;
|
||||
DELETE text_chunk_embedding WHERE source_id = $source_id;
|
||||
DELETE text_chunk WHERE source_id = $source_id;
|
||||
DELETE knowledge_entity_embedding WHERE source_id = $source_id;
|
||||
DELETE knowledge_entity WHERE source_id = $source_id;
|
||||
";
|
||||
|
||||
/// Removes chunks, embeddings, entities, and relationships for one ingested document snapshot.
|
||||
pub async fn clear_ingested_children(
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;\n{} COMMIT TRANSACTION;",
|
||||
Self::CLEAR_INGESTED_CHILD_ROWS_SURQL
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_string()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
id: &str,
|
||||
context: &str,
|
||||
@@ -364,7 +399,14 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||
use crate::{
|
||||
storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
},
|
||||
test_utils::{setup_test_db, setup_test_db_with_runtime_indexes},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||
@@ -638,4 +680,81 @@ mod tests {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_ingested_children_removes_chunks_entities_and_relationships(
|
||||
) -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "clear-user";
|
||||
let source_id = Uuid::new_v4().to_string();
|
||||
|
||||
let entity_a = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"entity-a".to_string(),
|
||||
"desc-a".to_string(),
|
||||
KnowledgeEntityType::Idea,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
let entity_b = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"entity-b".to_string(),
|
||||
"desc-b".to_string(),
|
||||
KnowledgeEntityType::Idea,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
KnowledgeEntity::store_with_embedding(entity_a.clone(), vec![0.1; 3], 3, &db)
|
||||
.await
|
||||
.context("store entity a")?;
|
||||
KnowledgeEntity::store_with_embedding(entity_b.clone(), vec![0.2; 3], 3, &db)
|
||||
.await
|
||||
.context("store entity b")?;
|
||||
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk".to_string(), user_id.to_string());
|
||||
TextChunk::store_with_embedding(chunk, vec![0.3; 3], 3, &db)
|
||||
.await
|
||||
.context("store chunk")?;
|
||||
|
||||
KnowledgeRelationship::new(
|
||||
entity_a.id.clone(),
|
||||
entity_b.id,
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
"relates_to".to_string(),
|
||||
)
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.context("store relationship")?;
|
||||
|
||||
TextContent::clear_ingested_children(&source_id, user_id, &db)
|
||||
.await
|
||||
.context("clear ingested children")?;
|
||||
|
||||
let chunks: Vec<TextChunk> = db
|
||||
.client
|
||||
.query("SELECT * FROM text_chunk WHERE source_id = $source_id;")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(chunks.is_empty());
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query("SELECT * FROM knowledge_entity WHERE source_id = $source_id;")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(entities.is_empty());
|
||||
|
||||
let relationships: Vec<KnowledgeRelationship> = db
|
||||
.client
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id;")
|
||||
.bind(("source_id", source_id))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(relationships.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::storage::{
|
||||
indexes::{ensure_runtime, rebuild},
|
||||
types::{
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
text_chunk_embedding::TextChunkEmbedding, EmbeddingRecord,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -91,3 +91,47 @@ pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||
rebuild(&db).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Ensures an FTS analyzer and BM25 indexes exist for a table.
|
||||
///
|
||||
/// Attempts snowball(english) tokenizer first; falls back to basic
|
||||
/// lowercase+ascii when the platform lacks the snowball extension.
|
||||
///
|
||||
/// `indexes` is a slice of `(field_name, index_id_suffix)` pairs —
|
||||
/// e.g. `&[("chunk", "chunk")]` produces index
|
||||
/// `text_chunk_fts_chunk_idx` on column `chunk` of `text_chunk`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the fallback definition fails. The initial
|
||||
/// snowball attempt is allowed to fail silently.
|
||||
pub async fn ensure_fts_index(
|
||||
db: &SurrealDbClient,
|
||||
table: &str,
|
||||
indexes: &[(&str, &str)],
|
||||
) -> Result<()> {
|
||||
use std::fmt::Write;
|
||||
|
||||
let mut define_indexes = String::new();
|
||||
for (field, suffix) in indexes {
|
||||
let _ = writeln!(
|
||||
define_indexes,
|
||||
"DEFINE INDEX IF NOT EXISTS {table}_fts_{suffix}_idx ON TABLE {table} FIELDS {field} SEARCH ANALYZER app_en_fts_analyzer BM25;"
|
||||
);
|
||||
}
|
||||
|
||||
let snowball_sql = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);\n{define_indexes}"
|
||||
);
|
||||
|
||||
if let Err(err) = db.client.query(&snowball_sql).await {
|
||||
let fallback_sql = format!(
|
||||
"DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;\n{define_indexes}"
|
||||
);
|
||||
db.client
|
||||
.query(&fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define fts index fallback for {table}: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -135,6 +135,9 @@ pub struct AppConfig {
|
||||
pub ingest_max_context_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_category_bytes")]
|
||||
pub ingest_max_category_bytes: usize,
|
||||
/// Seconds between scheduled `REBUILD INDEX` maintainer runs (`0` disables).
|
||||
#[serde(default = "default_index_rebuild_interval_secs")]
|
||||
pub index_rebuild_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Default data directory for persisted assets.
|
||||
@@ -172,6 +175,10 @@ fn default_ingest_max_category_bytes() -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
fn default_index_rebuild_interval_secs() -> u64 {
|
||||
86_400
|
||||
}
|
||||
|
||||
static ORT_PATH_INIT: Once = Once::new();
|
||||
|
||||
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
||||
@@ -238,6 +245,7 @@ impl Default for AppConfig {
|
||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
||||
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
||||
index_rebuild_interval_secs: default_index_rebuild_interval_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::{
|
||||
use serde::Serialize;
|
||||
use tracing::warn;
|
||||
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use async_openai::{types::embeddings::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
|
||||
|
||||
+14
-14
@@ -3,10 +3,10 @@
|
||||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1771066302,
|
||||
"lastModified": 1781800860,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
|
||||
"rev": "d59d872d80876d9eeb3e214d3b088bc4a14a9c4f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -22,10 +22,10 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1771052630,
|
||||
"lastModified": 1781779700,
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
|
||||
"rev": "ad30e585c7a2917325943c2b19511f5a249eff53",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -58,10 +58,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1770726378,
|
||||
"lastModified": 1781733627,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
|
||||
"rev": "3bbec39bc90eadfa031e6f3b77272f3f60803e39",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -92,10 +92,10 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1771008912,
|
||||
"lastModified": 1781577229,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
|
||||
"rev": "567a49d1913ce81ac6e9582e3553dd90a955875f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -107,10 +107,10 @@
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1770843696,
|
||||
"lastModified": 1781607440,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
|
||||
"rev": "3e41b24abd260e8f71dbe2f5737d24122f972158",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -135,10 +135,10 @@
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1771007332,
|
||||
"lastModified": 1781714865,
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
|
||||
"rev": "abb1301c3c14a40645bb2588b1cc858fe374b527",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -155,10 +155,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1771038269,
|
||||
"lastModified": 1781850613,
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
|
||||
"rev": "4baecb43a008cd004e5220a777e1724bd8d43e43",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
+32
-5
@@ -4,17 +4,32 @@
|
||||
config,
|
||||
inputs,
|
||||
...
|
||||
}:
|
||||
let
|
||||
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${toString ./.}/ort-version");
|
||||
}: let
|
||||
ortVersion = "1.23.2";
|
||||
_ortVersionCheck =
|
||||
if pkgs.onnxruntime.version == ortVersion
|
||||
then null
|
||||
else
|
||||
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
||||
else throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ortVersion in flake.nix (${ortVersion})";
|
||||
in {
|
||||
devenv.warnOnNewVersion = false;
|
||||
|
||||
cachix.enable = false;
|
||||
|
||||
git-hooks.install.enable = true;
|
||||
git-hooks.hooks = {
|
||||
rustfmt.enable = true;
|
||||
clippy = {
|
||||
enable = true;
|
||||
settings.allFeatures = true;
|
||||
};
|
||||
};
|
||||
|
||||
# Use pinned Rust toolchain from languages.rust for git-hooks wrappers
|
||||
# (git-hooks.nix defaults to nixpkgs's cargo/clippy/rustfmt, ignoring the pin)
|
||||
git-hooks.tools.cargo = lib.mkDefault config.languages.rust.toolchain.cargo;
|
||||
git-hooks.tools.clippy = lib.mkDefault config.languages.rust.toolchain.clippy;
|
||||
git-hooks.tools.rustfmt = lib.mkDefault config.languages.rust.toolchain.rustfmt;
|
||||
|
||||
packages = [
|
||||
pkgs.openssl
|
||||
pkgs.nodejs
|
||||
@@ -26,6 +41,14 @@ in {
|
||||
pkgs.onnxruntime
|
||||
pkgs.cargo-watch
|
||||
pkgs.tailwindcss_4
|
||||
pkgs.python3
|
||||
pkgs.fontconfig
|
||||
pkgs.fontconfig.dev
|
||||
pkgs.libGL
|
||||
pkgs.libGLU
|
||||
pkgs.libclang
|
||||
pkgs.wayland
|
||||
pkgs.libxkbcommon
|
||||
];
|
||||
|
||||
languages.rust = {
|
||||
@@ -38,6 +61,10 @@ in {
|
||||
};
|
||||
|
||||
env = {
|
||||
# tikv-jemalloc-sys configure flags: -O0 + -Werror triggers glibc _FORTIFY_SOURCE warning
|
||||
NIX_CFLAGS_COMPILE = "-Wno-error=cpp";
|
||||
LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
|
||||
LD_LIBRARY_PATH = "${pkgs.wayland}/lib:${pkgs.libxkbcommon}/lib:${pkgs.pipewire}/lib:${pkgs.libglvnd}/lib";
|
||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||
S3_ENDPOINT = "http://127.0.0.1:19000";
|
||||
S3_BUCKET = "minne-tests";
|
||||
|
||||
@@ -9,3 +9,6 @@ inputs:
|
||||
nixpkgs:
|
||||
follows: nixpkgs
|
||||
allowUnfree: true
|
||||
nixpkgs:
|
||||
permittedInsecurePackages:
|
||||
- "minio-2025-10-15T17-29-55Z"
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
services:
|
||||
minne:
|
||||
build: .
|
||||
image: ghcr.io/perstarkse/minne:latest
|
||||
container_name: minne_app
|
||||
ports:
|
||||
- "3000:3000"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
| Frontend | HTML + HTMX + minimal JS |
|
||||
| Database | SurrealDB (graph, document, vector) |
|
||||
| AI | OpenAI-compatible API |
|
||||
| Web Processing | Headless Chromium |
|
||||
| Web Processing | Servo engine (servo-fetch) + PDFium |
|
||||
|
||||
## Crate Structure
|
||||
|
||||
|
||||
+3
-1
@@ -10,7 +10,7 @@
|
||||
|
||||
Minne automatically processes saved content:
|
||||
|
||||
1. **Web scraping** extracts readable text from URLs (via headless Chrome)
|
||||
1. **Web scraping** extracts readable text from URLs (via embedded Servo engine)
|
||||
2. **Text analysis** identifies key concepts and relationships
|
||||
3. **Graph creation** builds connections between related content
|
||||
4. **Embedding generation** enables semantic search
|
||||
@@ -43,6 +43,7 @@ Optional **reranking** can rescore fused chunk lists with a cross-encoder model;
|
||||
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
|
||||
|
||||
**Trade-offs:**
|
||||
|
||||
- Downloads ~1.1 GB of model data
|
||||
- Adds latency per query
|
||||
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
|
||||
@@ -52,6 +53,7 @@ Enable via `RERANKING_ENABLED=true`. See [Configuration](./configuration.md).
|
||||
## Multi-Format Ingestion
|
||||
|
||||
Supported content types:
|
||||
|
||||
- Plain text and notes
|
||||
- URLs (web pages)
|
||||
- PDF documents
|
||||
|
||||
@@ -12,13 +12,13 @@ cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium automatically.
|
||||
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||
|
||||
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
|
||||
|
||||
## Nix
|
||||
|
||||
Run Minne directly with Nix (includes Chromium):
|
||||
Run Minne directly with Nix:
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
@@ -31,8 +31,9 @@ Configure via environment variables or a `config.yaml` file. See [Configuration]
|
||||
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- SurrealDB instance (local or remote)
|
||||
- Chromium (for web scraping)
|
||||
- `libEGL` + `libfontconfig` (for servo-fetch web scraping)
|
||||
|
||||
## Build from Source
|
||||
|
||||
@@ -45,9 +46,10 @@ cargo build --release --bin main
|
||||
The binary will be at `target/release/main`.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- Rust toolchain
|
||||
- SurrealDB accessible at configured address
|
||||
- Chromium in PATH
|
||||
- `libEGL` + `libfontconfig` for servo-fetch (web scraping) — bundled in Nix and Docker images
|
||||
|
||||
## Process Modes
|
||||
|
||||
|
||||
@@ -30,8 +30,6 @@ serde_json = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
once_cell = "1.19"
|
||||
serde_yaml = "0.9"
|
||||
criterion = "0.5"
|
||||
state-machines = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive", "env"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
+71
-181
@@ -1,212 +1,102 @@
|
||||
# Evaluations
|
||||
|
||||
The `evaluations` crate provides a retrieval evaluation framework for benchmarking Minne's information retrieval pipeline against standard datasets.
|
||||
The `evaluations` crate benchmarks Minne's retrieval pipeline against standard datasets.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run SQuAD v2.0 evaluation (vector-only, recommended)
|
||||
cargo run --package evaluations -- --ingest-chunks-only
|
||||
# One-time prep (convert, slice ledger, corpus cache, DB seed)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||
|
||||
# Run a specific dataset
|
||||
cargo run --package evaluations -- --dataset fiqa --ingest-chunks-only
|
||||
# Check readiness
|
||||
cargo eval --status --dataset beir --slice beir-mix-600
|
||||
|
||||
# Convert dataset only (no evaluation)
|
||||
cargo run --package evaluations -- --convert-only
|
||||
# Run benchmark (steady state after warm)
|
||||
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||
```
|
||||
|
||||
Default dataset is `beir`. When `--slice` is omitted, the first catalog slice for the dataset is applied automatically (e.g. `beir-mix-600`).
|
||||
|
||||
Chunk-only ingestion is the default. Pass `--include-entities` to opt into entity extraction during ingestion (requires `OPENAI_API_KEY`).
|
||||
|
||||
### Custom slice sizes
|
||||
|
||||
`--slice` is a ledger id, not only a catalog name. You can use any id; `--limit` controls how many questions the ledger contains:
|
||||
|
||||
```bash
|
||||
# 200-case BEIR mix (default --limit is 200)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-200
|
||||
cargo eval --dataset beir --slice beir-mix-200 --require-ready
|
||||
```
|
||||
|
||||
The catalog slice `beir-mix-600` in `manifest.yaml` is a preset with `limit: 600` and `negative_multiplier: 9.0`.
|
||||
|
||||
### BEIR mix layout
|
||||
|
||||
`beir` is a **virtual mix** across eight subset datasets (FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR). There is no monolithic `beir-minne/` store.
|
||||
|
||||
1. Build an in-memory qrels-world mix from raw subset data
|
||||
2. Resolve the slice ledger (`cache/slices/beir/<slice-id>.json`)
|
||||
3. Materialize only ledger paragraph ids into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||
4. Ingest the slice corpus and seed SurrealDB
|
||||
|
||||
Conversion is **qrels-closed**: only documents that appear in qrels are exported, not the full BEIR corpus.
|
||||
|
||||
Chunk-only mode may evaluate fewer cases than the slice ledger size when some questions are impossible or lack verifiable answer chunks.
|
||||
|
||||
Reports include a **Retrieved Context Volume** section: total characters and estimated tokens across all chunks returned per query (`~chars/4`, comparable across `--chunk-result-cap` sweeps). Use this to compare the cost of raising `--chunk-result-cap`.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. SurrealDB
|
||||
|
||||
Start a SurrealDB instance before running evaluations:
|
||||
### SurrealDB
|
||||
|
||||
```bash
|
||||
docker-compose up -d surrealdb
|
||||
```
|
||||
|
||||
Or using the default endpoint configuration:
|
||||
### Raw datasets
|
||||
|
||||
```bash
|
||||
surreal start --user root_user --pass root_password
|
||||
```
|
||||
Place raw datasets under `evaluations/data/raw/`. See [manifest.yaml](./manifest.yaml) for paths.
|
||||
|
||||
### 2. Download Raw Datasets
|
||||
BEIR subsets live in sibling directories (`data/raw/fever`, `data/raw/fiqa`, …). The `data/raw/beir` entry is a virtual catalog placeholder; warm uses the subset paths.
|
||||
|
||||
Raw datasets must be downloaded manually and placed in `evaluations/data/raw/`. See [Dataset Sources](#dataset-sources) below for links and formats.
|
||||
|
||||
## Directory Structure
|
||||
## Directory structure
|
||||
|
||||
```
|
||||
evaluations/
|
||||
├── data/
|
||||
│ ├── raw/ # Downloaded raw datasets (manual)
|
||||
│ │ ├── squad/ # SQuAD v2.0
|
||||
│ │ ├── nq-dev/ # Natural Questions
|
||||
│ │ ├── fiqa/ # BEIR: FiQA-2018
|
||||
│ │ ├── fever/ # BEIR: FEVER
|
||||
│ │ ├── hotpotqa/ # BEIR: HotpotQA
|
||||
│ │ └── ... # Other BEIR subsets
|
||||
│ └── converted/ # Auto-generated (Minne JSON format)
|
||||
├── cache/ # Ingestion and embedding caches
|
||||
├── reports/ # Evaluation output (JSON + Markdown)
|
||||
├── manifest.yaml # Dataset and slice definitions
|
||||
└── src/ # Evaluation source code
|
||||
│ ├── raw/ # Downloaded datasets (manual)
|
||||
│ │ ├── fever/ # BEIR subset raw dirs (corpus.jsonl, queries.jsonl, qrels/)
|
||||
│ │ ├── fiqa/
|
||||
│ │ └── …
|
||||
│ └── converted/ # Sharded stores (auto-generated)
|
||||
│ ├── fever-minne/ # per-BEIR-subset stores
|
||||
│ ├── fiqa-minne/
|
||||
│ └── … # BEIR mix loads from subset stores (no monolithic beir-minne/)
|
||||
├── cache/
|
||||
│ ├── slices/ # Slice ledgers
|
||||
│ └── ingested/ # Corpus ingestion caches (manifest includes namespace seed)
|
||||
├── reports/ # JSON + Markdown output from benchmark runs
|
||||
├── manifest.yaml
|
||||
└── src/
|
||||
```
|
||||
|
||||
## Dataset Sources
|
||||
**After upgrading:** delete old monolithic `*-minne.json` files, any legacy `beir-minne/` merged store, `cache/snapshots/` directories, and stale `reports/history/` artifacts, then re-run `--warm`.
|
||||
|
||||
### SQuAD v2.0
|
||||
|
||||
Download and place at `data/raw/squad/dev-v2.0.json`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/squad
|
||||
curl -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json \
|
||||
-o evaluations/data/raw/squad/dev-v2.0.json
|
||||
```
|
||||
|
||||
### Natural Questions (NQ)
|
||||
|
||||
Download and place at `data/raw/nq-dev/dev-all.jsonl`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/nq-dev
|
||||
# Download from Google's Natural Questions page or HuggingFace
|
||||
# File: dev-all.jsonl (simplified JSONL format)
|
||||
```
|
||||
|
||||
Source: [Google Natural Questions](https://ai.google.com/research/NaturalQuestions)
|
||||
|
||||
### BEIR Datasets
|
||||
|
||||
All BEIR datasets follow the same format structure:
|
||||
|
||||
```
|
||||
data/raw/<dataset>/
|
||||
├── corpus.jsonl # Document corpus
|
||||
├── queries.jsonl # Query set
|
||||
└── qrels/
|
||||
└── test.tsv # Relevance judgments (or dev.tsv)
|
||||
```
|
||||
|
||||
Download datasets from the [BEIR Benchmark repository](https://github.com/beir-cellar/beir). Each dataset zip extracts to the required directory structure.
|
||||
|
||||
| Dataset | Directory |
|
||||
|------------|---------------|
|
||||
| FEVER | `fever/` |
|
||||
| FiQA-2018 | `fiqa/` |
|
||||
| HotpotQA | `hotpotqa/` |
|
||||
| NFCorpus | `nfcorpus/` |
|
||||
| Quora | `quora/` |
|
||||
| TREC-COVID | `trec-covid/` |
|
||||
| SciFact | `scifact/` |
|
||||
| NQ (BEIR) | `nq/` |
|
||||
|
||||
Example download:
|
||||
|
||||
```bash
|
||||
cd evaluations/data/raw
|
||||
curl -L https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip -o fiqa.zip
|
||||
unzip fiqa.zip && rm fiqa.zip
|
||||
```
|
||||
|
||||
## Dataset Conversion
|
||||
|
||||
Raw datasets are automatically converted to Minne's internal JSON format on first run. To force reconversion:
|
||||
|
||||
```bash
|
||||
cargo run --package evaluations -- --force-convert
|
||||
```
|
||||
|
||||
Converted files are saved to `data/converted/` and cached for subsequent runs.
|
||||
|
||||
## CLI Reference
|
||||
|
||||
### Common Options
|
||||
## Common flags
|
||||
|
||||
| Flag | Description | Default |
|
||||
|------|-------------|---------|
|
||||
| `--dataset <NAME>` | Dataset to evaluate | `squad-v2` |
|
||||
| `--limit <N>` | Max questions to evaluate (0 = all) | `200` |
|
||||
| `--k <N>` | Precision@k cutoff | `5` |
|
||||
| `--slice <ID>` | Use a predefined slice from manifest | — |
|
||||
| `--rerank` | Enable FastEmbed reranking stage | disabled |
|
||||
| `--embedding-backend <BE>` | `fastembed` or `hashed` | `fastembed` |
|
||||
| `--ingest-chunks-only` | Skip entity extraction, ingest only text chunks | disabled |
|
||||
| `--dataset` | Dataset to evaluate | `beir` |
|
||||
| `--slice` | Slice ledger id (catalog or custom) | first catalog slice |
|
||||
| `--limit` | Max questions in the slice ledger | `200` |
|
||||
| `--warm` | Prepare without running queries | — |
|
||||
| `--status` | Print readiness | — |
|
||||
| `--require-ready` | Fail if not warmed | — |
|
||||
| `--include-entities` | Entity extraction during ingestion | off |
|
||||
| `--force-convert` | Rebuild converted store | — |
|
||||
| `--chunk-result-cap` | Max chunks returned per query (raise with `--k`) | `5` |
|
||||
| `--perf-log-console` | Print per-stage timings after a run | off |
|
||||
| `--label` | Label stored in JSON/Markdown reports | — |
|
||||
|
||||
> [!TIP]
|
||||
> Use `--ingest-chunks-only` when evaluating vector-only retrieval strategies. This skips the LLM-based entity extraction and graph generation, significantly speeding up ingestion while focusing on pure chunk-based vector search.
|
||||
|
||||
### Available Datasets
|
||||
|
||||
```
|
||||
squad-v2, natural-questions, beir, fever, fiqa, hotpotqa,
|
||||
nfcorpus, quora, trec-covid, scifact, nq-beir
|
||||
```
|
||||
|
||||
### Database Configuration
|
||||
|
||||
| Flag | Environment | Default |
|
||||
|------|-------------|---------|
|
||||
| `--db-endpoint` | `EVAL_DB_ENDPOINT` | `ws://127.0.0.1:8000` |
|
||||
| `--db-username` | `EVAL_DB_USERNAME` | `root_user` |
|
||||
| `--db-password` | `EVAL_DB_PASSWORD` | `root_password` |
|
||||
| `--db-namespace` | `EVAL_DB_NAMESPACE` | auto-generated |
|
||||
| `--db-database` | `EVAL_DB_DATABASE` | auto-generated |
|
||||
|
||||
### Example Runs
|
||||
|
||||
```bash
|
||||
# Vector-only evaluation (recommended for benchmarking)
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 200
|
||||
|
||||
# Full FiQA evaluation with reranking
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 500 \
|
||||
--rerank \
|
||||
--k 10
|
||||
|
||||
# Use a predefined slice for reproducibility
|
||||
cargo run --package evaluations -- --slice fiqa-test-200 --ingest-chunks-only
|
||||
|
||||
# Run the mixed BEIR benchmark
|
||||
cargo run --package evaluations -- --dataset beir --slice beir-mix-600 --ingest-chunks-only
|
||||
```
|
||||
|
||||
## Slices
|
||||
|
||||
Slices are predefined, reproducible subsets defined in `manifest.yaml`. Each slice specifies:
|
||||
|
||||
- **limit**: Number of questions
|
||||
- **corpus_limit**: Maximum corpus size
|
||||
- **seed**: Fixed RNG seed for reproducibility
|
||||
|
||||
View available slices in [manifest.yaml](./manifest.yaml).
|
||||
|
||||
## Reports
|
||||
|
||||
Evaluations generate reports in `reports/`:
|
||||
|
||||
- **JSON**: Full structured results (`*-report.json`)
|
||||
- **Markdown**: Human-readable summary with sample mismatches (`*-report.md`)
|
||||
- **History**: Timestamped run history (`history/`)
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
```bash
|
||||
# Log per-stage performance timings
|
||||
cargo run --package evaluations -- --perf-log-console
|
||||
|
||||
# Save telemetry to file
|
||||
cargo run --package evaluations -- --perf-log-json ./perf.json
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
See [../LICENSE](../LICENSE).
|
||||
See [REFACTOR.md](./REFACTOR.md) for architecture notes.
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# Evaluations crate refactor plan
|
||||
|
||||
This document records the architecture review and the simplification work applied to the
|
||||
`evaluations` crate. **No backwards compatibility** is maintained for converted JSON layouts,
|
||||
legacy report history, or old cache artifact formats.
|
||||
|
||||
## Goals
|
||||
|
||||
- Smaller, linear pipeline (no state machine ceremony)
|
||||
- Sharded converted store for **all** datasets (memory-efficient partial loading)
|
||||
- Slice-first loading when a catalog slice is selected
|
||||
- In-memory SurrealDB for ingestion (no ephemeral server namespaces)
|
||||
- Single DB lifecycle module (`db/`)
|
||||
- CLI helpers under `cli/`
|
||||
|
||||
## Primary workflow
|
||||
|
||||
```bash
|
||||
# One-time prep (converts raw data if needed, builds slice ledger, corpus cache, DB seed)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||
|
||||
# Check readiness
|
||||
cargo eval --status --dataset beir --slice beir-mix-600
|
||||
|
||||
# Steady-state benchmark
|
||||
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||
```
|
||||
|
||||
Default dataset is `beir`. Chunk-only ingestion is the default; pass `--include-entities` to
|
||||
opt into entity extraction (requires `OPENAI_API_KEY`). Slice tuning such as
|
||||
`negative_multiplier` lives in `manifest.yaml` (e.g. `beir-mix-600` uses `9.0`).
|
||||
|
||||
## Cache layers (after refactor)
|
||||
|
||||
| Layer | Location | Purpose |
|
||||
|-------|----------|---------|
|
||||
| Converted store | `data/converted/<name>/` | Sharded paragraphs + question catalog |
|
||||
| Slice ledger | `cache/slices/<dataset>/<slice-id>.json` | Deterministic questions + paragraph set |
|
||||
| Corpus cache | `cache/ingested/<dataset>/<slice-id>/` | Ingestion paragraph shards, manifest, and namespace reuse seed |
|
||||
|
||||
Namespace reuse state lives in the corpus manifest (`metadata.namespace_seed`), not a separate
|
||||
`snapshots/` tree. After upgrading, delete old `*-minne.json` monolithic files, any
|
||||
`cache/snapshots/` directories, and re-run `--warm`.
|
||||
|
||||
## Phases applied
|
||||
|
||||
### Phase 0 — dead code
|
||||
|
||||
- Removed unused `criterion` dependency
|
||||
- Removed unused `EmbeddingCache`
|
||||
- Updated README for current CLI
|
||||
|
||||
### Phase 1 — structure
|
||||
|
||||
- Flattened pipeline to linear `async fn` stages
|
||||
- Removed `eval.rs` hub; imports go to owning modules
|
||||
- Merged `namespace.rs`, `db_helpers.rs` → `db/`; dropped standalone `snapshot.rs`
|
||||
- Moved `status.rs` → `cli/status.rs`
|
||||
- Fixed catalog slice bootstrap (build ledger when explicit slice manifest is missing)
|
||||
|
||||
### Phase 2 — no legacy paths
|
||||
|
||||
- All datasets use sharded converted store only
|
||||
- Removed legacy JSON layout and migration
|
||||
- Removed legacy report history format
|
||||
- Auto-apply first catalog slice when `--slice` omitted
|
||||
- Namespace seed folded into corpus manifest (removed `cache/snapshots/`)
|
||||
|
||||
### Phase 3 — performance
|
||||
|
||||
- Ingestion always uses in-memory SurrealDB
|
||||
- Slice-first partial load when ledger is complete
|
||||
- Default catalog slice for dataset when `--slice` not passed
|
||||
- Split `slice/` into `mod.rs`, `build.rs`, and `beir.rs`
|
||||
|
||||
### Phase 4 — BEIR mix slice-first
|
||||
|
||||
- `beir` is a virtual mix: slice ledger references prefixed ids (`fever-…`, `fiqa-…`, …)
|
||||
- Conversion is **qrels-closed** per subset (only documents appearing in qrels, not full corpus)
|
||||
- Slice ledger is resolved for the requested `--slice` (catalog preset or custom id + `--limit`)
|
||||
- Only ledger paragraph ids are materialized into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||
- No monolithic `beir-minne/` merged store
|
||||
- Raw BEIR data lives in per-subset dirs under `data/raw/`; `data/raw/beir` is a catalog placeholder
|
||||
|
||||
## Do not re-introduce
|
||||
|
||||
- Monolithic `*-minne.json` converted files
|
||||
- Monolithic `beir-minne/` merged converted store (use per-subset stores + virtual mix loader)
|
||||
- `state-machines` pipeline for this linear flow
|
||||
- `eval.rs` re-export hub
|
||||
- Legacy history migration in reports
|
||||
- Ephemeral `ingest_eval_*` namespaces on the shared SurrealDB server
|
||||
- Separate `cache/snapshots/` namespace state files
|
||||
|
||||
## Open follow-ups
|
||||
|
||||
- Generate `DatasetKind` from `manifest.yaml` at build time
|
||||
- Split `report.rs` when touching reporting again
|
||||
@@ -1,4 +1,4 @@
|
||||
default_dataset: squad-v2
|
||||
default_dataset: beir
|
||||
datasets:
|
||||
- id: squad-v2
|
||||
label: "SQuAD v2.0"
|
||||
@@ -45,6 +45,7 @@ datasets:
|
||||
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
||||
limit: 600
|
||||
corpus_limit: 6000
|
||||
negative_multiplier: 9.0
|
||||
seed: 0x5eed2025
|
||||
- id: fever
|
||||
label: "FEVER (BEIR)"
|
||||
|
||||
+66
-18
@@ -137,9 +137,9 @@ pub struct IngestConfig {
|
||||
#[arg(long, default_value_t = 50)]
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
|
||||
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
|
||||
/// Include entity extraction and graph generation during ingestion (uses LLM tokens)
|
||||
#[arg(long)]
|
||||
pub ingest_chunks_only: bool,
|
||||
pub include_entities: bool,
|
||||
|
||||
/// Number of paragraphs to ingest concurrently
|
||||
#[arg(long, default_value_t = 10)]
|
||||
@@ -159,6 +159,7 @@ pub struct IngestConfig {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
#[allow(clippy::struct_field_names)]
|
||||
pub struct DatabaseArgs {
|
||||
/// `SurrealDB` server endpoint
|
||||
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
||||
@@ -179,10 +180,6 @@ pub struct DatabaseArgs {
|
||||
/// Override the database used on the `SurrealDB` server
|
||||
#[arg(long, env = "EVAL_DB_DATABASE")]
|
||||
pub db_database: Option<String>,
|
||||
|
||||
/// Path to inspect DB state
|
||||
#[arg(long)]
|
||||
pub inspect_db_state: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@@ -233,10 +230,6 @@ pub struct Config {
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub sample: usize,
|
||||
|
||||
/// Disable context cropping when converting datasets (ingest entire documents)
|
||||
#[arg(long)]
|
||||
pub full_context: bool,
|
||||
|
||||
#[command(flatten)]
|
||||
pub retrieval: RetrievalSettings,
|
||||
|
||||
@@ -322,6 +315,18 @@ pub struct Config {
|
||||
#[command(flatten)]
|
||||
pub database: DatabaseArgs,
|
||||
|
||||
/// Require warmed corpus/namespace before running queries
|
||||
#[arg(long)]
|
||||
pub require_ready: bool,
|
||||
|
||||
/// Prepare converted data, slice, corpus, and namespace without running queries
|
||||
#[arg(long, conflicts_with = "status")]
|
||||
pub warm: bool,
|
||||
|
||||
/// Print readiness of converted data, slice, corpus, and namespace
|
||||
#[arg(long, conflicts_with = "warm")]
|
||||
pub status: bool,
|
||||
|
||||
// Computed fields (not arguments)
|
||||
#[arg(skip)]
|
||||
pub raw_dataset_path: PathBuf,
|
||||
@@ -334,11 +339,6 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
#[allow(clippy::unused_self)]
|
||||
pub fn context_token_limit(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn finalize(&mut self) -> Result<()> {
|
||||
// Handle dataset paths
|
||||
@@ -367,9 +367,7 @@ impl Config {
|
||||
// Handle retrieval settings
|
||||
self.retrieval.require_verified_chunks = !self.llm_mode;
|
||||
|
||||
if self.dataset == DatasetKind::Beir {
|
||||
self.negative_multiplier = 9.0;
|
||||
}
|
||||
self.apply_catalog_slice_defaults()?;
|
||||
|
||||
// Validations
|
||||
if self.ingest.ingest_chunk_min_tokens == 0
|
||||
@@ -477,6 +475,56 @@ impl Config {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_catalog_slice_defaults(&mut self) -> Result<()> {
|
||||
let catalog = crate::datasets::catalog()?;
|
||||
let entry = catalog.dataset(self.dataset.id())?;
|
||||
|
||||
if self.slice.is_none() {
|
||||
if let Some(default_slice) = entry.slices.first() {
|
||||
self.slice = Some(default_slice.id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let Some(slice_id) = self.slice.as_deref() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Ok((_, slice)) = catalog.slice(slice_id) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if slice.dataset_id != self.dataset.id() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(limit) = slice.limit {
|
||||
if self.limit_arg == 200 {
|
||||
self.limit_arg = limit;
|
||||
self.limit = Some(limit);
|
||||
}
|
||||
}
|
||||
if self.corpus_limit.is_none() {
|
||||
self.corpus_limit = slice.corpus_limit;
|
||||
}
|
||||
if let Some(seed) = slice.seed {
|
||||
self.slice_seed = seed;
|
||||
}
|
||||
if let Some(include_unanswerable) = slice.include_unanswerable {
|
||||
self.llm_mode = include_unanswerable;
|
||||
self.retrieval.require_verified_chunks = !include_unanswerable;
|
||||
}
|
||||
if let Some(multiplier) = slice.negative_multiplier {
|
||||
if negative_multiplier_is_default(self.negative_multiplier) {
|
||||
self.negative_multiplier = multiplier;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn negative_multiplier_is_default(value: f32) -> bool {
|
||||
(value - crate::slice::DEFAULT_NEGATIVE_MULTIPLIER).abs() < f32::EPSILON
|
||||
}
|
||||
|
||||
pub struct ParsedArgs {
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct EmbeddingCacheData {
|
||||
entities: HashMap<String, Vec<f32>>,
|
||||
chunks: HashMap<String, Vec<f32>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingCache {
|
||||
path: Arc<Path>,
|
||||
data: Arc<Mutex<EmbeddingCacheData>>,
|
||||
dirty: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EmbeddingCache {
|
||||
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
let data = if path.exists() {
|
||||
let raw = tokio::fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading embedding cache {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing embedding cache {}", path.display()))?
|
||||
} else {
|
||||
EmbeddingCacheData::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
path: Arc::from(path.as_path()),
|
||||
data: Arc::new(Mutex::new(data)),
|
||||
dirty: Arc::new(AtomicBool::new(false)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.entities.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.entities.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.chunks.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.chunks.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> Result<()> {
|
||||
if !self.dirty.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guard = self.data.lock().await;
|
||||
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
|
||||
if let Some(parent) = self.path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("creating cache directory {}", parent.display()))?;
|
||||
}
|
||||
tokio::fs::write(&*self.path, body)
|
||||
.await
|
||||
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
|
||||
self.dirty.store(false, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -156,6 +156,7 @@ mod tests {
|
||||
chunk_min_tokens: 1,
|
||||
chunk_max_tokens: 10,
|
||||
chunk_only: false,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs,
|
||||
questions,
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
pub mod status;
|
||||
|
||||
pub use status::{collect_status, ensure_query_ready, print_status, warm};
|
||||
@@ -0,0 +1,311 @@
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
corpus::{self, CorpusCacheConfig},
|
||||
datasets::{
|
||||
beir_subset_store_summary, beir_subset_stores_ready, content_checksum_for_layout,
|
||||
detect_layout, mix_content_checksum, store_dir_for, ConvertedLayout, DatasetKind,
|
||||
},
|
||||
db::{connect_eval_db, default_database, default_namespace, namespace_has_corpus},
|
||||
slice::{self, ledger_target},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EvalStatus {
|
||||
pub dataset: String,
|
||||
pub slice: Option<String>,
|
||||
pub converted: ConvertedStatus,
|
||||
pub slice_ledger: SliceLedgerStatus,
|
||||
pub corpus_cache: CorpusCacheStatus,
|
||||
pub namespace: NamespaceStatus,
|
||||
pub query_ready: bool,
|
||||
pub notes: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ConvertedStatus {
|
||||
pub layout: String,
|
||||
pub path: String,
|
||||
pub ready: bool,
|
||||
pub partial_load_eligible: bool,
|
||||
pub checksum: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct SliceLedgerStatus {
|
||||
pub ready: bool,
|
||||
pub path: Option<String>,
|
||||
pub cases: Option<usize>,
|
||||
pub positives: Option<usize>,
|
||||
pub negatives: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CorpusCacheStatus {
|
||||
pub ready: bool,
|
||||
pub path: Option<String>,
|
||||
pub manifest_present: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NamespaceStatus {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub seeded: bool,
|
||||
pub namespace_seed_recorded: bool,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn collect_status(config: &Config) -> Result<EvalStatus> {
|
||||
let mut notes = Vec::new();
|
||||
let is_beir_mix = config.dataset == DatasetKind::Beir;
|
||||
let converted_path = &config.converted_dataset_path;
|
||||
let layout = if is_beir_mix {
|
||||
ConvertedLayout::Missing
|
||||
} else {
|
||||
detect_layout(converted_path)
|
||||
};
|
||||
let layout_label = if is_beir_mix {
|
||||
"beir-mix-subset-stores"
|
||||
} else {
|
||||
match layout {
|
||||
ConvertedLayout::ShardedStore => "sharded-store",
|
||||
ConvertedLayout::Missing => "missing",
|
||||
}
|
||||
};
|
||||
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
let display_path = if is_beir_mix {
|
||||
beir_subset_store_summary()?
|
||||
.into_iter()
|
||||
.map(|(subset, paragraphs, questions)| {
|
||||
format!("{subset}-minne ({paragraphs} paragraphs, {questions} questions)")
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ")
|
||||
} else {
|
||||
store_dir.display().to_string()
|
||||
};
|
||||
|
||||
let manifest_path = slice::cached_manifest_path(config);
|
||||
let slice_config = slice::slice_config_with_limit(config, ledger_target(config));
|
||||
let slice_manifest = manifest_path
|
||||
.as_ref()
|
||||
.and_then(|path| slice::read_manifest_if_exists(path).ok().flatten());
|
||||
|
||||
let slice_ledger = SliceLedgerStatus {
|
||||
ready: slice_manifest
|
||||
.as_ref()
|
||||
.is_some_and(|manifest| slice::manifest_is_complete(manifest, &slice_config)),
|
||||
path: manifest_path
|
||||
.as_ref()
|
||||
.map(|path| path.display().to_string()),
|
||||
cases: slice_manifest.as_ref().map(|manifest| manifest.case_count),
|
||||
positives: slice_manifest
|
||||
.as_ref()
|
||||
.map(|manifest| manifest.positive_paragraphs),
|
||||
negatives: slice_manifest
|
||||
.as_ref()
|
||||
.map(|manifest| manifest.negative_paragraphs),
|
||||
};
|
||||
|
||||
let beir_paragraph_ids = slice_manifest.as_ref().map(|manifest| {
|
||||
manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
});
|
||||
|
||||
let converted_ready = if is_beir_mix {
|
||||
slice_ledger.ready
|
||||
&& beir_paragraph_ids
|
||||
.as_ref()
|
||||
.is_some_and(|ids| beir_subset_stores_ready(ids).unwrap_or(false))
|
||||
} else {
|
||||
layout == ConvertedLayout::ShardedStore
|
||||
};
|
||||
|
||||
let checksum = if is_beir_mix {
|
||||
beir_paragraph_ids
|
||||
.as_ref()
|
||||
.and_then(|ids| mix_content_checksum(ids).ok())
|
||||
} else if layout == ConvertedLayout::ShardedStore {
|
||||
content_checksum_for_layout(converted_path).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let partial_load_eligible = slice_ledger.ready && config.slice.is_some();
|
||||
|
||||
let corpus_cache = if let Some(manifest) = slice_manifest.as_ref() {
|
||||
let cache_settings = CorpusCacheConfig::from(config);
|
||||
let base_dir = corpus::cached_corpus_dir(
|
||||
&cache_settings,
|
||||
config.dataset.id(),
|
||||
manifest.slice_id.as_str(),
|
||||
);
|
||||
let manifest_present = corpus::load_cached_manifest(&base_dir)?.is_some();
|
||||
CorpusCacheStatus {
|
||||
ready: manifest_present,
|
||||
path: Some(base_dir.display().to_string()),
|
||||
manifest_present,
|
||||
}
|
||||
} else {
|
||||
CorpusCacheStatus {
|
||||
ready: false,
|
||||
path: None,
|
||||
manifest_present: false,
|
||||
}
|
||||
};
|
||||
|
||||
let namespace = config.database.db_namespace.clone().unwrap_or_else(|| {
|
||||
default_namespace(config.dataset.id(), config.limit, config.slice.as_deref())
|
||||
});
|
||||
let database = config
|
||||
.database
|
||||
.db_database
|
||||
.clone()
|
||||
.unwrap_or_else(default_database);
|
||||
|
||||
let namespace_seed = corpus_cache.path.as_ref().and_then(|path| {
|
||||
corpus::load_cached_manifest(Path::new(path))
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|manifest| manifest.metadata.namespace_seed)
|
||||
});
|
||||
|
||||
let (seeded, namespace_seed_recorded) =
|
||||
match connect_eval_db(config, &namespace, &database).await {
|
||||
Ok(db) => {
|
||||
let has_corpus = namespace_has_corpus(&db).await.unwrap_or(false);
|
||||
(has_corpus, namespace_seed.is_some())
|
||||
}
|
||||
Err(err) => {
|
||||
notes.push(format!("SurrealDB unavailable: {err}"));
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
let query_ready = converted_ready
|
||||
&& slice_ledger.ready
|
||||
&& corpus_cache.ready
|
||||
&& seeded
|
||||
&& namespace_seed_recorded;
|
||||
|
||||
if !query_ready {
|
||||
notes.push("Run `cargo eval --warm --slice <id>` to prepare corpus and namespace.".into());
|
||||
}
|
||||
|
||||
Ok(EvalStatus {
|
||||
dataset: config.dataset.id().to_string(),
|
||||
slice: config.slice.clone(),
|
||||
converted: ConvertedStatus {
|
||||
layout: layout_label.to_string(),
|
||||
path: display_path,
|
||||
ready: converted_ready,
|
||||
partial_load_eligible,
|
||||
checksum,
|
||||
},
|
||||
slice_ledger,
|
||||
corpus_cache,
|
||||
namespace: NamespaceStatus {
|
||||
namespace,
|
||||
database,
|
||||
seeded,
|
||||
namespace_seed_recorded,
|
||||
},
|
||||
query_ready,
|
||||
notes,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn print_status(status: &EvalStatus) {
|
||||
println!("Evaluation status for dataset `{}`", status.dataset);
|
||||
if let Some(slice) = &status.slice {
|
||||
println!("Slice: {slice}");
|
||||
}
|
||||
println!(
|
||||
"Converted: {} ({})",
|
||||
if status.converted.ready {
|
||||
"ready"
|
||||
} else {
|
||||
"missing"
|
||||
},
|
||||
status.converted.layout
|
||||
);
|
||||
println!("Converted path: {}", status.converted.path);
|
||||
if status.converted.partial_load_eligible {
|
||||
println!("Slice-first loading: eligible");
|
||||
}
|
||||
println!(
|
||||
"Slice ledger: {}",
|
||||
if status.slice_ledger.ready {
|
||||
format!(
|
||||
"ready ({} cases, {} positives, {} negatives)",
|
||||
status.slice_ledger.cases.unwrap_or(0),
|
||||
status.slice_ledger.positives.unwrap_or(0),
|
||||
status.slice_ledger.negatives.unwrap_or(0)
|
||||
)
|
||||
} else {
|
||||
"missing or incomplete".to_string()
|
||||
}
|
||||
);
|
||||
if let Some(path) = &status.slice_ledger.path {
|
||||
println!("Slice ledger path: {path}");
|
||||
}
|
||||
println!(
|
||||
"Corpus cache: {}",
|
||||
if status.corpus_cache.ready {
|
||||
"ready"
|
||||
} else {
|
||||
"missing"
|
||||
}
|
||||
);
|
||||
if let Some(path) = &status.corpus_cache.path {
|
||||
println!("Corpus cache path: {path}");
|
||||
}
|
||||
println!(
|
||||
"Namespace `{}` / `{}`: seeded={}, namespace_seed_recorded={}",
|
||||
status.namespace.namespace,
|
||||
status.namespace.database,
|
||||
status.namespace.seeded,
|
||||
status.namespace.namespace_seed_recorded
|
||||
);
|
||||
println!(
|
||||
"Query-ready: {}",
|
||||
if status.query_ready { "yes" } else { "no" }
|
||||
);
|
||||
for note in &status.notes {
|
||||
println!("Note: {note}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn warm(config: &Config) -> Result<()> {
|
||||
let loaded =
|
||||
crate::datasets::prepare_dataset(config.dataset, config).context("preparing dataset")?;
|
||||
crate::pipeline::warm_evaluation(&loaded.dataset, config, &loaded.content_checksum)
|
||||
.await
|
||||
.context("warming evaluation corpus and namespace")?;
|
||||
let status = collect_status(config).await?;
|
||||
print_status(&status);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ensure_query_ready(config: &Config) -> Result<()> {
|
||||
let status = collect_status(config).await?;
|
||||
if status.query_ready {
|
||||
return Ok(());
|
||||
}
|
||||
print_status(&status);
|
||||
anyhow::bail!(
|
||||
"evaluation is not query-ready; run `cargo eval --warm --slice {}` first",
|
||||
config.slice.as_deref().unwrap_or("<slice-id>")
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
#![allow(clippy::arithmetic_side_effects)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
|
||||
use crate::types::EvaluationCandidate;
|
||||
|
||||
const TOKENIZER_LABEL: &str = "estimated (~chars/4; ingestion uses bert-base-cased)";
|
||||
|
||||
#[allow(clippy::struct_field_names)]
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RetrievedContextStats {
|
||||
pub chunk_count: usize,
|
||||
pub char_count: usize,
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct RetrievalContextStats {
|
||||
pub tokenizer: String,
|
||||
pub queries: usize,
|
||||
pub total_chunks: usize,
|
||||
pub total_chars: usize,
|
||||
pub total_tokens: usize,
|
||||
pub avg_chunks_per_query: f64,
|
||||
pub avg_chars_per_query: f64,
|
||||
pub avg_tokens_per_query: f64,
|
||||
pub p50_tokens_per_query: usize,
|
||||
pub p95_tokens_per_query: usize,
|
||||
pub max_tokens_per_query: usize,
|
||||
}
|
||||
|
||||
pub fn stats_for_candidates(candidates: &[EvaluationCandidate]) -> RetrievedContextStats {
|
||||
let mut seen_chunk_ids = std::collections::HashSet::new();
|
||||
let mut stats = RetrievedContextStats::default();
|
||||
|
||||
for candidate in candidates {
|
||||
for chunk in &candidate.chunks {
|
||||
let chunk_id = chunk.chunk.id().to_string();
|
||||
if !seen_chunk_ids.insert(chunk_id) {
|
||||
continue;
|
||||
}
|
||||
let text = chunk.chunk.chunk.as_str();
|
||||
stats.chunk_count += 1;
|
||||
stats.char_count += text.chars().count();
|
||||
stats.token_count += estimate_ingestion_tokens(text);
|
||||
}
|
||||
}
|
||||
|
||||
stats
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> RetrievalContextStats {
|
||||
let queries = per_query.len();
|
||||
if queries == 0 {
|
||||
return RetrievalContextStats {
|
||||
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||
queries: 0,
|
||||
total_chunks: 0,
|
||||
total_chars: 0,
|
||||
total_tokens: 0,
|
||||
avg_chunks_per_query: 0.0,
|
||||
avg_chars_per_query: 0.0,
|
||||
avg_tokens_per_query: 0.0,
|
||||
p50_tokens_per_query: 0,
|
||||
p95_tokens_per_query: 0,
|
||||
max_tokens_per_query: 0,
|
||||
};
|
||||
}
|
||||
|
||||
let total_chunks: usize = per_query.iter().map(|stats| stats.chunk_count).sum();
|
||||
let total_chars: usize = per_query.iter().map(|stats| stats.char_count).sum();
|
||||
let total_tokens: usize = per_query.iter().map(|stats| stats.token_count).sum();
|
||||
let mut tokens_per_query: Vec<usize> =
|
||||
per_query.iter().map(|stats| stats.token_count).collect();
|
||||
tokens_per_query.sort_unstable();
|
||||
let max_tokens_per_query = *tokens_per_query.last().unwrap_or(&0);
|
||||
|
||||
let total_chunks_f = total_chunks as f64;
|
||||
let total_chars_f = total_chars as f64;
|
||||
let total_tokens_f = total_tokens as f64;
|
||||
let queries_f = queries as f64;
|
||||
let avg_chunks_per_query = total_chunks_f / queries_f;
|
||||
let avg_chars_per_query = total_chars_f / queries_f;
|
||||
let avg_tokens_per_query = total_tokens_f / queries_f;
|
||||
|
||||
RetrievalContextStats {
|
||||
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||
queries,
|
||||
total_chunks,
|
||||
total_chars,
|
||||
total_tokens,
|
||||
avg_chunks_per_query,
|
||||
avg_chars_per_query,
|
||||
avg_tokens_per_query,
|
||||
p50_tokens_per_query: percentile_usize(&tokens_per_query, 0.50),
|
||||
p95_tokens_per_query: percentile_usize(&tokens_per_query, 0.95),
|
||||
max_tokens_per_query,
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_ingestion_tokens(text: &str) -> usize {
|
||||
let chars = text.chars().count();
|
||||
if chars == 0 {
|
||||
return 0;
|
||||
}
|
||||
chars.div_ceil(4)
|
||||
}
|
||||
|
||||
#[allow(
|
||||
clippy::cast_precision_loss,
|
||||
clippy::cast_sign_loss,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::indexing_slicing,
|
||||
clippy::arithmetic_side_effects
|
||||
)]
|
||||
fn percentile_usize(sorted: &[usize], fraction: f64) -> usize {
|
||||
if sorted.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let clamped = fraction.clamp(0.0, 1.0);
|
||||
let index = ((sorted.len() - 1) as f64 * clamped).round() as usize;
|
||||
sorted[index.min(sorted.len() - 1)]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use common::storage::types::text_chunk::TextChunk;
|
||||
use retrieval_pipeline::RetrievedChunk;
|
||||
|
||||
#[test]
|
||||
fn deduplicates_chunks_when_counting_context() {
|
||||
let shared = Arc::new(TextChunk::new(
|
||||
"src".into(),
|
||||
"hello world".into(),
|
||||
"user".into(),
|
||||
));
|
||||
let candidates = vec![
|
||||
EvaluationCandidate {
|
||||
entity_id: "a".into(),
|
||||
source_id: "src".into(),
|
||||
entity_name: "A".into(),
|
||||
entity_description: None,
|
||||
entity_category: None,
|
||||
score: 1.0,
|
||||
chunks: vec![RetrievedChunk {
|
||||
chunk: Arc::clone(&shared),
|
||||
score: 1.0,
|
||||
}],
|
||||
},
|
||||
EvaluationCandidate {
|
||||
entity_id: "b".into(),
|
||||
source_id: "src".into(),
|
||||
entity_name: "B".into(),
|
||||
entity_description: None,
|
||||
entity_category: None,
|
||||
score: 0.9,
|
||||
chunks: vec![RetrievedChunk {
|
||||
chunk: shared,
|
||||
score: 0.9,
|
||||
}],
|
||||
},
|
||||
];
|
||||
let stats = stats_for_candidates(&candidates);
|
||||
assert_eq!(stats.chunk_count, 1);
|
||||
assert_eq!(stats.char_count, "hello world".chars().count());
|
||||
assert_eq!(stats.token_count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aggregates_per_query_token_totals() {
|
||||
let per_query = vec![
|
||||
RetrievedContextStats {
|
||||
chunk_count: 2,
|
||||
char_count: 100,
|
||||
token_count: 40,
|
||||
},
|
||||
RetrievedContextStats {
|
||||
chunk_count: 5,
|
||||
char_count: 250,
|
||||
token_count: 100,
|
||||
},
|
||||
];
|
||||
let aggregate = aggregate_context_stats(&per_query);
|
||||
assert_eq!(aggregate.queries, 2);
|
||||
assert_eq!(aggregate.total_chunks, 7);
|
||||
assert_eq!(aggregate.total_tokens, 140);
|
||||
assert_eq!(aggregate.max_tokens_per_query, 100);
|
||||
assert!((aggregate.avg_tokens_per_query - 70.0).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
@@ -11,32 +11,14 @@ pub struct CorpusCacheConfig {
|
||||
pub ingestion_max_retries: usize,
|
||||
}
|
||||
|
||||
impl CorpusCacheConfig {
|
||||
pub fn new(
|
||||
ingestion_cache_dir: impl Into<PathBuf>,
|
||||
force_refresh: bool,
|
||||
refresh_embeddings_only: bool,
|
||||
ingestion_batch_size: usize,
|
||||
ingestion_max_retries: usize,
|
||||
) -> Self {
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
Self {
|
||||
ingestion_cache_dir: ingestion_cache_dir.into(),
|
||||
force_refresh,
|
||||
refresh_embeddings_only,
|
||||
ingestion_batch_size,
|
||||
ingestion_max_retries,
|
||||
ingestion_cache_dir: config.ingest.ingestion_cache_dir.clone(),
|
||||
force_refresh: config.force_convert || config.ingest.slice_reset_ingestion,
|
||||
refresh_embeddings_only: config.ingest.refresh_embeddings_only,
|
||||
ingestion_batch_size: config.ingest.ingestion_batch_size,
|
||||
ingestion_max_retries: config.ingest.ingestion_max_retries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
CorpusCacheConfig::new(
|
||||
config.ingest.ingestion_cache_dir.clone(),
|
||||
config.force_convert || config.ingest.slice_reset_ingestion,
|
||||
config.ingest.refresh_embeddings_only,
|
||||
config.ingest.ingestion_batch_size,
|
||||
config.ingest.ingestion_max_retries,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ pub(crate) mod store;
|
||||
pub use config::CorpusCacheConfig;
|
||||
pub use orchestrator::{
|
||||
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
|
||||
load_cached_manifest,
|
||||
load_cached_manifest, persist_corpus_manifest,
|
||||
};
|
||||
pub use store::{
|
||||
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
CorpusQuestion, NamespaceSeedRecord, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
};
|
||||
|
||||
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
|
||||
@@ -20,6 +20,6 @@ pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline
|
||||
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||
..Default::default()
|
||||
},
|
||||
chunk_only: config.ingest.ingest_chunks_only,
|
||||
chunk_only: !config.ingest.include_entities,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ use std::{
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::Client;
|
||||
use chrono::Utc;
|
||||
#[cfg(not(test))]
|
||||
use common::utils::config::get_config;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
@@ -125,10 +123,14 @@ pub async fn ensure_corpus(
|
||||
openai: Arc<OpenAIClient>,
|
||||
user_id: &str,
|
||||
converted_path: &Path,
|
||||
precomputed_checksum: Option<&str>,
|
||||
ingestion_config: IngestionConfig,
|
||||
) -> Result<CorpusHandle> {
|
||||
let checksum = compute_file_checksum(converted_path)
|
||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
|
||||
let checksum = match precomputed_checksum {
|
||||
Some(value) => value.to_string(),
|
||||
None => crate::datasets::content_checksum_for_layout(converted_path)
|
||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?,
|
||||
};
|
||||
let ingestion_fingerprint =
|
||||
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
|
||||
|
||||
@@ -381,6 +383,7 @@ pub async fn ensure_corpus(
|
||||
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
|
||||
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
|
||||
chunk_only: ingestion_config.chunk_only,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs: corpus_paragraphs,
|
||||
questions: corpus_questions,
|
||||
@@ -415,7 +418,7 @@ pub async fn ensure_corpus(
|
||||
negative_ingested: stats.negative_ingested,
|
||||
};
|
||||
|
||||
persist_manifest(&handle).context("persisting corpus manifest")?;
|
||||
persist_corpus_manifest(&handle).context("persisting corpus manifest")?;
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
@@ -501,7 +504,6 @@ async fn ingest_paragraph_batch(
|
||||
Ok(shards)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let db = SurrealDbClient::memory(namespace, "corpus")
|
||||
.await
|
||||
@@ -509,21 +511,6 @@ async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let config = get_config().context("loading app config for ingestion database")?;
|
||||
let db = SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
namespace,
|
||||
"corpus",
|
||||
)
|
||||
.await
|
||||
.context("creating surrealdb database for ingestion")?;
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn ingest_single_paragraph(
|
||||
pipeline: Arc<IngestionPipeline>,
|
||||
@@ -631,8 +618,12 @@ pub fn compute_ingestion_fingerprint(
|
||||
slice: &ResolvedSlice<'_>,
|
||||
converted_path: &Path,
|
||||
ingestion_config: &IngestionConfig,
|
||||
precomputed_checksum: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let checksum = compute_file_checksum(converted_path)?;
|
||||
let checksum = match precomputed_checksum {
|
||||
Some(value) => value.to_string(),
|
||||
None => crate::datasets::content_checksum_for_layout(converted_path)?,
|
||||
};
|
||||
Ok(build_ingestion_fingerprint(
|
||||
dataset,
|
||||
slice,
|
||||
@@ -641,7 +632,7 @@ pub fn compute_ingestion_fingerprint(
|
||||
))
|
||||
}
|
||||
|
||||
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
||||
pub fn load_cached_manifest(base_dir: &std::path::Path) -> Result<Option<CorpusManifest>> {
|
||||
let path = base_dir.join("manifest.json");
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
@@ -656,7 +647,7 @@ pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
||||
Ok(Some(manifest))
|
||||
}
|
||||
|
||||
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||
pub fn persist_corpus_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||
let path = handle.path.join("manifest.json");
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
@@ -685,24 +676,6 @@ pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn compute_file_checksum(path: &Path) -> Result<String> {
|
||||
let mut file = fs::File::open(path)
|
||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = [0u8; 8192];
|
||||
loop {
|
||||
let read = file
|
||||
.read(&mut buffer)
|
||||
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..read]);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -728,11 +701,7 @@ mod tests {
|
||||
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata: crate::datasets::DatasetMetadata::for_kind(
|
||||
DatasetKind::default(),
|
||||
false,
|
||||
None,
|
||||
),
|
||||
metadata: crate::datasets::DatasetMetadata::for_kind(DatasetKind::default(), false),
|
||||
source: "src".to_string(),
|
||||
paragraphs: vec![paragraph],
|
||||
}
|
||||
|
||||
+36
-287
@@ -7,33 +7,21 @@ use std::{
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
|
||||
text_chunk::TextChunk,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
text_content::TextContent,
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk, text_content::TextContent, StoredObject,
|
||||
},
|
||||
};
|
||||
use ingestion_pipeline::{persist_artifacts, IngestionTuning, PipelineArtifacts};
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use surrealdb::sql::Thing;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
pub const MANIFEST_VERSION: u32 = 3;
|
||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 3;
|
||||
const MANIFEST_BATCH_SIZE: usize = 100;
|
||||
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
|
||||
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
|
||||
const MAX_BATCHES_PER_REQUEST: usize = 24;
|
||||
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
|
||||
|
||||
fn current_manifest_version() -> u32 {
|
||||
MANIFEST_VERSION
|
||||
}
|
||||
@@ -51,7 +39,7 @@ fn default_chunk_max_tokens() -> usize {
|
||||
}
|
||||
|
||||
fn default_chunk_only() -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
|
||||
// Reuse the pipeline's canonical embedded-artifact types so the on-disk corpus
|
||||
@@ -131,6 +119,14 @@ pub struct CorpusManifest {
|
||||
pub questions: Vec<CorpusQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct NamespaceSeedRecord {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub slice_case_count: usize,
|
||||
pub seeded_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusMetadata {
|
||||
pub dataset_id: String,
|
||||
@@ -153,6 +149,8 @@ pub struct CorpusMetadata {
|
||||
pub chunk_max_tokens: usize,
|
||||
#[serde(default = "default_chunk_only")]
|
||||
pub chunk_only: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub namespace_seed: Option<NamespaceSeedRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
@@ -251,130 +249,6 @@ pub fn window_manifest(
|
||||
Ok(narrowed)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct RelationInsert {
|
||||
#[serde(rename = "in")]
|
||||
pub in_: Thing,
|
||||
#[serde(rename = "out")]
|
||||
pub out: Thing,
|
||||
pub id: String,
|
||||
pub metadata: RelationshipMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SizedBatch<T> {
|
||||
approx_bytes: usize,
|
||||
items: Vec<T>,
|
||||
}
|
||||
|
||||
struct ManifestBatches {
|
||||
text_contents: Vec<SizedBatch<TextContent>>,
|
||||
entities: Vec<SizedBatch<KnowledgeEntity>>,
|
||||
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
|
||||
relationships: Vec<SizedBatch<RelationInsert>>,
|
||||
chunks: Vec<SizedBatch<TextChunk>>,
|
||||
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
|
||||
}
|
||||
|
||||
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
|
||||
let mut text_contents = Vec::new();
|
||||
let mut entities = Vec::new();
|
||||
let mut entity_embeddings = Vec::new();
|
||||
let mut relationships = Vec::new();
|
||||
let mut chunks = Vec::new();
|
||||
let mut chunk_embeddings = Vec::new();
|
||||
|
||||
let mut seen_text_content = HashSet::new();
|
||||
let mut seen_entities = HashSet::new();
|
||||
let mut seen_relationships = HashSet::new();
|
||||
let mut seen_chunks = HashSet::new();
|
||||
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||
text_contents.push(paragraph.text_content.clone());
|
||||
}
|
||||
|
||||
for embedded_entity in ¶graph.entities {
|
||||
if seen_entities.insert(embedded_entity.entity.id.clone()) {
|
||||
let entity = embedded_entity.entity.clone();
|
||||
entities.push(entity.clone());
|
||||
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
||||
&entity.id,
|
||||
entity.source_id.clone(),
|
||||
embedded_entity.embedding.clone(),
|
||||
entity.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
for relationship in ¶graph.relationships {
|
||||
if seen_relationships.insert(relationship.id.clone()) {
|
||||
let table = KnowledgeEntity::table_name();
|
||||
let in_id = relationship
|
||||
.in_
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.in_);
|
||||
let out_id = relationship
|
||||
.out
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.out);
|
||||
let in_thing = Thing::from((table, in_id));
|
||||
let out_thing = Thing::from((table, out_id));
|
||||
relationships.push(RelationInsert {
|
||||
in_: in_thing,
|
||||
out: out_thing,
|
||||
id: relationship.id.clone(),
|
||||
metadata: relationship.metadata.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for embedded_chunk in ¶graph.chunks {
|
||||
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
|
||||
let chunk = embedded_chunk.chunk.clone();
|
||||
chunks.push(chunk.clone());
|
||||
chunk_embeddings.push(TextChunkEmbedding::new(
|
||||
&chunk.id,
|
||||
chunk.source_id.clone(),
|
||||
embedded_chunk.embedding.clone(),
|
||||
chunk.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ManifestBatches {
|
||||
text_contents: chunk_items(
|
||||
&text_contents,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_content payloads")?,
|
||||
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking knowledge_entity payloads")?,
|
||||
entity_embeddings: chunk_items(
|
||||
&entity_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking knowledge_entity_embedding payloads")?,
|
||||
relationships: chunk_items(
|
||||
&relationships,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking relationship payloads")?,
|
||||
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking text_chunk payloads")?,
|
||||
chunk_embeddings: chunk_items(
|
||||
&chunk_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_chunk_embedding payloads")?,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ParagraphShard {
|
||||
#[serde(default = "current_paragraph_shard_version")]
|
||||
@@ -599,157 +473,28 @@ fn normalize_answer_text(text: &str) -> String {
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
fn chunk_items<T: Clone + Serialize>(
|
||||
items: &[T],
|
||||
max_items: usize,
|
||||
max_bytes: usize,
|
||||
) -> Result<Vec<SizedBatch<T>>> {
|
||||
if items.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut batches = Vec::new();
|
||||
let mut current = Vec::new();
|
||||
let mut current_bytes = 0usize;
|
||||
|
||||
for item in items {
|
||||
let size = serde_json::to_vec(item)
|
||||
.map(|buf| buf.len())
|
||||
.context("serialising batch item for sizing")?;
|
||||
|
||||
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
|
||||
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
|
||||
|
||||
if would_overflow_items || would_overflow_bytes {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: std::mem::take(&mut current),
|
||||
});
|
||||
current_bytes = 0;
|
||||
}
|
||||
|
||||
current_bytes += size;
|
||||
current.push(item.clone());
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: current,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(batches)
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
|
||||
db: &SurrealDbClient,
|
||||
statement: impl AsRef<str>,
|
||||
prefix: &str,
|
||||
batches: &[SizedBatch<T>],
|
||||
) -> Result<()> {
|
||||
if batches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut start = 0;
|
||||
while start < batches.len() {
|
||||
let mut group_bytes = 0usize;
|
||||
let mut group_end = start;
|
||||
let mut group_count = 0usize;
|
||||
|
||||
while group_end < batches.len() {
|
||||
let batch_bytes = batches[group_end].approx_bytes.max(1);
|
||||
if group_count > 0
|
||||
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|
||||
|| group_count >= MAX_BATCHES_PER_REQUEST)
|
||||
{
|
||||
break;
|
||||
}
|
||||
group_bytes += batch_bytes;
|
||||
group_end += 1;
|
||||
group_count += 1;
|
||||
}
|
||||
|
||||
let slice = &batches[start..group_end];
|
||||
let mut query = db.client.query("BEGIN TRANSACTION;");
|
||||
for (bind_index, batch) in slice.iter().enumerate() {
|
||||
let name = format!("{prefix}{bind_index}");
|
||||
query = query
|
||||
.query(format!("{} ${};", statement.as_ref(), name))
|
||||
.bind((name, batch.items.clone()));
|
||||
}
|
||||
let response = query
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.await
|
||||
.context("executing batched insert transaction")?;
|
||||
if let Err(err) = response.check() {
|
||||
return Err(anyhow!(
|
||||
"batched insert failed for statement '{}': {err:?}",
|
||||
statement.as_ref()
|
||||
));
|
||||
}
|
||||
|
||||
start = group_end;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
|
||||
let tuning = IngestionTuning::default();
|
||||
let embedding_dimensions = manifest.metadata.embedding_dimension;
|
||||
let mut seen_text_content = HashSet::new();
|
||||
|
||||
let result = async {
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextContent::table_name()),
|
||||
"tc",
|
||||
&batches.text_contents,
|
||||
)
|
||||
.await?;
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if !seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
|
||||
"ke",
|
||||
&batches.entities,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunk::table_name()),
|
||||
"ch",
|
||||
&batches.chunks,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
"INSERT RELATION INTO relates_to",
|
||||
"rel",
|
||||
&batches.relationships,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
|
||||
"kee",
|
||||
&batches.entity_embeddings,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
|
||||
"tce",
|
||||
&batches.chunk_embeddings,
|
||||
)
|
||||
.await?;
|
||||
let artifacts = PipelineArtifacts {
|
||||
text_content: paragraph.text_content.clone(),
|
||||
entities: paragraph.entities.clone(),
|
||||
relationships: paragraph.relationships.clone(),
|
||||
chunks: paragraph.chunks.clone(),
|
||||
};
|
||||
|
||||
persist_artifacts(db, &tuning, embedding_dimensions, artifacts)
|
||||
.await
|
||||
.map_err(|err| anyhow!("persist manifest paragraph: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
@@ -778,7 +523,10 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use common::storage::types::knowledge_entity::KnowledgeEntityType;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
@@ -888,6 +636,7 @@ mod tests {
|
||||
chunk_min_tokens: 1,
|
||||
chunk_max_tokens: 10,
|
||||
chunk_only: false,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs: vec![paragraph_one, paragraph_two],
|
||||
questions: vec![question],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fs::File,
|
||||
io::{BufRead, BufReader},
|
||||
path::{Path, PathBuf},
|
||||
@@ -47,20 +47,71 @@ struct QrelEntry {
|
||||
score: i32,
|
||||
}
|
||||
|
||||
/// Convert only documents that appear in qrels (the BEIR evaluation closed world).
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
|
||||
convert_beir_documents(raw_dir, dataset, None)
|
||||
}
|
||||
|
||||
/// Convert a subset of qrels-world documents. `doc_ids` use corpus ids (unprefixed).
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::indexing_slicing
|
||||
)]
|
||||
pub fn convert_beir_documents(
|
||||
raw_dir: &Path,
|
||||
dataset: DatasetKind,
|
||||
doc_ids: Option<&HashSet<String>>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
let corpus_path = raw_dir.join("corpus.jsonl");
|
||||
let queries_path = raw_dir.join("queries.jsonl");
|
||||
let qrels_path = resolve_qrels_path(raw_dir)?;
|
||||
|
||||
let corpus = load_corpus(&corpus_path)?;
|
||||
let queries = load_queries(&queries_path)?;
|
||||
let qrels = load_qrels(&qrels_path)?;
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(corpus.len());
|
||||
let mut qrels_doc_ids = HashSet::new();
|
||||
for entries in qrels.values() {
|
||||
for entry in entries {
|
||||
qrels_doc_ids.insert(entry.doc_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let target_doc_ids: HashSet<String> = match doc_ids {
|
||||
Some(ids) => ids
|
||||
.iter()
|
||||
.filter(|id| qrels_doc_ids.contains(*id))
|
||||
.cloned()
|
||||
.collect(),
|
||||
None => qrels_doc_ids.clone(),
|
||||
};
|
||||
|
||||
if target_doc_ids.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no qrels documents to convert for {} at {}",
|
||||
dataset.id(),
|
||||
raw_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
let corpus = load_corpus_filtered(&corpus_path, &target_doc_ids)?;
|
||||
|
||||
let mut doc_ids_sorted: Vec<String> = target_doc_ids.into_iter().collect();
|
||||
doc_ids_sorted.sort();
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(doc_ids_sorted.len());
|
||||
let mut paragraph_index = HashMap::new();
|
||||
|
||||
for (doc_id, entry) in &corpus {
|
||||
for doc_id in &doc_ids_sorted {
|
||||
let Some(entry) = corpus.get(doc_id) else {
|
||||
warn!(
|
||||
doc_id = %doc_id,
|
||||
dataset = %dataset.id(),
|
||||
"Skipping qrels document missing from corpus"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
|
||||
let paragraph = ConvertedParagraph {
|
||||
id: paragraph_id.clone(),
|
||||
@@ -87,6 +138,12 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(filter) = doc_ids {
|
||||
if !filter.contains(&best.doc_id) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let Some(¶graph_slot) = paragraph_index.get(&best.doc_id) else {
|
||||
missing_docs += 1;
|
||||
warn!(
|
||||
@@ -106,7 +163,6 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let answers = vec![snippet];
|
||||
|
||||
let question_id = format!("{}-{query_id}", dataset.source_prefix());
|
||||
paragraphs[paragraph_slot]
|
||||
@@ -114,7 +170,7 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
.push(ConvertedQuestion {
|
||||
id: question_id,
|
||||
question: query.text.clone(),
|
||||
answers,
|
||||
answers: vec![snippet],
|
||||
is_impossible: false,
|
||||
});
|
||||
}
|
||||
@@ -122,13 +178,21 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
if missing_queries + missing_docs + skipped_answers > 0 {
|
||||
warn!(
|
||||
missing_queries,
|
||||
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
|
||||
missing_docs,
|
||||
skipped_answers,
|
||||
dataset = %dataset.id(),
|
||||
"Skipped some BEIR qrels entries during conversion"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
pub fn corpus_doc_id(paragraph_id: &str, dataset: DatasetKind) -> Option<String> {
|
||||
let prefix = format!("{}-", dataset.source_prefix());
|
||||
paragraph_id.strip_prefix(&prefix).map(str::to_string)
|
||||
}
|
||||
|
||||
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||
let qrels_dir = raw_dir.join("qrels");
|
||||
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
|
||||
@@ -148,7 +212,10 @@ fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
fn load_corpus_filtered(
|
||||
path: &Path,
|
||||
doc_ids: &HashSet<String>,
|
||||
) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
let file =
|
||||
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
@@ -167,6 +234,9 @@ fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
if !doc_ids.contains(&corpus_row.id) {
|
||||
continue;
|
||||
}
|
||||
let title = corpus_row.title.unwrap_or_else(|| corpus_row.id.clone());
|
||||
let text = corpus_row.text.unwrap_or_default();
|
||||
let context = build_context(&title, &text);
|
||||
@@ -296,10 +366,8 @@ mod tests {
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_basic_beir_layout() {
|
||||
let dir = tempdir().unwrap();
|
||||
#[allow(clippy::unwrap_used)]
|
||||
fn write_fixture(dir: &tempfile::TempDir) {
|
||||
let corpus = r#"
|
||||
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
|
||||
{"_id":"d2","title":"Doc 2","text":"Second document content."}
|
||||
@@ -313,24 +381,34 @@ mod tests {
|
||||
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
|
||||
fs::create_dir_all(dir.path().join("qrels")).unwrap();
|
||||
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_qrels_world_only() {
|
||||
let dir = tempdir().unwrap();
|
||||
write_fixture(&dir);
|
||||
|
||||
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
|
||||
|
||||
assert_eq!(paragraphs.len(), 2);
|
||||
let doc_one = paragraphs
|
||||
.iter()
|
||||
.find(|p| p.id == "fever-d1")
|
||||
.expect("missing paragraph for d1");
|
||||
assert_eq!(paragraphs.len(), 1);
|
||||
let doc_one = ¶graphs[0];
|
||||
assert_eq!(doc_one.id, "fever-d1");
|
||||
assert_eq!(doc_one.questions.len(), 1);
|
||||
let question = &doc_one.questions[0];
|
||||
assert_eq!(question.id, "fever-q1");
|
||||
assert!(!question.answers.is_empty());
|
||||
assert!(doc_one.context.contains(&question.answers[0]));
|
||||
assert_eq!(doc_one.questions[0].id, "fever-q1");
|
||||
}
|
||||
|
||||
let doc_two = paragraphs
|
||||
.iter()
|
||||
.find(|p| p.id == "fever-d2")
|
||||
.expect("missing paragraph for d2");
|
||||
assert!(doc_two.questions.is_empty());
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_filtered_doc_ids() {
|
||||
let dir = tempdir().unwrap();
|
||||
write_fixture(&dir);
|
||||
|
||||
let mut ids = HashSet::new();
|
||||
ids.insert("d1".to_string());
|
||||
let paragraphs =
|
||||
convert_beir_documents(dir.path(), DatasetKind::Fever, Some(&ids)).unwrap();
|
||||
assert_eq!(paragraphs.len(), 1);
|
||||
assert_eq!(paragraphs[0].id, "fever-d1");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
beir,
|
||||
checksum::hash_file,
|
||||
store::{
|
||||
self, build_dataset_from_catalog, paragraph_path, read_meta, store_dir_for,
|
||||
upsert_sharded_paragraphs, write_sharded,
|
||||
},
|
||||
ConvertedDataset, DatasetKind, DatasetMetadata, BEIR_DATASETS,
|
||||
};
|
||||
use crate::{args::Config, slice};
|
||||
|
||||
pub fn subset_for_paragraph_id(paragraph_id: &str) -> Option<DatasetKind> {
|
||||
let mut kinds: Vec<DatasetKind> = BEIR_DATASETS.to_vec();
|
||||
kinds.sort_by_key(|kind| std::cmp::Reverse(kind.source_prefix().len()));
|
||||
for kind in kinds {
|
||||
let prefix = format!("{}-", kind.source_prefix());
|
||||
if paragraph_id.starts_with(&prefix) {
|
||||
return Some(kind);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn build_beir_mix_qrels_dataset(include_unanswerable: bool) -> Result<ConvertedDataset> {
|
||||
if include_unanswerable {
|
||||
tracing::warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for subset in BEIR_DATASETS {
|
||||
let entry = super::dataset_entry_for_kind(subset)?;
|
||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||
paragraphs.extend(subset_paragraphs);
|
||||
}
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, include_unanswerable),
|
||||
source: "beir-mix".to_string(),
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn prepare_beir_mix(config: &Config) -> Result<super::loader::LoadedDataset> {
|
||||
let virtual_ds = build_beir_mix_qrels_dataset(config.llm_mode)?;
|
||||
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||
let resolved = slice::resolve_slice(&virtual_ds, &slice_config)
|
||||
.context("resolving BEIR mix slice ledger (check --slice and --limit match your intent)")?;
|
||||
|
||||
let unique: HashSet<String> = resolved
|
||||
.manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect();
|
||||
|
||||
materialize_subset_stores(&unique, config.force_convert)?;
|
||||
|
||||
let dataset = load_beir_mix_from_subsets(&unique)?;
|
||||
let checksum = mix_content_checksum(&unique)?;
|
||||
|
||||
info!(
|
||||
slice = resolved.manifest.slice_id.as_str(),
|
||||
paragraphs = unique.len(),
|
||||
checksum = %checksum,
|
||||
"Prepared BEIR mix from per-subset converted stores"
|
||||
);
|
||||
|
||||
Ok(super::loader::LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: true,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn materialize_subset_stores(paragraph_ids: &HashSet<String>, force: bool) -> Result<()> {
|
||||
let mut by_subset: HashMap<DatasetKind, Vec<String>> = HashMap::new();
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
by_subset
|
||||
.entry(kind)
|
||||
.or_default()
|
||||
.push(paragraph_id.clone());
|
||||
}
|
||||
|
||||
for (kind, ids) in by_subset {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let existing = if store_dir.join("meta.json").is_file() {
|
||||
store::load_paragraph_ids_set(&store_dir)?
|
||||
} else {
|
||||
HashSet::new()
|
||||
};
|
||||
|
||||
let missing: Vec<String> = if force {
|
||||
ids
|
||||
} else {
|
||||
ids.into_iter()
|
||||
.filter(|paragraph_id| !existing.contains(paragraph_id))
|
||||
.collect()
|
||||
};
|
||||
|
||||
if missing.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let corpus_ids: HashSet<String> = missing
|
||||
.iter()
|
||||
.filter_map(|paragraph_id| beir::corpus_doc_id(paragraph_id, kind))
|
||||
.collect();
|
||||
let paragraphs = beir::convert_beir_documents(&entry.raw_path, kind, Some(&corpus_ids))?;
|
||||
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
upsert_sharded_paragraphs(&store_dir, ¶graphs)?;
|
||||
} else {
|
||||
let question_count = paragraphs
|
||||
.iter()
|
||||
.map(|paragraph| paragraph.questions.len())
|
||||
.sum::<usize>();
|
||||
let dataset = ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(kind, false),
|
||||
source: entry.raw_path.display().to_string(),
|
||||
paragraphs,
|
||||
};
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
info!(
|
||||
subset = kind.id(),
|
||||
store = %store_dir.display(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
questions = question_count,
|
||||
"Created subset converted store for BEIR mix"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load_beir_mix_from_subsets(paragraph_ids: &HashSet<String>) -> Result<ConvertedDataset> {
|
||||
let mut by_subset: HashMap<DatasetKind, HashSet<String>> = HashMap::new();
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
by_subset
|
||||
.entry(kind)
|
||||
.or_default()
|
||||
.insert(paragraph_id.clone());
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(paragraph_ids.len());
|
||||
for (kind, subset_ids) in by_subset {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let partial = build_dataset_from_catalog(&store_dir, &subset_ids)?;
|
||||
paragraphs.extend(partial.paragraphs);
|
||||
}
|
||||
|
||||
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, false),
|
||||
source: "beir-mix".to_string(),
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn mix_content_checksum(paragraph_ids: &HashSet<String>) -> Result<String> {
|
||||
let mut ids: Vec<String> = paragraph_ids.iter().cloned().collect();
|
||||
ids.sort();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
for paragraph_id in ids {
|
||||
let kind = subset_for_paragraph_id(¶graph_id)
|
||||
.ok_or_else(|| anyhow!("unknown BEIR subset for paragraph '{paragraph_id}'"))?;
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let path = paragraph_path(&store_dir, ¶graph_id);
|
||||
if !path.is_file() {
|
||||
return Err(anyhow!(
|
||||
"missing converted paragraph {} at {}",
|
||||
paragraph_id,
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
hasher.update(paragraph_id.as_bytes());
|
||||
hasher.update([0]);
|
||||
hasher.update(hash_file(&path)?.as_bytes());
|
||||
}
|
||||
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
pub fn beir_subset_stores_ready(paragraph_ids: &HashSet<String>) -> Result<bool> {
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
if !store_dir.join("meta.json").is_file() {
|
||||
return Ok(false);
|
||||
}
|
||||
if !paragraph_path(&store_dir, paragraph_id).is_file() {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn beir_subset_store_summary() -> Result<Vec<(String, usize, usize)>> {
|
||||
let mut summary = Vec::new();
|
||||
for kind in BEIR_DATASETS {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
let meta = read_meta(&store_dir)?;
|
||||
summary.push((
|
||||
kind.id().to_string(),
|
||||
meta.paragraph_count,
|
||||
meta.question_count,
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(summary)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn routes_prefixed_paragraph_ids() {
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("fever-doc-1"),
|
||||
Some(DatasetKind::Fever)
|
||||
);
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("nq-beir-doc-1"),
|
||||
Some(DatasetKind::NqBeir)
|
||||
);
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("trec-covid-doc-1"),
|
||||
Some(DatasetKind::TrecCovid)
|
||||
);
|
||||
assert!(subset_for_paragraph_id("unknown-doc").is_none());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::Read,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SIDECAR_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChecksumSidecar {
|
||||
pub version: u32,
|
||||
pub sha256: String,
|
||||
pub size_bytes: u64,
|
||||
#[serde(default)]
|
||||
pub modified_unix_secs: u64,
|
||||
}
|
||||
|
||||
impl ChecksumSidecar {
|
||||
#[cfg(test)]
|
||||
pub fn sidecar_path(content_path: &Path) -> PathBuf {
|
||||
content_path.with_extension("sha256")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn is_valid_for(&self, content_path: &Path) -> bool {
|
||||
if self.version != SIDECAR_VERSION {
|
||||
return false;
|
||||
}
|
||||
let Ok(metadata) = fs::metadata(content_path) else {
|
||||
return false;
|
||||
};
|
||||
if metadata.len() != self.size_bytes {
|
||||
return false;
|
||||
}
|
||||
if self.modified_unix_secs != 0 {
|
||||
let Ok(modified) = metadata.modified() else {
|
||||
return true;
|
||||
};
|
||||
let Ok(secs) = modified.duration_since(std::time::UNIX_EPOCH) else {
|
||||
return true;
|
||||
};
|
||||
if secs.as_secs() != self.modified_unix_secs {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
pub fn hash_file(path: &Path) -> Result<String> {
|
||||
let mut file = File::open(path)
|
||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = vec![0u8; 65_536];
|
||||
loop {
|
||||
let read = file
|
||||
.read(&mut buffer)
|
||||
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..read]);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
pub fn read_sidecar(path: &Path) -> Result<Option<ChecksumSidecar>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let raw = fs::read_to_string(path)
|
||||
.with_context(|| format!("reading checksum sidecar {}", path.display()))?;
|
||||
let sidecar: ChecksumSidecar = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing checksum sidecar {}", path.display()))?;
|
||||
Ok(Some(sidecar))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn write_sidecar(content_path: &Path, sha256: &str) -> Result<()> {
|
||||
let metadata = fs::metadata(content_path)
|
||||
.with_context(|| format!("reading metadata for {}", content_path.display()))?;
|
||||
let modified_unix_secs = metadata
|
||||
.modified()
|
||||
.ok()
|
||||
.and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
|
||||
.map_or(0, |duration| duration.as_secs());
|
||||
let sidecar = ChecksumSidecar {
|
||||
version: SIDECAR_VERSION,
|
||||
sha256: sha256.to_string(),
|
||||
size_bytes: metadata.len(),
|
||||
modified_unix_secs,
|
||||
};
|
||||
let path = ChecksumSidecar::sidecar_path(content_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating checksum sidecar directory {}", parent.display()))?;
|
||||
}
|
||||
let blob = serde_json::to_vec_pretty(&sidecar).context("serialising checksum sidecar")?;
|
||||
fs::write(&path, blob)
|
||||
.with_context(|| format!("writing checksum sidecar {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn content_checksum(content_path: &Path) -> Result<String> {
|
||||
let sidecar_path = ChecksumSidecar::sidecar_path(content_path);
|
||||
if let Some(sidecar) = read_sidecar(&sidecar_path)? {
|
||||
if sidecar.is_valid_for(content_path) {
|
||||
return Ok(sidecar.sha256);
|
||||
}
|
||||
}
|
||||
let sha256 = hash_file(content_path)?;
|
||||
write_sidecar(content_path, &sha256)?;
|
||||
Ok(sha256)
|
||||
}
|
||||
|
||||
pub fn store_aggregate_checksum(store_dir: &Path) -> Result<String> {
|
||||
let marker = store_dir.join("checksum.sha256");
|
||||
let meta = store_dir.join("meta.json");
|
||||
if marker.is_file() && meta.is_file() {
|
||||
if let (Ok(marker_meta), Ok(meta_meta)) = (marker.metadata(), meta.metadata()) {
|
||||
if marker_meta
|
||||
.modified()
|
||||
.ok()
|
||||
.zip(meta_meta.modified().ok())
|
||||
.is_some_and(|(marker_modified, meta_modified)| marker_modified >= meta_modified)
|
||||
{
|
||||
if let Some(sidecar) = read_sidecar(&marker)? {
|
||||
return Ok(sidecar.sha256);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut entries = Vec::new();
|
||||
collect_store_files(store_dir, store_dir, &mut entries)?;
|
||||
entries.sort();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
for relative in &entries {
|
||||
let path = store_dir.join(relative);
|
||||
if path == marker {
|
||||
continue;
|
||||
}
|
||||
hasher.update(relative.as_bytes());
|
||||
hasher.update([0]);
|
||||
let file_hash = hash_file(&path)?;
|
||||
hasher.update(file_hash.as_bytes());
|
||||
}
|
||||
let digest = format!("{:x}", hasher.finalize());
|
||||
|
||||
let sidecar = ChecksumSidecar {
|
||||
version: SIDECAR_VERSION,
|
||||
sha256: digest.clone(),
|
||||
size_bytes: entries.len() as u64,
|
||||
modified_unix_secs: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_secs()),
|
||||
};
|
||||
if let Some(parent) = marker.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(&marker, serde_json::to_vec_pretty(&sidecar)?)?;
|
||||
Ok(digest)
|
||||
}
|
||||
|
||||
fn collect_store_files(base: &Path, current: &Path, entries: &mut Vec<String>) -> Result<()> {
|
||||
for entry in fs::read_dir(current)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path
|
||||
.file_name()
|
||||
.is_some_and(|name| name == "checksum.sha256")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if path.is_dir() {
|
||||
collect_store_files(base, &path, entries)?;
|
||||
} else if path.is_file() {
|
||||
let relative = path
|
||||
.strip_prefix(base)
|
||||
.unwrap_or(&path)
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
entries.push(relative);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn sidecar_round_trip() -> Result<()> {
|
||||
let dir = tempdir()?;
|
||||
let file = dir.path().join("sample.json");
|
||||
fs::write(&file, br#"{"hello":"world"}"#)?;
|
||||
|
||||
let first = content_checksum(&file)?;
|
||||
let second = content_checksum(&file)?;
|
||||
assert_eq!(first, second);
|
||||
|
||||
fs::write(&file, br#"{"hello":"world!"}"#)?;
|
||||
let third = content_checksum(&file)?;
|
||||
assert_ne!(first, third);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
catalog,
|
||||
store::{
|
||||
self, build_dataset_from_catalog, detect_layout, read_meta, store_dir_for, write_sharded,
|
||||
ConvertedLayout,
|
||||
},
|
||||
ConvertedDataset, DatasetKind,
|
||||
};
|
||||
use crate::{
|
||||
args::Config,
|
||||
slice::{self, SliceConfig},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadedDataset {
|
||||
pub dataset: ConvertedDataset,
|
||||
pub content_checksum: String,
|
||||
pub partial: bool,
|
||||
}
|
||||
|
||||
pub fn prepare_dataset(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||
if dataset_kind == DatasetKind::Beir {
|
||||
return super::beir_mix::prepare_beir_mix(config);
|
||||
}
|
||||
|
||||
let converted_path = &config.converted_dataset_path;
|
||||
let layout = detect_layout(converted_path);
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
|
||||
if layout == ConvertedLayout::Missing || config.force_convert {
|
||||
return convert_and_load(dataset_kind, config);
|
||||
}
|
||||
|
||||
load_from_store(dataset_kind, config, &store_dir, true)
|
||||
}
|
||||
|
||||
fn convert_and_load(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||
let dataset = super::convert(
|
||||
config.raw_dataset_path.as_path(),
|
||||
dataset_kind,
|
||||
config.llm_mode,
|
||||
)
|
||||
.with_context(|| format!("converting {} dataset", dataset_kind.label()))?;
|
||||
|
||||
let store_dir = store_dir_for(&config.converted_dataset_path);
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
prebuild_catalog_slices(&dataset, config)?;
|
||||
let checksum = crate::datasets::store_aggregate_checksum(&store_dir)?;
|
||||
|
||||
Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_from_store(
|
||||
dataset_kind: DatasetKind,
|
||||
config: &Config,
|
||||
store_dir: &std::path::Path,
|
||||
allow_partial: bool,
|
||||
) -> Result<LoadedDataset> {
|
||||
let checksum = crate::datasets::store_aggregate_checksum(store_dir)?;
|
||||
let meta = read_meta(store_dir)?;
|
||||
validate_metadata_fields(&meta.metadata, dataset_kind, config)?;
|
||||
|
||||
if allow_partial {
|
||||
if let Some(paragraph_ids) = slice_paragraph_ids_for_fast_path(config)? {
|
||||
let unique: HashSet<String> = paragraph_ids.into_iter().collect();
|
||||
info!(
|
||||
paragraphs = unique.len(),
|
||||
store = %store_dir.display(),
|
||||
"Loading slice-addressed paragraphs from sharded converted store"
|
||||
);
|
||||
let dataset = build_dataset_from_catalog(store_dir, &unique)?;
|
||||
return Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
paragraphs = meta.paragraph_count,
|
||||
"Loading full sharded converted store"
|
||||
);
|
||||
let dataset = store::load_sharded_full(store_dir)?;
|
||||
Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn slice_paragraph_ids_for_fast_path(config: &Config) -> Result<Option<Vec<String>>> {
|
||||
let Some(manifest_path) = slice::cached_manifest_path(config) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let Some(manifest) = slice::read_manifest_if_exists(&manifest_path)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||
if !slice::manifest_is_complete(&manifest, &slice_config) {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(
|
||||
manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
fn validate_metadata_fields(
|
||||
metadata: &super::DatasetMetadata,
|
||||
dataset_kind: DatasetKind,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
if metadata.id != dataset_kind.id() {
|
||||
anyhow::bail!(
|
||||
"converted dataset targets '{}', expected '{}'",
|
||||
metadata.id,
|
||||
dataset_kind.id()
|
||||
);
|
||||
}
|
||||
if metadata.include_unanswerable != config.llm_mode {
|
||||
anyhow::bail!(
|
||||
"converted dataset include_unanswerable mismatch (expected {}, found {})",
|
||||
config.llm_mode,
|
||||
metadata.include_unanswerable
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn prebuild_catalog_slices(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let catalog = catalog()?;
|
||||
let entry = catalog.dataset(dataset.metadata.id.as_str())?;
|
||||
if entry.slices.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
slices = entry.slices.len(),
|
||||
"Prebuilding catalog slice ledgers"
|
||||
);
|
||||
|
||||
for slice_entry in &entry.slices {
|
||||
let slice_config = slice_config_for_catalog_entry(config, slice_entry);
|
||||
match slice::resolve_slice(dataset, &slice_config) {
|
||||
Ok(resolved) => info!(
|
||||
slice = resolved.manifest.slice_id.as_str(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Prebuilt catalog slice ledger"
|
||||
),
|
||||
Err(err) => tracing::warn!(
|
||||
slice = slice_entry.id.as_str(),
|
||||
error = %err,
|
||||
"Failed to prebuild catalog slice ledger"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_config_for_catalog_entry<'a>(
|
||||
config: &'a Config,
|
||||
slice_entry: &'a super::SliceEntry,
|
||||
) -> SliceConfig<'a> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
force_convert: config.force_convert,
|
||||
explicit_slice: Some(slice_entry.id.as_str()),
|
||||
limit: slice_entry.limit,
|
||||
corpus_limit: slice_entry.corpus_limit,
|
||||
slice_seed: slice_entry.seed.unwrap_or(config.slice_seed),
|
||||
llm_mode: slice_entry.include_unanswerable.unwrap_or(config.llm_mode),
|
||||
negative_multiplier: slice_entry
|
||||
.negative_multiplier
|
||||
.unwrap_or(config.negative_multiplier),
|
||||
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||
}
|
||||
}
|
||||
+36
-143
@@ -1,6 +1,10 @@
|
||||
mod beir;
|
||||
mod beir_mix;
|
||||
mod checksum;
|
||||
mod loader;
|
||||
mod nq;
|
||||
mod squad;
|
||||
mod store;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
@@ -20,38 +24,31 @@ const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"
|
||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetCatalog {
|
||||
datasets: BTreeMap<String, DatasetEntry>,
|
||||
slices: HashMap<String, SliceLocation>,
|
||||
default_dataset: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetEntry {
|
||||
pub metadata: DatasetMetadata,
|
||||
pub raw_path: PathBuf,
|
||||
pub converted_path: PathBuf,
|
||||
pub include_unanswerable: bool,
|
||||
pub slices: Vec<SliceEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SliceEntry {
|
||||
pub id: String,
|
||||
pub dataset_id: String,
|
||||
pub label: String,
|
||||
pub description: Option<String>,
|
||||
pub limit: Option<usize>,
|
||||
pub corpus_limit: Option<usize>,
|
||||
pub include_unanswerable: Option<bool>,
|
||||
pub seed: Option<u64>,
|
||||
pub negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
struct SliceLocation {
|
||||
dataset_id: String,
|
||||
slice_index: usize,
|
||||
@@ -59,7 +56,6 @@ struct SliceLocation {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ManifestFile {
|
||||
default_dataset: Option<String>,
|
||||
datasets: Vec<ManifestDataset>,
|
||||
}
|
||||
|
||||
@@ -81,6 +77,7 @@ struct ManifestDataset {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ManifestSlice {
|
||||
id: String,
|
||||
label: String,
|
||||
@@ -94,6 +91,8 @@ struct ManifestSlice {
|
||||
include_unanswerable: Option<bool>,
|
||||
#[serde(default)]
|
||||
seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
impl DatasetCatalog {
|
||||
@@ -111,18 +110,19 @@ impl DatasetCatalog {
|
||||
let raw_path = resolve_path(root, &dataset.raw);
|
||||
let converted_path = resolve_path(root, &dataset.converted);
|
||||
|
||||
if !raw_path.exists() {
|
||||
if !raw_path.exists() && dataset.id != "beir" {
|
||||
bail!(
|
||||
"dataset '{}' raw file missing at {}",
|
||||
dataset.id,
|
||||
raw_path.display()
|
||||
);
|
||||
}
|
||||
if !converted_path.exists() {
|
||||
let store_dir = store::store_dir_for(&converted_path);
|
||||
if !converted_path.exists() && !store_dir.join("meta.json").is_file() {
|
||||
warn!(
|
||||
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
|
||||
"dataset '{}' converted store missing at {}; the next conversion run will regenerate it",
|
||||
dataset.id,
|
||||
converted_path.display()
|
||||
store_dir.display()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -139,7 +139,6 @@ impl DatasetCatalog {
|
||||
.clone()
|
||||
.unwrap_or_else(|| dataset.id.clone()),
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
context_token_limit: None,
|
||||
};
|
||||
|
||||
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
|
||||
@@ -154,12 +153,11 @@ impl DatasetCatalog {
|
||||
entry_slices.push(SliceEntry {
|
||||
id: manifest_slice.id.clone(),
|
||||
dataset_id: dataset.id.clone(),
|
||||
label: manifest_slice.label,
|
||||
description: manifest_slice.description,
|
||||
limit: manifest_slice.limit,
|
||||
corpus_limit: manifest_slice.corpus_limit,
|
||||
include_unanswerable: manifest_slice.include_unanswerable,
|
||||
seed: manifest_slice.seed,
|
||||
negative_multiplier: manifest_slice.negative_multiplier,
|
||||
});
|
||||
slices.insert(
|
||||
manifest_slice.id,
|
||||
@@ -176,22 +174,16 @@ impl DatasetCatalog {
|
||||
metadata,
|
||||
raw_path,
|
||||
converted_path,
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
slices: entry_slices,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let default_dataset = manifest
|
||||
.default_dataset
|
||||
.or_else(|| datasets.keys().next().cloned())
|
||||
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
|
||||
if datasets.is_empty() {
|
||||
bail!("dataset manifest does not include any datasets");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
datasets,
|
||||
slices,
|
||||
default_dataset,
|
||||
})
|
||||
Ok(Self { datasets, slices })
|
||||
}
|
||||
|
||||
pub fn global() -> Result<&'static Self> {
|
||||
@@ -204,12 +196,6 @@ impl DatasetCatalog {
|
||||
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
|
||||
self.dataset(&self.default_dataset)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
|
||||
let location = self
|
||||
.slices
|
||||
@@ -236,20 +222,27 @@ fn resolve_path(root: &Path, value: &str) -> PathBuf {
|
||||
}
|
||||
}
|
||||
|
||||
pub use beir_mix::{beir_subset_store_summary, beir_subset_stores_ready, mix_content_checksum};
|
||||
pub use checksum::store_aggregate_checksum;
|
||||
pub use loader::{prebuild_catalog_slices, prepare_dataset};
|
||||
pub use store::{
|
||||
content_checksum_for_layout, detect_layout, store_dir_for, write_sharded, ConvertedLayout,
|
||||
};
|
||||
|
||||
pub fn catalog() -> Result<&'static DatasetCatalog> {
|
||||
DatasetCatalog::global()
|
||||
}
|
||||
|
||||
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
pub(crate) fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
let catalog = catalog()?;
|
||||
catalog.dataset(kind.id())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum, Default)]
|
||||
pub enum DatasetKind {
|
||||
#[default]
|
||||
SquadV2,
|
||||
NaturalQuestions,
|
||||
#[default]
|
||||
Beir,
|
||||
#[value(name = "fever")]
|
||||
Fever,
|
||||
@@ -416,16 +409,10 @@ pub struct DatasetMetadata {
|
||||
pub source_prefix: String,
|
||||
#[serde(default)]
|
||||
pub include_unanswerable: bool,
|
||||
#[serde(default)]
|
||||
pub context_token_limit: Option<usize>,
|
||||
}
|
||||
|
||||
impl DatasetMetadata {
|
||||
pub fn for_kind(
|
||||
kind: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Self {
|
||||
pub fn for_kind(kind: DatasetKind, include_unanswerable: bool) -> Self {
|
||||
if let Ok(entry) = dataset_entry_for_kind(kind) {
|
||||
return Self {
|
||||
id: entry.metadata.id.clone(),
|
||||
@@ -434,7 +421,6 @@ impl DatasetMetadata {
|
||||
entity_suffix: entry.metadata.entity_suffix.clone(),
|
||||
source_prefix: entry.metadata.source_prefix.clone(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -445,13 +431,12 @@ impl DatasetMetadata {
|
||||
entity_suffix: kind.entity_suffix().to_string(),
|
||||
source_prefix: kind.source_prefix().to_string(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_metadata() -> DatasetMetadata {
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -483,14 +468,15 @@ pub fn convert(
|
||||
raw_path: &Path,
|
||||
dataset: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
let paragraphs = match dataset {
|
||||
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
|
||||
DatasetKind::NaturalQuestions => {
|
||||
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
||||
DatasetKind::NaturalQuestions => nq::convert_nq(raw_path, include_unanswerable)?,
|
||||
DatasetKind::Beir => {
|
||||
bail!(
|
||||
"BEIR mix is prepared via slice-first subset stores; use prepare_beir_mix instead of convert"
|
||||
);
|
||||
}
|
||||
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
|
||||
DatasetKind::Fever
|
||||
| DatasetKind::Fiqa
|
||||
| DatasetKind::HotpotQa
|
||||
@@ -501,11 +487,6 @@ pub fn convert(
|
||||
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
||||
};
|
||||
|
||||
let metadata_limit = match dataset {
|
||||
DatasetKind::NaturalQuestions => None,
|
||||
_ => context_token_limit,
|
||||
};
|
||||
|
||||
let generated_at = match dataset {
|
||||
DatasetKind::Beir
|
||||
| DatasetKind::Fever
|
||||
@@ -526,100 +507,12 @@ pub fn convert(
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at,
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable),
|
||||
source: source_label,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_beir_mix(
|
||||
include_unanswerable: bool,
|
||||
_context_token_limit: Option<usize>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
if include_unanswerable {
|
||||
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for subset in BEIR_DATASETS {
|
||||
let entry = dataset_entry_for_kind(subset)?;
|
||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||
paragraphs.extend(subset_paragraphs);
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
|
||||
ensure_parent(converted_path)?;
|
||||
let json =
|
||||
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
|
||||
fs::write(converted_path, json)
|
||||
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
|
||||
}
|
||||
|
||||
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
|
||||
let raw = fs::read_to_string(converted_path)
|
||||
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
|
||||
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
|
||||
if dataset.metadata.id.trim().is_empty() {
|
||||
dataset.metadata = default_metadata();
|
||||
}
|
||||
if dataset.source.is_empty() {
|
||||
dataset.source = converted_path.display().to_string();
|
||||
}
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
pub fn ensure_converted(
|
||||
dataset_kind: DatasetKind,
|
||||
raw_path: &Path,
|
||||
converted_path: &Path,
|
||||
force: bool,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
if force || !converted_path.exists() {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
return Ok(dataset);
|
||||
}
|
||||
|
||||
match read_converted(converted_path) {
|
||||
Ok(dataset)
|
||||
if dataset.metadata.id == dataset_kind.id()
|
||||
&& dataset.metadata.include_unanswerable == include_unanswerable
|
||||
&& dataset.metadata.context_token_limit == context_token_limit =>
|
||||
{
|
||||
Ok(dataset)
|
||||
}
|
||||
_ => {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
Ok(dataset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base_timestamp() -> DateTime<Utc> {
|
||||
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
|
||||
}
|
||||
|
||||
@@ -16,11 +16,7 @@ use super::{ConvertedParagraph, ConvertedQuestion};
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::cast_sign_loss
|
||||
)]
|
||||
pub fn convert_nq(
|
||||
raw_path: &Path,
|
||||
include_unanswerable: bool,
|
||||
_context_token_limit: Option<usize>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
pub fn convert_nq(raw_path: &Path, include_unanswerable: bool) -> Result<Vec<ConvertedParagraph>> {
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NqExample {
|
||||
|
||||
@@ -0,0 +1,412 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs::{self, File, OpenOptions},
|
||||
io::{BufRead, BufReader, Write},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
checksum::store_aggregate_checksum, ConvertedDataset, ConvertedParagraph, ConvertedQuestion,
|
||||
DatasetMetadata,
|
||||
};
|
||||
use crate::slice;
|
||||
|
||||
pub const SHARDED_STORE_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ShardedMeta {
|
||||
pub version: u32,
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub metadata: DatasetMetadata,
|
||||
pub source: String,
|
||||
pub paragraph_count: usize,
|
||||
pub question_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct QuestionRecord {
|
||||
paragraph_id: String,
|
||||
#[serde(flatten)]
|
||||
question: ConvertedQuestion,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuestionCatalog {
|
||||
pub entries: Vec<QuestionRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ConvertedLayout {
|
||||
ShardedStore,
|
||||
Missing,
|
||||
}
|
||||
|
||||
pub fn store_dir_for(converted_path: &Path) -> PathBuf {
|
||||
converted_path
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."))
|
||||
.join(converted_path.file_stem().map_or_else(
|
||||
|| "dataset".to_string(),
|
||||
|stem| stem.to_string_lossy().into(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn detect_layout(converted_path: &Path) -> ConvertedLayout {
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
ConvertedLayout::ShardedStore
|
||||
} else {
|
||||
ConvertedLayout::Missing
|
||||
}
|
||||
}
|
||||
|
||||
fn paragraph_file_name(paragraph_id: &str) -> String {
|
||||
format!("{}.json", slice::paragraph_storage_key(paragraph_id))
|
||||
}
|
||||
|
||||
pub fn paragraph_path(store_dir: &Path, paragraph_id: &str) -> PathBuf {
|
||||
store_dir
|
||||
.join("paragraphs")
|
||||
.join(paragraph_file_name(paragraph_id))
|
||||
}
|
||||
|
||||
pub fn write_sharded(dataset: &ConvertedDataset, store_dir: &Path) -> Result<String> {
|
||||
if store_dir.exists() {
|
||||
fs::remove_dir_all(store_dir)
|
||||
.with_context(|| format!("clearing sharded store {}", store_dir.display()))?;
|
||||
}
|
||||
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||
.with_context(|| format!("creating sharded store {}", store_dir.display()))?;
|
||||
|
||||
let question_count = dataset
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|paragraph| paragraph.questions.len())
|
||||
.sum::<usize>();
|
||||
|
||||
let meta = ShardedMeta {
|
||||
version: SHARDED_STORE_VERSION,
|
||||
generated_at: dataset.generated_at,
|
||||
metadata: dataset.metadata.clone(),
|
||||
source: dataset.source.clone(),
|
||||
paragraph_count: dataset.paragraphs.len(),
|
||||
question_count,
|
||||
};
|
||||
let meta_path = store_dir.join("meta.json");
|
||||
fs::write(
|
||||
&meta_path,
|
||||
serde_json::to_vec_pretty(&meta).context("serialising sharded store metadata")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded metadata {}", meta_path.display()))?;
|
||||
|
||||
let mut questions_file = File::create(store_dir.join("questions.jsonl"))
|
||||
.context("creating questions.jsonl for sharded store")?;
|
||||
let mut paragraph_ids_file = File::create(store_dir.join("paragraph_ids.jsonl"))
|
||||
.context("creating paragraph_ids.jsonl for sharded store")?;
|
||||
|
||||
for paragraph in &dataset.paragraphs {
|
||||
writeln!(paragraph_ids_file, "{}", paragraph.id)
|
||||
.context("writing paragraph id to paragraph_ids.jsonl")?;
|
||||
for question in ¶graph.questions {
|
||||
let record = QuestionRecord {
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
question: question.clone(),
|
||||
};
|
||||
serde_json::to_writer(&mut questions_file, &record)
|
||||
.context("writing question record to questions.jsonl")?;
|
||||
questions_file.write_all(b"\n")?;
|
||||
}
|
||||
|
||||
let path = paragraph_path(store_dir, ¶graph.id);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(
|
||||
&path,
|
||||
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||
}
|
||||
|
||||
let digest = store_aggregate_checksum(store_dir)?;
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
questions = question_count,
|
||||
checksum = %digest,
|
||||
"Wrote sharded converted dataset"
|
||||
);
|
||||
Ok(digest)
|
||||
}
|
||||
|
||||
pub fn read_meta(store_dir: &Path) -> Result<ShardedMeta> {
|
||||
let path = store_dir.join("meta.json");
|
||||
let raw = fs::read_to_string(&path)
|
||||
.with_context(|| format!("reading sharded metadata {}", path.display()))?;
|
||||
serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing sharded metadata {}", path.display()))
|
||||
}
|
||||
|
||||
pub fn content_checksum_for_layout(converted_path: &Path) -> Result<String> {
|
||||
match detect_layout(converted_path) {
|
||||
ConvertedLayout::ShardedStore => {
|
||||
crate::datasets::store_aggregate_checksum(&store_dir_for(converted_path))
|
||||
}
|
||||
ConvertedLayout::Missing => Err(anyhow!(
|
||||
"converted dataset missing at {}",
|
||||
converted_path.display()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_paragraph(store_dir: &Path, paragraph_id: &str) -> Result<ConvertedParagraph> {
|
||||
let path = paragraph_path(store_dir, paragraph_id);
|
||||
let raw =
|
||||
fs::read(&path).with_context(|| format!("reading sharded paragraph {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing sharded paragraph {}", path.display()))
|
||||
}
|
||||
|
||||
fn load_paragraphs(store_dir: &Path, paragraph_ids: &[String]) -> Result<Vec<ConvertedParagraph>> {
|
||||
paragraph_ids
|
||||
.iter()
|
||||
.map(|paragraph_id| load_paragraph(store_dir, paragraph_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_sharded_partial(
|
||||
store_dir: &Path,
|
||||
paragraph_ids: &[String],
|
||||
) -> Result<ConvertedDataset> {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let mut paragraphs = load_paragraphs(store_dir, paragraph_ids)?;
|
||||
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: meta.generated_at,
|
||||
metadata: meta.metadata,
|
||||
source: meta.source,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_sharded_full(store_dir: &Path) -> Result<ConvertedDataset> {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let ids = load_paragraph_ids(store_dir)?;
|
||||
let paragraphs = load_paragraphs(store_dir, &ids)?;
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: meta.generated_at,
|
||||
metadata: meta.metadata,
|
||||
source: meta.source,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_paragraph_ids_set(store_dir: &Path) -> Result<HashSet<String>> {
|
||||
Ok(load_paragraph_ids(store_dir)?.into_iter().collect())
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
pub fn upsert_sharded_paragraphs(
|
||||
store_dir: &Path,
|
||||
paragraphs: &[ConvertedParagraph],
|
||||
) -> Result<()> {
|
||||
if paragraphs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if !store_dir.join("meta.json").is_file() {
|
||||
return Err(anyhow!(
|
||||
"cannot upsert into missing sharded store at {}",
|
||||
store_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||
.with_context(|| format!("creating paragraphs directory in {}", store_dir.display()))?;
|
||||
|
||||
let existing = load_paragraph_ids_set(store_dir)?;
|
||||
let questions_path = store_dir.join("questions.jsonl");
|
||||
let mut questions_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&questions_path)
|
||||
.with_context(|| format!("opening question catalog {}", questions_path.display()))?;
|
||||
|
||||
let mut ids_file = None;
|
||||
let mut new_paragraphs = 0usize;
|
||||
let mut new_questions = 0usize;
|
||||
|
||||
for paragraph in paragraphs {
|
||||
let is_new = !existing.contains(¶graph.id);
|
||||
let path = paragraph_path(store_dir, ¶graph.id);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(
|
||||
&path,
|
||||
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||
|
||||
if is_new {
|
||||
if ids_file.is_none() {
|
||||
ids_file = Some(
|
||||
OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(store_dir.join("paragraph_ids.jsonl"))
|
||||
.context("opening paragraph_ids.jsonl for append")?,
|
||||
);
|
||||
}
|
||||
if let Some(file) = ids_file.as_mut() {
|
||||
writeln!(file, "{}", paragraph.id).context("appending paragraph id")?;
|
||||
}
|
||||
new_paragraphs += 1;
|
||||
|
||||
for question in ¶graph.questions {
|
||||
let record = QuestionRecord {
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
question: question.clone(),
|
||||
};
|
||||
serde_json::to_writer(&mut questions_file, &record)
|
||||
.context("writing question record to questions.jsonl")?;
|
||||
questions_file.write_all(b"\n")?;
|
||||
new_questions += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if new_paragraphs > 0 || new_questions > 0 {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let updated = ShardedMeta {
|
||||
paragraph_count: meta.paragraph_count + new_paragraphs,
|
||||
question_count: meta.question_count + new_questions,
|
||||
..meta
|
||||
};
|
||||
fs::write(
|
||||
store_dir.join("meta.json"),
|
||||
serde_json::to_vec_pretty(&updated).context("serialising updated sharded metadata")?,
|
||||
)?;
|
||||
store_aggregate_checksum(store_dir)?;
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
new_paragraphs,
|
||||
new_questions,
|
||||
"Upserted paragraphs into sharded converted store"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load_paragraph_ids(store_dir: &Path) -> Result<Vec<String>> {
|
||||
let path = store_dir.join("paragraph_ids.jsonl");
|
||||
let file = File::open(&path)
|
||||
.with_context(|| format!("opening paragraph id index {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
reader
|
||||
.lines()
|
||||
.map(|line| {
|
||||
line.context("reading paragraph id index line")
|
||||
.and_then(|value| {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
Err(anyhow!("empty paragraph id in index"))
|
||||
} else {
|
||||
Ok(trimmed.to_string())
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_question_catalog(store_dir: &Path) -> Result<QuestionCatalog> {
|
||||
let path = store_dir.join("questions.jsonl");
|
||||
let file = File::open(&path)
|
||||
.with_context(|| format!("opening question catalog {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut entries = Vec::new();
|
||||
for line in reader.lines() {
|
||||
let line = line.context("reading question catalog line")?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let record: QuestionRecord =
|
||||
serde_json::from_str(&line).context("parsing question catalog record")?;
|
||||
entries.push(record);
|
||||
}
|
||||
Ok(QuestionCatalog { entries })
|
||||
}
|
||||
|
||||
pub fn build_dataset_from_catalog(
|
||||
store_dir: &Path,
|
||||
paragraph_ids: &HashSet<String>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
let catalog = load_question_catalog(store_dir)?;
|
||||
let mut questions_by_paragraph: HashMap<String, Vec<ConvertedQuestion>> = HashMap::new();
|
||||
for entry in catalog.entries {
|
||||
if paragraph_ids.contains(&entry.paragraph_id) {
|
||||
questions_by_paragraph
|
||||
.entry(entry.paragraph_id.clone())
|
||||
.or_default()
|
||||
.push(entry.question);
|
||||
}
|
||||
}
|
||||
|
||||
let mut dataset = load_sharded_partial(
|
||||
store_dir,
|
||||
¶graph_ids.iter().cloned().collect::<Vec<_>>(),
|
||||
)?;
|
||||
for paragraph in &mut dataset.paragraphs {
|
||||
if let Some(questions) = questions_by_paragraph.remove(¶graph.id) {
|
||||
paragraph.questions = questions;
|
||||
} else {
|
||||
paragraph.questions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::datasets::{DatasetKind, DatasetMetadata};
|
||||
|
||||
fn sample_dataset() -> ConvertedDataset {
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::SquadV2, false),
|
||||
source: "test".to_string(),
|
||||
paragraphs: vec![ConvertedParagraph {
|
||||
id: "p1".to_string(),
|
||||
title: "Title".to_string(),
|
||||
context: "Body".to_string(),
|
||||
questions: vec![ConvertedQuestion {
|
||||
id: "q1".to_string(),
|
||||
question: "Question?".to_string(),
|
||||
answers: vec!["Answer".to_string()],
|
||||
is_impossible: false,
|
||||
}],
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn sharded_round_trip() -> Result<()> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let store_dir = dir.path().join("sample");
|
||||
let dataset = sample_dataset();
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
|
||||
let loaded = load_sharded_full(&store_dir)?;
|
||||
assert_eq!(loaded.paragraphs.len(), 1);
|
||||
assert_eq!(loaded.paragraphs[0].questions[0].id, "q1");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,22 @@
|
||||
//! Database namespace management utilities.
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::Utc;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::user::{Theme, User},
|
||||
types::StoredObject,
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::user::{Theme, User},
|
||||
types::StoredObject,
|
||||
},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
corpus::{self, CorpusHandle, CorpusManifest, NamespaceSeedRecord},
|
||||
datasets,
|
||||
snapshot::{self, DbSnapshotState},
|
||||
};
|
||||
|
||||
/// Connect to the evaluation database with fallback auth strategies.
|
||||
pub(crate) async fn connect_eval_db(
|
||||
config: &Config,
|
||||
namespace: &str,
|
||||
@@ -73,7 +73,6 @@ pub(crate) async fn connect_eval_db(
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the namespace contains any corpus data.
|
||||
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
||||
#[derive(Deserialize)]
|
||||
struct CountRow {
|
||||
@@ -89,41 +88,51 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
||||
Ok(rows.first().map_or(0, |row| row.count) > 0)
|
||||
}
|
||||
|
||||
/// Determine if we can reuse an existing namespace based on cached state.
|
||||
fn manifest_matches_runtime(
|
||||
manifest: &CorpusManifest,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
ingestion_fingerprint: &str,
|
||||
) -> bool {
|
||||
let metadata = &manifest.metadata;
|
||||
metadata.ingestion_fingerprint == ingestion_fingerprint
|
||||
&& metadata.embedding_backend == embedding_provider.backend_label()
|
||||
&& metadata.embedding_model == embedding_provider.model_code()
|
||||
&& metadata.embedding_dimension == embedding_provider.dimension()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn can_reuse_namespace(
|
||||
db: &SurrealDbClient,
|
||||
descriptor: &snapshot::Descriptor,
|
||||
manifest: &CorpusManifest,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
namespace: &str,
|
||||
database: &str,
|
||||
dataset_id: &str,
|
||||
slice_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
slice_case_count: usize,
|
||||
) -> Result<bool> {
|
||||
let Some(state) = descriptor.load_db_state().await? else {
|
||||
info!("No namespace state recorded; reseeding corpus from cached shards");
|
||||
if !manifest_matches_runtime(manifest, embedding_provider, ingestion_fingerprint) {
|
||||
info!("Corpus manifest metadata mismatch; rebuilding namespace from cached shards");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let Some(seed) = manifest.metadata.namespace_seed.as_ref() else {
|
||||
info!("No namespace seed recorded in corpus manifest; reseeding");
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
if state.slice_case_count != slice_case_count {
|
||||
if seed.slice_case_count != slice_case_count {
|
||||
info!(
|
||||
requested_cases = slice_case_count,
|
||||
stored_cases = state.slice_case_count,
|
||||
"Skipping live namespace reuse; cached state does not match requested window"
|
||||
stored_cases = seed.slice_case_count,
|
||||
"Skipping namespace reuse; case window mismatch"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if state.dataset_id != dataset_id
|
||||
|| state.slice_id != slice_id
|
||||
|| state.ingestion_fingerprint != ingestion_fingerprint
|
||||
|| state.namespace.as_deref() != Some(namespace)
|
||||
|| state.database.as_deref() != Some(database)
|
||||
{
|
||||
if seed.namespace != namespace || seed.database != database {
|
||||
info!(
|
||||
namespace,
|
||||
database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
|
||||
database, "Corpus manifest namespace metadata mismatch; reseeding"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
@@ -140,28 +149,20 @@ pub(crate) async fn can_reuse_namespace(
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the current namespace state to allow future reuse checks.
|
||||
pub(crate) async fn record_namespace_state(
|
||||
descriptor: &snapshot::Descriptor,
|
||||
dataset_id: &str,
|
||||
slice_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
pub(crate) async fn record_namespace_seed(
|
||||
handle: &mut CorpusHandle,
|
||||
namespace: &str,
|
||||
database: &str,
|
||||
slice_case_count: usize,
|
||||
) {
|
||||
let state = DbSnapshotState {
|
||||
dataset_id: dataset_id.to_string(),
|
||||
slice_id: slice_id.to_string(),
|
||||
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
||||
updated_at: Utc::now(),
|
||||
namespace: Some(namespace.to_string()),
|
||||
database: Some(database.to_string()),
|
||||
handle.manifest.metadata.namespace_seed = Some(NamespaceSeedRecord {
|
||||
namespace: namespace.to_string(),
|
||||
database: database.to_string(),
|
||||
slice_case_count,
|
||||
};
|
||||
if let Err(err) = descriptor.store_db_state(&state).await {
|
||||
warn!(error = %err, "Failed to record namespace state");
|
||||
seeded_at: Utc::now(),
|
||||
});
|
||||
if let Err(err) = corpus::persist_corpus_manifest(handle) {
|
||||
warn!(error = %err, "Failed to record namespace seed in corpus manifest");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,8 +186,17 @@ fn sanitize_identifier(input: &str) -> String {
|
||||
cleaned
|
||||
}
|
||||
|
||||
/// Generate a default namespace name based on dataset and limit.
|
||||
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
|
||||
pub(crate) fn default_namespace(
|
||||
dataset_id: &str,
|
||||
limit: Option<usize>,
|
||||
slice_id: Option<&str>,
|
||||
) -> String {
|
||||
if let Some(slice_id) = slice_id {
|
||||
let sanitized = sanitize_identifier(slice_id);
|
||||
if !sanitized.is_empty() {
|
||||
return format!("eval_{sanitized}");
|
||||
}
|
||||
}
|
||||
let dataset_component = sanitize_identifier(dataset_id);
|
||||
let limit_component = match limit {
|
||||
Some(value) if value > 0 => format!("limit{value}"),
|
||||
@@ -195,12 +205,10 @@ pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> Strin
|
||||
format!("eval_{dataset_component}_{limit_component}")
|
||||
}
|
||||
|
||||
/// Generate the default database name for evaluations.
|
||||
pub(crate) fn default_database() -> String {
|
||||
"retrieval_eval".to_string()
|
||||
}
|
||||
|
||||
/// Ensure the evaluation user exists in the database.
|
||||
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||
let timestamp = datasets::base_timestamp();
|
||||
let user = User {
|
||||
@@ -225,3 +233,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||
.context("storing evaluation user")?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
||||
sanitize_identifier(code)
|
||||
}
|
||||
@@ -2,13 +2,6 @@ use anyhow::{Context, Result};
|
||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
|
||||
use tracing::info;
|
||||
|
||||
// Helper functions for index management during namespace reseed
|
||||
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
||||
let _ = db;
|
||||
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
||||
ensure_runtime(db, dimension)
|
||||
@@ -34,14 +27,39 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// // Test helper to force index dimension change
|
||||
// #[allow(dead_code)]
|
||||
// pub async fn change_embedding_length_in_hnsw_indexes(
|
||||
// db: &SurrealDbClient,
|
||||
// dimension: usize,
|
||||
// ) -> Result<()> {
|
||||
// recreate_indexes(db, dimension).await
|
||||
// }
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
||||
|
||||
info!("Warming HNSW caches with sample queries");
|
||||
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT chunk_id
|
||||
FROM text_chunk_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding.clone()))
|
||||
.await
|
||||
.context("warming text chunk HNSW cache")?;
|
||||
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT entity_id
|
||||
FROM knowledge_entity_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding))
|
||||
.await
|
||||
.context("warming knowledge entity HNSW cache")?;
|
||||
|
||||
info!("HNSW cache warming completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -0,0 +1,9 @@
|
||||
mod connect;
|
||||
mod lifecycle;
|
||||
|
||||
pub(crate) use connect::{
|
||||
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
||||
namespace_has_corpus, record_namespace_seed, sanitize_model_code,
|
||||
};
|
||||
pub(crate) use lifecycle::warm_hnsw_cache;
|
||||
pub use lifecycle::{recreate_indexes, reset_namespace};
|
||||
@@ -1,128 +0,0 @@
|
||||
//! Evaluation utilities module - re-exports from focused submodules.
|
||||
|
||||
// Re-export types from the root types module
|
||||
pub use crate::types::*;
|
||||
|
||||
// Re-export from focused modules at crate root (crate-internal only)
|
||||
pub(crate) use crate::cases::{cases_from_manifest, SeededCase};
|
||||
pub(crate) use crate::namespace::{
|
||||
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
||||
record_namespace_state,
|
||||
};
|
||||
pub(crate) use crate::settings::{enforce_system_settings, load_or_init_system_settings};
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
args::{self, Config},
|
||||
datasets::ConvertedDataset,
|
||||
slice::{self},
|
||||
};
|
||||
|
||||
/// Grow the slice ledger to contain the target number of cases.
|
||||
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let ledger_limit = ledger_target(config);
|
||||
let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
|
||||
let slice =
|
||||
slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
||||
info!(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
cases = slice.manifest.case_count,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total_paragraphs = slice.manifest.total_paragraphs,
|
||||
"Slice ledger ready"
|
||||
);
|
||||
println!(
|
||||
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
||||
slice.manifest.slice_id,
|
||||
slice.manifest.case_count,
|
||||
slice.manifest.positive_paragraphs,
|
||||
slice.manifest.negative_paragraphs
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
|
||||
match (config.slice_grow, config.limit) {
|
||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||
(Some(grow), None) => Some(grow),
|
||||
(None, limit) => limit,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||
args::ensure_parent(path)?;
|
||||
let mut file = tokio::fs::File::create(path)
|
||||
.await
|
||||
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
||||
for case in cases {
|
||||
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
||||
file.write_all(&line).await?;
|
||||
file.write_all(b"\n").await?;
|
||||
}
|
||||
file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
||||
|
||||
info!("Warming HNSW caches with sample queries");
|
||||
|
||||
// Warm up chunk embedding index - just query the embedding table to load HNSW index
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT chunk_id
|
||||
FROM text_chunk_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding.clone()))
|
||||
.await
|
||||
.context("warming text chunk HNSW cache")?;
|
||||
|
||||
// Warm up entity embedding index
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT entity_id
|
||||
FROM knowledge_entity_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding))
|
||||
.await
|
||||
.context("warming knowledge entity HNSW cache")?;
|
||||
|
||||
info!("HNSW cache warming completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
|
||||
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
|
||||
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
|
||||
}
|
||||
|
||||
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
||||
code.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_ascii_alphanumeric() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Re-export run_evaluation from the pipeline module at crate root
|
||||
pub use crate::pipeline::run_evaluation;
|
||||
@@ -1,13 +1,9 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
|
||||
|
||||
use crate::{args::Config, corpus, eval::connect_eval_db, snapshot::DbSnapshotState};
|
||||
use crate::{args::Config, corpus, db::connect_eval_db};
|
||||
|
||||
pub async fn inspect_question(config: &Config) -> Result<()> {
|
||||
let question_id = config
|
||||
@@ -64,39 +60,26 @@ pub async fn inspect_question(config: &Config) -> Result<()> {
|
||||
);
|
||||
}
|
||||
|
||||
let db_state_path = config
|
||||
.database
|
||||
.inspect_db_state
|
||||
.clone()
|
||||
.unwrap_or_else(|| default_state_path(config, &manifest));
|
||||
if let Some(state) = load_db_state(&db_state_path)? {
|
||||
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
|
||||
match connect_eval_db(config, ns, db_name).await {
|
||||
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
||||
MissingChunks::None => println!(
|
||||
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
||||
),
|
||||
MissingChunks::Missing(list) => println!(
|
||||
"Missing chunks in namespace '{ns}', database '{db_name}': {list:?}"
|
||||
),
|
||||
},
|
||||
Err(err) => {
|
||||
println!(
|
||||
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
||||
);
|
||||
if let Some(seed) = manifest.metadata.namespace_seed.as_ref() {
|
||||
let ns = seed.namespace.as_str();
|
||||
let db_name = seed.database.as_str();
|
||||
match connect_eval_db(config, ns, db_name).await {
|
||||
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
||||
MissingChunks::None => println!(
|
||||
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
||||
),
|
||||
MissingChunks::Missing(list) => {
|
||||
println!("Missing chunks in namespace '{ns}', database '{db_name}': {list:?}");
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
println!(
|
||||
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"State file {} is missing namespace/database fields; skipping live DB validation",
|
||||
db_state_path.display()
|
||||
);
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"State file {} not found; skipping live DB validation",
|
||||
db_state_path.display()
|
||||
);
|
||||
println!("Corpus manifest has no namespace seed; skipping live DB validation");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -137,25 +120,6 @@ fn build_chunk_lookup(manifest: &corpus::CorpusManifest) -> HashMap<String, Chun
|
||||
lookup
|
||||
}
|
||||
|
||||
fn default_state_path(config: &Config, manifest: &corpus::CorpusManifest) -> PathBuf {
|
||||
config
|
||||
.cache_dir
|
||||
.join("snapshots")
|
||||
.join(&manifest.metadata.dataset_id)
|
||||
.join(&manifest.metadata.slice_id)
|
||||
.join("db/state.json")
|
||||
}
|
||||
|
||||
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("parsing db state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
enum MissingChunks {
|
||||
None,
|
||||
Missing(Vec<String>),
|
||||
|
||||
+49
-46
@@ -1,19 +1,17 @@
|
||||
mod args;
|
||||
mod cache;
|
||||
mod cases;
|
||||
mod cli;
|
||||
mod context_stats;
|
||||
mod corpus;
|
||||
mod datasets;
|
||||
mod db_helpers;
|
||||
mod eval;
|
||||
mod db;
|
||||
mod inspection;
|
||||
mod namespace;
|
||||
mod openai;
|
||||
mod perf;
|
||||
mod pipeline;
|
||||
mod report;
|
||||
mod settings;
|
||||
mod slice;
|
||||
mod snapshot;
|
||||
mod types;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -24,7 +22,6 @@ use tracing_subscriber::{fmt, EnvFilter};
|
||||
/// Configure `SurrealDB` environment variables for optimal performance
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::unwrap_used)]
|
||||
fn configure_surrealdb_performance(cpu_count: usize) {
|
||||
// Set environment variables only if they're not already set
|
||||
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 2).to_string());
|
||||
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
||||
@@ -62,12 +59,11 @@ fn configure_surrealdb_performance(cpu_count: usize) {
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Create an explicit multi-threaded runtime with optimized configuration
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.worker_threads(std::thread::available_parallelism()?.get())
|
||||
.max_blocking_threads(std::thread::available_parallelism()?.get())
|
||||
.thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
|
||||
.thread_stack_size(10 * 1024 * 1024)
|
||||
.thread_name("eval-retrieval-worker")
|
||||
.build()
|
||||
.context("failed to create tokio runtime")?;
|
||||
@@ -77,7 +73,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn async_main() -> anyhow::Result<()> {
|
||||
// Log runtime configuration
|
||||
let cpu_count = std::thread::available_parallelism()?.get();
|
||||
info!(
|
||||
cpu_cores = cpu_count,
|
||||
@@ -87,7 +82,6 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
"Started multi-threaded tokio runtime"
|
||||
);
|
||||
|
||||
// Configure SurrealDB environment variables for better performance
|
||||
configure_surrealdb_performance(cpu_count);
|
||||
|
||||
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
||||
@@ -97,13 +91,22 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
|
||||
let parsed = args::parse()?;
|
||||
|
||||
// Clap handles help automatically, so we don't need to check for it manually
|
||||
|
||||
if parsed.config.inspect_question.is_some() {
|
||||
inspection::inspect_question(&parsed.config).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.status {
|
||||
let status = cli::collect_status(&parsed.config).await?;
|
||||
cli::print_status(&status);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.warm {
|
||||
cli::warm(&parsed.config).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let dataset_kind = parsed.config.dataset;
|
||||
|
||||
if parsed.config.convert_only {
|
||||
@@ -115,7 +118,6 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
dataset_kind,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
@@ -124,56 +126,52 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
parsed.config.raw_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"writing converted dataset to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
println!(
|
||||
"Converted dataset written to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
);
|
||||
let store_dir = datasets::store_dir_for(&parsed.config.converted_dataset_path);
|
||||
datasets::write_sharded(&dataset, &store_dir)?;
|
||||
datasets::prebuild_catalog_slices(&dataset, &parsed.config)?;
|
||||
println!("Converted dataset written under {}", store_dir.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.require_ready {
|
||||
cli::ensure_query_ready(&parsed.config).await?;
|
||||
}
|
||||
|
||||
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
|
||||
let dataset = crate::datasets::ensure_converted(
|
||||
dataset_kind,
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
parsed.config.converted_dataset_path.as_path(),
|
||||
parsed.config.force_convert,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"preparing converted dataset at {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
let loaded =
|
||||
crate::datasets::prepare_dataset(dataset_kind, &parsed.config).with_context(|| {
|
||||
format!(
|
||||
"preparing converted dataset at {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
info!(
|
||||
questions = dataset
|
||||
questions = loaded
|
||||
.dataset
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|p| p.questions.len())
|
||||
.sum::<usize>(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
paragraphs = loaded.dataset.paragraphs.len(),
|
||||
partial = loaded.partial,
|
||||
dataset = loaded.dataset.metadata.id.as_str(),
|
||||
"Dataset ready"
|
||||
);
|
||||
|
||||
if parsed.config.slice_grow.is_some() {
|
||||
eval::grow_slice(&dataset, &parsed.config).context("growing slice ledger")?;
|
||||
slice::grow_slice(&loaded.dataset, &parsed.config).context("growing slice ledger")?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Running retrieval evaluation");
|
||||
let summary = eval::run_evaluation(&dataset, &parsed.config)
|
||||
.await
|
||||
.context("running retrieval evaluation")?;
|
||||
let summary = pipeline::run_evaluation(
|
||||
&loaded.dataset,
|
||||
&parsed.config,
|
||||
Some(loaded.content_checksum.as_str()),
|
||||
)
|
||||
.await
|
||||
.context("running retrieval evaluation")?;
|
||||
|
||||
let report = report::write_reports(
|
||||
&summary,
|
||||
@@ -226,12 +224,17 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) | Retrieved context: {chunks} chunks, {tokens} tokens ({tokenizer}, avg {avg_tokens:.0}/query, p95 {p95}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
retrieval_total = summary.retrieval_cases,
|
||||
chunks = summary.retrieved_context.total_chunks,
|
||||
tokens = summary.retrieved_context.total_tokens,
|
||||
tokenizer = summary.retrieved_context.tokenizer,
|
||||
avg_tokens = summary.retrieved_context.avg_tokens_per_query,
|
||||
p95 = summary.retrieved_context.p95_tokens_per_query,
|
||||
json = report.paths.json.display(),
|
||||
md = report.paths.markdown.display(),
|
||||
history = report.history_path.display(),
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||
pub fn ingestion_openai_client(
|
||||
include_entities: bool,
|
||||
) -> Result<(Arc<Client<OpenAIConfig>>, Option<String>)> {
|
||||
if include_entities {
|
||||
let (client, base_url) = build_client_from_env().context(
|
||||
"OPENAI_API_KEY must be set when --include-entities is enabled (entity extraction uses OpenAI)",
|
||||
)?;
|
||||
Ok((Arc::new(client), Some(base_url)))
|
||||
} else {
|
||||
Ok((Arc::new(Client::with_config(OpenAIConfig::default())), None))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
|
||||
let base_url =
|
||||
|
||||
+11
-7
@@ -7,8 +7,8 @@ use anyhow::{Context, Result};
|
||||
|
||||
use crate::{
|
||||
args,
|
||||
eval::EvaluationSummary,
|
||||
report::{self, EvaluationReport},
|
||||
types::EvaluationSummary,
|
||||
};
|
||||
|
||||
pub fn mirror_perf_outputs(
|
||||
@@ -91,23 +91,26 @@ fn format_duration(value: Option<u128>) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
|
||||
use crate::types::{
|
||||
EvaluationStageTimings, LatencyStats, PerformanceTimings, StageLatency,
|
||||
StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_latency() -> crate::eval::LatencyStats {
|
||||
crate::eval::LatencyStats {
|
||||
fn sample_latency() -> LatencyStats {
|
||||
LatencyStats {
|
||||
avg: 10.0,
|
||||
p50: 8,
|
||||
p95: 15,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||
crate::eval::StageLatencyBreakdown {
|
||||
fn sample_stage_latency() -> StageLatencyBreakdown {
|
||||
StageLatencyBreakdown {
|
||||
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
||||
.into_iter()
|
||||
.map(|stage| crate::eval::StageLatency {
|
||||
.map(|stage| StageLatency {
|
||||
stage: stage.to_string(),
|
||||
stats: sample_latency(),
|
||||
})
|
||||
@@ -206,6 +209,7 @@ mod tests {
|
||||
chunk_vector_take: 20,
|
||||
chunk_fts_take: 20,
|
||||
max_chunks_per_entity: 4,
|
||||
retrieved_context: crate::context_stats::aggregate_context_stats(&[]),
|
||||
cases: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,11 +20,11 @@ use retrieval_pipeline::{
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
cache::EmbeddingCache,
|
||||
cases::SeededCase,
|
||||
corpus,
|
||||
datasets::ConvertedDataset,
|
||||
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
|
||||
slice, snapshot,
|
||||
slice,
|
||||
types::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary},
|
||||
};
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
@@ -41,12 +41,10 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub db: Option<SurrealDbClient>,
|
||||
pub descriptor: Option<snapshot::Descriptor>,
|
||||
pub settings: Option<SystemSettings>,
|
||||
pub settings_missing: bool,
|
||||
pub must_reapply_settings: bool,
|
||||
pub embedding_provider: Option<EmbeddingProvider>,
|
||||
pub embedding_cache: Option<EmbeddingCache>,
|
||||
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
pub openai_base_url: Option<String>,
|
||||
pub expected_fingerprint: Option<String>,
|
||||
@@ -67,13 +65,19 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub summary: Option<EvaluationSummary>,
|
||||
pub diagnostics_path: Option<PathBuf>,
|
||||
pub diagnostics_enabled: bool,
|
||||
pub content_checksum: Option<String>,
|
||||
}
|
||||
|
||||
impl<'a> EvaluationContext<'a> {
|
||||
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
|
||||
pub fn new(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &'a Config,
|
||||
content_checksum: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
dataset,
|
||||
config,
|
||||
content_checksum,
|
||||
stage_timings: EvaluationStageTimings::default(),
|
||||
ledger_limit: None,
|
||||
slice_settings: None,
|
||||
@@ -84,12 +88,10 @@ impl<'a> EvaluationContext<'a> {
|
||||
namespace: String::new(),
|
||||
database: String::new(),
|
||||
db: None,
|
||||
descriptor: None,
|
||||
settings: None,
|
||||
settings_missing: false,
|
||||
must_reapply_settings: false,
|
||||
embedding_provider: None,
|
||||
embedding_cache: None,
|
||||
openai_client: None,
|
||||
openai_base_url: None,
|
||||
expected_fingerprint: None,
|
||||
@@ -133,12 +135,6 @@ impl<'a> EvaluationContext<'a> {
|
||||
.ok_or_else(|| anyhow!("database connection missing"))
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> Result<&snapshot::Descriptor> {
|
||||
self.descriptor
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("snapshot descriptor unavailable"))
|
||||
}
|
||||
|
||||
pub fn embedding_provider(&self) -> Result<&EmbeddingProvider> {
|
||||
self.embedding_provider
|
||||
.as_ref()
|
||||
@@ -159,6 +155,10 @@ impl<'a> EvaluationContext<'a> {
|
||||
.ok_or_else(|| anyhow!("corpus handle missing"))
|
||||
}
|
||||
|
||||
pub fn content_checksum(&self) -> Option<&str> {
|
||||
self.content_checksum.as_deref()
|
||||
}
|
||||
|
||||
pub fn evaluation_user(&self) -> Result<&User> {
|
||||
self.eval_user
|
||||
.as_ref()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::{args, types::CaseDiagnostics};
|
||||
|
||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||
args::ensure_parent(path)?;
|
||||
let mut file = tokio::fs::File::create(path)
|
||||
.await
|
||||
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
||||
for case in cases {
|
||||
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
||||
file.write_all(&line).await?;
|
||||
file.write_all(b"\n").await?;
|
||||
}
|
||||
file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
mod context;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
@@ -8,20 +8,40 @@ use crate::{args::Config, datasets::ConvertedDataset, types::EvaluationSummary};
|
||||
|
||||
use context::EvaluationContext;
|
||||
|
||||
async fn run_through_namespace<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &'a Config,
|
||||
content_checksum: Option<String>,
|
||||
) -> Result<EvaluationContext<'a>> {
|
||||
let mut ctx = EvaluationContext::new(dataset, config, content_checksum);
|
||||
stages::prepare_slice(&mut ctx).await?;
|
||||
stages::prepare_db(&mut ctx).await?;
|
||||
stages::prepare_corpus(&mut ctx).await?;
|
||||
stages::prepare_namespace(&mut ctx).await?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
pub async fn warm_evaluation(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &Config,
|
||||
content_checksum: &str,
|
||||
) -> Result<()> {
|
||||
let _ctx = run_through_namespace(dataset, config, Some(content_checksum.to_string())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_evaluation(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &Config,
|
||||
content_checksum: Option<&str>,
|
||||
) -> Result<EvaluationSummary> {
|
||||
let mut ctx = EvaluationContext::new(dataset, config);
|
||||
let machine = state::ready();
|
||||
|
||||
let machine = stages::prepare_slice(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_db(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_corpus(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_namespace(machine, &mut ctx).await?;
|
||||
let machine = stages::run_queries(machine, &mut ctx).await?;
|
||||
let machine = stages::summarize(machine, &mut ctx).await?;
|
||||
let _ = stages::finalize(machine, &mut ctx).await?;
|
||||
|
||||
let mut ctx = EvaluationContext::new(dataset, config, content_checksum.map(str::to_string));
|
||||
stages::prepare_slice(&mut ctx).await?;
|
||||
stages::prepare_db(&mut ctx).await?;
|
||||
stages::prepare_corpus(&mut ctx).await?;
|
||||
stages::prepare_namespace(&mut ctx).await?;
|
||||
stages::run_queries(&mut ctx).await?;
|
||||
stages::summarize(&mut ctx).await?;
|
||||
stages::finalize(&mut ctx).await?;
|
||||
ctx.into_summary()
|
||||
}
|
||||
|
||||
@@ -3,18 +3,12 @@ use std::time::Instant;
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::write_chunk_diagnostics;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{Completed, EvaluationMachine, Summarized},
|
||||
diagnostics::write_chunk_diagnostics,
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn finalize(
|
||||
machine: EvaluationMachine<(), Summarized>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Completed> {
|
||||
pub(crate) async fn finalize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::Finalize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -22,13 +16,6 @@ pub(crate) async fn finalize(
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
if let Some(cache) = ctx.embedding_cache.as_ref() {
|
||||
cache
|
||||
.persist()
|
||||
.await
|
||||
.context("persisting embedding cache")?;
|
||||
}
|
||||
|
||||
if let Some(path) = ctx.diagnostics_path.as_ref() {
|
||||
if ctx.diagnostics_enabled {
|
||||
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
||||
@@ -53,7 +40,5 @@ pub(crate) async fn finalize(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.finalize()
|
||||
.map_err(|(_, guard)| map_guard_error("finalize", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,14 +13,3 @@ pub(crate) use prepare_namespace::prepare_namespace;
|
||||
pub(crate) use prepare_slice::prepare_slice;
|
||||
pub(crate) use run_queries::run_queries;
|
||||
pub(crate) use summarize::summarize;
|
||||
|
||||
use anyhow::Result;
|
||||
use state_machines::core::GuardError;
|
||||
|
||||
use super::state::EvaluationMachine;
|
||||
|
||||
fn map_guard_error(event: &str, guard: &GuardError) -> anyhow::Error {
|
||||
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
|
||||
}
|
||||
|
||||
type StageResult<S> = Result<EvaluationMachine<(), S>>;
|
||||
|
||||
@@ -3,19 +3,12 @@ use std::time::Instant;
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{corpus, eval::can_reuse_namespace, slice, snapshot};
|
||||
use crate::{corpus, db::can_reuse_namespace, slice};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, DbReady, EvaluationMachine},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn prepare_corpus(
|
||||
machine: EvaluationMachine<(), DbReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<CorpusReady> {
|
||||
pub(crate) async fn prepare_corpus(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareCorpus;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -31,13 +24,13 @@ pub(crate) async fn prepare_corpus(
|
||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||
.context("selecting slice window for corpus preparation")?;
|
||||
|
||||
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider()?);
|
||||
let ingestion_config = corpus::make_ingestion_config(config);
|
||||
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
|
||||
ctx.dataset(),
|
||||
slice,
|
||||
config.converted_dataset_path.as_path(),
|
||||
&ingestion_config,
|
||||
ctx.content_checksum(),
|
||||
)?;
|
||||
let base_dir = corpus::cached_corpus_dir(
|
||||
&cache_settings,
|
||||
@@ -47,19 +40,18 @@ pub(crate) async fn prepare_corpus(
|
||||
|
||||
if !config.reseed_slice {
|
||||
let requested_cases = window.cases.len();
|
||||
if can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
&descriptor,
|
||||
&ctx.namespace,
|
||||
&ctx.database,
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
|
||||
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
|
||||
if can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
&manifest,
|
||||
&embedding_provider,
|
||||
&ctx.namespace,
|
||||
&ctx.database,
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
namespace = ctx.namespace.as_str(),
|
||||
@@ -70,7 +62,6 @@ pub(crate) async fn prepare_corpus(
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = 0;
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
@@ -80,14 +71,8 @@ pub(crate) async fn prepare_corpus(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
return machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard));
|
||||
return Ok(());
|
||||
}
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
"Namespace reusable but cached manifest missing; regenerating corpus"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +88,7 @@ pub(crate) async fn prepare_corpus(
|
||||
openai_client,
|
||||
&eval_user_id,
|
||||
config.converted_dataset_path.as_path(),
|
||||
ctx.content_checksum(),
|
||||
ingestion_config.clone(),
|
||||
)
|
||||
.await
|
||||
@@ -126,7 +112,6 @@ pub(crate) async fn prepare_corpus(
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
@@ -136,7 +121,5 @@ pub(crate) async fn prepare_corpus(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
args::EmbeddingBackend,
|
||||
cache::EmbeddingCache,
|
||||
eval::{
|
||||
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
|
||||
},
|
||||
db::{connect_eval_db, sanitize_model_code},
|
||||
openai,
|
||||
settings::{enforce_system_settings, load_or_init_system_settings},
|
||||
};
|
||||
use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{DbReady, EvaluationMachine, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
pub(crate) async fn prepare_db(
|
||||
machine: EvaluationMachine<(), SlicePrepared>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<DbReady> {
|
||||
pub(crate) async fn prepare_db(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareDb;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -36,19 +27,18 @@ pub(crate) async fn prepare_db(
|
||||
|
||||
let db = connect_eval_db(config, &namespace, &database).await?;
|
||||
|
||||
let (raw_openai_client, openai_base_url) =
|
||||
openai::build_client_from_env().context("building OpenAI client")?;
|
||||
let openai_client = Arc::new(raw_openai_client);
|
||||
let (openai_client, openai_base_url) =
|
||||
openai::ingestion_openai_client(config.ingest.include_entities)
|
||||
.context("building OpenAI client for ingestion")?;
|
||||
|
||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
||||
let embedding_provider = match config.embedding_backend {
|
||||
crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||
EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||
config.embedding_model.clone(),
|
||||
default_embedding_pool_size(),
|
||||
)
|
||||
.await
|
||||
.context("creating FastEmbed provider")?,
|
||||
crate::args::EmbeddingBackend::Hashed => {
|
||||
EmbeddingBackend::Hashed => {
|
||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||
}
|
||||
};
|
||||
@@ -68,12 +58,14 @@ pub(crate) async fn prepare_db(
|
||||
dimension = provider_dimension,
|
||||
"Embedding provider initialised"
|
||||
);
|
||||
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
|
||||
if let Some(base_url) = &openai_base_url {
|
||||
info!(openai_base_url = %base_url, "OpenAI client configured for entity ingestion");
|
||||
}
|
||||
|
||||
let (mut settings, settings_missing) =
|
||||
load_or_init_system_settings(&db, provider_dimension).await?;
|
||||
|
||||
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||
if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||
if let Some(model_code) = embedding_provider.model_code() {
|
||||
let sanitized = sanitize_model_code(&model_code);
|
||||
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
||||
@@ -83,15 +75,8 @@ pub(crate) async fn prepare_db(
|
||||
.with_context(|| format!("removing stale cache {}", path.display()))
|
||||
.ok();
|
||||
}
|
||||
let cache = EmbeddingCache::load(&path).await?;
|
||||
info!(path = %path.display(), "Embedding cache ready");
|
||||
Some(cache)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
let must_reapply_settings = settings_missing;
|
||||
let defer_initial_enforce = settings_missing && !config.reseed_slice;
|
||||
@@ -104,9 +89,8 @@ pub(crate) async fn prepare_db(
|
||||
ctx.must_reapply_settings = must_reapply_settings;
|
||||
ctx.settings = Some(settings);
|
||||
ctx.embedding_provider = Some(embedding_provider);
|
||||
ctx.embedding_cache = embedding_cache;
|
||||
ctx.openai_client = Some(openai_client);
|
||||
ctx.openai_base_url = Some(openai_base_url);
|
||||
ctx.openai_base_url = openai_base_url;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
@@ -116,7 +100,5 @@ pub(crate) async fn prepare_db(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_db()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_db", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,25 +5,19 @@ use common::storage::types::system_settings::SystemSettings;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
cases::cases_from_manifest,
|
||||
corpus,
|
||||
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
|
||||
eval::{
|
||||
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
|
||||
record_namespace_state, warm_hnsw_cache,
|
||||
db::{
|
||||
can_reuse_namespace, ensure_eval_user, record_namespace_seed, recreate_indexes,
|
||||
reset_namespace, warm_hnsw_cache,
|
||||
},
|
||||
settings::enforce_system_settings,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, EvaluationMachine, NamespaceReady},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn prepare_namespace(
|
||||
machine: EvaluationMachine<(), CorpusReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<NamespaceReady> {
|
||||
pub(crate) async fn prepare_namespace(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareNamespace;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -32,7 +26,6 @@ pub(crate) async fn prepare_namespace(
|
||||
let started = Instant::now();
|
||||
|
||||
let config = ctx.config();
|
||||
let dataset = ctx.dataset();
|
||||
let expected_fingerprint = ctx
|
||||
.expected_fingerprint
|
||||
.as_deref()
|
||||
@@ -60,20 +53,16 @@ pub(crate) async fn prepare_namespace(
|
||||
|
||||
let mut namespace_reused = false;
|
||||
if !config.reseed_slice {
|
||||
namespace_reused = {
|
||||
let slice = ctx.slice()?;
|
||||
can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
ctx.descriptor()?,
|
||||
&namespace,
|
||||
&database,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
namespace_reused = can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
base_manifest,
|
||||
&embedding_provider,
|
||||
&namespace,
|
||||
&database,
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut namespace_seed_ms = None;
|
||||
@@ -114,34 +103,20 @@ pub(crate) async fn prepare_namespace(
|
||||
"Seeding ingestion corpus into SurrealDB"
|
||||
);
|
||||
}
|
||||
let indexes_disabled = remove_all_indexes(ctx.db()?).await.is_ok();
|
||||
|
||||
let seed_start = Instant::now();
|
||||
corpus::seed_manifest_into_db(ctx.db()?, &manifest_for_seed)
|
||||
.await
|
||||
.context("seeding ingestion corpus from manifest")?;
|
||||
namespace_seed_ms = Some(seed_start.elapsed().as_millis());
|
||||
|
||||
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
|
||||
if indexes_disabled {
|
||||
info!("Recreating indexes after seeding data");
|
||||
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
||||
.await
|
||||
.context("recreating indexes with correct dimension")?;
|
||||
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
||||
}
|
||||
{
|
||||
let slice = ctx.slice()?;
|
||||
record_namespace_state(
|
||||
ctx.descriptor()?,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
&namespace,
|
||||
&database,
|
||||
requested_cases,
|
||||
)
|
||||
.await;
|
||||
info!("Recreating indexes after seeding data");
|
||||
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
||||
.await
|
||||
.context("recreating indexes with correct dimension")?;
|
||||
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
||||
|
||||
if let Some(handle) = ctx.corpus_handle.as_mut() {
|
||||
record_namespace_seed(handle, &namespace, &database, requested_cases).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +173,5 @@ pub(crate) async fn prepare_namespace(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_namespace()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_namespace", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,20 +4,13 @@ use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
eval::{default_database, default_namespace, ledger_target},
|
||||
db::{default_database, default_namespace},
|
||||
slice,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, Ready, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
pub(crate) async fn prepare_slice(
|
||||
machine: EvaluationMachine<(), Ready>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<SlicePrepared> {
|
||||
pub(crate) async fn prepare_slice(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareSlice;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -25,7 +18,7 @@ pub(crate) async fn prepare_slice(
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let ledger_limit = ledger_target(ctx.config());
|
||||
let ledger_limit = slice::ledger_target(ctx.config());
|
||||
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
|
||||
let resolved_slice =
|
||||
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
|
||||
@@ -49,7 +42,11 @@ pub(crate) async fn prepare_slice(
|
||||
.db_namespace
|
||||
.clone()
|
||||
.unwrap_or_else(|| {
|
||||
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
|
||||
default_namespace(
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
ctx.config().limit,
|
||||
ctx.config().slice.as_deref(),
|
||||
)
|
||||
});
|
||||
ctx.database = ctx
|
||||
.config()
|
||||
@@ -66,7 +63,5 @@ pub(crate) async fn prepare_slice(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_slice()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_slice", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,9 +5,13 @@ use common::storage::types::StoredObject;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::eval::{
|
||||
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
use crate::{
|
||||
cases::SeededCase,
|
||||
context_stats,
|
||||
types::{
|
||||
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{self, RetrievalConfig, StageTimings},
|
||||
@@ -15,17 +19,10 @@ use retrieval_pipeline::{
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
pub(crate) async fn run_queries(
|
||||
machine: EvaluationMachine<(), NamespaceReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<QueriesFinished> {
|
||||
pub(crate) async fn run_queries(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::RunQueries;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -153,7 +150,7 @@ pub(crate) async fn run_queries(
|
||||
.await
|
||||
.context("acquiring query semaphore permit")?;
|
||||
|
||||
let crate::eval::SeededCase {
|
||||
let SeededCase {
|
||||
question_id,
|
||||
question,
|
||||
expected_source,
|
||||
@@ -197,6 +194,7 @@ pub(crate) async fn run_queries(
|
||||
let query_latency = query_start.elapsed().as_millis();
|
||||
|
||||
let candidates = adapt_retrieval_output(result_output);
|
||||
let retrieved_context = context_stats::stats_for_candidates(&candidates);
|
||||
let mut retrieved = Vec::new();
|
||||
let mut match_rank = None;
|
||||
let answers_lower: Vec<String> =
|
||||
@@ -288,6 +286,7 @@ pub(crate) async fn run_queries(
|
||||
reciprocal_rank: Some(reciprocal_rank),
|
||||
ndcg: Some(ndcg),
|
||||
latency_ms: query_latency,
|
||||
retrieved_context,
|
||||
retrieved,
|
||||
};
|
||||
|
||||
@@ -353,9 +352,7 @@ pub(crate) async fn run_queries(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.run_queries()
|
||||
.map_err(|(_, guard)| map_guard_error("run_queries", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::cast_precision_loss)]
|
||||
|
||||
@@ -3,25 +3,19 @@ use std::time::Instant;
|
||||
use chrono::Utc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::{
|
||||
use crate::types::{
|
||||
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
|
||||
RetrievedContextStats,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, QueriesFinished, Summarized},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::cast_precision_loss
|
||||
)]
|
||||
pub(crate) async fn summarize(
|
||||
machine: EvaluationMachine<(), QueriesFinished>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Summarized> {
|
||||
pub(crate) async fn summarize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::Summarize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -123,6 +117,12 @@ pub(crate) async fn summarize(
|
||||
sum_ndcg / (retrieval_cases as f64)
|
||||
};
|
||||
|
||||
let per_query_context: Vec<RetrievedContextStats> = summaries
|
||||
.iter()
|
||||
.map(|summary| summary.retrieved_context)
|
||||
.collect();
|
||||
let retrieved_context = crate::context_stats::aggregate_context_stats(&per_query_context);
|
||||
|
||||
let active_tuning = ctx
|
||||
.retrieval_config
|
||||
.as_ref()
|
||||
@@ -133,7 +133,7 @@ pub(crate) async fn summarize(
|
||||
openai_base_url: ctx
|
||||
.openai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "<unknown>".to_string()),
|
||||
.unwrap_or_else(|| "n/a (chunk-only ingestion)".to_string()),
|
||||
ingestion_ms: ctx.ingestion_duration_ms,
|
||||
namespace_seed_ms: ctx.namespace_seed_ms,
|
||||
evaluation_stage_ms: ctx.stage_timings.clone(),
|
||||
@@ -217,11 +217,12 @@ pub(crate) async fn summarize(
|
||||
chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
||||
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
|
||||
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
|
||||
ingest_chunks_only: config.ingest.ingest_chunks_only,
|
||||
ingest_chunks_only: !config.ingest.include_entities,
|
||||
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||
retrieved_context,
|
||||
cases: summaries,
|
||||
});
|
||||
|
||||
@@ -233,7 +234,5 @@ pub(crate) async fn summarize(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.summarize()
|
||||
.map_err(|(_, guard)| map_guard_error("summarize", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: EvaluationMachine,
|
||||
state: EvaluationState,
|
||||
initial: Ready,
|
||||
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
|
||||
events {
|
||||
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
|
||||
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
|
||||
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
|
||||
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
|
||||
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
|
||||
summarize { transition: { from: QueriesFinished, to: Summarized } }
|
||||
finalize { transition: { from: Summarized, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: SlicePrepared, to: Failed }
|
||||
transition: { from: DbReady, to: Failed }
|
||||
transition: { from: CorpusReady, to: Failed }
|
||||
transition: { from: NamespaceReady, to: Failed }
|
||||
transition: { from: QueriesFinished, to: Failed }
|
||||
transition: { from: Summarized, to: Failed }
|
||||
transition: { from: Completed, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> EvaluationMachine<(), Ready> {
|
||||
EvaluationMachine::new(())
|
||||
}
|
||||
+81
-212
@@ -7,12 +7,10 @@ use std::{
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::eval::{
|
||||
use crate::types::{
|
||||
format_timestamp, CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats,
|
||||
StageLatencyBreakdown,
|
||||
RetrievalContextStats, StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportPaths {
|
||||
@@ -108,6 +106,7 @@ pub struct RetrievalSection {
|
||||
pub ingest_chunk_max_tokens: usize,
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
pub ingest_chunks_only: bool,
|
||||
pub retrieved_context: RetrievalContextStats,
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
@@ -242,6 +241,7 @@ impl EvaluationReport {
|
||||
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
||||
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
||||
ingest_chunks_only: summary.ingest_chunks_only,
|
||||
retrieved_context: summary.retrieved_context.clone(),
|
||||
};
|
||||
|
||||
let llm = if summary.llm_cases > 0 {
|
||||
@@ -345,7 +345,7 @@ impl LlmCaseEntry {
|
||||
}
|
||||
|
||||
impl RetrievedSnippet {
|
||||
fn from_summary(entry: &crate::eval::RetrievedSummary) -> Self {
|
||||
fn from_summary(entry: &crate::types::RetrievedSummary) -> Self {
|
||||
Self {
|
||||
rank: entry.rank,
|
||||
source_id: entry.source_id.clone(),
|
||||
@@ -558,6 +558,65 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
||||
} else {
|
||||
md.push_str("| Rerank | disabled |\\n");
|
||||
}
|
||||
write!(
|
||||
md,
|
||||
"| Chunk result cap | {} |\\n",
|
||||
report.retrieval.chunk_result_cap
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
md.push_str("\\n## Retrieved Context Volume\\n\\n");
|
||||
md.push_str("| Metric | Value |\\n| --- | --- |\\n");
|
||||
write!(
|
||||
md,
|
||||
"| Tokenizer | {} |\\n",
|
||||
report.retrieval.retrieved_context.tokenizer
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Queries measured | {} |\\n",
|
||||
report.retrieval.retrieved_context.queries
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total chunks returned | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_chunks
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total characters | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_chars
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total tokens | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_tokens
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Avg chunks / query | {:.1} |\\n",
|
||||
report.retrieval.retrieved_context.avg_chunks_per_query
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Avg tokens / query | {:.1} |\\n",
|
||||
report.retrieval.retrieved_context.avg_tokens_per_query
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| P50 / P95 / max tokens / query | {} / {} / {} |\\n",
|
||||
report.retrieval.retrieved_context.p50_tokens_per_query,
|
||||
report.retrieval.retrieved_context.p95_tokens_per_query,
|
||||
report.retrieval.retrieved_context.max_tokens_per_query
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if let Some(llm) = &report.llm {
|
||||
md.push_str("\\n## LLM Mode Metrics\\n\\n");
|
||||
@@ -797,182 +856,6 @@ pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
|
||||
report_dir.join(sanitize_component(dataset_id))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LegacyHistoryEntry {
|
||||
generated_at: String,
|
||||
run_label: Option<String>,
|
||||
dataset_id: String,
|
||||
dataset_label: String,
|
||||
slice_id: String,
|
||||
slice_seed: u64,
|
||||
slice_window_offset: usize,
|
||||
slice_window_length: usize,
|
||||
slice_cases: usize,
|
||||
slice_total_cases: usize,
|
||||
k: usize,
|
||||
limit: Option<usize>,
|
||||
precision: f64,
|
||||
precision_at_1: f64,
|
||||
precision_at_2: f64,
|
||||
precision_at_3: f64,
|
||||
#[serde(default)]
|
||||
mrr: f64,
|
||||
#[serde(default)]
|
||||
average_ndcg: f64,
|
||||
#[serde(default)]
|
||||
retrieval_cases: usize,
|
||||
#[serde(default)]
|
||||
retrieval_precision: f64,
|
||||
#[serde(default)]
|
||||
llm_cases: usize,
|
||||
#[serde(default)]
|
||||
llm_precision: f64,
|
||||
duration_ms: u128,
|
||||
latency_ms: LatencyStats,
|
||||
embedding_backend: String,
|
||||
embedding_model: Option<String>,
|
||||
ingestion_reused: bool,
|
||||
ingestion_embeddings_reused: bool,
|
||||
rerank_enabled: bool,
|
||||
rerank_keep_top: usize,
|
||||
rerank_pool_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
chunk_result_cap: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_min_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_max_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_overlap_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunks_only: Option<bool>,
|
||||
#[serde(default)]
|
||||
delta: Option<LegacyHistoryDelta>,
|
||||
openai_base_url: String,
|
||||
ingestion_ms: u128,
|
||||
#[serde(default)]
|
||||
namespace_seed_ms: Option<u128>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LegacyHistoryDelta {
|
||||
precision: f64,
|
||||
precision_at_1: f64,
|
||||
latency_avg_ms: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||
let overview = OverviewSection {
|
||||
generated_at: entry.generated_at,
|
||||
run_label: entry.run_label,
|
||||
total_cases: entry.slice_cases,
|
||||
filtered_questions: 0,
|
||||
};
|
||||
|
||||
let dataset = DatasetSection {
|
||||
id: entry.dataset_id,
|
||||
label: entry.dataset_label,
|
||||
source: String::new(),
|
||||
includes_unanswerable: entry.llm_cases > 0,
|
||||
require_verified_chunks: true,
|
||||
embedding_backend: entry.embedding_backend,
|
||||
embedding_model: entry.embedding_model,
|
||||
embedding_dimension: 0,
|
||||
};
|
||||
|
||||
let slice = SliceSection {
|
||||
id: entry.slice_id,
|
||||
seed: entry.slice_seed,
|
||||
window_offset: entry.slice_window_offset,
|
||||
window_length: entry.slice_window_length,
|
||||
slice_cases: entry.slice_cases,
|
||||
ledger_total_cases: entry.slice_total_cases,
|
||||
positives: 0,
|
||||
negatives: 0,
|
||||
total_paragraphs: 0,
|
||||
negative_multiplier: 0.0,
|
||||
};
|
||||
|
||||
let retrieval_cases = if entry.retrieval_cases > 0 {
|
||||
entry.retrieval_cases
|
||||
} else {
|
||||
entry.slice_cases.saturating_sub(entry.llm_cases)
|
||||
};
|
||||
let retrieval_precision = if entry.retrieval_precision > 0.0 {
|
||||
entry.retrieval_precision
|
||||
} else {
|
||||
entry.precision
|
||||
};
|
||||
|
||||
let retrieval = RetrievalSection {
|
||||
k: entry.k,
|
||||
cases: retrieval_cases,
|
||||
correct: 0,
|
||||
precision: retrieval_precision,
|
||||
precision_at_1: entry.precision_at_1,
|
||||
precision_at_2: entry.precision_at_2,
|
||||
precision_at_3: entry.precision_at_3,
|
||||
mrr: entry.mrr,
|
||||
average_ndcg: entry.average_ndcg,
|
||||
latency: entry.latency_ms,
|
||||
concurrency: 0,
|
||||
resolve_entities: false,
|
||||
rerank_enabled: entry.rerank_enabled,
|
||||
rerank_pool_size: entry.rerank_pool_size,
|
||||
rerank_keep_top: entry.rerank_keep_top,
|
||||
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_weight(),
|
||||
chunk_rrf_use_vector: default_chunk_rrf_use(),
|
||||
chunk_rrf_use_fts: default_chunk_rrf_use(),
|
||||
chunk_vector_take: 0,
|
||||
chunk_fts_take: 0,
|
||||
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
||||
ingest_chunk_max_tokens: entry.ingest_chunk_max_tokens.unwrap_or(512),
|
||||
ingest_chunk_overlap_tokens: entry.ingest_chunk_overlap_tokens.unwrap_or(50),
|
||||
ingest_chunks_only: entry.ingest_chunks_only.unwrap_or(false),
|
||||
};
|
||||
|
||||
let llm = if entry.llm_cases > 0 {
|
||||
Some(LlmSection {
|
||||
cases: entry.llm_cases,
|
||||
answered: 0,
|
||||
precision: entry.llm_precision,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let performance = PerformanceSection {
|
||||
openai_base_url: entry.openai_base_url,
|
||||
ingestion_ms: entry.ingestion_ms,
|
||||
namespace_seed_ms: entry.namespace_seed_ms,
|
||||
evaluation_stages_ms: EvaluationStageTimings::default(),
|
||||
stage_latency: StageLatencyBreakdown::default(),
|
||||
namespace_reused: false,
|
||||
ingestion_reused: entry.ingestion_reused,
|
||||
embeddings_reused: entry.ingestion_embeddings_reused,
|
||||
ingestion_cache_path: String::new(),
|
||||
corpus_paragraphs: 0,
|
||||
positive_paragraphs_reused: 0,
|
||||
negative_paragraphs_reused: 0,
|
||||
};
|
||||
|
||||
EvaluationReport {
|
||||
overview,
|
||||
dataset,
|
||||
slice,
|
||||
retrieval,
|
||||
llm,
|
||||
performance,
|
||||
misses: Vec::new(),
|
||||
llm_cases: Vec::new(),
|
||||
detailed_report: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
||||
if !path.exists() {
|
||||
return Ok(Vec::new());
|
||||
@@ -981,34 +864,12 @@ fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
||||
let contents =
|
||||
fs::read(path).with_context(|| format!("reading evaluation log {}", path.display()))?;
|
||||
|
||||
if let Ok(entries) = serde_json::from_slice::<Vec<EvaluationReport>>(&contents) {
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
match serde_json::from_slice::<Vec<LegacyHistoryEntry>>(&contents) {
|
||||
Ok(entries) => Ok(entries.into_iter().map(convert_legacy_entry).collect()),
|
||||
Err(err) => {
|
||||
let timestamp = Utc::now().format("%Y%m%dT%H%M%S");
|
||||
let backup_path = path
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."))
|
||||
.join(format!("evaluations.json.corrupted.{timestamp}"));
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
backup = %backup_path.display(),
|
||||
error = %err,
|
||||
"Evaluation history file is corrupted; backing up and starting fresh"
|
||||
);
|
||||
if let Err(e) = fs::rename(path, &backup_path) {
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"Failed to backup corrupted evaluation history"
|
||||
);
|
||||
}
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
serde_json::from_slice(&contents).with_context(|| {
|
||||
format!(
|
||||
"parsing evaluation history at {}; delete the file and re-run if upgrading from an older format",
|
||||
path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBuf> {
|
||||
@@ -1024,9 +885,9 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
|
||||
StageLatencyBreakdown,
|
||||
use crate::types::{
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedContextStats, RetrievedSummary,
|
||||
StageLatency, StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
@@ -1101,6 +962,7 @@ mod tests {
|
||||
has_verified_chunks: !is_impossible,
|
||||
match_rank: if matched { Some(1) } else { None },
|
||||
latency_ms: 42,
|
||||
retrieved_context: RetrievedContextStats::default(),
|
||||
retrieved: vec![RetrievedSummary {
|
||||
rank: 1,
|
||||
entity_id: "entity1".into(),
|
||||
@@ -1199,6 +1061,13 @@ mod tests {
|
||||
chunk_vector_take: 50,
|
||||
chunk_fts_take: 50,
|
||||
max_chunks_per_entity: 4,
|
||||
retrieved_context: crate::context_stats::aggregate_context_stats(&[
|
||||
RetrievedContextStats {
|
||||
chunk_count: 1,
|
||||
char_count: 10,
|
||||
token_count: 3,
|
||||
},
|
||||
]),
|
||||
cases,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::datasets::{ConvertedDataset, BEIR_DATASETS};
|
||||
|
||||
use super::build::{mix_seed, BuildParams};
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
pub(super) fn ordered_question_refs_beir(
|
||||
dataset: &ConvertedDataset,
|
||||
params: &BuildParams,
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||
.iter()
|
||||
.map(|kind| kind.source_prefix())
|
||||
.collect();
|
||||
|
||||
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||
let include = if params.include_impossible {
|
||||
true
|
||||
} else {
|
||||
!question.is_impossible && !question.answers.is_empty()
|
||||
};
|
||||
if !include {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(prefix) = question_prefix(&question.id) else {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
"Skipping BEIR question without expected prefix"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !prefixes.contains(&prefix) {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
prefix = %prefix,
|
||||
"Skipping BEIR question with unknown subset prefix"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||
}
|
||||
}
|
||||
|
||||
if grouped.values().all(std::vec::Vec::is_empty) {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
for prefix in &prefixes {
|
||||
if let Some(entries) = grouped.get_mut(prefix) {
|
||||
let seed = mix_seed(
|
||||
&format!("{}::{prefix}", dataset.metadata.id),
|
||||
params.base_seed,
|
||||
);
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
entries.shuffle(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
let dataset_count = prefixes.len().max(1);
|
||||
let base_quota = target_cases / dataset_count;
|
||||
let mut remainder = target_cases % dataset_count;
|
||||
|
||||
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||
for prefix in &prefixes {
|
||||
let mut quota = base_quota;
|
||||
if remainder > 0 {
|
||||
quota += 1;
|
||||
remainder -= 1;
|
||||
}
|
||||
quotas.insert(*prefix, quota);
|
||||
}
|
||||
|
||||
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||
let mut shortfall = 0usize;
|
||||
|
||||
for prefix in &prefixes {
|
||||
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
||||
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||
let take = quota.min(available);
|
||||
let missing = quota.saturating_sub(take);
|
||||
shortfall += missing;
|
||||
take_counts.insert(*prefix, take);
|
||||
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||
}
|
||||
|
||||
while shortfall > 0 {
|
||||
let mut allocated = false;
|
||||
for prefix in &prefixes {
|
||||
if shortfall == 0 {
|
||||
break;
|
||||
}
|
||||
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||
if spare == 0 {
|
||||
continue;
|
||||
}
|
||||
if let Some(count) = take_counts.get_mut(prefix) {
|
||||
*count += 1;
|
||||
}
|
||||
spare_slots.insert(*prefix, spare - 1);
|
||||
shortfall = shortfall.saturating_sub(1);
|
||||
allocated = true;
|
||||
}
|
||||
if !allocated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||
let mut total_selected = 0usize;
|
||||
for prefix in &prefixes {
|
||||
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||
let mut deque = VecDeque::new();
|
||||
if let Some(entries) = grouped.get(prefix) {
|
||||
for item in entries.iter().take(take) {
|
||||
deque.push_back(*item);
|
||||
total_selected += 1;
|
||||
}
|
||||
}
|
||||
queues.push(deque);
|
||||
}
|
||||
|
||||
if total_selected < target_cases {
|
||||
warn!(
|
||||
requested = target_cases,
|
||||
available = total_selected,
|
||||
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||
);
|
||||
}
|
||||
|
||||
let mut output = Vec::with_capacity(total_selected);
|
||||
loop {
|
||||
let mut progressed = false;
|
||||
for queue in &mut queues {
|
||||
if let Some(item) = queue.pop_front() {
|
||||
output.push(item);
|
||||
progressed = true;
|
||||
}
|
||||
}
|
||||
if !progressed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(super) fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||
if let Some(rest) = question_id.strip_prefix(prefix) {
|
||||
if rest.starts_with('-') {
|
||||
return Some(prefix);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct BuildParams {
|
||||
pub include_impossible: bool,
|
||||
pub base_seed: u64,
|
||||
pub rng_seed: u64,
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
pub(super) fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(dataset_id.as_bytes());
|
||||
hasher.update(seed.to_le_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let mut bytes = [0u8; 8];
|
||||
bytes.copy_from_slice(&digest[..8]);
|
||||
u64::from_le_bytes(bytes)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet, VecDeque},
|
||||
collections::{HashMap, HashSet},
|
||||
fmt::Write,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
@@ -12,10 +12,16 @@ use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::datasets::{
|
||||
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, BEIR_DATASETS,
|
||||
use crate::{
|
||||
args::Config,
|
||||
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
|
||||
};
|
||||
|
||||
mod beir;
|
||||
mod build;
|
||||
|
||||
use build::{mix_seed, BuildParams};
|
||||
|
||||
const SLICE_VERSION: u32 = 2;
|
||||
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
||||
|
||||
@@ -80,8 +86,12 @@ pub enum SliceParagraphKind {
|
||||
Negative,
|
||||
}
|
||||
|
||||
pub fn paragraph_storage_key(paragraph_id: &str) -> String {
|
||||
sanitize_identifier(paragraph_id)
|
||||
}
|
||||
|
||||
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
|
||||
let sanitized = sanitize_identifier(paragraph_id);
|
||||
let sanitized = paragraph_storage_key(paragraph_id);
|
||||
format!("paragraphs/{sanitized}.json")
|
||||
}
|
||||
|
||||
@@ -210,13 +220,6 @@ struct SliceKey<'a> {
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BuildParams {
|
||||
include_impossible: bool,
|
||||
base_seed: u64,
|
||||
rng_seed: u64,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn resolve_slice<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
@@ -225,15 +228,28 @@ pub fn resolve_slice<'a>(
|
||||
let index = DatasetIndex::build(dataset);
|
||||
|
||||
if let Some(slice_arg) = config.explicit_slice {
|
||||
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
||||
let path = explicit_slice_path(dataset, config, slice_arg);
|
||||
if path.exists() {
|
||||
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
||||
info!(
|
||||
slice = %resolved.manifest.slice_id,
|
||||
path = %resolved.path.display(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Using explicitly selected slice"
|
||||
);
|
||||
return Ok(resolved);
|
||||
}
|
||||
let resolved = materialize_slice_ledger(dataset, config, &index, slice_arg, path)?;
|
||||
info!(
|
||||
slice = %resolved.manifest.slice_id,
|
||||
path = %resolved.path.display(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Using explicitly selected slice"
|
||||
"Built catalog slice ledger"
|
||||
);
|
||||
return Ok(resolved);
|
||||
}
|
||||
@@ -256,6 +272,82 @@ pub fn resolve_slice<'a>(
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str());
|
||||
let path = base.join(format!("{slice_id}.json"));
|
||||
materialize_slice_ledger(dataset, config, &index, &slice_id, path)
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
pub fn select_window<'a>(
|
||||
resolved: &'a ResolvedSlice<'a>,
|
||||
offset: usize,
|
||||
limit: Option<usize>,
|
||||
) -> Result<SliceWindow<'a>> {
|
||||
let total = resolved.manifest.case_count;
|
||||
if total == 0 {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' contains no cases",
|
||||
resolved.manifest.slice_id
|
||||
));
|
||||
}
|
||||
if offset >= total {
|
||||
return Err(anyhow!(
|
||||
"slice offset {offset} exceeds available cases ({total})",
|
||||
));
|
||||
}
|
||||
let available = total - offset;
|
||||
let requested = limit.unwrap_or(available).max(1);
|
||||
let length = requested.min(available);
|
||||
let cases = resolved.cases[offset..offset + length].to_vec();
|
||||
let mut seen = HashSet::new();
|
||||
let mut positive_ids = Vec::new();
|
||||
for case in &cases {
|
||||
if seen.insert(case.paragraph.id.as_str()) {
|
||||
positive_ids.push(case.paragraph.id.clone());
|
||||
}
|
||||
}
|
||||
Ok(SliceWindow {
|
||||
offset,
|
||||
length,
|
||||
total_cases: total,
|
||||
cases,
|
||||
positive_paragraph_ids: positive_ids,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
||||
select_window(resolved, 0, None)
|
||||
}
|
||||
|
||||
fn explicit_slice_path(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &SliceConfig<'_>,
|
||||
slice_arg: &str,
|
||||
) -> PathBuf {
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
if explicit_path.exists() {
|
||||
explicit_path.to_path_buf()
|
||||
} else {
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str())
|
||||
.join(format!("{slice_arg}.json"))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn materialize_slice_ledger<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &SliceConfig<'_>,
|
||||
index: &DatasetIndex,
|
||||
slice_id: &str,
|
||||
path: PathBuf,
|
||||
) -> Result<ResolvedSlice<'a>> {
|
||||
let requested_corpus = config
|
||||
.corpus_limit
|
||||
.unwrap_or(dataset.paragraphs.len())
|
||||
.min(dataset.paragraphs.len())
|
||||
.max(1);
|
||||
|
||||
let total_questions = dataset
|
||||
.paragraphs
|
||||
@@ -339,7 +431,7 @@ pub fn resolve_slice<'a>(
|
||||
let mut manifest = manifest.unwrap_or_else(|| {
|
||||
empty_manifest(
|
||||
dataset,
|
||||
slice_id.clone(),
|
||||
slice_id.to_string(),
|
||||
¶ms,
|
||||
requested_corpus,
|
||||
config.negative_multiplier,
|
||||
@@ -396,52 +488,7 @@ pub fn resolve_slice<'a>(
|
||||
);
|
||||
}
|
||||
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest.clone(), path)?;
|
||||
|
||||
Ok(resolved)
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
pub fn select_window<'a>(
|
||||
resolved: &'a ResolvedSlice<'a>,
|
||||
offset: usize,
|
||||
limit: Option<usize>,
|
||||
) -> Result<SliceWindow<'a>> {
|
||||
let total = resolved.manifest.case_count;
|
||||
if total == 0 {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' contains no cases",
|
||||
resolved.manifest.slice_id
|
||||
));
|
||||
}
|
||||
if offset >= total {
|
||||
return Err(anyhow!(
|
||||
"slice offset {offset} exceeds available cases ({total})",
|
||||
));
|
||||
}
|
||||
let available = total - offset;
|
||||
let requested = limit.unwrap_or(available).max(1);
|
||||
let length = requested.min(available);
|
||||
let cases = resolved.cases[offset..offset + length].to_vec();
|
||||
let mut seen = HashSet::new();
|
||||
let mut positive_ids = Vec::new();
|
||||
for case in &cases {
|
||||
if seen.insert(case.paragraph.id.as_str()) {
|
||||
positive_ids.push(case.paragraph.id.clone());
|
||||
}
|
||||
}
|
||||
Ok(SliceWindow {
|
||||
offset,
|
||||
length,
|
||||
total_cases: total,
|
||||
cases,
|
||||
positive_paragraph_ids: positive_ids,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
||||
select_window(resolved, 0, None)
|
||||
manifest_to_resolved(dataset, index, manifest, path)
|
||||
}
|
||||
|
||||
fn load_explicit_slice(
|
||||
@@ -450,16 +497,7 @@ fn load_explicit_slice(
|
||||
config: &SliceConfig<'_>,
|
||||
slice_arg: &str,
|
||||
) -> Result<(PathBuf, SliceManifest)> {
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
let candidate_path = if explicit_path.exists() {
|
||||
explicit_path.to_path_buf()
|
||||
} else {
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str())
|
||||
.join(format!("{slice_arg}.json"))
|
||||
};
|
||||
let candidate_path = explicit_slice_path(dataset, config, slice_arg);
|
||||
|
||||
let manifest = read_manifest(&candidate_path)
|
||||
.with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
|
||||
@@ -613,7 +651,7 @@ fn ordered_question_refs(
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
if dataset.metadata.id == DatasetKind::Beir.id() {
|
||||
return ordered_question_refs_beir(dataset, params, target_cases);
|
||||
return beir::ordered_question_refs_beir(dataset, params, target_cases);
|
||||
}
|
||||
|
||||
let mut question_refs = Vec::new();
|
||||
@@ -642,171 +680,6 @@ fn ordered_question_refs(
|
||||
Ok(question_refs)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
fn ordered_question_refs_beir(
|
||||
dataset: &ConvertedDataset,
|
||||
params: &BuildParams,
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||
.iter()
|
||||
.map(|kind| kind.source_prefix())
|
||||
.collect();
|
||||
|
||||
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||
let include = if params.include_impossible {
|
||||
true
|
||||
} else {
|
||||
!question.is_impossible && !question.answers.is_empty()
|
||||
};
|
||||
if !include {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(prefix) = question_prefix(&question.id) else {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
"Skipping BEIR question without expected prefix"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !prefixes.contains(&prefix) {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
prefix = %prefix,
|
||||
"Skipping BEIR question with unknown subset prefix"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||
}
|
||||
}
|
||||
|
||||
if grouped.values().all(std::vec::Vec::is_empty) {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
for prefix in &prefixes {
|
||||
if let Some(entries) = grouped.get_mut(prefix) {
|
||||
let seed = mix_seed(
|
||||
&format!("{}::{prefix}", dataset.metadata.id),
|
||||
params.base_seed,
|
||||
);
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
entries.shuffle(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
let dataset_count = prefixes.len().max(1);
|
||||
let base_quota = target_cases / dataset_count;
|
||||
let mut remainder = target_cases % dataset_count;
|
||||
|
||||
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||
for prefix in &prefixes {
|
||||
let mut quota = base_quota;
|
||||
if remainder > 0 {
|
||||
quota += 1;
|
||||
remainder -= 1;
|
||||
}
|
||||
quotas.insert(*prefix, quota);
|
||||
}
|
||||
|
||||
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||
let mut shortfall = 0usize;
|
||||
|
||||
for prefix in &prefixes {
|
||||
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
||||
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||
let take = quota.min(available);
|
||||
let missing = quota.saturating_sub(take);
|
||||
shortfall += missing;
|
||||
take_counts.insert(*prefix, take);
|
||||
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||
}
|
||||
|
||||
while shortfall > 0 {
|
||||
let mut allocated = false;
|
||||
for prefix in &prefixes {
|
||||
if shortfall == 0 {
|
||||
break;
|
||||
}
|
||||
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||
if spare == 0 {
|
||||
continue;
|
||||
}
|
||||
if let Some(count) = take_counts.get_mut(prefix) {
|
||||
*count += 1;
|
||||
}
|
||||
spare_slots.insert(*prefix, spare - 1);
|
||||
shortfall = shortfall.saturating_sub(1);
|
||||
allocated = true;
|
||||
}
|
||||
if !allocated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||
let mut total_selected = 0usize;
|
||||
for prefix in &prefixes {
|
||||
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||
let mut deque = VecDeque::new();
|
||||
if let Some(entries) = grouped.get(prefix) {
|
||||
for item in entries.iter().take(take) {
|
||||
deque.push_back(*item);
|
||||
total_selected += 1;
|
||||
}
|
||||
}
|
||||
queues.push(deque);
|
||||
}
|
||||
|
||||
if total_selected < target_cases {
|
||||
warn!(
|
||||
requested = target_cases,
|
||||
available = total_selected,
|
||||
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||
);
|
||||
}
|
||||
|
||||
let mut output = Vec::with_capacity(total_selected);
|
||||
loop {
|
||||
let mut progressed = false;
|
||||
for queue in &mut queues {
|
||||
if let Some(item) = queue.pop_front() {
|
||||
output.push(item);
|
||||
progressed = true;
|
||||
}
|
||||
}
|
||||
if !progressed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||
if let Some(rest) = question_id.strip_prefix(prefix) {
|
||||
if rest.starts_with('-') {
|
||||
return Some(prefix);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn ensure_negative_pool(
|
||||
dataset: &ConvertedDataset,
|
||||
@@ -1028,15 +901,47 @@ fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
|
||||
}))
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(dataset_id.as_bytes());
|
||||
hasher.update(seed.to_le_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let mut bytes = [0u8; 8];
|
||||
bytes.copy_from_slice(&digest[..8]);
|
||||
u64::from_le_bytes(bytes)
|
||||
pub fn read_manifest_if_exists(path: &Path) -> Result<Option<SliceManifest>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
read_manifest(path).map(Some)
|
||||
}
|
||||
|
||||
pub fn cached_manifest_path(config: &crate::args::Config) -> Option<PathBuf> {
|
||||
let slice_arg = config.slice.as_deref()?;
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
if explicit_path.exists() {
|
||||
return Some(explicit_path.to_path_buf());
|
||||
}
|
||||
Some(
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(config.dataset.id())
|
||||
.join(format!("{slice_arg}.json")),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool {
|
||||
let requested_limit = config.limit.unwrap_or(manifest.case_count.max(1)).max(1);
|
||||
if manifest.case_count < requested_limit {
|
||||
return false;
|
||||
}
|
||||
|
||||
let requested_corpus = config
|
||||
.corpus_limit
|
||||
.unwrap_or(manifest.total_paragraphs.max(1))
|
||||
.max(1);
|
||||
let desired_negatives = desired_negative_target(
|
||||
manifest.positive_paragraphs,
|
||||
requested_corpus,
|
||||
manifest
|
||||
.total_paragraphs
|
||||
.max(manifest.positive_paragraphs.max(1)),
|
||||
config.negative_multiplier,
|
||||
);
|
||||
manifest.negative_paragraphs >= desired_negatives
|
||||
}
|
||||
|
||||
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
||||
@@ -1057,14 +962,37 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use crate::args::Config;
|
||||
|
||||
impl<'a> From<&'a Config> for SliceConfig<'a> {
|
||||
fn from(config: &'a Config) -> Self {
|
||||
slice_config_with_limit(config, None)
|
||||
pub fn ledger_target(config: &Config) -> Option<usize> {
|
||||
match (config.slice_grow, config.limit) {
|
||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||
(Some(grow), None) => Some(grow),
|
||||
(None, limit) => limit,
|
||||
}
|
||||
}
|
||||
|
||||
/// Grow the slice ledger to contain the target number of cases.
|
||||
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let ledger_limit = ledger_target(config);
|
||||
let slice_settings = slice_config_with_limit(config, ledger_limit);
|
||||
let slice = resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
||||
info!(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
cases = slice.manifest.case_count,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total_paragraphs = slice.manifest.total_paragraphs,
|
||||
"Slice ledger ready"
|
||||
);
|
||||
println!(
|
||||
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
||||
slice.manifest.slice_id,
|
||||
slice.manifest.case_count,
|
||||
slice.manifest.positive_paragraphs,
|
||||
slice.manifest.negative_paragraphs
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
@@ -1088,7 +1016,7 @@ mod tests {
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_dataset() -> ConvertedDataset {
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None);
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false);
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata,
|
||||
@@ -1226,7 +1154,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false, None);
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false);
|
||||
let dataset = ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata,
|
||||
@@ -1240,11 +1168,11 @@ mod tests {
|
||||
rng_seed: 0xBB,
|
||||
};
|
||||
|
||||
let refs = ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
||||
let refs = beir::ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
||||
let mut per_prefix: HashMap<String, usize> = HashMap::new();
|
||||
for (p_idx, q_idx) in refs {
|
||||
let question = &dataset.paragraphs[p_idx].questions[q_idx];
|
||||
let prefix = question_prefix(&question.id).unwrap_or("unknown");
|
||||
let prefix = beir::question_prefix(&question.id).unwrap_or("unknown");
|
||||
*per_prefix.entry(prefix.to_string()).or_default() += 1;
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{args::Config, slice};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SnapshotMetadata {
|
||||
pub dataset_id: String,
|
||||
pub slice_id: String,
|
||||
pub embedding_backend: String,
|
||||
pub embedding_model: Option<String>,
|
||||
pub embedding_dimension: usize,
|
||||
pub rerank_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbSnapshotState {
|
||||
pub dataset_id: String,
|
||||
pub slice_id: String,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub snapshot_hash: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub namespace: Option<String>,
|
||||
#[serde(default)]
|
||||
pub database: Option<String>,
|
||||
#[serde(default)]
|
||||
pub slice_case_count: usize,
|
||||
}
|
||||
|
||||
pub struct Descriptor {
|
||||
#[allow(dead_code)]
|
||||
metadata: SnapshotMetadata,
|
||||
dir: PathBuf,
|
||||
metadata_hash: String,
|
||||
}
|
||||
|
||||
impl Descriptor {
|
||||
pub fn new(
|
||||
config: &Config,
|
||||
slice: &slice::ResolvedSlice<'_>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Self {
|
||||
let metadata = SnapshotMetadata {
|
||||
dataset_id: slice.manifest.dataset_id.clone(),
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
embedding_backend: embedding_provider.backend_label().to_string(),
|
||||
embedding_model: embedding_provider.model_code(),
|
||||
embedding_dimension: embedding_provider.dimension(),
|
||||
rerank_enabled: config.retrieval.rerank,
|
||||
};
|
||||
|
||||
let dir = config
|
||||
.cache_dir
|
||||
.join("snapshots")
|
||||
.join(&metadata.dataset_id)
|
||||
.join(&metadata.slice_id);
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metadata_hash(&self) -> &str {
|
||||
&self.metadata_hash
|
||||
}
|
||||
|
||||
pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
|
||||
let path = self.db_state_path();
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading namespace state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("deserialising namespace state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
|
||||
let path = self.db_state_path();
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!("creating namespace state directory {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
let blob =
|
||||
serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
|
||||
fs::write(&path, blob)
|
||||
.await
|
||||
.with_context(|| format!("writing namespace state {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn db_dir(&self) -> PathBuf {
|
||||
self.dir.join("db")
|
||||
}
|
||||
|
||||
fn db_state_path(&self) -> PathBuf {
|
||||
self.db_dir().join("state.json")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
fn compute_hash(metadata: &SnapshotMetadata) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(
|
||||
serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"),
|
||||
);
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
async fn state_round_trip() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let metadata = SnapshotMetadata {
|
||||
dataset_id: "dataset".into(),
|
||||
slice_id: "slice".into(),
|
||||
embedding_backend: "hashed".into(),
|
||||
embedding_model: None,
|
||||
embedding_dimension: 128,
|
||||
rerank_enabled: true,
|
||||
};
|
||||
let descriptor = Descriptor::from_parts(
|
||||
metadata,
|
||||
temp_dir
|
||||
.path()
|
||||
.join("snapshots")
|
||||
.join("dataset")
|
||||
.join("slice"),
|
||||
);
|
||||
|
||||
let state = DbSnapshotState {
|
||||
dataset_id: "dataset".into(),
|
||||
slice_id: "slice".into(),
|
||||
ingestion_fingerprint: "fingerprint".into(),
|
||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
||||
updated_at: Utc::now(),
|
||||
namespace: Some("ns".into()),
|
||||
database: Some("db".into()),
|
||||
slice_case_count: 42,
|
||||
};
|
||||
descriptor.store_db_state(&state).await.unwrap();
|
||||
|
||||
let loaded = descriptor.load_db_state().await.unwrap().unwrap();
|
||||
assert_eq!(loaded.dataset_id, state.dataset_id);
|
||||
assert_eq!(loaded.slice_id, state.slice_id);
|
||||
assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
|
||||
assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
|
||||
assert_eq!(loaded.namespace, state.namespace);
|
||||
assert_eq!(loaded.database, state.database);
|
||||
assert_eq!(loaded.slice_case_count, state.slice_case_count);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use retrieval_pipeline::{
|
||||
Diagnostics, RetrievalOutput, RetrievedChunk, RetrievedEntity, StageKind, StageTimings,
|
||||
@@ -8,6 +8,8 @@ use retrieval_pipeline::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use unicode_normalization::UnicodeNormalization;
|
||||
|
||||
pub use crate::context_stats::{RetrievalContextStats, RetrievedContextStats};
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EvaluationSummary {
|
||||
@@ -83,6 +85,7 @@ pub struct EvaluationSummary {
|
||||
pub chunk_vector_take: usize,
|
||||
pub chunk_fts_take: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub retrieved_context: RetrievalContextStats,
|
||||
pub cases: Vec<CaseSummary>,
|
||||
}
|
||||
|
||||
@@ -108,6 +111,7 @@ pub struct CaseSummary {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ndcg: Option<f64>,
|
||||
pub latency_ms: u128,
|
||||
pub retrieved_context: RetrievedContextStats,
|
||||
pub retrieved: Vec<RetrievedSummary>,
|
||||
}
|
||||
|
||||
@@ -483,3 +487,7 @@ pub fn build_case_diagnostics(
|
||||
pipeline: pipeline_stats,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
|
||||
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
crane,
|
||||
}: let
|
||||
inherit (nixpkgs.legacyPackages.x86_64-linux) lib;
|
||||
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${self}/ort-version");
|
||||
ortVersion = "1.23.2";
|
||||
in
|
||||
flake-utils.lib.eachDefaultSystem (system: let
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
@@ -24,83 +24,182 @@
|
||||
if pkgs.stdenv.isDarwin
|
||||
then "dylib"
|
||||
else "so";
|
||||
minne-pkg =
|
||||
if pkgs.onnxruntime.version == ortVersion then
|
||||
craneLib.buildPackage {
|
||||
minneVersion = "1.0.4";
|
||||
|
||||
# Pre-download mozjs binary archive for mozjs_sys (servo dep).
|
||||
# When updating mozjs_sys version in Cargo.lock, update this URL too.
|
||||
mozjsArchive = pkgs.fetchurl {
|
||||
url = "https://github.com/servo/mozjs/releases/download/mozjs-sys-v140.10.1-0/libmozjs-x86_64-unknown-linux-gnu.tar.gz";
|
||||
hash = "sha256-e5kW8HTg6Hrd3sGgU9bqFNTTf7wJCChFOwKE3xyYT4Q=";
|
||||
};
|
||||
|
||||
# Extra paths (common/db, html-router/templates, html-router/assets) are
|
||||
# embedded at compile time via include_dir! / minijinja_embed.
|
||||
commonArgs = {
|
||||
version = minneVersion;
|
||||
src = lib.cleanSourceWith {
|
||||
src = ./.;
|
||||
filter = let
|
||||
extraPaths = [
|
||||
filter = path: type:
|
||||
craneLib.filterCargoSources path type
|
||||
|| lib.any (x: lib.hasPrefix (toString x) (toString path)) [
|
||||
(toString ./Cargo.lock)
|
||||
(toString ./common/db)
|
||||
(toString ./html-router/templates)
|
||||
(toString ./html-router/assets)
|
||||
];
|
||||
in
|
||||
path: type: let
|
||||
p = toString path;
|
||||
in
|
||||
craneLib.filterCargoSources path type
|
||||
|| lib.any (x: lib.hasPrefix x p) extraPaths;
|
||||
};
|
||||
strictDeps = true;
|
||||
|
||||
pname = "minne";
|
||||
version = "1.0.3";
|
||||
# Uses nixpkgs rustc (stable). Release/Docker pin: rust-toolchain.toml (1.91.1).
|
||||
doCheck = false;
|
||||
buildInputs = [
|
||||
pkgs.openssl
|
||||
pkgs.libglvnd
|
||||
pkgs.onnxruntime
|
||||
pkgs.fontconfig # .pc for yeslogic-fontconfig-sys (servo dep)
|
||||
pkgs.libclang.lib # libclang.so for bindgen (servo dep)
|
||||
];
|
||||
|
||||
nativeBuildInputs = [pkgs.pkg-config pkgs.rustfmt pkgs.makeWrapper];
|
||||
buildInputs = [pkgs.openssl pkgs.chromium pkgs.onnxruntime];
|
||||
nativeBuildInputs = [
|
||||
pkgs.pkg-config
|
||||
pkgs.rustfmt
|
||||
pkgs.makeWrapper
|
||||
pkgs.python3 # needed by servo's stylo crate build.rs
|
||||
pkgs.llvmPackages.llvm # llvm-objdump for mozjs_sys (servo dep)
|
||||
pkgs.rustPlatform.bindgenHook # configures bindgen (servo deps)
|
||||
];
|
||||
|
||||
postInstall = ''
|
||||
wrapProgram $out/bin/main \
|
||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
for b in worker server; do
|
||||
if [ -x "$out/bin/$b" ]; then
|
||||
wrapProgram $out/bin/$b \
|
||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
fi
|
||||
done
|
||||
'';
|
||||
}
|
||||
else
|
||||
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
||||
# Provide pre-downloaded mozjs archive so it doesn't need network
|
||||
MOZJS_ARCHIVE = "${mozjsArchive}";
|
||||
};
|
||||
|
||||
# cargoBuild (not buildDepsOnly) avoids mkDummySrc breaking native build scripts.
|
||||
cargoArtifacts = craneLib.cargoBuild (commonArgs
|
||||
// {
|
||||
cargoArtifacts = null;
|
||||
pname = "minne-deps";
|
||||
cargoExtraArgs = "--workspace";
|
||||
doCheck = false;
|
||||
doInstallCargoArtifacts = true;
|
||||
installPhaseCommand = "";
|
||||
});
|
||||
|
||||
minne-pkg =
|
||||
if pkgs.onnxruntime.version == ortVersion
|
||||
then
|
||||
craneLib.buildPackage (commonArgs
|
||||
// {
|
||||
pname = "minne";
|
||||
version = minneVersion;
|
||||
inherit cargoArtifacts;
|
||||
doCheck = false; # checks are in separate derivations
|
||||
doInstallCargoArtifacts = true; # for reuse by check derivations
|
||||
|
||||
postInstall = ''
|
||||
wrapProgram $out/bin/main \
|
||||
--prefix LD_LIBRARY_PATH : ${pkgs.libglvnd}/lib \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
for b in worker server; do
|
||||
if [ -x "$out/bin/$b" ]; then
|
||||
wrapProgram $out/bin/$b \
|
||||
--prefix LD_LIBRARY_PATH : ${pkgs.libglvnd}/lib \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
fi
|
||||
done
|
||||
'';
|
||||
})
|
||||
else throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ortVersion in flake.nix (${ortVersion})";
|
||||
|
||||
dockerImage = pkgs.dockerTools.buildLayeredImage {
|
||||
name = "minne";
|
||||
tag = minneVersion;
|
||||
created = "now";
|
||||
|
||||
contents = [
|
||||
minne-pkg
|
||||
pkgs.cacert
|
||||
pkgs.bashInteractive
|
||||
pkgs.libglvnd
|
||||
pkgs.fontconfig.lib
|
||||
pkgs.freetype
|
||||
pkgs.stdenv.cc.cc.lib # libgomp (OpenMP) for ONNX Runtime
|
||||
];
|
||||
|
||||
maxLayers = 25;
|
||||
|
||||
config = {
|
||||
Cmd = ["${minne-pkg}/bin/main"];
|
||||
Env = [
|
||||
"SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-certificates.crt"
|
||||
"ORT_DYLIB_PATH=${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}"
|
||||
];
|
||||
ExposedPorts = {"3000/tcp" = {};};
|
||||
User = "appuser";
|
||||
};
|
||||
};
|
||||
in {
|
||||
packages = {
|
||||
minne-pkg = minne-pkg;
|
||||
inherit minne-pkg dockerImage;
|
||||
default = minne-pkg;
|
||||
};
|
||||
|
||||
apps = {
|
||||
main = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "main";
|
||||
main = {
|
||||
type = "app";
|
||||
program = "${minne-pkg}/bin/main";
|
||||
meta.description = "Minne main server — API, web UI, and background worker";
|
||||
};
|
||||
worker = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "worker";
|
||||
worker = {
|
||||
type = "app";
|
||||
program = "${minne-pkg}/bin/worker";
|
||||
meta.description = "Minne standalone background worker (ingestion, indexing, maintenance)";
|
||||
};
|
||||
server = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "server";
|
||||
server = {
|
||||
type = "app";
|
||||
program = "${minne-pkg}/bin/server";
|
||||
meta.description = "Minne API-only server (no background worker)";
|
||||
};
|
||||
default = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "main";
|
||||
default = {
|
||||
type = "app";
|
||||
program = "${minne-pkg}/bin/main";
|
||||
meta.description = "Minne main server — API, web UI, and background worker";
|
||||
};
|
||||
};
|
||||
|
||||
checks = {
|
||||
ortVersion = pkgs.runCommand "ort-version-check" {} ''
|
||||
if [ "${pkgs.onnxruntime.version}" != "${ortVersion}" ]; then
|
||||
echo "pkgs.onnxruntime.version is ${pkgs.onnxruntime.version}, but ort-version pins ${ortVersion}" >&2
|
||||
echo "Update ort-version or wait for nixpkgs to catch up." >&2
|
||||
echo "pkgs.onnxruntime.version is ${pkgs.onnxruntime.version}, but flake pins ${ortVersion}" >&2
|
||||
echo "Update ortVersion in flake.nix or wait for nixpkgs to catch up." >&2
|
||||
exit 1
|
||||
fi
|
||||
touch $out
|
||||
'';
|
||||
|
||||
minne-clippy = craneLib.cargoClippy (commonArgs
|
||||
// {
|
||||
cargoArtifacts = minne-pkg;
|
||||
pname = "minne";
|
||||
cargoClippyExtraArgs = "--all-targets -- --deny warnings";
|
||||
});
|
||||
|
||||
minne-test = craneLib.cargoTest (commonArgs
|
||||
// {
|
||||
cargoArtifacts = minne-pkg;
|
||||
pname = "minne";
|
||||
buildInputs = commonArgs.buildInputs ++ [pkgs.cacert];
|
||||
SSL_CERT_FILE = "${pkgs.cacert}/etc/ssl/certs/ca-certificates.crt";
|
||||
cargoTestExtraArgs = "--lib --bins";
|
||||
});
|
||||
|
||||
minne-fmt = craneLib.cargoFmt {
|
||||
pname = "minne-fmt";
|
||||
version = minneVersion;
|
||||
src = craneLib.cleanCargoSource ./.;
|
||||
};
|
||||
};
|
||||
})
|
||||
// {
|
||||
lib = {
|
||||
inherit ortVersion;
|
||||
};
|
||||
}) // {
|
||||
ortVersion = ortVersion;
|
||||
};
|
||||
}
|
||||
|
||||
+182
-159
@@ -1,6 +1,6 @@
|
||||
/**
|
||||
* Design Polishing Pass - Interactive Effects
|
||||
*
|
||||
*
|
||||
* Includes:
|
||||
* - Scroll-Linked Navbar Shadow
|
||||
* - HTMX Swap Animation
|
||||
@@ -8,183 +8,207 @@
|
||||
* - Rubberbanding Scroll
|
||||
*/
|
||||
|
||||
(function() {
|
||||
'use strict';
|
||||
(() => {
|
||||
// === SCROLL-LINKED NAVBAR SHADOW ===
|
||||
function initScrollShadow() {
|
||||
const mainContent = document.querySelector("main");
|
||||
const navbar = document.querySelector("nav");
|
||||
if (!mainContent || !navbar) return;
|
||||
|
||||
// === SCROLL-LINKED NAVBAR SHADOW ===
|
||||
function initScrollShadow() {
|
||||
const mainContent = document.querySelector('main');
|
||||
const navbar = document.querySelector('nav');
|
||||
if (!mainContent || !navbar) return;
|
||||
mainContent.addEventListener(
|
||||
"scroll",
|
||||
() => {
|
||||
const scrollTop = mainContent.scrollTop;
|
||||
const scrollHeight =
|
||||
mainContent.scrollHeight - mainContent.clientHeight;
|
||||
const scrollDepth = scrollHeight > 0 ? Math.min(scrollTop / 200, 1) : 0;
|
||||
navbar.style.setProperty("--scroll-depth", scrollDepth.toFixed(2));
|
||||
},
|
||||
{ passive: true },
|
||||
);
|
||||
}
|
||||
|
||||
mainContent.addEventListener('scroll', () => {
|
||||
const scrollTop = mainContent.scrollTop;
|
||||
const scrollHeight = mainContent.scrollHeight - mainContent.clientHeight;
|
||||
const scrollDepth = scrollHeight > 0 ? Math.min(scrollTop / 200, 1) : 0;
|
||||
navbar.style.setProperty('--scroll-depth', scrollDepth.toFixed(2));
|
||||
}, { passive: true });
|
||||
}
|
||||
// === HTMX SWAP ANIMATION ===
|
||||
function initHtmxSwapAnimation() {
|
||||
document.body.addEventListener("htmx:afterSwap", (event) => {
|
||||
let target = event.detail.target;
|
||||
if (!target) return;
|
||||
|
||||
// === HTMX SWAP ANIMATION ===
|
||||
function initHtmxSwapAnimation() {
|
||||
document.body.addEventListener('htmx:afterSwap', (event) => {
|
||||
let target = event.detail.target;
|
||||
if (!target) return;
|
||||
// If full body swap (hx-boost), animate only the main content
|
||||
if (target.tagName === "BODY") {
|
||||
const main = document.querySelector("main");
|
||||
if (main) target = main;
|
||||
}
|
||||
|
||||
// If full body swap (hx-boost), animate only the main content
|
||||
if (target.tagName === 'BODY') {
|
||||
const main = document.querySelector('main');
|
||||
if (main) target = main;
|
||||
}
|
||||
// Only animate if target is valid and inside/is main content or a card/panel
|
||||
// Avoid animating sidebar or navbar updates
|
||||
if (target && (target.tagName === "MAIN" || target.closest("main"))) {
|
||||
if (!target.classList.contains("animate-fade-up")) {
|
||||
target.classList.add("animate-fade-up");
|
||||
// Remove class after animation completes to allow re-animation
|
||||
setTimeout(() => {
|
||||
target.classList.remove("animate-fade-up");
|
||||
}, 250);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Only animate if target is valid and inside/is main content or a card/panel
|
||||
// Avoid animating sidebar or navbar updates
|
||||
if (target && (target.tagName === 'MAIN' || target.closest('main'))) {
|
||||
if (!target.classList.contains('animate-fade-up')) {
|
||||
target.classList.add('animate-fade-up');
|
||||
// Remove class after animation completes to allow re-animation
|
||||
setTimeout(() => {
|
||||
target.classList.remove('animate-fade-up');
|
||||
}, 250);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// === TYPEWRITER AI RESPONSE ===
|
||||
// Works with SSE streaming - buffers text and reveals character by character
|
||||
window.initTypewriter = (element, options = {}) => {
|
||||
const { minDelay = 5, maxDelay = 15, showCursor = true } = options;
|
||||
|
||||
// === TYPEWRITER AI RESPONSE ===
|
||||
// Works with SSE streaming - buffers text and reveals character by character
|
||||
window.initTypewriter = function(element, options = {}) {
|
||||
const {
|
||||
minDelay = 5,
|
||||
maxDelay = 15,
|
||||
showCursor = true
|
||||
} = options;
|
||||
let buffer = "";
|
||||
let isTyping = false;
|
||||
let cursorElement = null;
|
||||
|
||||
let buffer = '';
|
||||
let isTyping = false;
|
||||
let cursorElement = null;
|
||||
if (showCursor) {
|
||||
cursorElement = document.createElement("span");
|
||||
cursorElement.className = "typewriter-cursor";
|
||||
cursorElement.textContent = "▌";
|
||||
cursorElement.style.animation = "blink 1s step-end infinite";
|
||||
element.appendChild(cursorElement);
|
||||
}
|
||||
|
||||
if (showCursor) {
|
||||
cursorElement = document.createElement('span');
|
||||
cursorElement.className = 'typewriter-cursor';
|
||||
cursorElement.textContent = '▌';
|
||||
cursorElement.style.animation = 'blink 1s step-end infinite';
|
||||
element.appendChild(cursorElement);
|
||||
}
|
||||
function typeNextChar() {
|
||||
if (buffer.length === 0) {
|
||||
isTyping = false;
|
||||
return;
|
||||
}
|
||||
|
||||
function typeNextChar() {
|
||||
if (buffer.length === 0) {
|
||||
isTyping = false;
|
||||
return;
|
||||
}
|
||||
isTyping = true;
|
||||
const char = buffer.charAt(0);
|
||||
buffer = buffer.slice(1);
|
||||
|
||||
isTyping = true;
|
||||
const char = buffer.charAt(0);
|
||||
buffer = buffer.slice(1);
|
||||
// Insert before cursor
|
||||
if (cursorElement && cursorElement.parentNode) {
|
||||
const textNode = document.createTextNode(char);
|
||||
element.insertBefore(textNode, cursorElement);
|
||||
} else {
|
||||
element.textContent += char;
|
||||
}
|
||||
|
||||
// Insert before cursor
|
||||
if (cursorElement && cursorElement.parentNode) {
|
||||
const textNode = document.createTextNode(char);
|
||||
element.insertBefore(textNode, cursorElement);
|
||||
} else {
|
||||
element.textContent += char;
|
||||
}
|
||||
const delay = minDelay + Math.random() * (maxDelay - minDelay);
|
||||
setTimeout(typeNextChar, delay);
|
||||
}
|
||||
|
||||
const delay = minDelay + Math.random() * (maxDelay - minDelay);
|
||||
setTimeout(typeNextChar, delay);
|
||||
}
|
||||
return {
|
||||
append: (text) => {
|
||||
buffer += text;
|
||||
if (!isTyping) {
|
||||
typeNextChar();
|
||||
}
|
||||
},
|
||||
complete: () => {
|
||||
// Flush remaining buffer immediately
|
||||
if (cursorElement && cursorElement.parentNode) {
|
||||
const textNode = document.createTextNode(buffer);
|
||||
element.insertBefore(textNode, cursorElement);
|
||||
cursorElement.remove();
|
||||
} else {
|
||||
element.textContent += buffer;
|
||||
}
|
||||
buffer = "";
|
||||
isTyping = false;
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
return {
|
||||
append: function(text) {
|
||||
buffer += text;
|
||||
if (!isTyping) {
|
||||
typeNextChar();
|
||||
}
|
||||
},
|
||||
complete: function() {
|
||||
// Flush remaining buffer immediately
|
||||
if (cursorElement && cursorElement.parentNode) {
|
||||
const textNode = document.createTextNode(buffer);
|
||||
element.insertBefore(textNode, cursorElement);
|
||||
cursorElement.remove();
|
||||
} else {
|
||||
element.textContent += buffer;
|
||||
}
|
||||
buffer = '';
|
||||
isTyping = false;
|
||||
}
|
||||
};
|
||||
};
|
||||
// === RUBBERBANDING SCROLL ===
|
||||
function attachRubberbanding(
|
||||
container,
|
||||
{ maxPull = 60, resistance = 0.4 } = {},
|
||||
) {
|
||||
let startY = 0;
|
||||
let pulling = false;
|
||||
|
||||
// === RUBBERBANDING SCROLL ===
|
||||
function initRubberbanding() {
|
||||
const containers = document.querySelectorAll('#chat-scroll-container, .content-scroll-container');
|
||||
|
||||
containers.forEach(container => {
|
||||
let startY = 0;
|
||||
let pulling = false;
|
||||
let pullDistance = 0;
|
||||
const maxPull = 60;
|
||||
const resistance = 0.4;
|
||||
function applyPull(distance) {
|
||||
container.style.transform = `translateY(${distance}px)`;
|
||||
}
|
||||
|
||||
container.addEventListener('touchstart', (e) => {
|
||||
startY = e.touches[0].clientY;
|
||||
}, { passive: true });
|
||||
function release() {
|
||||
container.style.transition =
|
||||
"transform 300ms cubic-bezier(0.25, 1, 0.5, 1)";
|
||||
container.style.transform = "translateY(0)";
|
||||
setTimeout(() => {
|
||||
container.style.transition = "";
|
||||
}, 300);
|
||||
pulling = false;
|
||||
}
|
||||
|
||||
container.addEventListener('touchmove', (e) => {
|
||||
const currentY = e.touches[0].clientY;
|
||||
const diff = currentY - startY;
|
||||
|
||||
// At top boundary, pulling down
|
||||
if (container.scrollTop <= 0 && diff > 0) {
|
||||
pulling = true;
|
||||
pullDistance = Math.min(diff * resistance, maxPull);
|
||||
container.style.transform = `translateY(${pullDistance}px)`;
|
||||
}
|
||||
// At bottom boundary, pulling up
|
||||
else if (container.scrollTop + container.clientHeight >= container.scrollHeight && diff < 0) {
|
||||
pulling = true;
|
||||
pullDistance = Math.max(diff * resistance, -maxPull);
|
||||
container.style.transform = `translateY(${pullDistance}px)`;
|
||||
}
|
||||
}, { passive: true });
|
||||
function isAtTop() {
|
||||
return container.scrollTop <= 0;
|
||||
}
|
||||
function isAtBottom() {
|
||||
return (
|
||||
container.scrollTop + container.clientHeight >= container.scrollHeight
|
||||
);
|
||||
}
|
||||
|
||||
container.addEventListener('touchend', () => {
|
||||
if (pulling) {
|
||||
container.style.transition = 'transform 300ms cubic-bezier(0.25, 1, 0.5, 1)';
|
||||
container.style.transform = 'translateY(0)';
|
||||
setTimeout(() => {
|
||||
container.style.transition = '';
|
||||
}, 300);
|
||||
pulling = false;
|
||||
pullDistance = 0;
|
||||
}
|
||||
}, { passive: true });
|
||||
});
|
||||
}
|
||||
container.addEventListener(
|
||||
"touchstart",
|
||||
(e) => {
|
||||
startY = e.touches[0].clientY;
|
||||
},
|
||||
{ passive: true },
|
||||
);
|
||||
|
||||
// === INITIALIZATION ===
|
||||
function init() {
|
||||
initScrollShadow();
|
||||
initHtmxSwapAnimation();
|
||||
initRubberbanding();
|
||||
}
|
||||
container.addEventListener(
|
||||
"touchmove",
|
||||
(e) => {
|
||||
const diff = e.touches[0].clientY - startY;
|
||||
const isPullingDown = diff > 0 && isAtTop();
|
||||
const isPullingUp = diff < 0 && isAtBottom();
|
||||
|
||||
// Run on DOMContentLoaded
|
||||
if (document.readyState === 'loading') {
|
||||
document.addEventListener('DOMContentLoaded', init);
|
||||
} else {
|
||||
init();
|
||||
}
|
||||
if (isPullingDown) {
|
||||
pulling = true;
|
||||
applyPull(Math.min(diff * resistance, maxPull));
|
||||
} else if (isPullingUp) {
|
||||
pulling = true;
|
||||
applyPull(Math.max(diff * resistance, -maxPull));
|
||||
}
|
||||
},
|
||||
{ passive: true },
|
||||
);
|
||||
|
||||
// Re-init rubberbanding after HTMX navigations
|
||||
document.body.addEventListener('htmx:afterSettle', () => {
|
||||
initRubberbanding();
|
||||
});
|
||||
container.addEventListener(
|
||||
"touchend",
|
||||
() => {
|
||||
if (pulling) release();
|
||||
},
|
||||
{ passive: true },
|
||||
);
|
||||
}
|
||||
|
||||
// Add typewriter cursor blink animation
|
||||
const style = document.createElement('style');
|
||||
style.textContent = `
|
||||
function initRubberbanding() {
|
||||
document
|
||||
.querySelectorAll("#chat-scroll-container, .content-scroll-container")
|
||||
.forEach((container) => attachRubberbanding(container));
|
||||
}
|
||||
|
||||
// === INITIALIZATION ===
|
||||
function init() {
|
||||
initScrollShadow();
|
||||
initHtmxSwapAnimation();
|
||||
initRubberbanding();
|
||||
}
|
||||
|
||||
// Run on DOMContentLoaded
|
||||
if (document.readyState === "loading") {
|
||||
document.addEventListener("DOMContentLoaded", init);
|
||||
} else {
|
||||
init();
|
||||
}
|
||||
|
||||
// Re-init rubberbanding after HTMX navigations
|
||||
document.body.addEventListener("htmx:afterSettle", () => {
|
||||
initRubberbanding();
|
||||
});
|
||||
|
||||
// Add typewriter cursor blink animation
|
||||
const style = document.createElement("style");
|
||||
style.textContent = `
|
||||
@keyframes blink {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0; }
|
||||
@@ -194,6 +218,5 @@
|
||||
font-weight: bold;
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
|
||||
document.head.appendChild(style);
|
||||
})();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,6 @@
|
||||
--leading-snug: 1.375;
|
||||
--leading-relaxed: 1.625;
|
||||
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
||||
--ease-in-out: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
--animate-pulse: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
--default-transition-duration: 150ms;
|
||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
@@ -285,37 +284,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.drawer-open {
|
||||
> .drawer-side {
|
||||
overflow-y: auto;
|
||||
}
|
||||
> .drawer-toggle {
|
||||
display: none;
|
||||
& ~ .drawer-side {
|
||||
pointer-events: auto;
|
||||
visibility: visible;
|
||||
position: sticky;
|
||||
display: block;
|
||||
width: auto;
|
||||
overscroll-behavior: auto;
|
||||
opacity: 100%;
|
||||
& > .drawer-overlay {
|
||||
cursor: default;
|
||||
background-color: transparent;
|
||||
}
|
||||
& > *:not(.drawer-overlay) {
|
||||
translate: 0%;
|
||||
[dir="rtl"] & {
|
||||
translate: 0%;
|
||||
}
|
||||
}
|
||||
}
|
||||
&:checked ~ .drawer-side {
|
||||
pointer-events: auto;
|
||||
visibility: visible;
|
||||
}
|
||||
}
|
||||
}
|
||||
.drawer-toggle {
|
||||
position: fixed;
|
||||
height: calc(0.25rem * 0);
|
||||
@@ -1074,22 +1042,6 @@
|
||||
grid-row-start: 1;
|
||||
min-width: calc(0.25rem * 0);
|
||||
}
|
||||
.chat-image {
|
||||
grid-row: span 2 / span 2;
|
||||
align-self: flex-end;
|
||||
}
|
||||
.chat-footer {
|
||||
grid-row-start: 3;
|
||||
display: flex;
|
||||
gap: calc(0.25rem * 1);
|
||||
font-size: 0.6875rem;
|
||||
}
|
||||
.chat-header {
|
||||
grid-row-start: 1;
|
||||
display: flex;
|
||||
gap: calc(0.25rem * 1);
|
||||
font-size: 0.6875rem;
|
||||
}
|
||||
.container {
|
||||
width: 100%;
|
||||
@media (width >= 40rem) {
|
||||
@@ -1796,9 +1748,6 @@
|
||||
.w-10 {
|
||||
width: calc(var(--spacing) * 10);
|
||||
}
|
||||
.w-11 {
|
||||
width: calc(var(--spacing) * 11);
|
||||
}
|
||||
.w-11\/12 {
|
||||
width: calc(11/12 * 100%);
|
||||
}
|
||||
@@ -1862,9 +1811,6 @@
|
||||
.flex-none {
|
||||
flex: none;
|
||||
}
|
||||
.flex-shrink {
|
||||
flex-shrink: 1;
|
||||
}
|
||||
.flex-shrink-0 {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
@@ -1877,13 +1823,6 @@
|
||||
.grow {
|
||||
flex-grow: 1;
|
||||
}
|
||||
.border-collapse {
|
||||
border-collapse: collapse;
|
||||
}
|
||||
.-translate-y-1 {
|
||||
--tw-translate-y: calc(var(--spacing) * -1);
|
||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||
}
|
||||
.-translate-y-1\/2 {
|
||||
--tw-translate-y: calc(calc(1/2 * 100%) * -1);
|
||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||
@@ -1956,9 +1895,6 @@
|
||||
.justify-start {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
.gap-0 {
|
||||
gap: calc(var(--spacing) * 0);
|
||||
}
|
||||
.gap-0\.5 {
|
||||
gap: calc(var(--spacing) * 0.5);
|
||||
}
|
||||
@@ -2091,9 +2027,6 @@
|
||||
.border-base-200 {
|
||||
border-color: var(--color-base-200);
|
||||
}
|
||||
.border-base-content {
|
||||
border-color: var(--color-base-content);
|
||||
}
|
||||
.border-base-content\/10 {
|
||||
border-color: var(--color-base-content);
|
||||
@supports (color: color-mix(in lab, red, red)) {
|
||||
@@ -2130,9 +2063,6 @@
|
||||
.bg-transparent {
|
||||
background-color: transparent;
|
||||
}
|
||||
.bg-warning {
|
||||
background-color: var(--color-warning);
|
||||
}
|
||||
.bg-warning\/10 {
|
||||
background-color: var(--color-warning);
|
||||
@supports (color: color-mix(in lab, red, red)) {
|
||||
@@ -2151,9 +2081,6 @@
|
||||
.loading-spinner {
|
||||
mask-image: url("data:image/svg+xml,%3Csvg width='24' height='24' stroke='black' viewBox='0 0 24 24' xmlns='http://www.w3.org/2000/svg'%3E%3Cg transform-origin='center'%3E%3Ccircle cx='12' cy='12' r='9.5' fill='none' stroke-width='3' stroke-linecap='round'%3E%3CanimateTransform attributeName='transform' type='rotate' from='0 12 12' to='360 12 12' dur='2s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dasharray' values='0,150;42,150;42,150' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3Canimate attributeName='stroke-dashoffset' values='0;-16;-59' keyTimes='0;0.475;1' dur='1.5s' repeatCount='indefinite'/%3E%3C/circle%3E%3C/g%3E%3C/svg%3E");
|
||||
}
|
||||
.mask-repeat {
|
||||
mask-repeat: repeat;
|
||||
}
|
||||
.fill-current {
|
||||
fill: currentcolor;
|
||||
}
|
||||
@@ -2184,9 +2111,6 @@
|
||||
.p-8 {
|
||||
padding: calc(var(--spacing) * 8);
|
||||
}
|
||||
.px-1 {
|
||||
padding-inline: calc(var(--spacing) * 1);
|
||||
}
|
||||
.px-1\.5 {
|
||||
padding-inline: calc(var(--spacing) * 1.5);
|
||||
}
|
||||
@@ -2341,9 +2265,6 @@
|
||||
--tw-tracking: var(--tracking-widest);
|
||||
letter-spacing: var(--tracking-widest);
|
||||
}
|
||||
.text-wrap {
|
||||
text-wrap: wrap;
|
||||
}
|
||||
.break-words {
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
@@ -2410,17 +2331,6 @@
|
||||
.italic {
|
||||
font-style: italic;
|
||||
}
|
||||
.underline {
|
||||
text-decoration-line: underline;
|
||||
}
|
||||
.swap-active {
|
||||
.swap-off {
|
||||
opacity: 0%;
|
||||
}
|
||||
.swap-on {
|
||||
opacity: 100%;
|
||||
}
|
||||
}
|
||||
.opacity-0 {
|
||||
opacity: 0%;
|
||||
}
|
||||
@@ -2514,10 +2424,6 @@
|
||||
--tw-duration: 300ms;
|
||||
transition-duration: 300ms;
|
||||
}
|
||||
.ease-in-out {
|
||||
--tw-ease: var(--ease-in-out);
|
||||
transition-timing-function: var(--ease-in-out);
|
||||
}
|
||||
.ease-out {
|
||||
--tw-ease: var(--ease-out);
|
||||
transition-timing-function: var(--ease-out);
|
||||
|
||||
Generated
+3
-3
@@ -958,9 +958,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
|
||||
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
//! the template middleware renders them with shared layout context. Route composition
|
||||
//! and middleware layering are handled by [`router_factory::RouterFactory`].
|
||||
|
||||
// minijinja_embed output (release builds) triggers these lints.
|
||||
#![allow(unused_variables, clippy::expect_used, clippy::missing_panics_doc)]
|
||||
|
||||
pub mod html_state;
|
||||
pub mod middlewares;
|
||||
pub mod router_factory;
|
||||
|
||||
@@ -22,13 +22,13 @@ macro_rules! create_asset_service {
|
||||
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
let assets_path = crate_dir.join($relative_path);
|
||||
tracing::debug!("Assets: Serving from filesystem: {:?}", assets_path);
|
||||
tower_http::services::ServeDir::new(assets_path)
|
||||
tower_http::services::ServeDir::new(&assets_path)
|
||||
}
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
tracing::debug!("Assets: Serving embedded directory");
|
||||
static ASSETS_DIR: include_dir::Dir<'static> =
|
||||
include_dir::include_dir!("$CARGO_MANIFEST_DIR/assets");
|
||||
tracing::debug!(directory = %$relative_path, "Assets: Serving embedded directory");
|
||||
tower_serve_static::ServeDir::new(&ASSETS_DIR)
|
||||
}
|
||||
}};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use async_openai::types::ListModelResponse;
|
||||
use async_openai::types::models::ListModelResponse;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
Form,
|
||||
@@ -350,6 +350,9 @@ mod tests {
|
||||
image_processing_model: "gpt-4o-mini".into(),
|
||||
image_processing_prompt: "p".into(),
|
||||
voice_processing_model: "whisper-1".into(),
|
||||
last_index_rebuild_at: None,
|
||||
index_rebuild_lease_owner: None,
|
||||
index_rebuild_lease_expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -234,7 +234,7 @@ fn build_chat_event_stream(
|
||||
state: HtmlState,
|
||||
openai_stream: impl Stream<
|
||||
Item = Result<
|
||||
async_openai::types::CreateChatCompletionStreamResponse,
|
||||
async_openai::types::chat::CreateChatCompletionStreamResponse,
|
||||
async_openai::error::OpenAIError,
|
||||
>,
|
||||
> + Send
|
||||
@@ -342,7 +342,7 @@ async fn prepare_chat_request(
|
||||
history: &[Message],
|
||||
) -> Result<
|
||||
(
|
||||
async_openai::types::CreateChatCompletionRequest,
|
||||
async_openai::types::chat::CreateChatCompletionRequest,
|
||||
Vec<String>,
|
||||
),
|
||||
SseResponse,
|
||||
|
||||
@@ -5,10 +5,7 @@ use axum::{
|
||||
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::storage::types::{
|
||||
file_info::FileInfo, knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
|
||||
text_content::TextContent, user::User,
|
||||
};
|
||||
use common::storage::types::{file_info::FileInfo, text_content::TextContent, user::User};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
@@ -180,9 +177,7 @@ pub async fn delete_text_content(
|
||||
}
|
||||
}
|
||||
|
||||
// Delete related knowledge entities and text chunks
|
||||
KnowledgeEntity::delete_by_source_id(&id, &state.db).await?;
|
||||
TextChunk::delete_by_source_id(&id, &state.db).await?;
|
||||
TextContent::clear_ingested_children(&id, &user.id, &state.db).await?;
|
||||
|
||||
// Delete the text content
|
||||
state.db.delete_item::<TextContent>(&id).await?;
|
||||
|
||||
@@ -23,9 +23,7 @@ use common::storage::types::user::DashboardStats;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent, user::User,
|
||||
file_info::FileInfo, ingestion_task::IngestionTask, text_content::TextContent, user::User,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -81,11 +79,7 @@ pub async fn delete_text_content(
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the text content and any related data
|
||||
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &user.id, &state.db)
|
||||
.await?;
|
||||
TextContent::clear_ingested_children(&text_content.id, &user.id, &state.db).await?;
|
||||
state
|
||||
.db
|
||||
.delete_item::<TextContent>(&text_content.id)
|
||||
|
||||
@@ -203,7 +203,13 @@ pub async fn create_knowledge_entity(
|
||||
);
|
||||
let new_entity_id = new_entity.id.clone();
|
||||
|
||||
KnowledgeEntity::store_with_embedding(new_entity, embedding, &state.db).await?;
|
||||
KnowledgeEntity::store_with_embedding(
|
||||
new_entity,
|
||||
embedding,
|
||||
state.embedding_provider.dimension(),
|
||||
&state.db,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let relationship_type = relationship_type_or_default(form.relationship_type.as_deref());
|
||||
let user_id = user.id.clone();
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
{# Default: one outer #modal_form. Modals with multiple forms (scratchpad editor)
|
||||
override modal_form_open / modal_form_close — nested <form> is invalid HTML. #}
|
||||
{% block modal_form_open %}<form id="modal_form" hx-on::after-request="if(event.detail.successful) document.getElementById('body_modal').close()" {% block form_attributes %}{% endblock %}>{% endblock %}
|
||||
{% block modal_form_open %}<form id="modal_form" hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()" {% block form_attributes %}{% endblock %}>{% endblock %}
|
||||
<div class="flex flex-col flex-1 gap-5">
|
||||
{% block modal_content %}{% endblock %}
|
||||
</div>
|
||||
|
||||
@@ -333,6 +333,22 @@ async fn snapshot_new_entity_modal() {
|
||||
snapshot_settings().bind(|| insta::assert_snapshot!("new_entity_modal", body));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn modal_form_after_request_ignores_inner_htmx_requests() {
|
||||
let (app, db) = build_test_app().await;
|
||||
let cookie = seeded_cookie(&app, &db).await;
|
||||
let modal = get_html(&app, "/knowledge-entity/new", Some(&cookie)).await;
|
||||
|
||||
// Inner buttons (e.g. Suggest Relationships) bubble htmx:afterRequest to
|
||||
// #modal_form; closing must only run when the form itself submitted.
|
||||
assert!(
|
||||
modal.contains(
|
||||
r#"hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()"#
|
||||
),
|
||||
"#modal_form should ignore bubbled after-request events from child elements"
|
||||
);
|
||||
}
|
||||
|
||||
async fn sign_in(app: &Router, email: &str, password: &str) -> String {
|
||||
let response = app
|
||||
.clone()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
---
|
||||
source: html-router/tests/router_integration.rs
|
||||
assertion_line: 333
|
||||
expression: body
|
||||
---
|
||||
<dialog id="body_modal" class="modal">
|
||||
@@ -18,7 +19,7 @@ expression: body
|
||||
</button>
|
||||
|
||||
|
||||
<form id="modal_form" hx-on::after-request="if(event.detail.successful) document.getElementById('body_modal').close()"
|
||||
<form id="modal_form" hx-on::after-request="if(event.detail.successful && event.detail.elt === event.currentTarget) document.getElementById('body_modal').close()"
|
||||
hx-post="/knowledge-entity"
|
||||
hx-target="#knowledge_pane"
|
||||
hx-swap="outerHTML"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user