mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-26 03:46:24 +02:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d273390de8 | |||
| ba3fd6ed46 | |||
| bddb603bd9 | |||
| aee9136f1e | |||
| 253ac9819d | |||
| 511e42a078 | |||
| d8e839bf46 | |||
| 588e616baf | |||
| 87e6fa14b2 | |||
| 09e545816e | |||
| 01ef1bcb7a | |||
| 530cd0a8f1 | |||
| 3b20adc50f | |||
| b3d42d2586 | |||
| fb51a8b55f | |||
| adc04d8c6d | |||
| 1013035731 | |||
| cf69cb7b05 | |||
| ead17530bd | |||
| 4e8a58fff1 |
+1
-1
@@ -1,2 +1,2 @@
|
||||
[alias]
|
||||
eval = "run -p evaluations --"
|
||||
eval = "run -p evaluations --release --"
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
# Git stuff
|
||||
.git/
|
||||
.gitignore
|
||||
.github
|
||||
|
||||
# Node build artifacts
|
||||
**/node_modules/
|
||||
|
||||
# Nix/Devenv environment files
|
||||
.direnv/
|
||||
.devenv/
|
||||
devenv.lock
|
||||
devenv.nix
|
||||
devenv.yaml
|
||||
docker-compose.yml
|
||||
.envrc
|
||||
.devenv.flake.nix
|
||||
flake.lock
|
||||
flake.nix
|
||||
|
||||
# Rust build artifacts (crucial for multi-stage builds)
|
||||
**/target/
|
||||
|
||||
# Runtime data directories
|
||||
data/
|
||||
database/
|
||||
|
||||
# Local environment config (sensitive)
|
||||
.env
|
||||
|
||||
# IDE specific
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# OS specific
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs / Temporary files
|
||||
*.log
|
||||
@@ -0,0 +1,31 @@
|
||||
name: CI
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Format, lint, build & test
|
||||
runs-on: ubuntu-24.04
|
||||
if: ${{ github.event_name == 'workflow_dispatch' || !startsWith(github.event.head_commit.message, 'release:') }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
|
||||
- name: Check formatting, clippy lint, unit tests & ort version
|
||||
run: nix flake check --show-trace
|
||||
+156
-267
@@ -7,200 +7,138 @@ on:
|
||||
pull_request:
|
||||
push:
|
||||
tags:
|
||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
||||
- "**[0-9]+.[0-9]+.[0-9]+*"
|
||||
|
||||
jobs:
|
||||
plan:
|
||||
runs-on: ubuntu-22.04
|
||||
ci:
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
val: ${{ steps.plan.outputs.manifest }}
|
||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||
tag-flag: ${{ !github.event.pull_request && format('--tag={0}', github.ref_name) || '' }}
|
||||
publishing: ${{ !github.event.pull_request }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
ort-version: ${{ steps.ort_version.outputs.value }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v27
|
||||
uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
extra_nix_config: |
|
||||
experimental-features = nix-command flakes
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
|
||||
- name: Verify ort-version matches nixpkgs onnxruntime
|
||||
run: nix flake check --system x86_64-linux -L
|
||||
- name: Read ORT version from flake
|
||||
id: ort_version
|
||||
run: echo "value=$(nix eval .#lib.ortVersion --raw)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Install dist
|
||||
shell: bash
|
||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.3/cargo-dist-installer.sh | sh"
|
||||
- name: Run nix flake check
|
||||
run: nix flake check --system x86_64-linux
|
||||
|
||||
- name: Cache dist
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/dist
|
||||
- name: Warm Linux release outputs for cache
|
||||
if: ${{ !github.event.pull_request }}
|
||||
run: nix build .#minne-release .#minne-release-windows --no-link -L
|
||||
|
||||
- id: plan
|
||||
run: |
|
||||
dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
cat plan-dist-manifest.json
|
||||
echo "manifest=$(jq -c . plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-plan-dist-manifest
|
||||
path: plan-dist-manifest.json
|
||||
|
||||
build-local-artifacts:
|
||||
name: build-local-artifacts (${{ join(matrix.targets, ', ') }})
|
||||
needs: [plan]
|
||||
if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }}
|
||||
build-nix-artifacts:
|
||||
name: build (${{ matrix.triple }})
|
||||
needs: [ci]
|
||||
if: ${{ needs.ci.outputs.publishing == 'true' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }}
|
||||
matrix:
|
||||
include:
|
||||
- runner: ubuntu-24.04
|
||||
triple: x86_64-unknown-linux-gnu
|
||||
nix_package: minne-release
|
||||
cache_save: false
|
||||
- runner: macos-14
|
||||
triple: aarch64-apple-darwin
|
||||
nix_package: minne-release
|
||||
cache_save: false
|
||||
- runner: ubuntu-24.04
|
||||
triple: x86_64-pc-windows-msvc
|
||||
nix_package: minne-release-windows
|
||||
cache_save: false
|
||||
runs-on: ${{ matrix.runner }}
|
||||
container: ${{ matrix.container && matrix.container.image || null }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json
|
||||
steps:
|
||||
- name: enable windows longpaths
|
||||
run: git config --global core.longpaths true
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Load ONNX Runtime version
|
||||
shell: bash
|
||||
run: echo "ORT_VER=$(tr -d '[:space:]' < ort-version)" >> "$GITHUB_ENV"
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- name: Install Rust non-interactively if not already installed
|
||||
if: ${{ matrix.container }}
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
save: ${{ matrix.cache_save }}
|
||||
|
||||
- name: Build release archive (Nix)
|
||||
run: nix build .#${{ matrix.nix_package }} -L --out-link minne-release
|
||||
|
||||
- name: Stage artifact
|
||||
shell: bash
|
||||
run: |
|
||||
if ! command -v cargo > /dev/null 2>&1; then
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
set -euo pipefail
|
||||
TRIPLE="${{ matrix.triple }}"
|
||||
if [[ "$TRIPLE" == *windows* ]]; then
|
||||
ARTIFACT="main-${TRIPLE}.zip"
|
||||
else
|
||||
ARTIFACT="main-${TRIPLE}.tar.xz"
|
||||
fi
|
||||
RELEASE="$(nix path-info ./minne-release)"
|
||||
echo "Release output at $RELEASE:"
|
||||
ls -la "$RELEASE"
|
||||
mapfile -t BUILT < <(find "$RELEASE" -maxdepth 1 \( -name 'main-*.tar.xz' -o -name 'main-*.zip' \) -print)
|
||||
if [ "${#BUILT[@]}" -ne 1 ]; then
|
||||
echo "Expected exactly one release archive in $RELEASE" >&2
|
||||
exit 1
|
||||
fi
|
||||
cp "${BUILT[0]}" "$ARTIFACT"
|
||||
if command -v sha256sum >/dev/null; then
|
||||
sha256sum "$ARTIFACT" > "${ARTIFACT}.sha256"
|
||||
else
|
||||
shasum -a 256 "$ARTIFACT" > "${ARTIFACT}.sha256"
|
||||
fi
|
||||
|
||||
- name: Install dist
|
||||
run: ${{ matrix.install_dist.run }}
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
# ===== BEGIN: Injected ORT staging for cargo-dist bundling =====
|
||||
- run: echo "=== BUILD-SETUP START ==="
|
||||
|
||||
# Unix shells
|
||||
- name: Prepare lib dir (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p lib
|
||||
rm -f lib/*
|
||||
|
||||
# Windows PowerShell
|
||||
- name: Prepare lib dir (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
New-Item -ItemType Directory -Force -Path lib | Out-Null
|
||||
# remove contents if any
|
||||
Get-ChildItem -Path lib -Force | Remove-Item -Force -Recurse -ErrorAction SilentlyContinue
|
||||
|
||||
- name: Fetch ONNX Runtime (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
set -euo pipefail
|
||||
ARCH="$(uname -m)"
|
||||
case "$ARCH" in
|
||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||
esac
|
||||
curl -fsSL -o ort.tgz "$URL"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||
# normalize to stable name if needed
|
||||
[ -f lib/libonnxruntime.so ] || cp -v lib/libonnxruntime.so.* lib/libonnxruntime.so
|
||||
|
||||
- name: Fetch ONNX Runtime (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
set -euo pipefail
|
||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||
[ -f lib/libonnxruntime.dylib ] || cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||
|
||||
- name: Fetch ONNX Runtime (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||
Invoke-WebRequest $url -OutFile ort.zip
|
||||
Expand-Archive ort.zip -DestinationPath ort
|
||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||
|
||||
- run: |
|
||||
echo "=== BUILD-SETUP END ==="
|
||||
echo "lib/ contents:"
|
||||
ls -l lib || dir lib
|
||||
# ===== END: Injected ORT staging =====
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
${{ matrix.packages_install }}
|
||||
|
||||
- name: Build artifacts
|
||||
run: |
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
|
||||
- id: cargo-dist
|
||||
name: Post-build
|
||||
shell: bash
|
||||
run: |
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
|
||||
- name: Upload artifacts
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-local-${{ join(matrix.targets, '_') }}
|
||||
name: release-${{ matrix.triple }}
|
||||
path: |
|
||||
${{ steps.cargo-dist.outputs.paths }}
|
||||
${{ env.BUILD_MANIFEST_NAME }}
|
||||
main-${{ matrix.triple }}.*
|
||||
|
||||
build_and_push_docker_image:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
name: Build and Push Docker Image (Nix)
|
||||
runs-on: ubuntu-24.04
|
||||
needs: [ci]
|
||||
if: ${{ needs.ci.outputs.publishing == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
packages: write
|
||||
actions: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/determinate-nix-action@v3
|
||||
|
||||
- uses: nix-community/cache-nix-action@v7
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ hashFiles('**/*.nix', '**/flake.lock', 'Cargo.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-
|
||||
gc-max-store-size-linux: 10G
|
||||
save: false
|
||||
|
||||
- name: Build Docker image with Nix
|
||||
run: nix build .#dockerImage -L --show-trace
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
@@ -209,133 +147,84 @@ jobs:
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
- name: Load and push Docker image
|
||||
env:
|
||||
IMAGE_NAME: ghcr.io/${{ github.repository }}
|
||||
IMAGE_TAG: ${{ needs.ci.outputs.tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
LOADED_IMAGE="$(docker load < result | awk '/Loaded image:/ {print $3; exit}')"
|
||||
if [ -z "$LOADED_IMAGE" ]; then
|
||||
echo "failed to load docker image from nix result" >&2
|
||||
exit 1
|
||||
fi
|
||||
docker tag "$LOADED_IMAGE" "$IMAGE_NAME:$IMAGE_TAG"
|
||||
docker tag "$LOADED_IMAGE" "$IMAGE_NAME:latest"
|
||||
docker push "$IMAGE_NAME:$IMAGE_TAG"
|
||||
docker push "$IMAGE_NAME:latest"
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
build-global-artifacts:
|
||||
needs: [plan, build-local-artifacts]
|
||||
runs-on: ubuntu-22.04
|
||||
release:
|
||||
name: Create GitHub Release
|
||||
needs: [ci, build-nix-artifacts, build_and_push_docker_image]
|
||||
if: ${{ needs.ci.outputs.publishing == 'true' }}
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: cargo-dist
|
||||
shell: bash
|
||||
run: |
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-global
|
||||
path: |
|
||||
${{ steps.cargo-dist.outputs.paths }}
|
||||
${{ env.BUILD_MANIFEST_NAME }}
|
||||
|
||||
host:
|
||||
needs: [plan, build-local-artifacts, build-global-artifacts]
|
||||
if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
val: ${{ steps.host.outputs.manifest }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
|
||||
- name: Fetch artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: host
|
||||
shell: bash
|
||||
run: |
|
||||
dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json
|
||||
echo "artifacts uploaded and released successfully"
|
||||
cat dist-manifest.json
|
||||
echo "manifest=$(jq -c . dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-dist-manifest
|
||||
path: dist-manifest.json
|
||||
|
||||
- name: Download GitHub Artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
pattern: release-*
|
||||
path: artifacts
|
||||
merge-multiple: true
|
||||
|
||||
- name: Cleanup
|
||||
run: rm -f artifacts/*-dist-manifest.json
|
||||
- name: Flatten artifacts
|
||||
run: find artifacts -type f -exec mv {} . \;
|
||||
|
||||
- name: Prepare release notes
|
||||
env:
|
||||
VERSION: ${{ needs.ci.outputs.tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if grep -q "^## ${VERSION} (" CHANGELOG.md; then
|
||||
awk -v ver="$VERSION" '
|
||||
/^## / { if (found) exit; if ($0 ~ "^## " ver " \\(") found=1; next }
|
||||
found { print }
|
||||
' CHANGELOG.md > "$RUNNER_TEMP/notes.txt"
|
||||
else
|
||||
awk '
|
||||
/^## Unreleased/ { found=1; next }
|
||||
found && /^## [0-9]/ { exit }
|
||||
found { print }
|
||||
' CHANGELOG.md > "$RUNNER_TEMP/notes.txt"
|
||||
fi
|
||||
if [ ! -s "$RUNNER_TEMP/notes.txt" ]; then
|
||||
echo "Release ${VERSION}" > "$RUNNER_TEMP/notes.txt"
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release
|
||||
env:
|
||||
PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}"
|
||||
ANNOUNCEMENT_TITLE: "${{ fromJson(steps.host.outputs.manifest).announcement_title }}"
|
||||
ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}"
|
||||
RELEASE_COMMIT: "${{ github.sha }}"
|
||||
TAG: ${{ needs.ci.outputs.tag }}
|
||||
PRERELEASE_FLAG: ${{ contains(needs.ci.outputs.tag, 'alpha') || contains(needs.ci.outputs.tag, 'beta') || contains(needs.ci.outputs.tag, 'rc') && '--prerelease' || '' }}
|
||||
run: |
|
||||
echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt
|
||||
gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/*
|
||||
|
||||
announce:
|
||||
needs: [plan, host]
|
||||
if: ${{ always() && needs.host.result == 'success' }}
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
set -euo pipefail
|
||||
FILES=()
|
||||
for f in main-*; do
|
||||
[ -f "$f" ] || continue
|
||||
FILES+=("$f")
|
||||
done
|
||||
if [ "${#FILES[@]}" -eq 0 ]; then
|
||||
echo "no release artifacts found" >&2
|
||||
ls -la
|
||||
exit 1
|
||||
fi
|
||||
gh release create "$TAG" \
|
||||
--target "${{ github.sha }}" \
|
||||
--title "minne $TAG" \
|
||||
--notes-file "$RUNNER_TEMP/notes.txt" \
|
||||
$PRERELEASE_FLAG \
|
||||
"${FILES[@]}"
|
||||
|
||||
+43
-1
@@ -1,7 +1,33 @@
|
||||
# Changelog
|
||||
|
||||
## Unreleased
|
||||
|
||||
|
||||
## 1.0.5 (2026-06-24)
|
||||
|
||||
- Infra: CI workflow fixes. CI is now a nix flake check which includes compilation, caching and running tests, clippy, fmt, validation for ort version.
|
||||
- Docker-compose: The example now references the ghcr image, this is so we can remove the Dockerfile and reducing maintenance scope.
|
||||
- Refactor: web scraping now uses `servo-fetch` (pure-Rust Servo engine) and PDF rendering uses `pdfium-render` (direct PDFium bindings) — reduces Docker image size by ~300MB, improves startup latency by ~100× for PDF rendering, and provides more stable output
|
||||
- Fix: added `pkgs.libglvnd` to `LD_LIBRARY_PATH` in devenv so Servo engine can find `libEGL.so` at runtime
|
||||
- Fix: updated Dockerfile to add `libegl1 libegl-mesa0 libgles2 libfontconfig1 libfreetype6` runtime dependencies for servo-fetch
|
||||
- Docs: updated architecture, features, and installation docs to reflect the new web processing stack
|
||||
- Fix: added pre-commit hooks to further maintain code consistency.
|
||||
- Security: updated some deps because dependabot told me, good bot.
|
||||
- Refactor: deduplicated test database setup across common/src/storage/.
|
||||
- Refactor: split knowledge-graph.js monolith into focused functions.
|
||||
- Evaluations: simplified crate layout — linear pipeline, sharded-only converted store, in-memory ingestion, `db/` and `cli/` modules; namespace reuse state in corpus manifest (removed `cache/snapshots/`); no legacy JSON/history compatibility (re-run `--warm` after upgrade)
|
||||
- Performance: ingestion skips per-task index rebuild; worker runs scheduled `REBUILD INDEX` (default every 24h via `index_rebuild_interval_secs`, `0` disables)
|
||||
- Performance: ingestion persists all artifacts in a single SurrealDB transaction per task (atomic replace by task id)
|
||||
- Performance: entity embeddings during ingestion use batched `embed_batch`, matching chunk embedding
|
||||
- Fix: ingestion reclaims tasks after a successful persist without re-running the pipeline when `mark_succeeded` failed
|
||||
- Fix: content deletion clears graph relationships via shared `TextContent::clear_ingested_children`
|
||||
- Fix: regression re suggestion of relationships
|
||||
- Internal: extracted duplicate entity+embedding patterns into `HasEmbedding` and `EmbeddingRecord` traits with generic `store_with_embedding`, `delete_by_source_id`, and `vector_search` on `SurrealDbClient`.
|
||||
- Infra: `ort-version` file removed — version inlined in `flake.nix` and `devenv.nix`; `release.yml` reads it via `nix eval .#lib.ortVersion` from the plan job
|
||||
- Infra: `screenshot-graph.webp` and `.dockerignore` deleted — stale artifacts from Dockerfile era
|
||||
|
||||
## 1.0.3 (2026-06-12)
|
||||
|
||||
- Search: filter results by type — knowledge entities, ingested content, or both
|
||||
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
||||
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
||||
@@ -11,6 +37,7 @@
|
||||
- Fix: API key revocation now correctly clears the stored key
|
||||
|
||||
## 1.0.2 (2026-02-15)
|
||||
|
||||
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||
- Fix: chat references now validate and render more reliably
|
||||
- Fix: improved admin access checks for restricted routes
|
||||
@@ -19,73 +46,88 @@
|
||||
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
||||
|
||||
## 1.0.1 (2026-02-11)
|
||||
|
||||
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
||||
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
||||
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
|
||||
|
||||
## 1.0.0 (2026-01-02)
|
||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||
|
||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||
- Added a benchmarks create for evaluating the retrieval process
|
||||
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
|
||||
- Embeddings stored on own table.
|
||||
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
||||
|
||||
## Version 0.2.7 (2025-12-04)
|
||||
|
||||
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||
- Fix: timezone aware info in scratchpad
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
- Fix: default name for relationships harmonized across application
|
||||
|
||||
## Version 0.2.5 (2025-10-24)
|
||||
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
- Scratchpad feature, with the feature to convert scratchpads to content.
|
||||
- Added knowledge entity search results to the global search
|
||||
- Backend fixes for improved performance when ingesting and retrieval
|
||||
|
||||
## Version 0.2.4 (2025-10-15)
|
||||
|
||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||
- Ingestion task archive
|
||||
|
||||
## Version 0.2.3 (2025-10-12)
|
||||
|
||||
- Fix changing vector dimensions on a fresh database (#3)
|
||||
|
||||
## Version 0.2.2 (2025-10-07)
|
||||
|
||||
- Support for ingestion of PDF files
|
||||
- Improved ingestion speed
|
||||
- Fix deletion of items work as expected
|
||||
- Fix enabling GPT-5 use via OpenAI API
|
||||
|
||||
## Version 0.2.1 (2025-09-24)
|
||||
|
||||
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||
|
||||
## Version 0.2.0 (2025-09-23)
|
||||
|
||||
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||
- Enabled audio ingestion and merged the new storage backend.
|
||||
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||
|
||||
## Version 0.1.4 (2025-07-01)
|
||||
|
||||
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||
|
||||
## Version 0.1.3 (2025-06-08)
|
||||
|
||||
- Added support for AI providers beyond OpenAI.
|
||||
- Made the HTTP port configurable for deployments.
|
||||
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||
|
||||
## Version 0.1.2 (2025-05-26)
|
||||
|
||||
- Introduced full-text search across indexed knowledge.
|
||||
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||
|
||||
## Version 0.1.1 (2025-05-13)
|
||||
|
||||
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||
- Made the data storage path configurable.
|
||||
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||
|
||||
## Version 0.1.0 (2025-05-06)
|
||||
|
||||
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||
|
||||
Generated
+6203
-364
File diff suppressed because it is too large
Load Diff
+26
-13
@@ -7,19 +7,24 @@ members = [
|
||||
"ingestion-pipeline",
|
||||
"retrieval-pipeline",
|
||||
"json-stream-parser",
|
||||
"evaluations"
|
||||
"evaluations",
|
||||
]
|
||||
resolver = "2"
|
||||
resolver = "3"
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.94"
|
||||
async-openai = "0.29.3"
|
||||
async-openai = { version = "0.41.1", features = [
|
||||
"chat-completion",
|
||||
"embedding",
|
||||
"audio",
|
||||
"model",
|
||||
] }
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1.88"
|
||||
axum-htmx = "0.7.0"
|
||||
axum_session = "0.16"
|
||||
axum_session_auth = "0.16"
|
||||
axum_session_surreal = "0.4"
|
||||
axum_session = "0.18"
|
||||
axum_session_auth = "0.18"
|
||||
axum_session_surreal = "0.6"
|
||||
axum_typed_multipart = "0.16"
|
||||
axum = { version = "0.8", features = ["multipart", "macros"] }
|
||||
chrono-tz = "0.10.1"
|
||||
@@ -27,7 +32,6 @@ chrono = { version = "0.4.39", features = ["serde"] }
|
||||
config = "0.15.4"
|
||||
dom_smoothie = "0.10.0"
|
||||
futures = "0.3.31"
|
||||
headless_chrome = "1.0.17"
|
||||
include_dir = "0.7.4"
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
@@ -35,12 +39,12 @@ minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
reqwest = { version = "0.12.12", features = ["charset", "json"] }
|
||||
serde_json = "1.0.128"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sha2 = "0.10.8"
|
||||
surrealdb-migrations = "2.2.2"
|
||||
surrealdb = { version = "2" }
|
||||
surrealdb-migrations = "2.4.0"
|
||||
surrealdb = { version = "2.6" }
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||
@@ -61,14 +65,24 @@ bytes = "1.7.1"
|
||||
state-machines = "0.9"
|
||||
pdf-extract = "0.9"
|
||||
lopdf = "0.32"
|
||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||
pdfium-auto = "0.3"
|
||||
pdfium-render = "0.8"
|
||||
servo-fetch = "0.13"
|
||||
tendril = "0.4"
|
||||
image = { version = "0.25", default-features = false, features = ["png"] }
|
||||
fastembed = { version = "5.2.0", default-features = false, features = [
|
||||
"hf-hub-native-tls",
|
||||
"ort-load-dynamic",
|
||||
] }
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
lto = "thin"
|
||||
|
||||
[workspace.lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(feature, values(\"inspect\"))"] }
|
||||
unexpected_cfgs = { level = "warn", check-cfg = [
|
||||
"cfg(feature, values(\"inspect\"))",
|
||||
] }
|
||||
|
||||
[workspace.lints.clippy]
|
||||
# Performance-focused lints
|
||||
@@ -118,4 +132,3 @@ needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
cargo_common_metadata = "allow"
|
||||
multiple-crate-versions = "allow"
|
||||
|
||||
|
||||
-53
@@ -1,53 +0,0 @@
|
||||
# === Builder ===
|
||||
FROM rust:1.91.1-bookworm AS builder
|
||||
WORKDIR /usr/src/minne
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cache deps
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
|
||||
COPY api-router/Cargo.toml ./api-router/
|
||||
COPY common/Cargo.toml ./common/
|
||||
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
|
||||
COPY html-router/Cargo.toml ./html-router/
|
||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
||||
COPY main/Cargo.toml ./main/
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Build
|
||||
COPY . .
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
||||
|
||||
# === Runtime ===
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Chromium + runtime deps + OpenMP for ORT
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
||||
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
||||
libgomp1 libstdc++6 curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ONNX Runtime (CPU). Version is read from ort-version (override with --build-arg ORT_VERSION=...).
|
||||
COPY ort-version /tmp/ort-version
|
||||
ARG ORT_VERSION
|
||||
RUN ORT_VERSION="${ORT_VERSION:-$(tr -d '[:space:]' < /tmp/ort-version)}" && \
|
||||
mkdir -p /opt/onnxruntime && \
|
||||
curl -fsSL -o /tmp/ort.tgz \
|
||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||
|
||||
ENV CHROME_BIN=/usr/bin/chromium \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
||||
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
||||
|
||||
# Non-root
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
WORKDIR /home/appuser
|
||||
|
||||
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
||||
EXPOSE 3000
|
||||
CMD ["main"]
|
||||
@@ -121,7 +121,7 @@ fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults t
|
||||
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
||||
- **Database:** SurrealDB (graph, document, and vector search)
|
||||
- **AI Integration:** OpenAI-compatible API with structured outputs
|
||||
- **Web Processing:** Headless Chrome for robust webpage content extraction
|
||||
- **Web Processing:** Embedded Servo engine (servo-fetch) for webpage content extraction + PDFium for PDF rendering
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -172,7 +172,7 @@ cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
|
||||
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||
|
||||
### 2. Nix
|
||||
|
||||
@@ -180,13 +180,13 @@ The included `docker-compose.yml` handles SurrealDB and Chromium dependencies au
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
This fetches Minne and all dependencies, including Chromium.
|
||||
This fetches Minne and all dependencies.
|
||||
|
||||
### 3. Pre-built Binaries
|
||||
|
||||
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
|
||||
**Requirements:** You'll need to provide SurrealDB separately.
|
||||
|
||||
### 4. Build from Source
|
||||
|
||||
@@ -196,7 +196,7 @@ cd minne
|
||||
cargo run --release --bin main
|
||||
```
|
||||
|
||||
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
|
||||
**Requirements:** SurrealDB must be installed and accessible in your PATH.
|
||||
|
||||
## Application Architecture
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "api-router"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use axum::{
|
||||
Json,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use common::error::AppError;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use api_state::ApiState;
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{DefaultBodyLimit, FromRef},
|
||||
middleware::from_fn_with_state,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
||||
use axum::{Extension, Json, extract::State, response::IntoResponse};
|
||||
use common::storage::types::user::User;
|
||||
|
||||
use crate::{api_state::ApiState, error::ApiErr};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||
use axum::{Extension, Json, extract::State, http::StatusCode, response::IntoResponse};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -6,9 +6,9 @@ use common::{
|
||||
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||
user::User,
|
||||
},
|
||||
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
|
||||
utils::ingest_limits::{IngestValidationError, validate_ingest_input},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use futures::{TryFutureExt, future::try_join_all};
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
use axum::{Json, http::StatusCode, response::IntoResponse};
|
||||
use serde_json::json;
|
||||
|
||||
/// Liveness probe: always returns 200 to indicate the process is running.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ use std::sync::Arc;
|
||||
|
||||
use api_router::{api_routes_v1, api_state::ApiState};
|
||||
use axum::{
|
||||
body::{to_bytes, Body},
|
||||
http::{Request, StatusCode},
|
||||
Router,
|
||||
body::{Body, to_bytes},
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use common::{
|
||||
storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
|
||||
|
||||
+2
-2
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "common"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
@@ -24,7 +24,7 @@ dom_smoothie = { workspace = true }
|
||||
axum_session = { workspace = true }
|
||||
axum_session_auth = { workspace = true }
|
||||
axum_session_surreal = { workspace = true}
|
||||
axum_typed_multipart = { workspace = true}
|
||||
axum_typed_multipart = { workspace = true}
|
||||
include_dir = { workspace = true }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-autoreload = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Track scheduled runtime index rebuild state on the system_settings singleton.
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -201,6 +201,10 @@\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;\n+DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||
@@ -14,3 +14,7 @@ DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_index_rebuild_at ON system_settings TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_owner ON system_settings TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS index_rebuild_lease_expires_at ON system_settings TYPE option<datetime>;
|
||||
|
||||
+160
-122
@@ -1,14 +1,16 @@
|
||||
use super::types::StoredObject;
|
||||
use super::types::{EmbeddingRecord, HasEmbedding, StoredObject};
|
||||
use crate::error::AppError;
|
||||
use axum_session::{SessionConfig, SessionError, SessionStore};
|
||||
use axum_session_surreal::SessionSurrealPool;
|
||||
use futures::Stream;
|
||||
use include_dir::{include_dir, Dir};
|
||||
use include_dir::{Dir, include_dir};
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
opt::auth::{Namespace, Root},
|
||||
Error, Notification, Surreal,
|
||||
engine::any::{Any, connect},
|
||||
opt::auth::{Namespace, Root},
|
||||
};
|
||||
use surrealdb_migrations::MigrationRunner;
|
||||
use tracing::debug;
|
||||
@@ -26,20 +28,6 @@ pub trait ProvidesDb {
|
||||
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// Initialize a new database client.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||
/// * `username` — Root username for authentication.
|
||||
/// * `password` — Root password for authentication.
|
||||
/// * `namespace` — SurrealDB namespace to use.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||
/// In-memory (`mem://`) connections skip authentication.
|
||||
pub async fn new(
|
||||
address: &str,
|
||||
username: &str,
|
||||
@@ -49,30 +37,15 @@ impl SurrealDbClient {
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
|
||||
// Skip sign-in for in-memory engine (no auth support)
|
||||
if !address.starts_with("mem://") {
|
||||
db.signin(Root { username, password }).await?;
|
||||
}
|
||||
|
||||
// Set namespace
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Initialize a new database client using namespace-level authentication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` — Database connection string.
|
||||
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||
/// * `username` — Namespace username for authentication.
|
||||
/// * `password` — Namespace password for authentication.
|
||||
/// * `database` — SurrealDB database to use.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
@@ -91,11 +64,6 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
/// Create an Axum session store backed by SurrealDB.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
@@ -109,15 +77,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Applies all pending database migrations found in the embedded MIGRATIONS_DIR.
|
||||
///
|
||||
/// This function should be called during application startup, after connecting to
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
@@ -129,15 +88,6 @@ impl SurrealDbClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Store an object in SurrealDB.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database create operation fails.
|
||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -148,13 +98,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||
///
|
||||
/// Useful for idempotent ingestion flows.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database upsert operation fails.
|
||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
@@ -166,11 +109,6 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieve all objects from a table.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -178,16 +116,6 @@ impl SurrealDbClient {
|
||||
self.client.select(T::table_name()).await
|
||||
}
|
||||
|
||||
/// Retrieve a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to retrieve.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database select operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -195,16 +123,6 @@ impl SurrealDbClient {
|
||||
self.client.select((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Delete a single object by its ID.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` — The ID of the item to delete.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database delete operation fails.
|
||||
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: for<'de> StoredObject,
|
||||
@@ -212,11 +130,6 @@ impl SurrealDbClient {
|
||||
self.client.delete((T::table_name(), id)).await
|
||||
}
|
||||
|
||||
/// Listen to a table for real-time updates via a live query stream.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the database live query subscription fails.
|
||||
pub async fn listen<T>(
|
||||
&self,
|
||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||
@@ -225,6 +138,156 @@ impl SurrealDbClient {
|
||||
{
|
||||
self.client.select(T::table_name()).live().await
|
||||
}
|
||||
|
||||
/// Atomically store an entity and its embedding vector in a single
|
||||
/// SurrealDB transaction.
|
||||
///
|
||||
/// Creates (or overwrites) the entity row and upserts the linked
|
||||
/// embedding record. The embedding dimension is validated against
|
||||
/// `embedding_dimensions` before the query is issued.
|
||||
pub async fn store_with_embedding<E>(
|
||||
&self,
|
||||
entity: E,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding + Serialize + Send + Sync + 'static,
|
||||
<E as HasEmbedding>::Embedding: Serialize + Send + Sync,
|
||||
{
|
||||
E::Embedding::validate_dimension(&embedding, embedding_dimensions)?;
|
||||
|
||||
let entity_id = entity.id().to_string();
|
||||
let emb = <E as HasEmbedding>::Embedding::new(
|
||||
&entity_id,
|
||||
entity.source_id().to_string(),
|
||||
embedding,
|
||||
entity.user_id().to_string(),
|
||||
E::table_name(),
|
||||
);
|
||||
|
||||
let sql = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{et}', $id) CONTENT $entity;
|
||||
UPSERT type::thing('{emt}', $id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
et = E::table_name(),
|
||||
emt = <E as HasEmbedding>::Embedding::table_name(),
|
||||
);
|
||||
|
||||
self.client
|
||||
.query(sql)
|
||||
.bind(("id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all entity and embedding rows matching a given `source_id`.
|
||||
///
|
||||
/// Runs inside a SurrealDB transaction so that entity and embedding
|
||||
/// deletes are atomic.
|
||||
pub async fn delete_by_source_id<E>(&self, source_id: &str) -> Result<(), AppError>
|
||||
where
|
||||
E: HasEmbedding,
|
||||
E::Embedding: Send + Sync,
|
||||
{
|
||||
self.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::Embedding::table_name()
|
||||
))
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
E::table_name()
|
||||
))
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector similarity search over entities using HNSW index.
|
||||
///
|
||||
/// Performs a cosine-similarity search against the embedding table,
|
||||
/// fetches the corresponding entity rows server-side via `FETCH`,
|
||||
/// and returns `(entity, score)` pairs ordered by descending
|
||||
/// similarity. Orphaned embeddings (entity deleted but its
|
||||
/// embedding row remains) are logged as a warning and dropped.
|
||||
///
|
||||
/// This is a single round-trip — SurrealDB resolves the link field
|
||||
/// (`entity_id` or `chunk_id`) inside the query engine.
|
||||
pub async fn vector_search<E, Emb>(
|
||||
&self,
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
user_id: &str,
|
||||
) -> Result<Vec<(E, f32)>, AppError>
|
||||
where
|
||||
E: StoredObject + DeserializeOwned + Clone + Send + Sync,
|
||||
Emb: EmbeddingRecord + Send + Sync,
|
||||
{
|
||||
// Generic row that works with both `entity_id` and `chunk_id` link
|
||||
// fields via `#[serde(alias)]`. SurrealDB's `FETCH` resolves the link
|
||||
// server-side so we get the full entity in a single round-trip.
|
||||
#[derive(serde::Deserialize)]
|
||||
struct FetchRow<Ent> {
|
||||
score: f32,
|
||||
#[serde(alias = "entity_id", alias = "chunk_id")]
|
||||
entity: Option<Ent>,
|
||||
}
|
||||
|
||||
let link_field = Emb::link_field();
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
{link_field},
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH {link_field}
|
||||
"#,
|
||||
link_field = link_field,
|
||||
emb_table = Emb::table_name(),
|
||||
take = take,
|
||||
);
|
||||
|
||||
let mut response = self
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?;
|
||||
|
||||
response = response.check()?;
|
||||
|
||||
let rows: Vec<FetchRow<E>> = response.take(0)?;
|
||||
|
||||
let mut results = Vec::with_capacity(rows.len());
|
||||
for r in rows {
|
||||
if let Some(entity) = r.entity {
|
||||
results.push((entity, r.score));
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Vector search hit orphaned {} row with missing {link_field}",
|
||||
Emb::table_name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SurrealDbClient {
|
||||
@@ -237,12 +300,9 @@ impl Deref for SurrealDbClient {
|
||||
|
||||
#[cfg(any(test, feature = "test-utils"))]
|
||||
impl SurrealDbClient {
|
||||
/// Create an in-memory SurrealDB client for testing.
|
||||
pub async fn memory(namespace: &str, database: &str) -> Result<Self, Error> {
|
||||
let db = connect("mem://").await?;
|
||||
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
}
|
||||
@@ -253,8 +313,7 @@ mod tests {
|
||||
use crate::stored_object;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
stored_object!(Dummy, "dummy", {
|
||||
name: String
|
||||
@@ -262,15 +321,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -314,15 +365,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
@@ -371,12 +414,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to build indexes".to_string())?;
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::future::try_join_all;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map, Value};
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
||||
@@ -204,6 +208,9 @@ pub async fn ensure_runtime(
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
///
|
||||
/// Uses `DEFINE INDEX OVERWRITE` and is reserved for dimension migrations, re-embed
|
||||
/// flows, and tests. Routine optimization should use [`rebuild_runtime`].
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||
@@ -211,6 +218,115 @@ pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Rebuilds existing runtime FTS and HNSW indexes in place via SurrealQL `REBUILD INDEX`.
|
||||
///
|
||||
/// SurrealDB maintains ready indexes incrementally on writes; this is for periodic
|
||||
/// optimization (for example a nightly maintainer job), not ingest correctness.
|
||||
/// On SurrealDB 2.6 this runs synchronously (`CONCURRENTLY` is not supported on `REBUILD`).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any rebuild operation fails.
|
||||
pub async fn rebuild_runtime(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_runtime_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
/// Returns whether a scheduled index rebuild is due based on the persisted last-run time.
|
||||
#[must_use]
|
||||
pub fn scheduled_index_rebuild_due(
|
||||
last_run: Option<DateTime<Utc>>,
|
||||
interval_secs: u64,
|
||||
now: DateTime<Utc>,
|
||||
) -> bool {
|
||||
if interval_secs == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let Some(last_run) = last_run else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let elapsed = now.signed_duration_since(last_run);
|
||||
elapsed.num_seconds() >= i64::try_from(interval_secs).unwrap_or(i64::MAX)
|
||||
}
|
||||
|
||||
/// Runs a scheduled native `REBUILD INDEX` pass when due, using a DB lock so only one
|
||||
/// maintainer rebuilds at a time. Seeds a checkpoint on first run so the initial rebuild
|
||||
/// waits one full interval after worker startup.
|
||||
pub async fn maybe_run_scheduled_index_rebuild(
|
||||
db: &SurrealDbClient,
|
||||
worker_id: &str,
|
||||
interval_secs: u64,
|
||||
) {
|
||||
if interval_secs == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
let settings = match SystemSettings::get_current(db).await {
|
||||
Ok(settings) => settings,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to load system settings for index rebuild schedule");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let last_run = settings.last_index_rebuild_at;
|
||||
|
||||
if last_run.is_none() {
|
||||
match SystemSettings::seed_index_rebuild_checkpoint(db).await {
|
||||
Ok(true) => debug!("seeded index rebuild checkpoint; first rebuild deferred"),
|
||||
Ok(false) => {}
|
||||
Err(err) => warn!(error = %err, "failed to seed index rebuild checkpoint"),
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !scheduled_index_rebuild_due(last_run, interval_secs, now) {
|
||||
return;
|
||||
}
|
||||
|
||||
let lock_owner = format!("{worker_id}-index-rebuild");
|
||||
let acquired = match SystemSettings::try_acquire_index_rebuild_lease(db, &lock_owner).await {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to acquire index rebuild lease");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !acquired {
|
||||
debug!("another maintainer is rebuilding indexes");
|
||||
return;
|
||||
}
|
||||
|
||||
let started = Instant::now();
|
||||
info!(interval_secs, "starting scheduled runtime index rebuild");
|
||||
let rebuild_result = rebuild_runtime(db).await;
|
||||
|
||||
match rebuild_result {
|
||||
Ok(()) => {
|
||||
if let Err(err) = SystemSettings::record_index_rebuild_completed(db, &lock_owner).await
|
||||
{
|
||||
warn!(error = %err, "failed to persist index rebuild checkpoint");
|
||||
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||
}
|
||||
info!(
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"scheduled runtime index rebuild completed"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
SystemSettings::release_index_rebuild_lease(db, &lock_owner).await;
|
||||
error!(
|
||||
error = %err,
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"scheduled runtime index rebuild failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||
///
|
||||
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||
@@ -382,6 +498,45 @@ async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn rebuild_runtime_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
debug!("Rebuilding runtime indexes with REBUILD INDEX");
|
||||
|
||||
for spec in fts_index_specs() {
|
||||
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await?;
|
||||
}
|
||||
|
||||
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||
rebuild_existing_index_in_place(db, spec.index_name, spec.table).await
|
||||
});
|
||||
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn rebuild_existing_index_in_place(
|
||||
db: &SurrealDbClient,
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
) -> Result<()> {
|
||||
if !index_exists(db, table, index_name).await? {
|
||||
debug!(
|
||||
index = index_name,
|
||||
table, "Skipping in-place rebuild because index is missing"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let query = format!("REBUILD INDEX IF EXISTS {index_name} ON {table};");
|
||||
let res = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| format!("rebuilding index {index_name} on table {table}"))?;
|
||||
res.check()
|
||||
.with_context(|| format!("rebuild index {index_name} on table {table} failed"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn existing_hnsw_dimension(
|
||||
db: &SurrealDbClient,
|
||||
spec: &HnswIndexSpec,
|
||||
@@ -906,6 +1061,43 @@ mod tests {
|
||||
assert_eq!(extract_dimension(definition), Some(1536));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scheduled_index_rebuild_due_respects_interval_and_disabled() {
|
||||
let now = Utc::now();
|
||||
let last = now - chrono::Duration::hours(25);
|
||||
|
||||
assert!(!scheduled_index_rebuild_due(None, 86_400, now));
|
||||
assert!(!scheduled_index_rebuild_due(Some(last), 0, now));
|
||||
assert!(!scheduled_index_rebuild_due(
|
||||
Some(now - chrono::Duration::hours(1)),
|
||||
86_400,
|
||||
now
|
||||
));
|
||||
assert!(scheduled_index_rebuild_due(Some(last), 86_400, now));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rebuild_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_in_place_rebuild";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.context("in-memory db")?;
|
||||
|
||||
db.apply_migrations().await.context("migrations")?;
|
||||
ensure_runtime(&db, 8)
|
||||
.await
|
||||
.context("ensure runtime indexes")?;
|
||||
|
||||
rebuild_runtime(&db)
|
||||
.await
|
||||
.context("first in-place rebuild")?;
|
||||
rebuild_runtime(&db)
|
||||
.await
|
||||
.context("second in-place rebuild")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||
let namespace = "indexes_ns";
|
||||
|
||||
+82
-55
@@ -2,14 +2,14 @@ use std::io::ErrorKind;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Context, Result as AnyResult};
|
||||
use anyhow::{Context, Result as AnyResult, anyhow};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use object_store::aws::AmazonS3Builder;
|
||||
use object_store::local::LocalFileSystem;
|
||||
use object_store::memory::InMemory;
|
||||
use object_store::{path::Path as ObjPath, ObjectStore};
|
||||
use object_store::{ObjectStore, path::Path as ObjPath};
|
||||
|
||||
use crate::utils::config::{AppConfig, StorageKind};
|
||||
|
||||
@@ -461,9 +461,12 @@ pub mod testing {
|
||||
pub async fn new_s3() -> object_store::Result<Self> {
|
||||
// Ensure credentials are set for MinIO
|
||||
// We set these env vars for the process, which AmazonS3Builder will pick up
|
||||
std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
|
||||
std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
|
||||
std::env::set_var("AWS_REGION", "us-east-1");
|
||||
// SAFETY: test setup runs before concurrent S3 client use in this process.
|
||||
unsafe {
|
||||
std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
|
||||
std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
|
||||
std::env::set_var("AWS_REGION", "us-east-1");
|
||||
}
|
||||
|
||||
let cfg = test_config_s3();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
@@ -543,10 +546,10 @@ pub mod testing {
|
||||
impl Drop for TestStorageManager {
|
||||
fn drop(&mut self) {
|
||||
// Clean up temporary directories for local storage
|
||||
if let Some((_, path)) = &self.temp_dir {
|
||||
if path.exists() {
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
if let Some((_, path)) = &self.temp_dir
|
||||
&& path.exists()
|
||||
{
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -690,20 +693,24 @@ mod tests {
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check".to_string())?);
|
||||
assert!(
|
||||
storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check".to_string())?
|
||||
);
|
||||
|
||||
// Test delete
|
||||
storage
|
||||
.delete_prefix("test/data/")
|
||||
.await
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
assert!(
|
||||
!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check after delete".to_string())?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -741,20 +748,24 @@ mod tests {
|
||||
.with_context(|| "object directory exists after write".to_string())?;
|
||||
|
||||
// Test exists
|
||||
assert!(storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check".to_string())?);
|
||||
assert!(
|
||||
storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check".to_string())?
|
||||
);
|
||||
|
||||
// Test delete
|
||||
storage
|
||||
.delete_prefix("test/data/")
|
||||
.await
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check after delete".to_string())?);
|
||||
assert!(
|
||||
!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists check after delete".to_string())?
|
||||
);
|
||||
assert!(
|
||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||
"object directory should be removed"
|
||||
@@ -846,12 +857,16 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "list dir1".to_string())?;
|
||||
assert_eq!(dir1_files.len(), 2);
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file1.txt")));
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file2.txt")));
|
||||
assert!(
|
||||
dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file1.txt"))
|
||||
);
|
||||
assert!(
|
||||
dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file2.txt"))
|
||||
);
|
||||
|
||||
// Test listing non-existent prefix
|
||||
let empty_files = storage
|
||||
@@ -918,10 +933,12 @@ mod tests {
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
assert!(
|
||||
storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?
|
||||
);
|
||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||
|
||||
Ok(())
|
||||
@@ -975,10 +992,12 @@ mod tests {
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
assert!(
|
||||
test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?
|
||||
);
|
||||
|
||||
// Test list
|
||||
let files = test_storage
|
||||
@@ -992,10 +1011,12 @@ mod tests {
|
||||
.delete_prefix("test/storage/")
|
||||
.await
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
assert!(
|
||||
!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists after delete".to_string())?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1019,10 +1040,12 @@ mod tests {
|
||||
.with_context(|| "get".to_string())?;
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
assert!(
|
||||
test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1119,20 +1142,24 @@ mod tests {
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?);
|
||||
assert!(
|
||||
storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.with_context(|| "exists".to_string())?
|
||||
);
|
||||
|
||||
// Test delete
|
||||
storage
|
||||
.delete_prefix(&format!("{prefix}/"))
|
||||
.await
|
||||
.with_context(|| "delete".to_string())?;
|
||||
assert!(!storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.with_context(|| "exists after delete".to_string())?);
|
||||
assert!(
|
||||
!storage
|
||||
.exists(&location)
|
||||
.await
|
||||
.with_context(|| "exists after delete".to_string())?
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::storage::types::{user::User, StoredObject};
|
||||
use crate::storage::types::{StoredObject, user::User};
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -108,6 +108,7 @@ mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::stored_object;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -120,10 +121,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Test initialization of analytics
|
||||
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -145,10 +143,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -165,10 +160,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -190,10 +182,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Initialize analytics
|
||||
Analytics::ensure_initialized(&db).await?;
|
||||
|
||||
@@ -214,11 +203,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Test with no users
|
||||
let count = Analytics::get_users_amount(&db).await?;
|
||||
assert_eq!(count, 0);
|
||||
@@ -246,10 +231,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
@@ -259,10 +241,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
@@ -272,10 +251,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(after_visitors.visitors, 1);
|
||||
assert_eq!(after_visitors.page_loads, 0);
|
||||
@@ -293,10 +269,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_record_page_view() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let first_view = Analytics::record_page_view(&db, true).await?;
|
||||
assert_eq!(first_view.visitors, 1);
|
||||
assert_eq!(first_view.page_loads, 1);
|
||||
@@ -310,11 +283,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Don't initialize analytics and try to get it
|
||||
let result = Analytics::get_current(&db).await;
|
||||
|
||||
|
||||
@@ -157,6 +157,7 @@ impl Conversation {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use crate::storage::types::message::MessageRole;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
@@ -181,11 +182,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let title = "Test Conversation";
|
||||
@@ -214,11 +211,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let result =
|
||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||
@@ -234,11 +227,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id_1 = "user_1";
|
||||
let conversation =
|
||||
@@ -264,11 +253,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "user_1";
|
||||
let original_title = "Original Title";
|
||||
@@ -297,11 +282,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||
|
||||
@@ -316,11 +297,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user_id = "intruder";
|
||||
@@ -345,11 +322,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let user_id = "sidebar_user";
|
||||
let other_user_id = "other_user";
|
||||
@@ -398,11 +371,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let user_id = "sidebar_patch_user";
|
||||
let base = Utc::now();
|
||||
@@ -440,11 +409,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id_1 = "user_1";
|
||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||
@@ -527,11 +492,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await.expect("setup_test_db");
|
||||
|
||||
let owner = "sidebar_owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||
@@ -551,9 +512,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner = "owner_user";
|
||||
let intruder = "intruder_user";
|
||||
@@ -590,9 +549,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "order_user";
|
||||
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||
@@ -637,9 +594,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner = "owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||
|
||||
@@ -327,6 +327,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
use std::{io::Write, path::Path};
|
||||
@@ -378,15 +379,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file for StorageManager operations";
|
||||
let file_name = "storage_manager_test.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -435,15 +428,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"filename sanitization";
|
||||
let original_name = "Complex name (1).txt";
|
||||
let expected_sanitized = "Complex_name__1_.txt";
|
||||
@@ -470,15 +455,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_duplicate_detection_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file for StorageManager duplicate detection";
|
||||
let file_name = "storage_manager_duplicate.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -538,15 +515,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_creation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"This is a test file content";
|
||||
let file_name = "test_file.txt";
|
||||
let field_data = create_test_file(content, file_name)?;
|
||||
@@ -585,15 +554,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_duplicate_detection() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// First, store a file with known content
|
||||
let content = b"This is a test file for duplicate detection";
|
||||
let file_name = "original.txt";
|
||||
@@ -692,12 +653,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
@@ -710,12 +666,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let db = setup_test_db().await.expect("setup test db");
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -740,15 +691,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let content = b"shared content across users";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
@@ -783,10 +726,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
let now = Utc::now();
|
||||
let sha = "abc123sha";
|
||||
let owner = "owner_user";
|
||||
@@ -816,9 +756,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let field_data = create_test_file_without_name(b"data")?;
|
||||
@@ -832,9 +770,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let file_info = FileInfo::new_with_storage(
|
||||
@@ -856,10 +792,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
db.apply_migrations().await?;
|
||||
let db = setup_test_db().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
let storage = test_storage.storage();
|
||||
let user_id = "dedup_user";
|
||||
@@ -901,12 +834,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let db = setup_test_db().await.expect("setup test db");
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
@@ -939,15 +867,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Create and persist a test file via FileInfo::new_with_storage
|
||||
let user_id = "user123";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
@@ -985,12 +905,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Try to delete a file that doesn't exist
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
@@ -1006,12 +921,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Create a FileInfo instance directly
|
||||
let now = Utc::now();
|
||||
let file_id = Uuid::new_v4().to_string();
|
||||
@@ -1045,12 +955,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_id_not_found() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = setup_test_db().await?;
|
||||
// Try to retrieve a non-existent ID
|
||||
let non_existent_id = "non-existent-file-id";
|
||||
let result = FileInfo::get_by_id(non_existent_id, &db).await;
|
||||
|
||||
@@ -315,9 +315,10 @@ impl IngestionTask {
|
||||
"#;
|
||||
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok());
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok_and(|m| m
|
||||
.start_processing()
|
||||
.is_ok_and(|m| m.fail().is_ok_and(|m| m.reserve().is_ok()))));
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok_and(|m| {
|
||||
m.start_processing()
|
||||
.is_ok_and(|m| m.fail().is_ok_and(|m| m.reserve().is_ok()))
|
||||
}));
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
@@ -630,6 +631,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
@@ -641,11 +643,7 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.with_context(|| "in-memory surrealdb".to_string())
|
||||
setup_test_db().await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -4,10 +4,13 @@ use std::fmt::Write;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::db::SurrealDbClient,
|
||||
storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::hnsw_index_overwrite_sql,
|
||||
types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
types::system_settings::SystemSettings,
|
||||
types::{EmbeddingRecord, HasEmbedding},
|
||||
},
|
||||
stored_object,
|
||||
utils::embedding::{EmbeddingProvider, RE_EMBED_BATCH_SIZE},
|
||||
};
|
||||
@@ -70,6 +73,18 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl HasEmbedding for KnowledgeEntity {
|
||||
type Embedding = KnowledgeEntityEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl KnowledgeEntity {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
@@ -227,67 +242,22 @@ impl KnowledgeEntity {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store a knowledge entity and its embedding.
|
||||
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
|
||||
/// Atomically store one knowledge entity and its embedding (single-record path).
|
||||
///
|
||||
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||
pub async fn store_with_embedding(
|
||||
entity: KnowledgeEntity,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
KnowledgeEntityEmbedding::validate_dimension(
|
||||
&embedding,
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let entity_id = entity.id.clone();
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
&entity_id,
|
||||
entity.source_id.clone(),
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
||||
UPSERT type::thing('{emb_table}', $entity_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
entity_table = Self::table_name(),
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(entity, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||
@@ -297,48 +267,14 @@ impl KnowledgeEntity {
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: Option<KnowledgeEntity>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
entity_id,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH entity_id;
|
||||
"#,
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, KnowledgeEntityEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.entity_id.map(|entity| KnowledgeEntitySearchResult {
|
||||
entity,
|
||||
score: r.score,
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(entity, score)| KnowledgeEntitySearchResult { entity, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
@@ -364,7 +300,13 @@ impl KnowledgeEntity {
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
id,
|
||||
entity.source_id,
|
||||
embedding,
|
||||
entity.user_id,
|
||||
Self::table_name(),
|
||||
);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
@@ -457,7 +399,9 @@ impl KnowledgeEntity {
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
entity.id,
|
||||
embedding.len(),
|
||||
new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
@@ -554,9 +498,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::indexes::rebuild;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use crate::test_utils::{ensure_fts_index, prepare_knowledge_entity_test_db, setup_test_db};
|
||||
use anyhow::{self, Context};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn embedding_input_text_uses_canonical_type_label() {
|
||||
@@ -568,27 +511,6 @@ mod tests {
|
||||
assert_eq!(text, "name: Alpha, description: Beta, type: TextSnippet");
|
||||
}
|
||||
|
||||
async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define entity fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -675,19 +597,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(5).await?;
|
||||
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
@@ -722,13 +632,13 @@ mod tests {
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity 1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity 2".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), 5, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store different entity".to_string())?;
|
||||
|
||||
@@ -783,21 +693,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_knowledge_entity_test_db(3)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
configure_embedding_dimension(&db, 3)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.expect("prepare test db");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
@@ -819,10 +717,10 @@ mod tests {
|
||||
user_id,
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.expect("store entity1");
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], 3, &db)
|
||||
.await
|
||||
.expect("store entity2");
|
||||
|
||||
@@ -849,18 +747,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_knowledge_entity_test_db(3)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
.expect("prepare test db");
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
@@ -870,19 +759,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -895,7 +772,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
@@ -918,7 +795,7 @@ mod tests {
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
@@ -938,19 +815,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
@@ -970,10 +835,10 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store e1".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store e2".to_string())?;
|
||||
|
||||
@@ -1001,14 +866,18 @@ mod tests {
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
.await
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
.await
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e1)
|
||||
.await
|
||||
.with_context(|| "get embedding e1".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &rid_e2)
|
||||
.await
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
@@ -1037,19 +906,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "Failed to redefine index length".to_string())?;
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
@@ -1062,7 +919,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
@@ -1089,15 +946,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_entities() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
@@ -1112,15 +967,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -1151,15 +1004,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_entity_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_entity_fts_indexes(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(
|
||||
&db,
|
||||
"knowledge_entity",
|
||||
&[("name", "name"), ("description", "description")],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_entity = KnowledgeEntity::new(
|
||||
|
||||
@@ -4,7 +4,7 @@ use surrealdb::RecordId;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
storage::{db::SurrealDbClient, types::EmbeddingRecord},
|
||||
stored_object,
|
||||
};
|
||||
|
||||
@@ -17,72 +17,48 @@ stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for KnowledgeEntityEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"entity_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_knowledge_entity_embedding"
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding.
|
||||
///
|
||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||
#[must_use]
|
||||
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
entity_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: entity_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
entity_id: RecordId::from_table_key(entity_table, entity_id),
|
||||
embedding,
|
||||
source_id,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding by entity ID
|
||||
pub async fn get_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Get embeddings for multiple entities in batch
|
||||
pub async fn get_by_entity_ids(
|
||||
entity_ids: &[RecordId],
|
||||
@@ -109,44 +85,6 @@ impl KnowledgeEntityEmbedding {
|
||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete embedding by entity ID
|
||||
pub async fn delete_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings with the given denormalized `source_id`.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -184,6 +122,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "entity-abc");
|
||||
}
|
||||
@@ -205,13 +144,13 @@ mod tests {
|
||||
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||
@@ -234,22 +173,22 @@ mod tests {
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let existing = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
KnowledgeEntityEmbedding::delete_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let after = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none());
|
||||
@@ -266,7 +205,7 @@ mod tests {
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
@@ -277,7 +216,7 @@ mod tests {
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_record_id(&db, &entity_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||
let stored_embedding =
|
||||
@@ -295,7 +234,7 @@ mod tests {
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], 3, &db).await;
|
||||
|
||||
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||
|
||||
@@ -313,15 +252,20 @@ mod tests {
|
||||
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
||||
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
KnowledgeEntity::store_with_embedding(
|
||||
entity_other.clone(),
|
||||
vec![3.0_f32, 3.1, 3.2],
|
||||
3,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
@@ -332,21 +276,23 @@ mod tests {
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity1_rid)
|
||||
.await
|
||||
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &entity2_rid)
|
||||
.await
|
||||
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
.await
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some());
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_record_id(&db, &other_rid)
|
||||
.await
|
||||
.with_context(|| "get other embedding after delete".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -403,7 +349,7 @@ mod tests {
|
||||
let source_id = "source-fetch";
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], 3, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||
|
||||
@@ -441,7 +387,7 @@ mod tests {
|
||||
let source_id = "source-upsert";
|
||||
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "initial store".to_string())?;
|
||||
|
||||
@@ -450,6 +396,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
KnowledgeEntity::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -575,12 +575,16 @@ mod tests {
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||
.await?;
|
||||
|
||||
assert!(get_relationship_by_id(&owner_relationship_id, &db)
|
||||
.await
|
||||
.is_none());
|
||||
assert!(get_relationship_by_id(&other_relationship_id, &db)
|
||||
.await
|
||||
.is_some());
|
||||
assert!(
|
||||
get_relationship_by_id(&owner_relationship_id, &db)
|
||||
.await
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
get_relationship_by_id(&other_relationship_id, &db)
|
||||
.await
|
||||
.is_some()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ pub fn format_history(history: &[Message]) -> String {
|
||||
mod tests {
|
||||
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use anyhow::{self, Context};
|
||||
|
||||
#[tokio::test]
|
||||
@@ -106,11 +106,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let conversation_id = "test_conversation";
|
||||
let message = Message::new(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![allow(clippy::unsafe_derive_deserialize)]
|
||||
#![allow(async_fn_in_trait)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod analytics;
|
||||
pub mod conversation;
|
||||
@@ -22,6 +23,135 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||
fn id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An entity that has an associated embedding record for vector search.
|
||||
pub trait HasEmbedding: StoredObject {
|
||||
/// The embedding record type paired with this entity.
|
||||
type Embedding: EmbeddingRecord;
|
||||
|
||||
fn source_id(&self) -> &str;
|
||||
fn user_id(&self) -> &str;
|
||||
}
|
||||
|
||||
/// An embedding record linked to a `HasEmbedding` entity.
|
||||
pub trait EmbeddingRecord: StoredObject {
|
||||
/// The field name in the embedding table that links back to the entity
|
||||
/// (e.g. `"entity_id"` or `"chunk_id"`). Used in FETCH and WHERE clauses.
|
||||
fn link_field() -> &'static str;
|
||||
|
||||
/// The HNSW index name (e.g. `"idx_embedding_knowledge_entity_embedding"`).
|
||||
fn index_name() -> &'static str;
|
||||
|
||||
fn source_id(&self) -> &str;
|
||||
fn user_id(&self) -> &str;
|
||||
fn embedding(&self) -> &[f32];
|
||||
|
||||
/// Construct a new embedding record.
|
||||
///
|
||||
/// * `id` – shared record id (same as the entity id).
|
||||
/// * `source_id` – denormalised source id for bulk deletes.
|
||||
/// * `embedding` – the embedding vector.
|
||||
/// * `user_id` – denormalised user id for query scoping.
|
||||
/// * `entity_table` – the entity's table name (used to build the link `RecordId`).
|
||||
fn new(
|
||||
id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self;
|
||||
|
||||
/// Validate that an embedding vector matches the expected dimension.
|
||||
fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
if embedding.len() != expected {
|
||||
return Err(crate::error::AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recreate the HNSW vector index with a new dimension.
|
||||
///
|
||||
/// This drops and recreates the index inside a transaction.
|
||||
async fn redefine_hnsw_index(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = crate::storage::indexes::hnsw_index_redefine_transaction_sql(
|
||||
Self::index_name(),
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
db.client.query(query).await?.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch a single embedding record by its link `RecordId`.
|
||||
async fn get_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<Option<Self>, crate::error::AppError>
|
||||
where
|
||||
Self: Sized + serde::de::DeserializeOwned,
|
||||
{
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE {} = $rid LIMIT 1",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
let mut result = db.client.query(query).bind(("rid", rid.clone())).await?;
|
||||
Ok(result.take(0)?)
|
||||
}
|
||||
|
||||
/// Delete an embedding record by its link `RecordId`.
|
||||
async fn delete_by_record_id(
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
rid: &surrealdb::RecordId,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE {} = $rid",
|
||||
Self::table_name(),
|
||||
Self::link_field(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("rid", rid.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embedding records with a given `source_id`.
|
||||
async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &crate::storage::db::SurrealDbClient,
|
||||
) -> Result<(), crate::error::AppError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name(),
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! stored_object {
|
||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
|
||||
@@ -221,19 +221,11 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
@@ -271,15 +263,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
@@ -333,15 +317,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -368,15 +344,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -398,15 +366,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
@@ -428,15 +388,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
@@ -461,15 +413,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
@@ -498,13 +442,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to create test database".to_string())?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::utils::config::EmbeddingBackend;
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -22,6 +25,15 @@ pub struct SystemSettings {
|
||||
pub image_processing_model: String,
|
||||
pub image_processing_prompt: String,
|
||||
pub voice_processing_model: String,
|
||||
/// When the maintainer last completed a scheduled `REBUILD INDEX` pass.
|
||||
#[serde(default)]
|
||||
pub last_index_rebuild_at: Option<DateTime<Utc>>,
|
||||
/// Worker id holding the index-rebuild lease, if any.
|
||||
#[serde(default)]
|
||||
pub index_rebuild_lease_owner: Option<String>,
|
||||
/// Lease expiry for in-flight scheduled index rebuilds.
|
||||
#[serde(default)]
|
||||
pub index_rebuild_lease_expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Partial update for singleton system settings without cloning unchanged fields.
|
||||
@@ -100,6 +112,8 @@ impl SystemSettingsPatch {
|
||||
}
|
||||
}
|
||||
|
||||
const INDEX_REBUILD_LEASE_TTL: &str = "6h";
|
||||
|
||||
impl SystemSettings {
|
||||
pub const RECORD_ID: &'static str = "current";
|
||||
|
||||
@@ -209,16 +223,16 @@ impl SystemSettings {
|
||||
needs_update = true;
|
||||
}
|
||||
|
||||
if let Some(model) = provider_model {
|
||||
if settings.embedding_model != model {
|
||||
tracing::info!(
|
||||
old_model = %settings.embedding_model,
|
||||
new_model = %model,
|
||||
"Embedding model changed, updating SystemSettings"
|
||||
);
|
||||
settings.embedding_model = model;
|
||||
needs_update = true;
|
||||
}
|
||||
if let Some(model) = provider_model
|
||||
&& settings.embedding_model != model
|
||||
{
|
||||
tracing::info!(
|
||||
old_model = %settings.embedding_model,
|
||||
new_model = %model,
|
||||
"Embedding model changed, updating SystemSettings"
|
||||
);
|
||||
settings.embedding_model = model;
|
||||
needs_update = true;
|
||||
}
|
||||
|
||||
if needs_update {
|
||||
@@ -227,6 +241,89 @@ impl SystemSettings {
|
||||
|
||||
Ok((settings, needs_update))
|
||||
}
|
||||
|
||||
/// Seeds the first rebuild checkpoint so the initial scheduled rebuild waits one interval.
|
||||
pub async fn seed_index_rebuild_checkpoint(db: &SurrealDbClient) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET last_index_rebuild_at = time::now()
|
||||
WHERE last_index_rebuild_at IS NONE
|
||||
RETURN AFTER;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||
Ok(updated.is_some())
|
||||
}
|
||||
|
||||
/// Claims the singleton index-rebuild lease when it is free or expired.
|
||||
pub async fn try_acquire_index_rebuild_lease(
|
||||
db: &SurrealDbClient,
|
||||
owner: &str,
|
||||
) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(format!(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
index_rebuild_lease_owner = $owner,
|
||||
index_rebuild_lease_expires_at = time::now() + {INDEX_REBUILD_LEASE_TTL}
|
||||
WHERE index_rebuild_lease_expires_at IS NONE
|
||||
OR index_rebuild_lease_expires_at < time::now()
|
||||
RETURN AFTER;"
|
||||
))
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let updated: Option<Self> = response.take(0).map_err(AppError::from)?;
|
||||
Ok(updated.is_some())
|
||||
}
|
||||
|
||||
/// Releases the index-rebuild lease when held by `owner`.
|
||||
pub async fn release_index_rebuild_lease(db: &SurrealDbClient, owner: &str) {
|
||||
let released = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
index_rebuild_lease_owner = NONE,
|
||||
index_rebuild_lease_expires_at = NONE
|
||||
WHERE index_rebuild_lease_owner = $owner;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.and_then(surrealdb::Response::check);
|
||||
|
||||
if let Err(err) = released {
|
||||
warn!(error = %err, "failed to release index rebuild lease");
|
||||
}
|
||||
}
|
||||
|
||||
/// Records a completed scheduled index rebuild and clears the lease.
|
||||
pub async fn record_index_rebuild_completed(
|
||||
db: &SurrealDbClient,
|
||||
owner: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let response = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('system_settings', $id) SET
|
||||
last_index_rebuild_at = time::now(),
|
||||
index_rebuild_lease_owner = NONE,
|
||||
index_rebuild_lease_expires_at = NONE
|
||||
WHERE index_rebuild_lease_owner = $owner;",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.bind(("owner", owner.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
response.check().map_err(AppError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -237,6 +334,7 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn get_hnsw_index_dimension(
|
||||
@@ -320,17 +418,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Test initialization of system settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to get system settings".to_string())?;
|
||||
@@ -367,19 +455,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to get current settings".to_string())?;
|
||||
@@ -392,17 +469,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_settings() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Initialize settings
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::get_current(&db)
|
||||
@@ -435,13 +502,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string()).await?;
|
||||
// Don't initialize settings and try to get them
|
||||
let result = SystemSettings::get_current(&db).await;
|
||||
|
||||
@@ -458,12 +519,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -477,12 +533,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let updated = SystemSettingsPatch {
|
||||
registrations_enabled: Some(false),
|
||||
@@ -498,12 +549,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let original = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -533,12 +579,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -552,12 +593,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -575,12 +611,7 @@ mod tests {
|
||||
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -607,12 +638,7 @@ mod tests {
|
||||
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -636,12 +662,7 @@ mod tests {
|
||||
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed(384)
|
||||
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||
@@ -660,12 +681,7 @@ mod tests {
|
||||
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
||||
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
||||
@@ -676,14 +692,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
@@ -691,7 +700,7 @@ mod tests {
|
||||
"user1".into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], 1536, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||
|
||||
@@ -710,18 +719,11 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length(
|
||||
) -> anyhow::Result<()> {
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length()
|
||||
-> anyhow::Result<()> {
|
||||
use crate::utils::embedding::EmbeddingProvider;
|
||||
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.with_context(|| "Failed to start DB".to_string())?;
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "Initial migration failed".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
@@ -802,4 +804,28 @@ mod tests {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn index_rebuild_lease_is_exclusive_on_system_settings() -> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
assert!(
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-a").await?,
|
||||
"first lease claim should succeed"
|
||||
);
|
||||
assert!(
|
||||
!SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?,
|
||||
"second lease claim should fail while lease is held"
|
||||
);
|
||||
|
||||
SystemSettings::release_index_rebuild_lease(&db, "worker-a").await;
|
||||
|
||||
SystemSettings::try_acquire_index_rebuild_lease(&db, "worker-b").await?;
|
||||
SystemSettings::record_index_rebuild_completed(&db, "worker-b").await?;
|
||||
|
||||
let settings = SystemSettings::get_current(&db).await?;
|
||||
assert!(settings.last_index_rebuild_at.is_some());
|
||||
assert!(settings.index_rebuild_lease_owner.is_none());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::storage::indexes::hnsw_index_overwrite_sql;
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::storage::types::{
|
||||
EmbeddingRecord, HasEmbedding, text_chunk_embedding::TextChunkEmbedding,
|
||||
};
|
||||
use crate::utils::embedding::RE_EMBED_BATCH_SIZE;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
@@ -25,6 +26,18 @@ pub struct TextChunkSearchResult {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl HasEmbedding for TextChunk {
|
||||
type Embedding = TextChunkEmbedding;
|
||||
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl TextChunk {
|
||||
#[must_use]
|
||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||
@@ -41,123 +54,39 @@ impl TextChunk {
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db_client
|
||||
.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id;",
|
||||
TextChunkEmbedding::table_name()
|
||||
))
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id;")
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.bind(("table", Self::table_name()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
db.delete_by_source_id::<Self>(source_id).await
|
||||
}
|
||||
|
||||
/// Atomically store a text chunk and its embedding.
|
||||
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
|
||||
/// Atomically store one text chunk and its embedding (single-record path).
|
||||
///
|
||||
/// Bulk ingestion uses `ingestion_pipeline::persist_artifacts` instead.
|
||||
pub async fn store_with_embedding(
|
||||
chunk: TextChunk,
|
||||
embedding: Vec<f32>,
|
||||
embedding_dimensions: usize,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
TextChunkEmbedding::validate_dimension(&embedding, settings.embedding_dimensions as usize)?;
|
||||
|
||||
let chunk_id = chunk.id.clone();
|
||||
let emb = TextChunkEmbedding::new(
|
||||
&chunk_id,
|
||||
chunk.source_id.clone(),
|
||||
embedding,
|
||||
chunk.user_id.clone(),
|
||||
);
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
|
||||
UPSERT type::thing('{emb_table}', $chunk_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
chunk_table = Self::table_name(),
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id))
|
||||
.bind(("chunk", chunk))
|
||||
.bind(("emb", emb))
|
||||
db.store_with_embedding(chunk, embedding, embedding_dimensions)
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and scores.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: &[f32],
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
chunk_id: Option<TextChunk>,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
chunk_id,
|
||||
embedding,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH chunk_id;
|
||||
"#,
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
db.vector_search::<Self, TextChunkEmbedding>(take, query_embedding, user_id)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
response = response.check().map_err(AppError::from)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
r.chunk_id.map(|chunk| TextChunkSearchResult {
|
||||
chunk,
|
||||
score: r.score,
|
||||
}).or_else(|| {
|
||||
warn!("vector search hit orphaned text_chunk_embedding row with missing chunk");
|
||||
None
|
||||
})
|
||||
.map(|results| {
|
||||
results
|
||||
.into_iter()
|
||||
.map(|(chunk, score)| TextChunkSearchResult { chunk, score })
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Full-text search over text chunks using the BM25 FTS index.
|
||||
@@ -287,7 +216,9 @@ impl TextChunk {
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
chunk.id,
|
||||
embedding.len(),
|
||||
new_dimensions
|
||||
);
|
||||
error!("{err_msg}");
|
||||
return Err(AppError::internal(err_msg));
|
||||
@@ -393,29 +324,10 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime, rebuild};
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::test_utils::configure_embedding_dimension;
|
||||
use crate::test_utils::{
|
||||
configure_embedding_dimension, ensure_fts_index, prepare_text_chunk_test_db, setup_test_db,
|
||||
};
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) -> anyhow::Result<()> {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define chunk fts index fallback: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() -> anyhow::Result<()> {
|
||||
@@ -434,21 +346,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = prepare_text_chunk_test_db(5).await?;
|
||||
let source_id = "source123".to_string();
|
||||
let user_id = "user123".to_string();
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
@@ -466,15 +366,16 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
5,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -516,18 +417,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
configure_embedding_dimension(&db, 5).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(5).await?;
|
||||
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -536,7 +426,7 @@ mod tests {
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk".to_string())?;
|
||||
|
||||
@@ -560,18 +450,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
let db = prepare_text_chunk_test_db(5)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
configure_embedding_dimension(&db, 5)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
.expect("prepare test db");
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
"safe_source".to_string(),
|
||||
@@ -584,10 +465,10 @@ mod tests {
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], 5, &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], 5, &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
@@ -614,25 +495,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
@@ -645,7 +513,7 @@ mod tests {
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "expected embedding".to_string())?;
|
||||
@@ -658,14 +526,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_with_runtime_indexes() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_runtime";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
|
||||
let embedding_dimension = 3usize;
|
||||
configure_embedding_dimension(
|
||||
@@ -683,7 +544,7 @@ mod tests {
|
||||
"runtime_user".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store with embedding".to_string())?;
|
||||
|
||||
@@ -695,7 +556,7 @@ mod tests {
|
||||
assert!(stored_chunk.id == chunk.id, "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
let embedding = TextChunkEmbedding::get_by_record_id(&db, &rid)
|
||||
.await
|
||||
.with_context(|| "get embedding".to_string())?
|
||||
.with_context(|| "embedding should exist".to_string())?;
|
||||
@@ -709,19 +570,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
@@ -733,19 +582,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
@@ -755,7 +592,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
@@ -774,28 +611,16 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk1".to_string())?;
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
|
||||
@@ -815,15 +640,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
rebuild(&db)
|
||||
.await
|
||||
.with_context(|| "rebuild indexes".to_string())?;
|
||||
@@ -838,15 +656,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let chunk = TextChunk::new(
|
||||
@@ -874,15 +685,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> {
|
||||
let namespace = "fts_chunk_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
ensure_chunk_fts_index(&db).await?;
|
||||
let db = setup_test_db().await?;
|
||||
ensure_fts_index(&db, "text_chunk", &[("chunk", "chunk")]).await?;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_chunk = TextChunk::new(
|
||||
@@ -936,19 +740,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_dim";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
let db = setup_test_db().await?;
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
|
||||
let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
|
||||
|
||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db)
|
||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], 3, &db)
|
||||
.await
|
||||
.expect_err("expected dimension validation failure");
|
||||
assert!(matches!(err, AppError::Validation(_)));
|
||||
@@ -958,18 +755,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_with_orphaned_embedding() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns_orphan_chunk";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
let db = prepare_text_chunk_test_db(3).await?;
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
@@ -978,7 +764,7 @@ mod tests {
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], 3, &db)
|
||||
.await
|
||||
.with_context(|| "store chunk with embedding".to_string())?;
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||
stored_object,
|
||||
};
|
||||
use crate::{storage::types::EmbeddingRecord, stored_object};
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::error::AppError;
|
||||
|
||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
/// Record link to the owning text_chunk
|
||||
@@ -18,123 +16,46 @@ stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunkEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
///
|
||||
/// This is useful when the embedding length changes; Surreal requires the
|
||||
/// index definition to be recreated with the updated dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = hnsw_index_redefine_transaction_sql(
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
Self::table_name(),
|
||||
dimension,
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||
res.check().map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
impl EmbeddingRecord for TextChunkEmbedding {
|
||||
fn link_field() -> &'static str {
|
||||
"chunk_id"
|
||||
}
|
||||
|
||||
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||
#[allow(clippy::result_large_err)]
|
||||
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||
if embedding.len() != expected {
|
||||
return Err(AppError::Validation(format!(
|
||||
"embedding dimension mismatch: got {}, expected {expected}",
|
||||
embedding.len()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
fn index_name() -> &'static str {
|
||||
"idx_embedding_text_chunk_embedding"
|
||||
}
|
||||
|
||||
/// Create a new text chunk embedding.
|
||||
///
|
||||
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
||||
#[must_use]
|
||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
fn source_id(&self) -> &str {
|
||||
&self.source_id
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
|
||||
fn embedding(&self) -> &[f32] {
|
||||
&self.embedding
|
||||
}
|
||||
|
||||
fn new(
|
||||
chunk_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
entity_table: &str,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
id: chunk_id.to_owned(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||
chunk_id: RecordId::from_table_key(entity_table, chunk_id),
|
||||
source_id,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a single embedding by its chunk RecordId
|
||||
pub async fn get_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Delete embeddings for a given chunk RecordId
|
||||
pub async fn delete_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||
///
|
||||
/// This uses the denormalized `source_id` on the embedding table.
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE source_id = $source_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -144,8 +65,31 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||
use surrealdb::Value as SurrealValue;
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: surrealdb::Value = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
@@ -169,29 +113,6 @@ mod tests {
|
||||
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||
}
|
||||
|
||||
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.with_context(|| "info query failed".to_string())?;
|
||||
let info: SurrealValue = info_res
|
||||
.take(0)
|
||||
.with_context(|| "failed to take info result".to_string())?;
|
||||
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||
.with_context(|| "failed to convert info to json".to_string())?;
|
||||
let idx_sql = info_json
|
||||
.get("Object")
|
||||
.and_then(|v| v.get("indexes"))
|
||||
.and_then(|v| v.get("Object"))
|
||||
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||
.and_then(|v| v.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
Ok(idx_sql)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_chunk_id_as_record_id() {
|
||||
let emb = TextChunkEmbedding::new(
|
||||
@@ -199,6 +120,7 @@ mod tests {
|
||||
"source-1".to_owned(),
|
||||
vec![0.1, 0.2],
|
||||
"user-1".to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
assert_eq!(emb.id, "chunk-abc");
|
||||
}
|
||||
@@ -226,13 +148,14 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
embedding_vec.clone(),
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let fetched = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||
@@ -259,22 +182,23 @@ mod tests {
|
||||
source_id.to_string(),
|
||||
vec![0.4_f32, 0.5, 0.6],
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| "Failed to store embedding".to_string())?;
|
||||
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let existing = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
TextChunkEmbedding::delete_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
let after = TextChunkEmbedding::get_by_record_id(&db, &chunk_rid)
|
||||
.await
|
||||
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
@@ -299,41 +223,59 @@ mod tests {
|
||||
("chunk-s2", source_id, vec![0.2]),
|
||||
("chunk-other", other_source, vec![0.3]),
|
||||
] {
|
||||
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||
let emb = TextChunkEmbedding::new(
|
||||
key,
|
||||
src.to_string(),
|
||||
vec,
|
||||
user_id.to_string(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| format!("store embedding for {key}"))?;
|
||||
}
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some());
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk1".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk2".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "get chunk_other".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some());
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk1_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk1".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk2_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk2".to_string())?
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
TextChunkEmbedding::get_by_record_id(&db, &chunk_other_rid)
|
||||
.await
|
||||
.with_context(|| "check chunk_other".to_string())?
|
||||
.is_some()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -352,6 +294,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![1.0_f32, 0.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(initial)
|
||||
.await
|
||||
@@ -362,6 +305,7 @@ mod tests {
|
||||
source_id.to_owned(),
|
||||
vec![0.0, 1.0, 0.0],
|
||||
user_id.to_owned(),
|
||||
TextChunk::table_name(),
|
||||
);
|
||||
db.upsert_item(replacement)
|
||||
.await
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::str::FromStr;
|
||||
|
||||
use surrealdb::opt::PatchOp;
|
||||
use surrealdb::RecordId;
|
||||
use surrealdb::opt::PatchOp;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
@@ -96,6 +96,41 @@ impl TextContent {
|
||||
}
|
||||
}
|
||||
|
||||
/// SurrealQL deletes for ingested child rows keyed by `source_id` (no transaction wrapper).
|
||||
///
|
||||
/// Used inside larger transactions (e.g. ingestion `persist_artifacts`) and mirrored by
|
||||
/// [`Self::clear_ingested_children`].
|
||||
pub const CLEAR_INGESTED_CHILD_ROWS_SURQL: &'static str = r"
|
||||
DELETE relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id;
|
||||
DELETE text_chunk_embedding WHERE source_id = $source_id;
|
||||
DELETE text_chunk WHERE source_id = $source_id;
|
||||
DELETE knowledge_entity_embedding WHERE source_id = $source_id;
|
||||
DELETE knowledge_entity WHERE source_id = $source_id;
|
||||
";
|
||||
|
||||
/// Removes chunks, embeddings, entities, and relationships for one ingested document snapshot.
|
||||
pub async fn clear_ingested_children(
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;\n{} COMMIT TRANSACTION;",
|
||||
Self::CLEAR_INGESTED_CHILD_ROWS_SURQL
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_string()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
.map_err(AppError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
id: &str,
|
||||
context: &str,
|
||||
@@ -364,7 +399,14 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||
use crate::{
|
||||
storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
},
|
||||
test_utils::{setup_test_db, setup_test_db_with_runtime_indexes},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||
@@ -638,4 +680,81 @@ mod tests {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_ingested_children_removes_chunks_entities_and_relationships()
|
||||
-> anyhow::Result<()> {
|
||||
let db = setup_test_db().await?;
|
||||
let user_id = "clear-user";
|
||||
let source_id = Uuid::new_v4().to_string();
|
||||
|
||||
let entity_a = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"entity-a".to_string(),
|
||||
"desc-a".to_string(),
|
||||
KnowledgeEntityType::Idea,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
let entity_b = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"entity-b".to_string(),
|
||||
"desc-b".to_string(),
|
||||
KnowledgeEntityType::Idea,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
KnowledgeEntity::store_with_embedding(entity_a.clone(), vec![0.1; 3], 3, &db)
|
||||
.await
|
||||
.context("store entity a")?;
|
||||
KnowledgeEntity::store_with_embedding(entity_b.clone(), vec![0.2; 3], 3, &db)
|
||||
.await
|
||||
.context("store entity b")?;
|
||||
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk".to_string(), user_id.to_string());
|
||||
TextChunk::store_with_embedding(chunk, vec![0.3; 3], 3, &db)
|
||||
.await
|
||||
.context("store chunk")?;
|
||||
|
||||
KnowledgeRelationship::new(
|
||||
entity_a.id.clone(),
|
||||
entity_b.id,
|
||||
user_id.to_string(),
|
||||
source_id.clone(),
|
||||
"relates_to".to_string(),
|
||||
)
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.context("store relationship")?;
|
||||
|
||||
TextContent::clear_ingested_children(&source_id, user_id, &db)
|
||||
.await
|
||||
.context("clear ingested children")?;
|
||||
|
||||
let chunks: Vec<TextChunk> = db
|
||||
.client
|
||||
.query("SELECT * FROM text_chunk WHERE source_id = $source_id;")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(chunks.is_empty());
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query("SELECT * FROM knowledge_entity WHERE source_id = $source_id;")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(entities.is_empty());
|
||||
|
||||
let relationships: Vec<KnowledgeRelationship> = db
|
||||
.client
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id;")
|
||||
.bind(("source_id", source_id))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert!(relationships.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
use axum_session_auth::Authentication;
|
||||
use chrono_tz::Tz;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use surrealdb::{Surreal, engine::any::Any};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::text_chunk::TextChunk;
|
||||
@@ -729,7 +729,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, MAX_ATTEMPTS, TaskState};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::test_utils::setup_test_db;
|
||||
|
||||
@@ -8,8 +8,8 @@ use crate::storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::{ensure_runtime, rebuild},
|
||||
types::{
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
EmbeddingRecord, knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
system_settings::SystemSettings, text_chunk_embedding::TextChunkEmbedding,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -91,3 +91,47 @@ pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||
rebuild(&db).await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
/// Ensures an FTS analyzer and BM25 indexes exist for a table.
|
||||
///
|
||||
/// Attempts snowball(english) tokenizer first; falls back to basic
|
||||
/// lowercase+ascii when the platform lacks the snowball extension.
|
||||
///
|
||||
/// `indexes` is a slice of `(field_name, index_id_suffix)` pairs —
|
||||
/// e.g. `&[("chunk", "chunk")]` produces index
|
||||
/// `text_chunk_fts_chunk_idx` on column `chunk` of `text_chunk`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the fallback definition fails. The initial
|
||||
/// snowball attempt is allowed to fail silently.
|
||||
pub async fn ensure_fts_index(
|
||||
db: &SurrealDbClient,
|
||||
table: &str,
|
||||
indexes: &[(&str, &str)],
|
||||
) -> Result<()> {
|
||||
use std::fmt::Write;
|
||||
|
||||
let mut define_indexes = String::new();
|
||||
for (field, suffix) in indexes {
|
||||
let _ = writeln!(
|
||||
define_indexes,
|
||||
"DEFINE INDEX IF NOT EXISTS {table}_fts_{suffix}_idx ON TABLE {table} FIELDS {field} SEARCH ANALYZER app_en_fts_analyzer BM25;"
|
||||
);
|
||||
}
|
||||
|
||||
let snowball_sql = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);\n{define_indexes}"
|
||||
);
|
||||
|
||||
if let Err(err) = db.client.query(&snowball_sql).await {
|
||||
let fallback_sql = format!(
|
||||
"DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;\n{define_indexes}"
|
||||
);
|
||||
db.client
|
||||
.query(&fallback_sql)
|
||||
.await
|
||||
.with_context(|| format!("define fts index fallback for {table}: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -135,6 +135,9 @@ pub struct AppConfig {
|
||||
pub ingest_max_context_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_category_bytes")]
|
||||
pub ingest_max_category_bytes: usize,
|
||||
/// Seconds between scheduled `REBUILD INDEX` maintainer runs (`0` disables).
|
||||
#[serde(default = "default_index_rebuild_interval_secs")]
|
||||
pub index_rebuild_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Default data directory for persisted assets.
|
||||
@@ -172,6 +175,10 @@ fn default_ingest_max_category_bytes() -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
fn default_index_rebuild_interval_secs() -> u64 {
|
||||
86_400
|
||||
}
|
||||
|
||||
static ORT_PATH_INIT: Once = Once::new();
|
||||
|
||||
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
||||
@@ -191,7 +198,10 @@ pub fn ensure_ort_path() {
|
||||
exe.join("lib").join("onnxruntime.dll"),
|
||||
] {
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
// SAFETY: `Once` ensures this runs on a single thread during startup.
|
||||
unsafe {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -203,7 +213,10 @@ pub fn ensure_ort_path() {
|
||||
};
|
||||
let p = exe.join("lib").join(name);
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
// SAFETY: `Once` ensures this runs on a single thread during startup.
|
||||
unsafe {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -238,6 +251,7 @@ impl Default for AppConfig {
|
||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
||||
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
||||
index_rebuild_interval_secs: default_index_rebuild_interval_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::{
|
||||
use serde::Serialize;
|
||||
use tracing::warn;
|
||||
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use async_openai::{Client, types::embeddings::CreateEmbeddingRequestArgs};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
|
||||
@@ -588,9 +588,8 @@ mod tests {
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use super::{
|
||||
align_fastembed_system_settings, fastembed_model_dimension,
|
||||
list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError,
|
||||
DEFAULT_FASTEMBED_MODEL_CODE,
|
||||
DEFAULT_FASTEMBED_MODEL_CODE, EmbeddingError, align_fastembed_system_settings,
|
||||
fastembed_model_dimension, list_fastembed_embedding_models, resolve_fastembed_model_code,
|
||||
};
|
||||
use crate::storage::types::system_settings::SystemSettings;
|
||||
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
|
||||
|
||||
@@ -47,13 +47,13 @@ pub fn validate_ingest_input(
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(content) = content {
|
||||
if content.len() > config.ingest_max_content_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"content is too large: maximum allowed is {} bytes",
|
||||
config.ingest_max_content_bytes
|
||||
)));
|
||||
}
|
||||
if let Some(content) = content
|
||||
&& content.len() > config.ingest_max_content_bytes
|
||||
{
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"content is too large: maximum allowed is {} bytes",
|
||||
config.ingest_max_content_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
if ctx.len() > config.ingest_max_context_bytes {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pub use minijinja::{path_loader, Environment, Value};
|
||||
pub use minijinja::{Environment, Value, path_loader};
|
||||
pub use minijinja_autoreload::AutoReloader;
|
||||
pub use minijinja_contrib;
|
||||
pub use minijinja_embed;
|
||||
|
||||
+14
-14
@@ -3,10 +3,10 @@
|
||||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1771066302,
|
||||
"lastModified": 1781800860,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
|
||||
"rev": "d59d872d80876d9eeb3e214d3b088bc4a14a9c4f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -22,10 +22,10 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1771052630,
|
||||
"lastModified": 1781779700,
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
|
||||
"rev": "ad30e585c7a2917325943c2b19511f5a249eff53",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -58,10 +58,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1770726378,
|
||||
"lastModified": 1781733627,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
|
||||
"rev": "3bbec39bc90eadfa031e6f3b77272f3f60803e39",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -92,10 +92,10 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1771008912,
|
||||
"lastModified": 1781577229,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
|
||||
"rev": "567a49d1913ce81ac6e9582e3553dd90a955875f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -107,10 +107,10 @@
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1770843696,
|
||||
"lastModified": 1781607440,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
|
||||
"rev": "3e41b24abd260e8f71dbe2f5737d24122f972158",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -135,10 +135,10 @@
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1771007332,
|
||||
"lastModified": 1781714865,
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
|
||||
"rev": "abb1301c3c14a40645bb2588b1cc858fe374b527",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -155,10 +155,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1771038269,
|
||||
"lastModified": 1781850613,
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
|
||||
"rev": "4baecb43a008cd004e5220a777e1724bd8d43e43",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
+32
-6
@@ -4,28 +4,50 @@
|
||||
config,
|
||||
inputs,
|
||||
...
|
||||
}:
|
||||
let
|
||||
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${toString ./.}/ort-version");
|
||||
}: let
|
||||
ortVersion = "1.23.2";
|
||||
_ortVersionCheck =
|
||||
if pkgs.onnxruntime.version == ortVersion
|
||||
then null
|
||||
else
|
||||
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
||||
else throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ortVersion in flake.nix (${ortVersion})";
|
||||
in {
|
||||
devenv.warnOnNewVersion = false;
|
||||
|
||||
cachix.enable = false;
|
||||
|
||||
git-hooks.install.enable = true;
|
||||
git-hooks.hooks = {
|
||||
rustfmt.enable = true;
|
||||
clippy = {
|
||||
enable = true;
|
||||
settings.allFeatures = true;
|
||||
};
|
||||
};
|
||||
|
||||
# Use pinned Rust toolchain from languages.rust for git-hooks wrappers
|
||||
# (git-hooks.nix defaults to nixpkgs's cargo/clippy/rustfmt, ignoring the pin)
|
||||
git-hooks.tools.cargo = lib.mkDefault config.languages.rust.toolchain.cargo;
|
||||
git-hooks.tools.clippy = lib.mkDefault config.languages.rust.toolchain.clippy;
|
||||
git-hooks.tools.rustfmt = lib.mkDefault config.languages.rust.toolchain.rustfmt;
|
||||
|
||||
packages = [
|
||||
pkgs.openssl
|
||||
pkgs.nodejs
|
||||
pkgs.watchman
|
||||
pkgs.vscode-langservers-extracted
|
||||
pkgs.cargo-dist
|
||||
pkgs.cargo-xwin
|
||||
pkgs.clang
|
||||
pkgs.onnxruntime
|
||||
pkgs.cargo-watch
|
||||
pkgs.tailwindcss_4
|
||||
pkgs.python3
|
||||
pkgs.fontconfig
|
||||
pkgs.fontconfig.dev
|
||||
pkgs.libGL
|
||||
pkgs.libGLU
|
||||
pkgs.libclang
|
||||
pkgs.wayland
|
||||
pkgs.libxkbcommon
|
||||
];
|
||||
|
||||
languages.rust = {
|
||||
@@ -38,6 +60,10 @@ in {
|
||||
};
|
||||
|
||||
env = {
|
||||
# tikv-jemalloc-sys configure flags: -O0 + -Werror triggers glibc _FORTIFY_SOURCE warning
|
||||
NIX_CFLAGS_COMPILE = "-Wno-error=cpp";
|
||||
LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
|
||||
LD_LIBRARY_PATH = "${pkgs.wayland}/lib:${pkgs.libxkbcommon}/lib:${pkgs.pipewire}/lib:${pkgs.libglvnd}/lib";
|
||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||
S3_ENDPOINT = "http://127.0.0.1:19000";
|
||||
S3_BUCKET = "minne-tests";
|
||||
|
||||
@@ -9,3 +9,6 @@ inputs:
|
||||
nixpkgs:
|
||||
follows: nixpkgs
|
||||
allowUnfree: true
|
||||
nixpkgs:
|
||||
permittedInsecurePackages:
|
||||
- "minio-2025-10-15T17-29-55Z"
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
[workspace]
|
||||
members = ["cargo:."]
|
||||
|
||||
# Config for 'dist'
|
||||
[dist]
|
||||
# The preferred dist version to use in CI (Cargo.toml SemVer syntax)
|
||||
cargo-dist-version = "0.30.3"
|
||||
# CI backends to support
|
||||
ci = "github"
|
||||
# Extra static files to include in each App (path relative to this Cargo.toml's dir)
|
||||
include = ["lib"]
|
||||
# The installers to generate for each app
|
||||
installers = []
|
||||
# Target platforms to build apps for (Rust target-triple syntax)
|
||||
targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"]
|
||||
# Skip checking whether the specified configuration files are up to date
|
||||
allow-dirty = ["ci"]
|
||||
|
||||
[dist.github-custom-runners]
|
||||
aarch64-apple-darwin = "macos-latest"
|
||||
x86_64-apple-darwin = "macos-15-intel"
|
||||
x86_64-unknown-linux-gnu = "ubuntu-22.04"
|
||||
x86_64-unknown-linux-musl = "ubuntu-22.04"
|
||||
x86_64-pc-windows-msvc = "windows-latest"
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
services:
|
||||
minne:
|
||||
build: .
|
||||
image: ghcr.io/perstarkse/minne:latest
|
||||
container_name: minne_app
|
||||
ports:
|
||||
- "3000:3000"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
| Frontend | HTML + HTMX + minimal JS |
|
||||
| Database | SurrealDB (graph, document, vector) |
|
||||
| AI | OpenAI-compatible API |
|
||||
| Web Processing | Headless Chromium |
|
||||
| Web Processing | Servo engine (servo-fetch) + PDFium |
|
||||
|
||||
## Crate Structure
|
||||
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
# CI/CD Roadmap: Nix-First Release Builds
|
||||
|
||||
This document tracks the migration from cargo-dist raw `cargo build --release` on bare GitHub runners to Nix-built release artifacts for all platforms. The goal is a single build system (the flake) shared by CI, Docker, and release binaries.
|
||||
|
||||
**Status:** Phase 3–4 complete locally — Nix builds all release targets including Windows cross (`nix build .#minne-release-windows` verified on x86_64-linux). cargo-dist removed from workflow and devenv. GHA tag-push validation pending.
|
||||
|
||||
**Decision (2026-06-23):** Drop `x86_64-apple-darwin` (Intel macOS). Ship `aarch64-apple-darwin` only; Intel Mac users can run via Rosetta 2.
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
Nix is now the sole compiler for all release binaries. Per-platform `minne-release` flake outputs produce archives compatible with GitHub Releases layout (binaries + `lib/libonnxruntime.*` + docs). The release workflow uses matrix jobs running `nix build` with `cache-nix-action` on every job. cargo-dist has been removed; releases use `gh release create` with CHANGELOG-driven notes.
|
||||
|
||||
This fixes the mozangle/clang failure at the root: the flake already wires `libclang`, `bindgenHook`, `llvm`, `python3`, `fontconfig`, and `MOZJS_ARCHIVE` — cargo-dist on bare Ubuntu cannot see any of that without duplicating it in apt/workflow steps.
|
||||
|
||||
---
|
||||
|
||||
## Current State
|
||||
|
||||
### What works
|
||||
|
||||
- [x] CI (`nix flake check`) — format, clippy, tests, ort-version gate via Crane `buildDepsOnly`
|
||||
- [x] Release Docker job — `nix build .#dockerImage`, push to GHCR with dynamic tag from `docker load`
|
||||
- [x] Release plan job — `nix flake check`, ORT version from flake (no cargo-dist)
|
||||
- [x] Harmonized native deps in `flake.nix` for CI/Docker (openssl, libglvnd, onnxruntime, fontconfig, bindgen, mozjs)
|
||||
|
||||
### What is broken or painful
|
||||
|
||||
- [x] ~~cargo-dist Linux build fails without apt/mozjs workarounds~~ — resolved: Nix builds all platforms
|
||||
- [x] ~~Two build systems: Nix for CI/Docker, cargo + apt/homebrew for dist binaries~~ — resolved: Nix-only release
|
||||
- [x] ~~Four independent release compiles with no shared Nix store across jobs~~ — resolved: `cache-nix-action` on all release jobs
|
||||
- [x] ~~`[dist.dependencies.apt]` duplicates flake.nix logic~~ — resolved: `dist-workspace.toml` deleted
|
||||
- [ ] Release compile time on GHA not yet measured post-migration (expected ~10–25 min warm vs ~50–110 min cold)
|
||||
- [ ] GHA tag-push validation pending for macOS and Windows archives
|
||||
|
||||
---
|
||||
|
||||
## Target Architecture
|
||||
|
||||
| Layer | Owner | Notes |
|
||||
|-------|-------|-------|
|
||||
| Compile binaries | **Nix** (`minne-pkg` / cross derivations) | Crane + `commonArgs`, per-platform mozjs |
|
||||
| Bundle ORT + runtime libs | **Nix** (`minne-release`) | Match `include = ["lib"]` layout |
|
||||
| Create archives | **Nix** or thin shell | `.tar.xz` (Unix), `.zip` (Windows) |
|
||||
| Publish GitHub Release | **`gh release create`** | CHANGELOG body |
|
||||
| Docker image | **Nix** (unchanged) | Shares `minne-pkg` derivation with Linux release |
|
||||
| cargo-dist | **Removed** | Replaced by Nix jobs + `gh release` |
|
||||
|
||||
### Release targets (end state)
|
||||
|
||||
| Target | Builder | Nix output |
|
||||
|--------|---------|------------|
|
||||
| `x86_64-unknown-linux-gnu` | `ubuntu-22.04` native | `.#minne-release` |
|
||||
| `aarch64-apple-darwin` | `macos-latest` native | `.#minne-release` |
|
||||
| ~~`x86_64-apple-darwin`~~ | **Dropped** | — |
|
||||
| `x86_64-pc-windows-msvc` | `ubuntu-22.04` cross | `.#minne-release-windows` |
|
||||
|
||||
---
|
||||
|
||||
## Per-Platform Build Matrix
|
||||
|
||||
| Target | Feasibility | Nix command | Artifact layout | Blockers |
|
||||
|--------|-------------|-------------|-----------------|----------|
|
||||
| `x86_64-unknown-linux-gnu` | Ready with modest flake changes | `nix build .#minne-release --system x86_64-linux` | `main-{ver}-x86_64-unknown-linux-gnu.tar.xz` → `main/`, `server/`, `worker/`, `lib/libonnxruntime.so`, README, LICENSE, CHANGELOG | glibc 2.40 (nixpkgs-unstable) vs Ubuntu 22.04 glibc 2.35; portable runtime bundling needed |
|
||||
| `aarch64-apple-darwin` | Feasible | `nix build .#minne-release --system aarch64-darwin` | `main-{ver}-aarch64-apple-darwin.tar.xz` + `lib/libonnxruntime.dylib` | Per-system mozjs URL; Darwin `postInstall` assumes Linux today |
|
||||
| `x86_64-pc-windows-msvc` | Feasible with new cross flake | `nix build .#minne-release-windows` (x86_64-linux host) | `main-{ver}-x86_64-pc-windows-msvc.zip` + `lib/onnxruntime.dll` | Crane + cargo-xwin cross setup; no native Nix-on-Windows for v1 |
|
||||
|
||||
### mozjs prebuilt availability (mozjs-sys-v140.10.1-0)
|
||||
|
||||
Confirmed for all release triples:
|
||||
|
||||
- `libmozjs-x86_64-unknown-linux-gnu.tar.gz`
|
||||
- `libmozjs-aarch64-apple-darwin.tar.gz`
|
||||
- `libmozjs-x86_64-pc-windows-msvc.tar.gz`
|
||||
|
||||
---
|
||||
|
||||
## Caching Strategy
|
||||
|
||||
| Layer | Invalidated by | Shared across |
|
||||
|-------|----------------|---------------|
|
||||
| Nix store (system deps) | `flake.lock`, `*.nix`, `Cargo.lock` | plan, CI, Docker, all release jobs (per OS) |
|
||||
| `cargoArtifacts` (`buildDepsOnly`) | `Cargo.lock` dep changes only | minne-pkg, clippy, test, dockerImage, release |
|
||||
| `minne-pkg` (source) | Application source changes | dockerImage, release |
|
||||
| cargo-dist `target/` | Version bumps (weak) | Removed — Nix store replaces it |
|
||||
|
||||
### Expected release times
|
||||
|
||||
| Scenario | Current (cargo-dist) | After migration (Nix) |
|
||||
|----------|---------------------|-------------------------|
|
||||
| Cold release (version bump, no cache) | ~50–110 min × 4 jobs | ~45–90 min × 3 jobs (no Intel Mac) |
|
||||
| Warm release (source-only, cache hit) | Still ~full rebuild | ~10–25 min incremental per OS |
|
||||
| No-op re-release | Full rebuild | ~2–5 min if derivations unchanged |
|
||||
| Docker job (cached) | ~5–15 min | Unchanged; shares `minne-pkg` with Linux release |
|
||||
|
||||
`buildDepsOnly` survives version bumps (version is in flake `minneVersion`, not `Cargo.lock`) — major win over cargo-dist.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1 — Linux via Nix (highest pain, highest value)
|
||||
|
||||
- [x] Add per-system `mozjsArchive` helper with hash map (at minimum fix structure for all platforms)
|
||||
- [x] Add `nix/minne-release.nix` — bundle ORT + portable runtime libs + docs into archive
|
||||
- [x] Linux portable runtime lib bundling + `patchelf --set-rpath '$ORIGIN/lib'`
|
||||
- [x] Replace Linux `build-local-artifacts` steps with `nix build .#minne-release`
|
||||
- [x] Add `cache-nix-action` to Linux release build job
|
||||
- [x] Validate glibc portability (test binary on Ubuntu 22.04)
|
||||
- [x] Remove Linux apt/mozjs/ORT curl workarounds from `.github/workflows/release.yml`
|
||||
- [x] Archive naming matches prior releases (`main-{triple}.tar.xz`, no version in filename)
|
||||
|
||||
### Phase 2 — macOS aarch64
|
||||
|
||||
- [x] Platform-conditional `postInstall` in `flake.nix` (Darwin vs Linux wrapping)
|
||||
- [x] Add `nix/minne-release-darwin.nix` — ORT + runtime dylibs + docs archive
|
||||
- [x] macOS `build-local-artifacts` uses `nix build .#minne-release` on `macos-latest`
|
||||
- [x] `cache-nix-action` on macOS release build job
|
||||
- [x] Drop `x86_64-apple-darwin` target (was in `dist-workspace.toml`, now deleted)
|
||||
- [ ] Test archive on clean macOS VM / GHA release run
|
||||
- [x] Update `docs/installation.md` to note aarch64-only macOS binary (Rosetta 2 for Intel Macs)
|
||||
|
||||
### Phase 3 — Windows cross from Linux
|
||||
|
||||
- [x] Add `minne-release-windows` cross derivation (Crane + cargo-xwin)
|
||||
- [x] Add `nix/clang-cl-msvc-link-wrapper.sh` for mozangle DLL links under clang-cl
|
||||
- [x] Windows GHA job on `ubuntu-22.04` (cross-build, not Nix-on-Windows)
|
||||
- [x] Bundle `onnxruntime.dll` in release zip (match cargo-dist flat layout)
|
||||
- [x] Fenix `rust-std` for `x86_64-pc-windows-msvc` via `fenix.combine`
|
||||
- [x] Local cross-build verified: `nix build .#minne-release-windows` on x86_64-linux
|
||||
- [ ] Test archive on Windows VM / GHA release run
|
||||
|
||||
### Phase 4 — Cleanup
|
||||
|
||||
- [x] Remove cargo-dist compile steps from release workflow
|
||||
- [x] Delete `dist-workspace.toml`
|
||||
- [x] Simplify CI to `nix flake check` only (drop `cargo-dist plan`)
|
||||
- [x] Replace `host`/cargo-dist with `gh release create` + CHANGELOG
|
||||
- [x] Remove `pkgs.cargo-dist` from `devenv.nix`
|
||||
- [x] Update `AGENTS.md` release checklist
|
||||
- [ ] Update README release badges/docs if workflow structure changes
|
||||
|
||||
---
|
||||
|
||||
## Flake Changes (outline)
|
||||
|
||||
New/modified outputs:
|
||||
|
||||
```nix
|
||||
# Per-system mozjs (replace hardcoded Linux x86_64)
|
||||
mozjsTarget = { "x86_64-linux" = "x86_64-unknown-linux-gnu"; ... }.${system};
|
||||
mozjsArchive = pkgs.fetchurl { url = ".../libmozjs-${mozjsTarget}.tar.gz"; hash = mozjsHashes.${system}; };
|
||||
|
||||
# Platform-conditional postInstall (Linux LD_LIBRARY_PATH vs Darwin)
|
||||
|
||||
# NEW: release archive derivation
|
||||
packages.minne-release = callPackage ./nix/minne-release.nix { inherit minne-pkg minneVersion ortVersion; };
|
||||
|
||||
# NEW: Windows cross (x86_64-linux host only)
|
||||
packages.minne-release-windows = ...;
|
||||
```
|
||||
|
||||
New file: `nix/minne-release.nix` — copies stripped binaries, stages `lib/libonnxruntime.{so,dylib}`, optional runtime `.so` copies, includes README/LICENSE/CHANGELOG, builds `.tar.xz` / `.zip`.
|
||||
|
||||
Optional: `devShells.dist` for local release-build debugging.
|
||||
|
||||
---
|
||||
|
||||
## Workflow Changes (outline)
|
||||
|
||||
Target `release.yml` structure:
|
||||
|
||||
```
|
||||
plan:
|
||||
- nix flake check, nix eval ortVersion
|
||||
- output tag from github.ref (no hardcoded versions)
|
||||
|
||||
build-nix-artifacts: # replaces build-local-artifacts
|
||||
matrix: linux | macos-aarch64 | windows-cross
|
||||
- determinate-nix + cache-nix-action on ALL jobs
|
||||
- nix build .#${attr} --system ${system} -L
|
||||
- upload: main-*-{triple}.tar.xz / .zip
|
||||
|
||||
build_and_push_docker_image: # unchanged
|
||||
|
||||
release: # replaces build-global-artifacts + host
|
||||
- download artifacts
|
||||
- gh release create with CHANGELOG body
|
||||
```
|
||||
|
||||
Artifact naming: match current convention for backwards compatibility — `main-{version}-{triple}.tar.xz` (Unix) / `.zip` (Windows).
|
||||
|
||||
---
|
||||
|
||||
## cargo-dist Fate
|
||||
|
||||
**Status:** Removed (Option B implemented in Phase 4).
|
||||
|
||||
| Option | Verdict |
|
||||
|--------|---------|
|
||||
| A) Nix builds → cargo-dist packages only | No clean skip-compile mode; high friction |
|
||||
| **B) Replace with custom Nix jobs + `gh release`** | **Implemented** |
|
||||
| C) `build-local-artifacts = false` + custom jobs | Experimental; superseded by Option B |
|
||||
|
||||
---
|
||||
|
||||
## Task Checklist (with complexity)
|
||||
|
||||
| # | Task | Size | Phase | Done |
|
||||
|---|------|------|-------|------|
|
||||
| 1 | `mozjsArchive` per `system` with hash map | S | 1 | [x] |
|
||||
| 2 | Platform-conditional `postInstall` in flake | S | 1–2 | [x] |
|
||||
| 3 | `nix/minne-release.nix` archive bundler | M | 1 | [x] |
|
||||
| 4 | Linux portable runtime lib bundling + patchelf | M | 1 | [x] |
|
||||
| 5 | Replace Linux `build-local-artifacts` with Nix job | S | 1 | [x] |
|
||||
| 6 | Add `cache-nix-action` to all release build jobs | S | 1–3 | [x] |
|
||||
| 7 | glibc portability test + fix | M | 1 | [x] |
|
||||
| 8 | Darwin release bundle + macOS GHA job | M | 2 | [x] |
|
||||
| 9 | Drop `x86_64-apple-darwin` from targets | S | 2 | [x] |
|
||||
| 10 | Windows cross flake (`minne-release-windows`) | L | 3 | [x] |
|
||||
| 11 | Windows GHA job | S | 3 | [x] |
|
||||
| 12 | Replace `host`/cargo-dist with `gh release` | S | 4 | [x] |
|
||||
| 13 | Remove apt deps, ORT curl, cargo-dist install | S | 4 | [x] |
|
||||
| 14 | Update docs/AGENTS release checklist | S | 4 | [x] |
|
||||
|
||||
S = hours–1 day, M = 2–4 days, L = 1–2 weeks
|
||||
|
||||
---
|
||||
|
||||
## Risks & Blockers
|
||||
|
||||
| Risk | Severity | Mitigation | Resolved |
|
||||
|------|----------|------------|----------|
|
||||
| glibc compatibility (nixpkgs 2.40 vs Ubuntu 22.04 2.35) | High | Bundle runtime libs in `lib/` + `LD_LIBRARY_PATH` wrappers; bundled glibc interpreter | [x] |
|
||||
| mozjs per-platform hashes drift on `Cargo.lock` bump | Medium | Centralize in `mozjsHashes` attrset; document bump procedure | [ ] |
|
||||
| Darwin `postInstall` assumes Linux (`LD_LIBRARY_PATH`, `libglvnd`) | Medium | Platform-conditional wrapping in flake | [x] |
|
||||
| Windows cross complexity (Crane + cargo-xwin) | Medium–High | cargo-xwin env + clang-cl wrapper for mozangle; Dbghelp.lib case symlink | [x] |
|
||||
| Nix on macOS GHA speed | Medium | cache-nix-action; larger runner if needed | [ ] |
|
||||
| Codesigning / notarization (macOS) | Low | Not required for CLI today; document `xattr` workaround; revisit if needed | [ ] |
|
||||
| musl target (`x86_64-unknown-linux-musl`) | N/A | mozjs/servo stack is glibc-oriented; stay on `*-linux-gnu` unless explicitly requested | [ ] |
|
||||
| ORT version drift | Low | Existing `ortVersion` gate in flake + devenv | [x] |
|
||||
|
||||
---
|
||||
|
||||
## Open Questions
|
||||
|
||||
1. **glibc portability strategy** — Bundle runtime libs in `lib/` (preferred for portability) vs pin `nixpkgs` to an older release channel for release builds vs document minimum distro? Need a test matrix: Ubuntu 22.04, Debian 12, Fedora current.
|
||||
|
||||
2. **Archive format** — Confirmed: `.tar.xz` (Unix), `.zip` (Windows); naming `main-{triple}.*` (no version in filename).
|
||||
|
||||
3. **Binary scope** — Release all three binaries (`main`, `server`, `worker`) in one archive per platform (unchanged from prior cargo-dist behavior).
|
||||
|
||||
4. **PR artifact builds** — Not implemented; cargo-dist `pr-run-mode` was disabled. Revisit if PR smoke-test artifacts are wanted.
|
||||
|
||||
5. **Cachix** — Deferred; `cache-nix-action` on all release jobs is sufficient for now.
|
||||
|
||||
6. **Windows cross approach** — Resolved: Crane + offline xwin MSVC cache + fenix `rust-std` + clang-cl/lld-link shims (`nix build .#minne-release-windows` verified locally).
|
||||
|
||||
7. **Version source of truth** — Release workflow reads version from flake (`minneVersion`).
|
||||
|
||||
8. **cargo-dist removal timing** — Resolved: removed in Phase 4.
|
||||
|
||||
9. **Intel Mac deprecation communication** — Done: `docs/installation.md` notes aarch64-only + Rosetta 2.
|
||||
|
||||
---
|
||||
|
||||
## Success Criteria
|
||||
|
||||
After implementation:
|
||||
|
||||
- [x] Release workflow no longer runs raw `cargo build --release` on bare GitHub runners
|
||||
- [x] Native deps (clang, mozjs, onnxruntime, etc.) come from flake/Nix, not apt
|
||||
- [x] Linux, macOS (aarch64), and Windows release binaries are produced via Nix
|
||||
- [x] Docker and release binaries share maximum Nix store cache (`cache-nix-action` on all jobs)
|
||||
- [x] No hardcoded version strings in `release.yml`
|
||||
- [ ] Warm release compile time materially improved (~10–25 min/platform vs ~50–110 min today) — pending GHA measurement
|
||||
- [ ] macOS and Windows archives validated on clean VM / GHA tag-push release run
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- [Crane cross-windows example](https://crane.dev/examples/cross-windows.html)
|
||||
- [Crane discussion: MSVC / cargo-xwin](https://github.com/ipetkov/crane/discussions/555)
|
||||
- [cargo-dist CI customization](https://axodotdev.github.io/cargo-dist/book/ci/customizing.html)
|
||||
- [servo/mozjs releases](https://github.com/servo/mozjs/releases)
|
||||
- Project files: `flake.nix`, `.github/workflows/release.yml`, `devenv.nix`, `nix/minne-release*.nix`
|
||||
+3
-1
@@ -10,7 +10,7 @@
|
||||
|
||||
Minne automatically processes saved content:
|
||||
|
||||
1. **Web scraping** extracts readable text from URLs (via headless Chrome)
|
||||
1. **Web scraping** extracts readable text from URLs (via embedded Servo engine)
|
||||
2. **Text analysis** identifies key concepts and relationships
|
||||
3. **Graph creation** builds connections between related content
|
||||
4. **Embedding generation** enables semantic search
|
||||
@@ -43,6 +43,7 @@ Optional **reranking** can rescore fused chunk lists with a cross-encoder model;
|
||||
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
|
||||
|
||||
**Trade-offs:**
|
||||
|
||||
- Downloads ~1.1 GB of model data
|
||||
- Adds latency per query
|
||||
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
|
||||
@@ -52,6 +53,7 @@ Enable via `RERANKING_ENABLED=true`. See [Configuration](./configuration.md).
|
||||
## Multi-Format Ingestion
|
||||
|
||||
Supported content types:
|
||||
|
||||
- Plain text and notes
|
||||
- URLs (web pages)
|
||||
- PDF documents
|
||||
|
||||
+10
-5
@@ -12,13 +12,13 @@ cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium automatically.
|
||||
The included `docker-compose.yml` handles SurrealDB automatically.
|
||||
|
||||
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
|
||||
|
||||
## Nix
|
||||
|
||||
Run Minne directly with Nix (includes Chromium):
|
||||
Run Minne directly with Nix:
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
@@ -28,11 +28,15 @@ Configure via environment variables or a `config.yaml` file. See [Configuration]
|
||||
|
||||
## Pre-built Binaries
|
||||
|
||||
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
Download binaries for Windows, macOS (Apple Silicon), and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**macOS:** Release builds target `aarch64-apple-darwin` (Apple Silicon). Intel Macs can run the binary via [Rosetta 2](https://support.apple.com/en-us/102527).
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- SurrealDB instance (local or remote)
|
||||
- Chromium (for web scraping)
|
||||
- Linux: `libEGL` + `libfontconfig` for servo-fetch (bundled in release archives)
|
||||
- macOS: system frameworks; ONNX Runtime is bundled in the archive `lib/` directory
|
||||
|
||||
## Build from Source
|
||||
|
||||
@@ -45,9 +49,10 @@ cargo build --release --bin main
|
||||
The binary will be at `target/release/main`.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- Rust toolchain
|
||||
- SurrealDB accessible at configured address
|
||||
- Chromium in PATH
|
||||
- `libEGL` + `libfontconfig` for servo-fetch (web scraping) — bundled in Nix and Docker images
|
||||
|
||||
## Process Modes
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "evaluations"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
@@ -30,8 +30,6 @@ serde_json = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
once_cell = "1.19"
|
||||
serde_yaml = "0.9"
|
||||
criterion = "0.5"
|
||||
state-machines = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive", "env"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
+71
-181
@@ -1,212 +1,102 @@
|
||||
# Evaluations
|
||||
|
||||
The `evaluations` crate provides a retrieval evaluation framework for benchmarking Minne's information retrieval pipeline against standard datasets.
|
||||
The `evaluations` crate benchmarks Minne's retrieval pipeline against standard datasets.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run SQuAD v2.0 evaluation (vector-only, recommended)
|
||||
cargo run --package evaluations -- --ingest-chunks-only
|
||||
# One-time prep (convert, slice ledger, corpus cache, DB seed)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||
|
||||
# Run a specific dataset
|
||||
cargo run --package evaluations -- --dataset fiqa --ingest-chunks-only
|
||||
# Check readiness
|
||||
cargo eval --status --dataset beir --slice beir-mix-600
|
||||
|
||||
# Convert dataset only (no evaluation)
|
||||
cargo run --package evaluations -- --convert-only
|
||||
# Run benchmark (steady state after warm)
|
||||
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||
```
|
||||
|
||||
Default dataset is `beir`. When `--slice` is omitted, the first catalog slice for the dataset is applied automatically (e.g. `beir-mix-600`).
|
||||
|
||||
Chunk-only ingestion is the default. Pass `--include-entities` to opt into entity extraction during ingestion (requires `OPENAI_API_KEY`).
|
||||
|
||||
### Custom slice sizes
|
||||
|
||||
`--slice` is a ledger id, not only a catalog name. You can use any id; `--limit` controls how many questions the ledger contains:
|
||||
|
||||
```bash
|
||||
# 200-case BEIR mix (default --limit is 200)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-200
|
||||
cargo eval --dataset beir --slice beir-mix-200 --require-ready
|
||||
```
|
||||
|
||||
The catalog slice `beir-mix-600` in `manifest.yaml` is a preset with `limit: 600` and `negative_multiplier: 9.0`.
|
||||
|
||||
### BEIR mix layout
|
||||
|
||||
`beir` is a **virtual mix** across eight subset datasets (FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR). There is no monolithic `beir-minne/` store.
|
||||
|
||||
1. Build an in-memory qrels-world mix from raw subset data
|
||||
2. Resolve the slice ledger (`cache/slices/beir/<slice-id>.json`)
|
||||
3. Materialize only ledger paragraph ids into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||
4. Ingest the slice corpus and seed SurrealDB
|
||||
|
||||
Conversion is **qrels-closed**: only documents that appear in qrels are exported, not the full BEIR corpus.
|
||||
|
||||
Chunk-only mode may evaluate fewer cases than the slice ledger size when some questions are impossible or lack verifiable answer chunks.
|
||||
|
||||
Reports include a **Retrieved Context Volume** section: total characters and estimated tokens across all chunks returned per query (`~chars/4`, comparable across `--chunk-result-cap` sweeps). Use this to compare the cost of raising `--chunk-result-cap`.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. SurrealDB
|
||||
|
||||
Start a SurrealDB instance before running evaluations:
|
||||
### SurrealDB
|
||||
|
||||
```bash
|
||||
docker-compose up -d surrealdb
|
||||
```
|
||||
|
||||
Or using the default endpoint configuration:
|
||||
### Raw datasets
|
||||
|
||||
```bash
|
||||
surreal start --user root_user --pass root_password
|
||||
```
|
||||
Place raw datasets under `evaluations/data/raw/`. See [manifest.yaml](./manifest.yaml) for paths.
|
||||
|
||||
### 2. Download Raw Datasets
|
||||
BEIR subsets live in sibling directories (`data/raw/fever`, `data/raw/fiqa`, …). The `data/raw/beir` entry is a virtual catalog placeholder; warm uses the subset paths.
|
||||
|
||||
Raw datasets must be downloaded manually and placed in `evaluations/data/raw/`. See [Dataset Sources](#dataset-sources) below for links and formats.
|
||||
|
||||
## Directory Structure
|
||||
## Directory structure
|
||||
|
||||
```
|
||||
evaluations/
|
||||
├── data/
|
||||
│ ├── raw/ # Downloaded raw datasets (manual)
|
||||
│ │ ├── squad/ # SQuAD v2.0
|
||||
│ │ ├── nq-dev/ # Natural Questions
|
||||
│ │ ├── fiqa/ # BEIR: FiQA-2018
|
||||
│ │ ├── fever/ # BEIR: FEVER
|
||||
│ │ ├── hotpotqa/ # BEIR: HotpotQA
|
||||
│ │ └── ... # Other BEIR subsets
|
||||
│ └── converted/ # Auto-generated (Minne JSON format)
|
||||
├── cache/ # Ingestion and embedding caches
|
||||
├── reports/ # Evaluation output (JSON + Markdown)
|
||||
├── manifest.yaml # Dataset and slice definitions
|
||||
└── src/ # Evaluation source code
|
||||
│ ├── raw/ # Downloaded datasets (manual)
|
||||
│ │ ├── fever/ # BEIR subset raw dirs (corpus.jsonl, queries.jsonl, qrels/)
|
||||
│ │ ├── fiqa/
|
||||
│ │ └── …
|
||||
│ └── converted/ # Sharded stores (auto-generated)
|
||||
│ ├── fever-minne/ # per-BEIR-subset stores
|
||||
│ ├── fiqa-minne/
|
||||
│ └── … # BEIR mix loads from subset stores (no monolithic beir-minne/)
|
||||
├── cache/
|
||||
│ ├── slices/ # Slice ledgers
|
||||
│ └── ingested/ # Corpus ingestion caches (manifest includes namespace seed)
|
||||
├── reports/ # JSON + Markdown output from benchmark runs
|
||||
├── manifest.yaml
|
||||
└── src/
|
||||
```
|
||||
|
||||
## Dataset Sources
|
||||
**After upgrading:** delete old monolithic `*-minne.json` files, any legacy `beir-minne/` merged store, `cache/snapshots/` directories, and stale `reports/history/` artifacts, then re-run `--warm`.
|
||||
|
||||
### SQuAD v2.0
|
||||
|
||||
Download and place at `data/raw/squad/dev-v2.0.json`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/squad
|
||||
curl -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json \
|
||||
-o evaluations/data/raw/squad/dev-v2.0.json
|
||||
```
|
||||
|
||||
### Natural Questions (NQ)
|
||||
|
||||
Download and place at `data/raw/nq-dev/dev-all.jsonl`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/nq-dev
|
||||
# Download from Google's Natural Questions page or HuggingFace
|
||||
# File: dev-all.jsonl (simplified JSONL format)
|
||||
```
|
||||
|
||||
Source: [Google Natural Questions](https://ai.google.com/research/NaturalQuestions)
|
||||
|
||||
### BEIR Datasets
|
||||
|
||||
All BEIR datasets follow the same format structure:
|
||||
|
||||
```
|
||||
data/raw/<dataset>/
|
||||
├── corpus.jsonl # Document corpus
|
||||
├── queries.jsonl # Query set
|
||||
└── qrels/
|
||||
└── test.tsv # Relevance judgments (or dev.tsv)
|
||||
```
|
||||
|
||||
Download datasets from the [BEIR Benchmark repository](https://github.com/beir-cellar/beir). Each dataset zip extracts to the required directory structure.
|
||||
|
||||
| Dataset | Directory |
|
||||
|------------|---------------|
|
||||
| FEVER | `fever/` |
|
||||
| FiQA-2018 | `fiqa/` |
|
||||
| HotpotQA | `hotpotqa/` |
|
||||
| NFCorpus | `nfcorpus/` |
|
||||
| Quora | `quora/` |
|
||||
| TREC-COVID | `trec-covid/` |
|
||||
| SciFact | `scifact/` |
|
||||
| NQ (BEIR) | `nq/` |
|
||||
|
||||
Example download:
|
||||
|
||||
```bash
|
||||
cd evaluations/data/raw
|
||||
curl -L https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip -o fiqa.zip
|
||||
unzip fiqa.zip && rm fiqa.zip
|
||||
```
|
||||
|
||||
## Dataset Conversion
|
||||
|
||||
Raw datasets are automatically converted to Minne's internal JSON format on first run. To force reconversion:
|
||||
|
||||
```bash
|
||||
cargo run --package evaluations -- --force-convert
|
||||
```
|
||||
|
||||
Converted files are saved to `data/converted/` and cached for subsequent runs.
|
||||
|
||||
## CLI Reference
|
||||
|
||||
### Common Options
|
||||
## Common flags
|
||||
|
||||
| Flag | Description | Default |
|
||||
|------|-------------|---------|
|
||||
| `--dataset <NAME>` | Dataset to evaluate | `squad-v2` |
|
||||
| `--limit <N>` | Max questions to evaluate (0 = all) | `200` |
|
||||
| `--k <N>` | Precision@k cutoff | `5` |
|
||||
| `--slice <ID>` | Use a predefined slice from manifest | — |
|
||||
| `--rerank` | Enable FastEmbed reranking stage | disabled |
|
||||
| `--embedding-backend <BE>` | `fastembed` or `hashed` | `fastembed` |
|
||||
| `--ingest-chunks-only` | Skip entity extraction, ingest only text chunks | disabled |
|
||||
| `--dataset` | Dataset to evaluate | `beir` |
|
||||
| `--slice` | Slice ledger id (catalog or custom) | first catalog slice |
|
||||
| `--limit` | Max questions in the slice ledger | `200` |
|
||||
| `--warm` | Prepare without running queries | — |
|
||||
| `--status` | Print readiness | — |
|
||||
| `--require-ready` | Fail if not warmed | — |
|
||||
| `--include-entities` | Entity extraction during ingestion | off |
|
||||
| `--force-convert` | Rebuild converted store | — |
|
||||
| `--chunk-result-cap` | Max chunks returned per query (raise with `--k`) | `5` |
|
||||
| `--perf-log-console` | Print per-stage timings after a run | off |
|
||||
| `--label` | Label stored in JSON/Markdown reports | — |
|
||||
|
||||
> [!TIP]
|
||||
> Use `--ingest-chunks-only` when evaluating vector-only retrieval strategies. This skips the LLM-based entity extraction and graph generation, significantly speeding up ingestion while focusing on pure chunk-based vector search.
|
||||
|
||||
### Available Datasets
|
||||
|
||||
```
|
||||
squad-v2, natural-questions, beir, fever, fiqa, hotpotqa,
|
||||
nfcorpus, quora, trec-covid, scifact, nq-beir
|
||||
```
|
||||
|
||||
### Database Configuration
|
||||
|
||||
| Flag | Environment | Default |
|
||||
|------|-------------|---------|
|
||||
| `--db-endpoint` | `EVAL_DB_ENDPOINT` | `ws://127.0.0.1:8000` |
|
||||
| `--db-username` | `EVAL_DB_USERNAME` | `root_user` |
|
||||
| `--db-password` | `EVAL_DB_PASSWORD` | `root_password` |
|
||||
| `--db-namespace` | `EVAL_DB_NAMESPACE` | auto-generated |
|
||||
| `--db-database` | `EVAL_DB_DATABASE` | auto-generated |
|
||||
|
||||
### Example Runs
|
||||
|
||||
```bash
|
||||
# Vector-only evaluation (recommended for benchmarking)
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 200
|
||||
|
||||
# Full FiQA evaluation with reranking
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 500 \
|
||||
--rerank \
|
||||
--k 10
|
||||
|
||||
# Use a predefined slice for reproducibility
|
||||
cargo run --package evaluations -- --slice fiqa-test-200 --ingest-chunks-only
|
||||
|
||||
# Run the mixed BEIR benchmark
|
||||
cargo run --package evaluations -- --dataset beir --slice beir-mix-600 --ingest-chunks-only
|
||||
```
|
||||
|
||||
## Slices
|
||||
|
||||
Slices are predefined, reproducible subsets defined in `manifest.yaml`. Each slice specifies:
|
||||
|
||||
- **limit**: Number of questions
|
||||
- **corpus_limit**: Maximum corpus size
|
||||
- **seed**: Fixed RNG seed for reproducibility
|
||||
|
||||
View available slices in [manifest.yaml](./manifest.yaml).
|
||||
|
||||
## Reports
|
||||
|
||||
Evaluations generate reports in `reports/`:
|
||||
|
||||
- **JSON**: Full structured results (`*-report.json`)
|
||||
- **Markdown**: Human-readable summary with sample mismatches (`*-report.md`)
|
||||
- **History**: Timestamped run history (`history/`)
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
```bash
|
||||
# Log per-stage performance timings
|
||||
cargo run --package evaluations -- --perf-log-console
|
||||
|
||||
# Save telemetry to file
|
||||
cargo run --package evaluations -- --perf-log-json ./perf.json
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
See [../LICENSE](../LICENSE).
|
||||
See [REFACTOR.md](./REFACTOR.md) for architecture notes.
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# Evaluations crate refactor plan
|
||||
|
||||
This document records the architecture review and the simplification work applied to the
|
||||
`evaluations` crate. **No backwards compatibility** is maintained for converted JSON layouts,
|
||||
legacy report history, or old cache artifact formats.
|
||||
|
||||
## Goals
|
||||
|
||||
- Smaller, linear pipeline (no state machine ceremony)
|
||||
- Sharded converted store for **all** datasets (memory-efficient partial loading)
|
||||
- Slice-first loading when a catalog slice is selected
|
||||
- In-memory SurrealDB for ingestion (no ephemeral server namespaces)
|
||||
- Single DB lifecycle module (`db/`)
|
||||
- CLI helpers under `cli/`
|
||||
|
||||
## Primary workflow
|
||||
|
||||
```bash
|
||||
# One-time prep (converts raw data if needed, builds slice ledger, corpus cache, DB seed)
|
||||
cargo eval --warm --dataset beir --slice beir-mix-600
|
||||
|
||||
# Check readiness
|
||||
cargo eval --status --dataset beir --slice beir-mix-600
|
||||
|
||||
# Steady-state benchmark
|
||||
cargo eval --dataset beir --slice beir-mix-600 --require-ready
|
||||
```
|
||||
|
||||
Default dataset is `beir`. Chunk-only ingestion is the default; pass `--include-entities` to
|
||||
opt into entity extraction (requires `OPENAI_API_KEY`). Slice tuning such as
|
||||
`negative_multiplier` lives in `manifest.yaml` (e.g. `beir-mix-600` uses `9.0`).
|
||||
|
||||
## Cache layers (after refactor)
|
||||
|
||||
| Layer | Location | Purpose |
|
||||
|-------|----------|---------|
|
||||
| Converted store | `data/converted/<name>/` | Sharded paragraphs + question catalog |
|
||||
| Slice ledger | `cache/slices/<dataset>/<slice-id>.json` | Deterministic questions + paragraph set |
|
||||
| Corpus cache | `cache/ingested/<dataset>/<slice-id>/` | Ingestion paragraph shards, manifest, and namespace reuse seed |
|
||||
|
||||
Namespace reuse state lives in the corpus manifest (`metadata.namespace_seed`), not a separate
|
||||
`snapshots/` tree. After upgrading, delete old `*-minne.json` monolithic files, any
|
||||
`cache/snapshots/` directories, and re-run `--warm`.
|
||||
|
||||
## Phases applied
|
||||
|
||||
### Phase 0 — dead code
|
||||
|
||||
- Removed unused `criterion` dependency
|
||||
- Removed unused `EmbeddingCache`
|
||||
- Updated README for current CLI
|
||||
|
||||
### Phase 1 — structure
|
||||
|
||||
- Flattened pipeline to linear `async fn` stages
|
||||
- Removed `eval.rs` hub; imports go to owning modules
|
||||
- Merged `namespace.rs`, `db_helpers.rs` → `db/`; dropped standalone `snapshot.rs`
|
||||
- Moved `status.rs` → `cli/status.rs`
|
||||
- Fixed catalog slice bootstrap (build ledger when explicit slice manifest is missing)
|
||||
|
||||
### Phase 2 — no legacy paths
|
||||
|
||||
- All datasets use sharded converted store only
|
||||
- Removed legacy JSON layout and migration
|
||||
- Removed legacy report history format
|
||||
- Auto-apply first catalog slice when `--slice` omitted
|
||||
- Namespace seed folded into corpus manifest (removed `cache/snapshots/`)
|
||||
|
||||
### Phase 3 — performance
|
||||
|
||||
- Ingestion always uses in-memory SurrealDB
|
||||
- Slice-first partial load when ledger is complete
|
||||
- Default catalog slice for dataset when `--slice` not passed
|
||||
- Split `slice/` into `mod.rs`, `build.rs`, and `beir.rs`
|
||||
|
||||
### Phase 4 — BEIR mix slice-first
|
||||
|
||||
- `beir` is a virtual mix: slice ledger references prefixed ids (`fever-…`, `fiqa-…`, …)
|
||||
- Conversion is **qrels-closed** per subset (only documents appearing in qrels, not full corpus)
|
||||
- Slice ledger is resolved for the requested `--slice` (catalog preset or custom id + `--limit`)
|
||||
- Only ledger paragraph ids are materialized into per-subset stores (`fever-minne/`, `fiqa-minne/`, …)
|
||||
- No monolithic `beir-minne/` merged store
|
||||
- Raw BEIR data lives in per-subset dirs under `data/raw/`; `data/raw/beir` is a catalog placeholder
|
||||
|
||||
## Do not re-introduce
|
||||
|
||||
- Monolithic `*-minne.json` converted files
|
||||
- Monolithic `beir-minne/` merged converted store (use per-subset stores + virtual mix loader)
|
||||
- `state-machines` pipeline for this linear flow
|
||||
- `eval.rs` re-export hub
|
||||
- Legacy history migration in reports
|
||||
- Ephemeral `ingest_eval_*` namespaces on the shared SurrealDB server
|
||||
- Separate `cache/snapshots/` namespace state files
|
||||
|
||||
## Open follow-ups
|
||||
|
||||
- Generate `DatasetKind` from `manifest.yaml` at build time
|
||||
- Split `report.rs` when touching reporting again
|
||||
@@ -1,4 +1,4 @@
|
||||
default_dataset: squad-v2
|
||||
default_dataset: beir
|
||||
datasets:
|
||||
- id: squad-v2
|
||||
label: "SQuAD v2.0"
|
||||
@@ -45,6 +45,7 @@ datasets:
|
||||
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
||||
limit: 600
|
||||
corpus_limit: 6000
|
||||
negative_multiplier: 9.0
|
||||
seed: 0x5eed2025
|
||||
- id: fever
|
||||
label: "FEVER (BEIR)"
|
||||
|
||||
+98
-51
@@ -3,7 +3,7 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use clap::{Args, Parser, ValueEnum};
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
@@ -137,9 +137,9 @@ pub struct IngestConfig {
|
||||
#[arg(long, default_value_t = 50)]
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
|
||||
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
|
||||
/// Include entity extraction and graph generation during ingestion (uses LLM tokens)
|
||||
#[arg(long)]
|
||||
pub ingest_chunks_only: bool,
|
||||
pub include_entities: bool,
|
||||
|
||||
/// Number of paragraphs to ingest concurrently
|
||||
#[arg(long, default_value_t = 10)]
|
||||
@@ -159,6 +159,7 @@ pub struct IngestConfig {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
#[allow(clippy::struct_field_names)]
|
||||
pub struct DatabaseArgs {
|
||||
/// `SurrealDB` server endpoint
|
||||
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
||||
@@ -179,10 +180,6 @@ pub struct DatabaseArgs {
|
||||
/// Override the database used on the `SurrealDB` server
|
||||
#[arg(long, env = "EVAL_DB_DATABASE")]
|
||||
pub db_database: Option<String>,
|
||||
|
||||
/// Path to inspect DB state
|
||||
#[arg(long)]
|
||||
pub inspect_db_state: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@@ -233,10 +230,6 @@ pub struct Config {
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub sample: usize,
|
||||
|
||||
/// Disable context cropping when converting datasets (ingest entire documents)
|
||||
#[arg(long)]
|
||||
pub full_context: bool,
|
||||
|
||||
#[command(flatten)]
|
||||
pub retrieval: RetrievalSettings,
|
||||
|
||||
@@ -322,6 +315,18 @@ pub struct Config {
|
||||
#[command(flatten)]
|
||||
pub database: DatabaseArgs,
|
||||
|
||||
/// Require warmed corpus/namespace before running queries
|
||||
#[arg(long)]
|
||||
pub require_ready: bool,
|
||||
|
||||
/// Prepare converted data, slice, corpus, and namespace without running queries
|
||||
#[arg(long, conflicts_with = "status")]
|
||||
pub warm: bool,
|
||||
|
||||
/// Print readiness of converted data, slice, corpus, and namespace
|
||||
#[arg(long, conflicts_with = "warm")]
|
||||
pub status: bool,
|
||||
|
||||
// Computed fields (not arguments)
|
||||
#[arg(skip)]
|
||||
pub raw_dataset_path: PathBuf,
|
||||
@@ -334,11 +339,6 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
#[allow(clippy::unused_self)]
|
||||
pub fn context_token_limit(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn finalize(&mut self) -> Result<()> {
|
||||
// Handle dataset paths
|
||||
@@ -367,9 +367,7 @@ impl Config {
|
||||
// Handle retrieval settings
|
||||
self.retrieval.require_verified_chunks = !self.llm_mode;
|
||||
|
||||
if self.dataset == DatasetKind::Beir {
|
||||
self.negative_multiplier = 9.0;
|
||||
}
|
||||
self.apply_catalog_slice_defaults()?;
|
||||
|
||||
// Validations
|
||||
if self.ingest.ingest_chunk_min_tokens == 0
|
||||
@@ -396,26 +394,26 @@ impl Config {
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(k) = self.retrieval.chunk_rrf_k {
|
||||
if k <= 0.0 || !k.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-k must be a positive, finite number (got {k})"
|
||||
));
|
||||
}
|
||||
if let Some(k) = self.retrieval.chunk_rrf_k
|
||||
&& (k <= 0.0 || !k.is_finite())
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-k must be a positive, finite number (got {k})"
|
||||
));
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight {
|
||||
if weight < 0.0 || !weight.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight
|
||||
&& (weight < 0.0 || !weight.is_finite())
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight {
|
||||
if weight < 0.0 || !weight.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight
|
||||
&& (weight < 0.0 || !weight.is_finite())
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
|
||||
if self.concurrency == 0 {
|
||||
@@ -428,16 +426,16 @@ impl Config {
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(query_model) = &self.query_model {
|
||||
if query_model.trim().is_empty() {
|
||||
return Err(anyhow!("--query-model requires a non-empty model name"));
|
||||
}
|
||||
if let Some(query_model) = &self.query_model
|
||||
&& query_model.trim().is_empty()
|
||||
{
|
||||
return Err(anyhow!("--query-model requires a non-empty model name"));
|
||||
}
|
||||
|
||||
if let Some(grow) = self.slice_grow {
|
||||
if grow == 0 {
|
||||
return Err(anyhow!("--slice-grow must be greater than zero"));
|
||||
}
|
||||
if let Some(grow) = self.slice_grow
|
||||
&& grow == 0
|
||||
{
|
||||
return Err(anyhow!("--slice-grow must be greater than zero"));
|
||||
}
|
||||
|
||||
if self.negative_multiplier <= 0.0 || !self.negative_multiplier.is_finite() {
|
||||
@@ -467,16 +465,65 @@ impl Config {
|
||||
}
|
||||
|
||||
// Handle perf log dir env var fallback
|
||||
if self.perf_log_dir.is_none() {
|
||||
if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
|
||||
if !dir.trim().is_empty() {
|
||||
self.perf_log_dir = Some(PathBuf::from(dir));
|
||||
}
|
||||
}
|
||||
if self.perf_log_dir.is_none()
|
||||
&& let Ok(dir) = env::var("EVAL_PERF_LOG_DIR")
|
||||
&& !dir.trim().is_empty()
|
||||
{
|
||||
self.perf_log_dir = Some(PathBuf::from(dir));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_catalog_slice_defaults(&mut self) -> Result<()> {
|
||||
let catalog = crate::datasets::catalog()?;
|
||||
let entry = catalog.dataset(self.dataset.id())?;
|
||||
|
||||
if self.slice.is_none()
|
||||
&& let Some(default_slice) = entry.slices.first()
|
||||
{
|
||||
self.slice = Some(default_slice.id.clone());
|
||||
}
|
||||
|
||||
let Some(slice_id) = self.slice.as_deref() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Ok((_, slice)) = catalog.slice(slice_id) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if slice.dataset_id != self.dataset.id() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(limit) = slice.limit
|
||||
&& self.limit_arg == 200
|
||||
{
|
||||
self.limit_arg = limit;
|
||||
self.limit = Some(limit);
|
||||
}
|
||||
if self.corpus_limit.is_none() {
|
||||
self.corpus_limit = slice.corpus_limit;
|
||||
}
|
||||
if let Some(seed) = slice.seed {
|
||||
self.slice_seed = seed;
|
||||
}
|
||||
if let Some(include_unanswerable) = slice.include_unanswerable {
|
||||
self.llm_mode = include_unanswerable;
|
||||
self.retrieval.require_verified_chunks = !include_unanswerable;
|
||||
}
|
||||
if let Some(multiplier) = slice.negative_multiplier
|
||||
&& negative_multiplier_is_default(self.negative_multiplier)
|
||||
{
|
||||
self.negative_multiplier = multiplier;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn negative_multiplier_is_default(value: f32) -> bool {
|
||||
(value - crate::slice::DEFAULT_NEGATIVE_MULTIPLIER).abs() < f32::EPSILON
|
||||
}
|
||||
|
||||
pub struct ParsedArgs {
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct EmbeddingCacheData {
|
||||
entities: HashMap<String, Vec<f32>>,
|
||||
chunks: HashMap<String, Vec<f32>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingCache {
|
||||
path: Arc<Path>,
|
||||
data: Arc<Mutex<EmbeddingCacheData>>,
|
||||
dirty: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EmbeddingCache {
|
||||
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
let data = if path.exists() {
|
||||
let raw = tokio::fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading embedding cache {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing embedding cache {}", path.display()))?
|
||||
} else {
|
||||
EmbeddingCacheData::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
path: Arc::from(path.as_path()),
|
||||
data: Arc::new(Mutex::new(data)),
|
||||
dirty: Arc::new(AtomicBool::new(false)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.entities.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.entities.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.chunks.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.chunks.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> Result<()> {
|
||||
if !self.dirty.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guard = self.data.lock().await;
|
||||
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
|
||||
if let Some(parent) = self.path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("creating cache directory {}", parent.display()))?;
|
||||
}
|
||||
tokio::fs::write(&*self.path, body)
|
||||
.await
|
||||
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
|
||||
self.dirty.store(false, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -156,6 +156,7 @@ mod tests {
|
||||
chunk_min_tokens: 1,
|
||||
chunk_max_tokens: 10,
|
||||
chunk_only: false,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs,
|
||||
questions,
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
pub mod status;
|
||||
|
||||
pub use status::{collect_status, ensure_query_ready, print_status, warm};
|
||||
@@ -0,0 +1,311 @@
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
corpus::{self, CorpusCacheConfig},
|
||||
datasets::{
|
||||
ConvertedLayout, DatasetKind, beir_subset_store_summary, beir_subset_stores_ready,
|
||||
content_checksum_for_layout, detect_layout, mix_content_checksum, store_dir_for,
|
||||
},
|
||||
db::{connect_eval_db, default_database, default_namespace, namespace_has_corpus},
|
||||
slice::{self, ledger_target},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EvalStatus {
|
||||
pub dataset: String,
|
||||
pub slice: Option<String>,
|
||||
pub converted: ConvertedStatus,
|
||||
pub slice_ledger: SliceLedgerStatus,
|
||||
pub corpus_cache: CorpusCacheStatus,
|
||||
pub namespace: NamespaceStatus,
|
||||
pub query_ready: bool,
|
||||
pub notes: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ConvertedStatus {
|
||||
pub layout: String,
|
||||
pub path: String,
|
||||
pub ready: bool,
|
||||
pub partial_load_eligible: bool,
|
||||
pub checksum: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct SliceLedgerStatus {
|
||||
pub ready: bool,
|
||||
pub path: Option<String>,
|
||||
pub cases: Option<usize>,
|
||||
pub positives: Option<usize>,
|
||||
pub negatives: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CorpusCacheStatus {
|
||||
pub ready: bool,
|
||||
pub path: Option<String>,
|
||||
pub manifest_present: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct NamespaceStatus {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub seeded: bool,
|
||||
pub namespace_seed_recorded: bool,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn collect_status(config: &Config) -> Result<EvalStatus> {
|
||||
let mut notes = Vec::new();
|
||||
let is_beir_mix = config.dataset == DatasetKind::Beir;
|
||||
let converted_path = &config.converted_dataset_path;
|
||||
let layout = if is_beir_mix {
|
||||
ConvertedLayout::Missing
|
||||
} else {
|
||||
detect_layout(converted_path)
|
||||
};
|
||||
let layout_label = if is_beir_mix {
|
||||
"beir-mix-subset-stores"
|
||||
} else {
|
||||
match layout {
|
||||
ConvertedLayout::ShardedStore => "sharded-store",
|
||||
ConvertedLayout::Missing => "missing",
|
||||
}
|
||||
};
|
||||
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
let display_path = if is_beir_mix {
|
||||
beir_subset_store_summary()?
|
||||
.into_iter()
|
||||
.map(|(subset, paragraphs, questions)| {
|
||||
format!("{subset}-minne ({paragraphs} paragraphs, {questions} questions)")
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("; ")
|
||||
} else {
|
||||
store_dir.display().to_string()
|
||||
};
|
||||
|
||||
let manifest_path = slice::cached_manifest_path(config);
|
||||
let slice_config = slice::slice_config_with_limit(config, ledger_target(config));
|
||||
let slice_manifest = manifest_path
|
||||
.as_ref()
|
||||
.and_then(|path| slice::read_manifest_if_exists(path).ok().flatten());
|
||||
|
||||
let slice_ledger = SliceLedgerStatus {
|
||||
ready: slice_manifest
|
||||
.as_ref()
|
||||
.is_some_and(|manifest| slice::manifest_is_complete(manifest, &slice_config)),
|
||||
path: manifest_path
|
||||
.as_ref()
|
||||
.map(|path| path.display().to_string()),
|
||||
cases: slice_manifest.as_ref().map(|manifest| manifest.case_count),
|
||||
positives: slice_manifest
|
||||
.as_ref()
|
||||
.map(|manifest| manifest.positive_paragraphs),
|
||||
negatives: slice_manifest
|
||||
.as_ref()
|
||||
.map(|manifest| manifest.negative_paragraphs),
|
||||
};
|
||||
|
||||
let beir_paragraph_ids = slice_manifest.as_ref().map(|manifest| {
|
||||
manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
});
|
||||
|
||||
let converted_ready = if is_beir_mix {
|
||||
slice_ledger.ready
|
||||
&& beir_paragraph_ids
|
||||
.as_ref()
|
||||
.is_some_and(|ids| beir_subset_stores_ready(ids).unwrap_or(false))
|
||||
} else {
|
||||
layout == ConvertedLayout::ShardedStore
|
||||
};
|
||||
|
||||
let checksum = if is_beir_mix {
|
||||
beir_paragraph_ids
|
||||
.as_ref()
|
||||
.and_then(|ids| mix_content_checksum(ids).ok())
|
||||
} else if layout == ConvertedLayout::ShardedStore {
|
||||
content_checksum_for_layout(converted_path).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let partial_load_eligible = slice_ledger.ready && config.slice.is_some();
|
||||
|
||||
let corpus_cache = if let Some(manifest) = slice_manifest.as_ref() {
|
||||
let cache_settings = CorpusCacheConfig::from(config);
|
||||
let base_dir = corpus::cached_corpus_dir(
|
||||
&cache_settings,
|
||||
config.dataset.id(),
|
||||
manifest.slice_id.as_str(),
|
||||
);
|
||||
let manifest_present = corpus::load_cached_manifest(&base_dir)?.is_some();
|
||||
CorpusCacheStatus {
|
||||
ready: manifest_present,
|
||||
path: Some(base_dir.display().to_string()),
|
||||
manifest_present,
|
||||
}
|
||||
} else {
|
||||
CorpusCacheStatus {
|
||||
ready: false,
|
||||
path: None,
|
||||
manifest_present: false,
|
||||
}
|
||||
};
|
||||
|
||||
let namespace = config.database.db_namespace.clone().unwrap_or_else(|| {
|
||||
default_namespace(config.dataset.id(), config.limit, config.slice.as_deref())
|
||||
});
|
||||
let database = config
|
||||
.database
|
||||
.db_database
|
||||
.clone()
|
||||
.unwrap_or_else(default_database);
|
||||
|
||||
let namespace_seed = corpus_cache.path.as_ref().and_then(|path| {
|
||||
corpus::load_cached_manifest(Path::new(path))
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|manifest| manifest.metadata.namespace_seed)
|
||||
});
|
||||
|
||||
let (seeded, namespace_seed_recorded) =
|
||||
match connect_eval_db(config, &namespace, &database).await {
|
||||
Ok(db) => {
|
||||
let has_corpus = namespace_has_corpus(&db).await.unwrap_or(false);
|
||||
(has_corpus, namespace_seed.is_some())
|
||||
}
|
||||
Err(err) => {
|
||||
notes.push(format!("SurrealDB unavailable: {err}"));
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
let query_ready = converted_ready
|
||||
&& slice_ledger.ready
|
||||
&& corpus_cache.ready
|
||||
&& seeded
|
||||
&& namespace_seed_recorded;
|
||||
|
||||
if !query_ready {
|
||||
notes.push("Run `cargo eval --warm --slice <id>` to prepare corpus and namespace.".into());
|
||||
}
|
||||
|
||||
Ok(EvalStatus {
|
||||
dataset: config.dataset.id().to_string(),
|
||||
slice: config.slice.clone(),
|
||||
converted: ConvertedStatus {
|
||||
layout: layout_label.to_string(),
|
||||
path: display_path,
|
||||
ready: converted_ready,
|
||||
partial_load_eligible,
|
||||
checksum,
|
||||
},
|
||||
slice_ledger,
|
||||
corpus_cache,
|
||||
namespace: NamespaceStatus {
|
||||
namespace,
|
||||
database,
|
||||
seeded,
|
||||
namespace_seed_recorded,
|
||||
},
|
||||
query_ready,
|
||||
notes,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn print_status(status: &EvalStatus) {
|
||||
println!("Evaluation status for dataset `{}`", status.dataset);
|
||||
if let Some(slice) = &status.slice {
|
||||
println!("Slice: {slice}");
|
||||
}
|
||||
println!(
|
||||
"Converted: {} ({})",
|
||||
if status.converted.ready {
|
||||
"ready"
|
||||
} else {
|
||||
"missing"
|
||||
},
|
||||
status.converted.layout
|
||||
);
|
||||
println!("Converted path: {}", status.converted.path);
|
||||
if status.converted.partial_load_eligible {
|
||||
println!("Slice-first loading: eligible");
|
||||
}
|
||||
println!(
|
||||
"Slice ledger: {}",
|
||||
if status.slice_ledger.ready {
|
||||
format!(
|
||||
"ready ({} cases, {} positives, {} negatives)",
|
||||
status.slice_ledger.cases.unwrap_or(0),
|
||||
status.slice_ledger.positives.unwrap_or(0),
|
||||
status.slice_ledger.negatives.unwrap_or(0)
|
||||
)
|
||||
} else {
|
||||
"missing or incomplete".to_string()
|
||||
}
|
||||
);
|
||||
if let Some(path) = &status.slice_ledger.path {
|
||||
println!("Slice ledger path: {path}");
|
||||
}
|
||||
println!(
|
||||
"Corpus cache: {}",
|
||||
if status.corpus_cache.ready {
|
||||
"ready"
|
||||
} else {
|
||||
"missing"
|
||||
}
|
||||
);
|
||||
if let Some(path) = &status.corpus_cache.path {
|
||||
println!("Corpus cache path: {path}");
|
||||
}
|
||||
println!(
|
||||
"Namespace `{}` / `{}`: seeded={}, namespace_seed_recorded={}",
|
||||
status.namespace.namespace,
|
||||
status.namespace.database,
|
||||
status.namespace.seeded,
|
||||
status.namespace.namespace_seed_recorded
|
||||
);
|
||||
println!(
|
||||
"Query-ready: {}",
|
||||
if status.query_ready { "yes" } else { "no" }
|
||||
);
|
||||
for note in &status.notes {
|
||||
println!("Note: {note}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn warm(config: &Config) -> Result<()> {
|
||||
let loaded =
|
||||
crate::datasets::prepare_dataset(config.dataset, config).context("preparing dataset")?;
|
||||
crate::pipeline::warm_evaluation(&loaded.dataset, config, &loaded.content_checksum)
|
||||
.await
|
||||
.context("warming evaluation corpus and namespace")?;
|
||||
let status = collect_status(config).await?;
|
||||
print_status(&status);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ensure_query_ready(config: &Config) -> Result<()> {
|
||||
let status = collect_status(config).await?;
|
||||
if status.query_ready {
|
||||
return Ok(());
|
||||
}
|
||||
print_status(&status);
|
||||
anyhow::bail!(
|
||||
"evaluation is not query-ready; run `cargo eval --warm --slice {}` first",
|
||||
config.slice.as_deref().unwrap_or("<slice-id>")
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
#![allow(clippy::arithmetic_side_effects)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
|
||||
use crate::types::EvaluationCandidate;
|
||||
|
||||
const TOKENIZER_LABEL: &str = "estimated (~chars/4; ingestion uses bert-base-cased)";
|
||||
|
||||
#[allow(clippy::struct_field_names)]
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RetrievedContextStats {
|
||||
pub chunk_count: usize,
|
||||
pub char_count: usize,
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct RetrievalContextStats {
|
||||
pub tokenizer: String,
|
||||
pub queries: usize,
|
||||
pub total_chunks: usize,
|
||||
pub total_chars: usize,
|
||||
pub total_tokens: usize,
|
||||
pub avg_chunks_per_query: f64,
|
||||
pub avg_chars_per_query: f64,
|
||||
pub avg_tokens_per_query: f64,
|
||||
pub p50_tokens_per_query: usize,
|
||||
pub p95_tokens_per_query: usize,
|
||||
pub max_tokens_per_query: usize,
|
||||
}
|
||||
|
||||
pub fn stats_for_candidates(candidates: &[EvaluationCandidate]) -> RetrievedContextStats {
|
||||
let mut seen_chunk_ids = std::collections::HashSet::new();
|
||||
let mut stats = RetrievedContextStats::default();
|
||||
|
||||
for candidate in candidates {
|
||||
for chunk in &candidate.chunks {
|
||||
let chunk_id = chunk.chunk.id().to_string();
|
||||
if !seen_chunk_ids.insert(chunk_id) {
|
||||
continue;
|
||||
}
|
||||
let text = chunk.chunk.chunk.as_str();
|
||||
stats.chunk_count += 1;
|
||||
stats.char_count += text.chars().count();
|
||||
stats.token_count += estimate_ingestion_tokens(text);
|
||||
}
|
||||
}
|
||||
|
||||
stats
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> RetrievalContextStats {
|
||||
let queries = per_query.len();
|
||||
if queries == 0 {
|
||||
return RetrievalContextStats {
|
||||
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||
queries: 0,
|
||||
total_chunks: 0,
|
||||
total_chars: 0,
|
||||
total_tokens: 0,
|
||||
avg_chunks_per_query: 0.0,
|
||||
avg_chars_per_query: 0.0,
|
||||
avg_tokens_per_query: 0.0,
|
||||
p50_tokens_per_query: 0,
|
||||
p95_tokens_per_query: 0,
|
||||
max_tokens_per_query: 0,
|
||||
};
|
||||
}
|
||||
|
||||
let total_chunks: usize = per_query.iter().map(|stats| stats.chunk_count).sum();
|
||||
let total_chars: usize = per_query.iter().map(|stats| stats.char_count).sum();
|
||||
let total_tokens: usize = per_query.iter().map(|stats| stats.token_count).sum();
|
||||
let mut tokens_per_query: Vec<usize> =
|
||||
per_query.iter().map(|stats| stats.token_count).collect();
|
||||
tokens_per_query.sort_unstable();
|
||||
let max_tokens_per_query = *tokens_per_query.last().unwrap_or(&0);
|
||||
|
||||
let total_chunks_f = total_chunks as f64;
|
||||
let total_chars_f = total_chars as f64;
|
||||
let total_tokens_f = total_tokens as f64;
|
||||
let queries_f = queries as f64;
|
||||
let avg_chunks_per_query = total_chunks_f / queries_f;
|
||||
let avg_chars_per_query = total_chars_f / queries_f;
|
||||
let avg_tokens_per_query = total_tokens_f / queries_f;
|
||||
|
||||
RetrievalContextStats {
|
||||
tokenizer: TOKENIZER_LABEL.to_string(),
|
||||
queries,
|
||||
total_chunks,
|
||||
total_chars,
|
||||
total_tokens,
|
||||
avg_chunks_per_query,
|
||||
avg_chars_per_query,
|
||||
avg_tokens_per_query,
|
||||
p50_tokens_per_query: percentile_usize(&tokens_per_query, 0.50),
|
||||
p95_tokens_per_query: percentile_usize(&tokens_per_query, 0.95),
|
||||
max_tokens_per_query,
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_ingestion_tokens(text: &str) -> usize {
|
||||
let chars = text.chars().count();
|
||||
if chars == 0 {
|
||||
return 0;
|
||||
}
|
||||
chars.div_ceil(4)
|
||||
}
|
||||
|
||||
#[allow(
|
||||
clippy::cast_precision_loss,
|
||||
clippy::cast_sign_loss,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::indexing_slicing,
|
||||
clippy::arithmetic_side_effects
|
||||
)]
|
||||
fn percentile_usize(sorted: &[usize], fraction: f64) -> usize {
|
||||
if sorted.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let clamped = fraction.clamp(0.0, 1.0);
|
||||
let index = ((sorted.len() - 1) as f64 * clamped).round() as usize;
|
||||
sorted[index.min(sorted.len() - 1)]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use common::storage::types::text_chunk::TextChunk;
|
||||
use retrieval_pipeline::RetrievedChunk;
|
||||
|
||||
#[test]
|
||||
fn deduplicates_chunks_when_counting_context() {
|
||||
let shared = Arc::new(TextChunk::new(
|
||||
"src".into(),
|
||||
"hello world".into(),
|
||||
"user".into(),
|
||||
));
|
||||
let candidates = vec![
|
||||
EvaluationCandidate {
|
||||
entity_id: "a".into(),
|
||||
source_id: "src".into(),
|
||||
entity_name: "A".into(),
|
||||
entity_description: None,
|
||||
entity_category: None,
|
||||
score: 1.0,
|
||||
chunks: vec![RetrievedChunk {
|
||||
chunk: Arc::clone(&shared),
|
||||
score: 1.0,
|
||||
}],
|
||||
},
|
||||
EvaluationCandidate {
|
||||
entity_id: "b".into(),
|
||||
source_id: "src".into(),
|
||||
entity_name: "B".into(),
|
||||
entity_description: None,
|
||||
entity_category: None,
|
||||
score: 0.9,
|
||||
chunks: vec![RetrievedChunk {
|
||||
chunk: shared,
|
||||
score: 0.9,
|
||||
}],
|
||||
},
|
||||
];
|
||||
let stats = stats_for_candidates(&candidates);
|
||||
assert_eq!(stats.chunk_count, 1);
|
||||
assert_eq!(stats.char_count, "hello world".chars().count());
|
||||
assert_eq!(stats.token_count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aggregates_per_query_token_totals() {
|
||||
let per_query = vec![
|
||||
RetrievedContextStats {
|
||||
chunk_count: 2,
|
||||
char_count: 100,
|
||||
token_count: 40,
|
||||
},
|
||||
RetrievedContextStats {
|
||||
chunk_count: 5,
|
||||
char_count: 250,
|
||||
token_count: 100,
|
||||
},
|
||||
];
|
||||
let aggregate = aggregate_context_stats(&per_query);
|
||||
assert_eq!(aggregate.queries, 2);
|
||||
assert_eq!(aggregate.total_chunks, 7);
|
||||
assert_eq!(aggregate.total_tokens, 140);
|
||||
assert_eq!(aggregate.max_tokens_per_query, 100);
|
||||
assert!((aggregate.avg_tokens_per_query - 70.0).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
@@ -11,32 +11,14 @@ pub struct CorpusCacheConfig {
|
||||
pub ingestion_max_retries: usize,
|
||||
}
|
||||
|
||||
impl CorpusCacheConfig {
|
||||
pub fn new(
|
||||
ingestion_cache_dir: impl Into<PathBuf>,
|
||||
force_refresh: bool,
|
||||
refresh_embeddings_only: bool,
|
||||
ingestion_batch_size: usize,
|
||||
ingestion_max_retries: usize,
|
||||
) -> Self {
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
Self {
|
||||
ingestion_cache_dir: ingestion_cache_dir.into(),
|
||||
force_refresh,
|
||||
refresh_embeddings_only,
|
||||
ingestion_batch_size,
|
||||
ingestion_max_retries,
|
||||
ingestion_cache_dir: config.ingest.ingestion_cache_dir.clone(),
|
||||
force_refresh: config.force_convert || config.ingest.slice_reset_ingestion,
|
||||
refresh_embeddings_only: config.ingest.refresh_embeddings_only,
|
||||
ingestion_batch_size: config.ingest.ingestion_batch_size,
|
||||
ingestion_max_retries: config.ingest.ingestion_max_retries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Config> for CorpusCacheConfig {
|
||||
fn from(config: &Config) -> Self {
|
||||
CorpusCacheConfig::new(
|
||||
config.ingest.ingestion_cache_dir.clone(),
|
||||
config.force_convert || config.ingest.slice_reset_ingestion,
|
||||
config.ingest.refresh_embeddings_only,
|
||||
config.ingest.ingestion_batch_size,
|
||||
config.ingest.ingestion_max_retries,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,12 @@ pub(crate) mod store;
|
||||
pub use config::CorpusCacheConfig;
|
||||
pub use orchestrator::{
|
||||
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
|
||||
load_cached_manifest,
|
||||
load_cached_manifest, persist_corpus_manifest,
|
||||
};
|
||||
pub use store::{
|
||||
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
|
||||
CorpusQuestion, ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION,
|
||||
NamespaceSeedRecord, ParagraphShard, ParagraphShardStore, seed_manifest_into_db,
|
||||
window_manifest,
|
||||
};
|
||||
|
||||
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
|
||||
@@ -20,6 +21,6 @@ pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline
|
||||
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||
..Default::default()
|
||||
},
|
||||
chunk_only: config.ingest.ingest_chunks_only,
|
||||
chunk_only: !config.ingest.include_entities,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,16 +6,14 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use async_openai::Client;
|
||||
use chrono::Utc;
|
||||
#[cfg(not(test))]
|
||||
use common::utils::config::get_config;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
store::{DynStorage, StorageManager},
|
||||
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
|
||||
types::{StoredObject, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask},
|
||||
},
|
||||
utils::config::{AppConfig, StorageKind},
|
||||
};
|
||||
@@ -33,7 +31,7 @@ use crate::{
|
||||
|
||||
use crate::corpus::{
|
||||
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
|
||||
ParagraphShard, ParagraphShardStore, MANIFEST_VERSION,
|
||||
MANIFEST_VERSION, ParagraphShard, ParagraphShardStore,
|
||||
};
|
||||
|
||||
const INGESTION_SPEC_VERSION: u32 = 2;
|
||||
@@ -125,10 +123,14 @@ pub async fn ensure_corpus(
|
||||
openai: Arc<OpenAIClient>,
|
||||
user_id: &str,
|
||||
converted_path: &Path,
|
||||
precomputed_checksum: Option<&str>,
|
||||
ingestion_config: IngestionConfig,
|
||||
) -> Result<CorpusHandle> {
|
||||
let checksum = compute_file_checksum(converted_path)
|
||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
|
||||
let checksum = match precomputed_checksum {
|
||||
Some(value) => value.to_string(),
|
||||
None => crate::datasets::content_checksum_for_layout(converted_path)
|
||||
.with_context(|| format!("computing checksum for {}", converted_path.display()))?,
|
||||
};
|
||||
let ingestion_fingerprint =
|
||||
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
|
||||
|
||||
@@ -381,6 +383,7 @@ pub async fn ensure_corpus(
|
||||
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
|
||||
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
|
||||
chunk_only: ingestion_config.chunk_only,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs: corpus_paragraphs,
|
||||
questions: corpus_questions,
|
||||
@@ -415,7 +418,7 @@ pub async fn ensure_corpus(
|
||||
negative_ingested: stats.negative_ingested,
|
||||
};
|
||||
|
||||
persist_manifest(&handle).context("persisting corpus manifest")?;
|
||||
persist_corpus_manifest(&handle).context("persisting corpus manifest")?;
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
@@ -501,7 +504,6 @@ async fn ingest_paragraph_batch(
|
||||
Ok(shards)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let db = SurrealDbClient::memory(namespace, "corpus")
|
||||
.await
|
||||
@@ -509,21 +511,6 @@ async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let config = get_config().context("loading app config for ingestion database")?;
|
||||
let db = SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
namespace,
|
||||
"corpus",
|
||||
)
|
||||
.await
|
||||
.context("creating surrealdb database for ingestion")?;
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn ingest_single_paragraph(
|
||||
pipeline: Arc<IngestionPipeline>,
|
||||
@@ -631,8 +618,12 @@ pub fn compute_ingestion_fingerprint(
|
||||
slice: &ResolvedSlice<'_>,
|
||||
converted_path: &Path,
|
||||
ingestion_config: &IngestionConfig,
|
||||
precomputed_checksum: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let checksum = compute_file_checksum(converted_path)?;
|
||||
let checksum = match precomputed_checksum {
|
||||
Some(value) => value.to_string(),
|
||||
None => crate::datasets::content_checksum_for_layout(converted_path)?,
|
||||
};
|
||||
Ok(build_ingestion_fingerprint(
|
||||
dataset,
|
||||
slice,
|
||||
@@ -641,7 +632,7 @@ pub fn compute_ingestion_fingerprint(
|
||||
))
|
||||
}
|
||||
|
||||
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
||||
pub fn load_cached_manifest(base_dir: &std::path::Path) -> Result<Option<CorpusManifest>> {
|
||||
let path = base_dir.join("manifest.json");
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
@@ -656,7 +647,7 @@ pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
|
||||
Ok(Some(manifest))
|
||||
}
|
||||
|
||||
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||
pub fn persist_corpus_manifest(handle: &CorpusHandle) -> Result<()> {
|
||||
let path = handle.path.join("manifest.json");
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
@@ -685,24 +676,6 @@ pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn compute_file_checksum(path: &Path) -> Result<String> {
|
||||
let mut file = fs::File::open(path)
|
||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = [0u8; 8192];
|
||||
loop {
|
||||
let read = file
|
||||
.read(&mut buffer)
|
||||
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..read]);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -728,11 +701,7 @@ mod tests {
|
||||
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata: crate::datasets::DatasetMetadata::for_kind(
|
||||
DatasetKind::default(),
|
||||
false,
|
||||
None,
|
||||
),
|
||||
metadata: crate::datasets::DatasetMetadata::for_kind(DatasetKind::default(), false),
|
||||
source: "src".to_string(),
|
||||
paragraphs: vec![paragraph],
|
||||
}
|
||||
|
||||
+38
-288
@@ -5,35 +5,24 @@ use std::{
|
||||
path::PathBuf,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use common::storage::types::StoredObject;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
|
||||
text_chunk::TextChunk,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
StoredObject, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
|
||||
text_content::TextContent,
|
||||
},
|
||||
};
|
||||
use ingestion_pipeline::{IngestionTuning, PipelineArtifacts, persist_artifacts};
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use surrealdb::sql::Thing;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
||||
|
||||
pub const MANIFEST_VERSION: u32 = 3;
|
||||
pub const PARAGRAPH_SHARD_VERSION: u32 = 3;
|
||||
const MANIFEST_BATCH_SIZE: usize = 100;
|
||||
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
|
||||
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
|
||||
const MAX_BATCHES_PER_REQUEST: usize = 24;
|
||||
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
|
||||
|
||||
fn current_manifest_version() -> u32 {
|
||||
MANIFEST_VERSION
|
||||
}
|
||||
@@ -51,7 +40,7 @@ fn default_chunk_max_tokens() -> usize {
|
||||
}
|
||||
|
||||
fn default_chunk_only() -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
|
||||
// Reuse the pipeline's canonical embedded-artifact types so the on-disk corpus
|
||||
@@ -131,6 +120,14 @@ pub struct CorpusManifest {
|
||||
pub questions: Vec<CorpusQuestion>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct NamespaceSeedRecord {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub slice_case_count: usize,
|
||||
pub seeded_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CorpusMetadata {
|
||||
pub dataset_id: String,
|
||||
@@ -153,6 +150,8 @@ pub struct CorpusMetadata {
|
||||
pub chunk_max_tokens: usize,
|
||||
#[serde(default = "default_chunk_only")]
|
||||
pub chunk_only: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub namespace_seed: Option<NamespaceSeedRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
@@ -251,130 +250,6 @@ pub fn window_manifest(
|
||||
Ok(narrowed)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct RelationInsert {
|
||||
#[serde(rename = "in")]
|
||||
pub in_: Thing,
|
||||
#[serde(rename = "out")]
|
||||
pub out: Thing,
|
||||
pub id: String,
|
||||
pub metadata: RelationshipMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SizedBatch<T> {
|
||||
approx_bytes: usize,
|
||||
items: Vec<T>,
|
||||
}
|
||||
|
||||
struct ManifestBatches {
|
||||
text_contents: Vec<SizedBatch<TextContent>>,
|
||||
entities: Vec<SizedBatch<KnowledgeEntity>>,
|
||||
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
|
||||
relationships: Vec<SizedBatch<RelationInsert>>,
|
||||
chunks: Vec<SizedBatch<TextChunk>>,
|
||||
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
|
||||
}
|
||||
|
||||
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
|
||||
let mut text_contents = Vec::new();
|
||||
let mut entities = Vec::new();
|
||||
let mut entity_embeddings = Vec::new();
|
||||
let mut relationships = Vec::new();
|
||||
let mut chunks = Vec::new();
|
||||
let mut chunk_embeddings = Vec::new();
|
||||
|
||||
let mut seen_text_content = HashSet::new();
|
||||
let mut seen_entities = HashSet::new();
|
||||
let mut seen_relationships = HashSet::new();
|
||||
let mut seen_chunks = HashSet::new();
|
||||
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||
text_contents.push(paragraph.text_content.clone());
|
||||
}
|
||||
|
||||
for embedded_entity in ¶graph.entities {
|
||||
if seen_entities.insert(embedded_entity.entity.id.clone()) {
|
||||
let entity = embedded_entity.entity.clone();
|
||||
entities.push(entity.clone());
|
||||
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
||||
&entity.id,
|
||||
entity.source_id.clone(),
|
||||
embedded_entity.embedding.clone(),
|
||||
entity.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
for relationship in ¶graph.relationships {
|
||||
if seen_relationships.insert(relationship.id.clone()) {
|
||||
let table = KnowledgeEntity::table_name();
|
||||
let in_id = relationship
|
||||
.in_
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.in_);
|
||||
let out_id = relationship
|
||||
.out
|
||||
.strip_prefix(&format!("{table}:"))
|
||||
.unwrap_or(&relationship.out);
|
||||
let in_thing = Thing::from((table, in_id));
|
||||
let out_thing = Thing::from((table, out_id));
|
||||
relationships.push(RelationInsert {
|
||||
in_: in_thing,
|
||||
out: out_thing,
|
||||
id: relationship.id.clone(),
|
||||
metadata: relationship.metadata.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for embedded_chunk in ¶graph.chunks {
|
||||
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
|
||||
let chunk = embedded_chunk.chunk.clone();
|
||||
chunks.push(chunk.clone());
|
||||
chunk_embeddings.push(TextChunkEmbedding::new(
|
||||
&chunk.id,
|
||||
chunk.source_id.clone(),
|
||||
embedded_chunk.embedding.clone(),
|
||||
chunk.user_id.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ManifestBatches {
|
||||
text_contents: chunk_items(
|
||||
&text_contents,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_content payloads")?,
|
||||
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking knowledge_entity payloads")?,
|
||||
entity_embeddings: chunk_items(
|
||||
&entity_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking knowledge_entity_embedding payloads")?,
|
||||
relationships: chunk_items(
|
||||
&relationships,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking relationship payloads")?,
|
||||
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
||||
.context("chunking text_chunk payloads")?,
|
||||
chunk_embeddings: chunk_items(
|
||||
&chunk_embeddings,
|
||||
MANIFEST_BATCH_SIZE,
|
||||
MANIFEST_MAX_BYTES_PER_BATCH,
|
||||
)
|
||||
.context("chunking text_chunk_embedding payloads")?,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ParagraphShard {
|
||||
#[serde(default = "current_paragraph_shard_version")]
|
||||
@@ -430,7 +305,7 @@ impl ParagraphShardStore {
|
||||
Ok(file) => file,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).with_context(|| format!("opening shard {}", path.display()))
|
||||
return Err(err).with_context(|| format!("opening shard {}", path.display()));
|
||||
}
|
||||
};
|
||||
let reader = BufReader::new(file);
|
||||
@@ -599,157 +474,28 @@ fn normalize_answer_text(text: &str) -> String {
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
fn chunk_items<T: Clone + Serialize>(
|
||||
items: &[T],
|
||||
max_items: usize,
|
||||
max_bytes: usize,
|
||||
) -> Result<Vec<SizedBatch<T>>> {
|
||||
if items.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut batches = Vec::new();
|
||||
let mut current = Vec::new();
|
||||
let mut current_bytes = 0usize;
|
||||
|
||||
for item in items {
|
||||
let size = serde_json::to_vec(item)
|
||||
.map(|buf| buf.len())
|
||||
.context("serialising batch item for sizing")?;
|
||||
|
||||
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
|
||||
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
|
||||
|
||||
if would_overflow_items || would_overflow_bytes {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: std::mem::take(&mut current),
|
||||
});
|
||||
current_bytes = 0;
|
||||
}
|
||||
|
||||
current_bytes += size;
|
||||
current.push(item.clone());
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
batches.push(SizedBatch {
|
||||
approx_bytes: current_bytes.max(1),
|
||||
items: current,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(batches)
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
|
||||
db: &SurrealDbClient,
|
||||
statement: impl AsRef<str>,
|
||||
prefix: &str,
|
||||
batches: &[SizedBatch<T>],
|
||||
) -> Result<()> {
|
||||
if batches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut start = 0;
|
||||
while start < batches.len() {
|
||||
let mut group_bytes = 0usize;
|
||||
let mut group_end = start;
|
||||
let mut group_count = 0usize;
|
||||
|
||||
while group_end < batches.len() {
|
||||
let batch_bytes = batches[group_end].approx_bytes.max(1);
|
||||
if group_count > 0
|
||||
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|
||||
|| group_count >= MAX_BATCHES_PER_REQUEST)
|
||||
{
|
||||
break;
|
||||
}
|
||||
group_bytes += batch_bytes;
|
||||
group_end += 1;
|
||||
group_count += 1;
|
||||
}
|
||||
|
||||
let slice = &batches[start..group_end];
|
||||
let mut query = db.client.query("BEGIN TRANSACTION;");
|
||||
for (bind_index, batch) in slice.iter().enumerate() {
|
||||
let name = format!("{prefix}{bind_index}");
|
||||
query = query
|
||||
.query(format!("{} ${};", statement.as_ref(), name))
|
||||
.bind((name, batch.items.clone()));
|
||||
}
|
||||
let response = query
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.await
|
||||
.context("executing batched insert transaction")?;
|
||||
if let Err(err) = response.check() {
|
||||
return Err(anyhow!(
|
||||
"batched insert failed for statement '{}': {err:?}",
|
||||
statement.as_ref()
|
||||
));
|
||||
}
|
||||
|
||||
start = group_end;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
||||
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
|
||||
let tuning = IngestionTuning::default();
|
||||
let embedding_dimensions = manifest.metadata.embedding_dimension;
|
||||
let mut seen_text_content = HashSet::new();
|
||||
|
||||
let result = async {
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextContent::table_name()),
|
||||
"tc",
|
||||
&batches.text_contents,
|
||||
)
|
||||
.await?;
|
||||
for paragraph in &manifest.paragraphs {
|
||||
if !seen_text_content.insert(paragraph.text_content.id.clone()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
|
||||
"ke",
|
||||
&batches.entities,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunk::table_name()),
|
||||
"ch",
|
||||
&batches.chunks,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
"INSERT RELATION INTO relates_to",
|
||||
"rel",
|
||||
&batches.relationships,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
|
||||
"kee",
|
||||
&batches.entity_embeddings,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_batched_inserts(
|
||||
db,
|
||||
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
|
||||
"tce",
|
||||
&batches.chunk_embeddings,
|
||||
)
|
||||
.await?;
|
||||
let artifacts = PipelineArtifacts {
|
||||
text_content: paragraph.text_content.clone(),
|
||||
entities: paragraph.entities.clone(),
|
||||
relationships: paragraph.relationships.clone(),
|
||||
chunks: paragraph.chunks.clone(),
|
||||
};
|
||||
|
||||
persist_artifacts(db, &tuning, embedding_dimensions, artifacts)
|
||||
.await
|
||||
.map_err(|err| anyhow!("persist manifest paragraph: {err}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
@@ -778,7 +524,10 @@ pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManife
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use common::storage::types::knowledge_entity::KnowledgeEntityType;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
@@ -888,6 +637,7 @@ mod tests {
|
||||
chunk_min_tokens: 1,
|
||||
chunk_max_tokens: 10,
|
||||
chunk_only: false,
|
||||
namespace_seed: None,
|
||||
},
|
||||
paragraphs: vec![paragraph_one, paragraph_two],
|
||||
questions: vec![question],
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fs::File,
|
||||
io::{BufRead, BufReader},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use serde::Deserialize;
|
||||
use tracing::warn;
|
||||
|
||||
@@ -47,20 +47,71 @@ struct QrelEntry {
|
||||
score: i32,
|
||||
}
|
||||
|
||||
/// Convert only documents that appear in qrels (the BEIR evaluation closed world).
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
|
||||
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
|
||||
convert_beir_documents(raw_dir, dataset, None)
|
||||
}
|
||||
|
||||
/// Convert a subset of qrels-world documents. `doc_ids` use corpus ids (unprefixed).
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::indexing_slicing
|
||||
)]
|
||||
pub fn convert_beir_documents(
|
||||
raw_dir: &Path,
|
||||
dataset: DatasetKind,
|
||||
doc_ids: Option<&HashSet<String>>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
let corpus_path = raw_dir.join("corpus.jsonl");
|
||||
let queries_path = raw_dir.join("queries.jsonl");
|
||||
let qrels_path = resolve_qrels_path(raw_dir)?;
|
||||
|
||||
let corpus = load_corpus(&corpus_path)?;
|
||||
let queries = load_queries(&queries_path)?;
|
||||
let qrels = load_qrels(&qrels_path)?;
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(corpus.len());
|
||||
let mut qrels_doc_ids = HashSet::new();
|
||||
for entries in qrels.values() {
|
||||
for entry in entries {
|
||||
qrels_doc_ids.insert(entry.doc_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let target_doc_ids: HashSet<String> = match doc_ids {
|
||||
Some(ids) => ids
|
||||
.iter()
|
||||
.filter(|id| qrels_doc_ids.contains(*id))
|
||||
.cloned()
|
||||
.collect(),
|
||||
None => qrels_doc_ids.clone(),
|
||||
};
|
||||
|
||||
if target_doc_ids.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no qrels documents to convert for {} at {}",
|
||||
dataset.id(),
|
||||
raw_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
let corpus = load_corpus_filtered(&corpus_path, &target_doc_ids)?;
|
||||
|
||||
let mut doc_ids_sorted: Vec<String> = target_doc_ids.into_iter().collect();
|
||||
doc_ids_sorted.sort();
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(doc_ids_sorted.len());
|
||||
let mut paragraph_index = HashMap::new();
|
||||
|
||||
for (doc_id, entry) in &corpus {
|
||||
for doc_id in &doc_ids_sorted {
|
||||
let Some(entry) = corpus.get(doc_id) else {
|
||||
warn!(
|
||||
doc_id = %doc_id,
|
||||
dataset = %dataset.id(),
|
||||
"Skipping qrels document missing from corpus"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
|
||||
let paragraph = ConvertedParagraph {
|
||||
id: paragraph_id.clone(),
|
||||
@@ -87,6 +138,12 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(filter) = doc_ids
|
||||
&& !filter.contains(&best.doc_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(¶graph_slot) = paragraph_index.get(&best.doc_id) else {
|
||||
missing_docs += 1;
|
||||
warn!(
|
||||
@@ -106,7 +163,6 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let answers = vec![snippet];
|
||||
|
||||
let question_id = format!("{}-{query_id}", dataset.source_prefix());
|
||||
paragraphs[paragraph_slot]
|
||||
@@ -114,7 +170,7 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
.push(ConvertedQuestion {
|
||||
id: question_id,
|
||||
question: query.text.clone(),
|
||||
answers,
|
||||
answers: vec![snippet],
|
||||
is_impossible: false,
|
||||
});
|
||||
}
|
||||
@@ -122,13 +178,21 @@ pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<Converte
|
||||
if missing_queries + missing_docs + skipped_answers > 0 {
|
||||
warn!(
|
||||
missing_queries,
|
||||
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
|
||||
missing_docs,
|
||||
skipped_answers,
|
||||
dataset = %dataset.id(),
|
||||
"Skipped some BEIR qrels entries during conversion"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
pub fn corpus_doc_id(paragraph_id: &str, dataset: DatasetKind) -> Option<String> {
|
||||
let prefix = format!("{}-", dataset.source_prefix());
|
||||
paragraph_id.strip_prefix(&prefix).map(str::to_string)
|
||||
}
|
||||
|
||||
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||
let qrels_dir = raw_dir.join("qrels");
|
||||
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
|
||||
@@ -148,7 +212,10 @@ fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
fn load_corpus_filtered(
|
||||
path: &Path,
|
||||
doc_ids: &HashSet<String>,
|
||||
) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
let file =
|
||||
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
@@ -167,6 +234,9 @@ fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
if !doc_ids.contains(&corpus_row.id) {
|
||||
continue;
|
||||
}
|
||||
let title = corpus_row.title.unwrap_or_else(|| corpus_row.id.clone());
|
||||
let text = corpus_row.text.unwrap_or_default();
|
||||
let context = build_context(&title, &text);
|
||||
@@ -296,10 +366,8 @@ mod tests {
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_basic_beir_layout() {
|
||||
let dir = tempdir().unwrap();
|
||||
#[allow(clippy::unwrap_used)]
|
||||
fn write_fixture(dir: &tempfile::TempDir) {
|
||||
let corpus = r#"
|
||||
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
|
||||
{"_id":"d2","title":"Doc 2","text":"Second document content."}
|
||||
@@ -313,24 +381,34 @@ mod tests {
|
||||
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
|
||||
fs::create_dir_all(dir.path().join("qrels")).unwrap();
|
||||
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_qrels_world_only() {
|
||||
let dir = tempdir().unwrap();
|
||||
write_fixture(&dir);
|
||||
|
||||
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
|
||||
|
||||
assert_eq!(paragraphs.len(), 2);
|
||||
let doc_one = paragraphs
|
||||
.iter()
|
||||
.find(|p| p.id == "fever-d1")
|
||||
.expect("missing paragraph for d1");
|
||||
assert_eq!(paragraphs.len(), 1);
|
||||
let doc_one = ¶graphs[0];
|
||||
assert_eq!(doc_one.id, "fever-d1");
|
||||
assert_eq!(doc_one.questions.len(), 1);
|
||||
let question = &doc_one.questions[0];
|
||||
assert_eq!(question.id, "fever-q1");
|
||||
assert!(!question.answers.is_empty());
|
||||
assert!(doc_one.context.contains(&question.answers[0]));
|
||||
assert_eq!(doc_one.questions[0].id, "fever-q1");
|
||||
}
|
||||
|
||||
let doc_two = paragraphs
|
||||
.iter()
|
||||
.find(|p| p.id == "fever-d2")
|
||||
.expect("missing paragraph for d2");
|
||||
assert!(doc_two.questions.is_empty());
|
||||
#[test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing)]
|
||||
fn converts_filtered_doc_ids() {
|
||||
let dir = tempdir().unwrap();
|
||||
write_fixture(&dir);
|
||||
|
||||
let mut ids = HashSet::new();
|
||||
ids.insert("d1".to_string());
|
||||
let paragraphs =
|
||||
convert_beir_documents(dir.path(), DatasetKind::Fever, Some(&ids)).unwrap();
|
||||
assert_eq!(paragraphs.len(), 1);
|
||||
assert_eq!(paragraphs[0].id, "fever-d1");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
BEIR_DATASETS, ConvertedDataset, DatasetKind, DatasetMetadata, beir,
|
||||
checksum::hash_file,
|
||||
store::{
|
||||
self, build_dataset_from_catalog, paragraph_path, read_meta, store_dir_for,
|
||||
upsert_sharded_paragraphs, write_sharded,
|
||||
},
|
||||
};
|
||||
use crate::{args::Config, slice};
|
||||
|
||||
pub fn subset_for_paragraph_id(paragraph_id: &str) -> Option<DatasetKind> {
|
||||
let mut kinds: Vec<DatasetKind> = BEIR_DATASETS.to_vec();
|
||||
kinds.sort_by_key(|kind| std::cmp::Reverse(kind.source_prefix().len()));
|
||||
for kind in kinds {
|
||||
let prefix = format!("{}-", kind.source_prefix());
|
||||
if paragraph_id.starts_with(&prefix) {
|
||||
return Some(kind);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn build_beir_mix_qrels_dataset(include_unanswerable: bool) -> Result<ConvertedDataset> {
|
||||
if include_unanswerable {
|
||||
tracing::warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for subset in BEIR_DATASETS {
|
||||
let entry = super::dataset_entry_for_kind(subset)?;
|
||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||
paragraphs.extend(subset_paragraphs);
|
||||
}
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, include_unanswerable),
|
||||
source: "beir-mix".to_string(),
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn prepare_beir_mix(config: &Config) -> Result<super::loader::LoadedDataset> {
|
||||
let virtual_ds = build_beir_mix_qrels_dataset(config.llm_mode)?;
|
||||
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||
let resolved = slice::resolve_slice(&virtual_ds, &slice_config)
|
||||
.context("resolving BEIR mix slice ledger (check --slice and --limit match your intent)")?;
|
||||
|
||||
let unique: HashSet<String> = resolved
|
||||
.manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect();
|
||||
|
||||
materialize_subset_stores(&unique, config.force_convert)?;
|
||||
|
||||
let dataset = load_beir_mix_from_subsets(&unique)?;
|
||||
let checksum = mix_content_checksum(&unique)?;
|
||||
|
||||
info!(
|
||||
slice = resolved.manifest.slice_id.as_str(),
|
||||
paragraphs = unique.len(),
|
||||
checksum = %checksum,
|
||||
"Prepared BEIR mix from per-subset converted stores"
|
||||
);
|
||||
|
||||
Ok(super::loader::LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: true,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn materialize_subset_stores(paragraph_ids: &HashSet<String>, force: bool) -> Result<()> {
|
||||
let mut by_subset: HashMap<DatasetKind, Vec<String>> = HashMap::new();
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
by_subset
|
||||
.entry(kind)
|
||||
.or_default()
|
||||
.push(paragraph_id.clone());
|
||||
}
|
||||
|
||||
for (kind, ids) in by_subset {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let existing = if store_dir.join("meta.json").is_file() {
|
||||
store::load_paragraph_ids_set(&store_dir)?
|
||||
} else {
|
||||
HashSet::new()
|
||||
};
|
||||
|
||||
let missing: Vec<String> = if force {
|
||||
ids
|
||||
} else {
|
||||
ids.into_iter()
|
||||
.filter(|paragraph_id| !existing.contains(paragraph_id))
|
||||
.collect()
|
||||
};
|
||||
|
||||
if missing.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let corpus_ids: HashSet<String> = missing
|
||||
.iter()
|
||||
.filter_map(|paragraph_id| beir::corpus_doc_id(paragraph_id, kind))
|
||||
.collect();
|
||||
let paragraphs = beir::convert_beir_documents(&entry.raw_path, kind, Some(&corpus_ids))?;
|
||||
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
upsert_sharded_paragraphs(&store_dir, ¶graphs)?;
|
||||
} else {
|
||||
let question_count = paragraphs
|
||||
.iter()
|
||||
.map(|paragraph| paragraph.questions.len())
|
||||
.sum::<usize>();
|
||||
let dataset = ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(kind, false),
|
||||
source: entry.raw_path.display().to_string(),
|
||||
paragraphs,
|
||||
};
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
info!(
|
||||
subset = kind.id(),
|
||||
store = %store_dir.display(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
questions = question_count,
|
||||
"Created subset converted store for BEIR mix"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load_beir_mix_from_subsets(paragraph_ids: &HashSet<String>) -> Result<ConvertedDataset> {
|
||||
let mut by_subset: HashMap<DatasetKind, HashSet<String>> = HashMap::new();
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
by_subset
|
||||
.entry(kind)
|
||||
.or_default()
|
||||
.insert(paragraph_id.clone());
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::with_capacity(paragraph_ids.len());
|
||||
for (kind, subset_ids) in by_subset {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let partial = build_dataset_from_catalog(&store_dir, &subset_ids)?;
|
||||
paragraphs.extend(partial.paragraphs);
|
||||
}
|
||||
|
||||
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: super::base_timestamp(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::Beir, false),
|
||||
source: "beir-mix".to_string(),
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn mix_content_checksum(paragraph_ids: &HashSet<String>) -> Result<String> {
|
||||
let mut ids: Vec<String> = paragraph_ids.iter().cloned().collect();
|
||||
ids.sort();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
for paragraph_id in ids {
|
||||
let kind = subset_for_paragraph_id(¶graph_id)
|
||||
.ok_or_else(|| anyhow!("unknown BEIR subset for paragraph '{paragraph_id}'"))?;
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
let path = paragraph_path(&store_dir, ¶graph_id);
|
||||
if !path.is_file() {
|
||||
return Err(anyhow!(
|
||||
"missing converted paragraph {} at {}",
|
||||
paragraph_id,
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
hasher.update(paragraph_id.as_bytes());
|
||||
hasher.update([0]);
|
||||
hasher.update(hash_file(&path)?.as_bytes());
|
||||
}
|
||||
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
pub fn beir_subset_stores_ready(paragraph_ids: &HashSet<String>) -> Result<bool> {
|
||||
for paragraph_id in paragraph_ids {
|
||||
let kind = subset_for_paragraph_id(paragraph_id).with_context(|| {
|
||||
format!("routing BEIR mix paragraph id '{paragraph_id}' to subset store")
|
||||
})?;
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
if !store_dir.join("meta.json").is_file() {
|
||||
return Ok(false);
|
||||
}
|
||||
if !paragraph_path(&store_dir, paragraph_id).is_file() {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn beir_subset_store_summary() -> Result<Vec<(String, usize, usize)>> {
|
||||
let mut summary = Vec::new();
|
||||
for kind in BEIR_DATASETS {
|
||||
let entry = super::dataset_entry_for_kind(kind)?;
|
||||
let store_dir = store_dir_for(&entry.converted_path);
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
let meta = read_meta(&store_dir)?;
|
||||
summary.push((
|
||||
kind.id().to_string(),
|
||||
meta.paragraph_count,
|
||||
meta.question_count,
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(summary)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn routes_prefixed_paragraph_ids() {
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("fever-doc-1"),
|
||||
Some(DatasetKind::Fever)
|
||||
);
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("nq-beir-doc-1"),
|
||||
Some(DatasetKind::NqBeir)
|
||||
);
|
||||
assert_eq!(
|
||||
subset_for_paragraph_id("trec-covid-doc-1"),
|
||||
Some(DatasetKind::TrecCovid)
|
||||
);
|
||||
assert!(subset_for_paragraph_id("unknown-doc").is_none());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::Read,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SIDECAR_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChecksumSidecar {
|
||||
pub version: u32,
|
||||
pub sha256: String,
|
||||
pub size_bytes: u64,
|
||||
#[serde(default)]
|
||||
pub modified_unix_secs: u64,
|
||||
}
|
||||
|
||||
impl ChecksumSidecar {
|
||||
#[cfg(test)]
|
||||
pub fn sidecar_path(content_path: &Path) -> PathBuf {
|
||||
content_path.with_extension("sha256")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn is_valid_for(&self, content_path: &Path) -> bool {
|
||||
if self.version != SIDECAR_VERSION {
|
||||
return false;
|
||||
}
|
||||
let Ok(metadata) = fs::metadata(content_path) else {
|
||||
return false;
|
||||
};
|
||||
if metadata.len() != self.size_bytes {
|
||||
return false;
|
||||
}
|
||||
if self.modified_unix_secs != 0 {
|
||||
let Ok(modified) = metadata.modified() else {
|
||||
return true;
|
||||
};
|
||||
let Ok(secs) = modified.duration_since(std::time::UNIX_EPOCH) else {
|
||||
return true;
|
||||
};
|
||||
if secs.as_secs() != self.modified_unix_secs {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
pub fn hash_file(path: &Path) -> Result<String> {
|
||||
let mut file = File::open(path)
|
||||
.with_context(|| format!("opening file {} for checksum", path.display()))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = vec![0u8; 65_536];
|
||||
loop {
|
||||
let read = file
|
||||
.read(&mut buffer)
|
||||
.with_context(|| format!("reading {} for checksum", path.display()))?;
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..read]);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
pub fn read_sidecar(path: &Path) -> Result<Option<ChecksumSidecar>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let raw = fs::read_to_string(path)
|
||||
.with_context(|| format!("reading checksum sidecar {}", path.display()))?;
|
||||
let sidecar: ChecksumSidecar = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing checksum sidecar {}", path.display()))?;
|
||||
Ok(Some(sidecar))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn write_sidecar(content_path: &Path, sha256: &str) -> Result<()> {
|
||||
let metadata = fs::metadata(content_path)
|
||||
.with_context(|| format!("reading metadata for {}", content_path.display()))?;
|
||||
let modified_unix_secs = metadata
|
||||
.modified()
|
||||
.ok()
|
||||
.and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
|
||||
.map_or(0, |duration| duration.as_secs());
|
||||
let sidecar = ChecksumSidecar {
|
||||
version: SIDECAR_VERSION,
|
||||
sha256: sha256.to_string(),
|
||||
size_bytes: metadata.len(),
|
||||
modified_unix_secs,
|
||||
};
|
||||
let path = ChecksumSidecar::sidecar_path(content_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating checksum sidecar directory {}", parent.display()))?;
|
||||
}
|
||||
let blob = serde_json::to_vec_pretty(&sidecar).context("serialising checksum sidecar")?;
|
||||
fs::write(&path, blob)
|
||||
.with_context(|| format!("writing checksum sidecar {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn content_checksum(content_path: &Path) -> Result<String> {
|
||||
let sidecar_path = ChecksumSidecar::sidecar_path(content_path);
|
||||
if let Some(sidecar) = read_sidecar(&sidecar_path)?
|
||||
&& sidecar.is_valid_for(content_path)
|
||||
{
|
||||
return Ok(sidecar.sha256);
|
||||
}
|
||||
let sha256 = hash_file(content_path)?;
|
||||
write_sidecar(content_path, &sha256)?;
|
||||
Ok(sha256)
|
||||
}
|
||||
|
||||
pub fn store_aggregate_checksum(store_dir: &Path) -> Result<String> {
|
||||
let marker = store_dir.join("checksum.sha256");
|
||||
let meta = store_dir.join("meta.json");
|
||||
if marker.is_file()
|
||||
&& meta.is_file()
|
||||
&& let (Ok(marker_meta), Ok(meta_meta)) = (marker.metadata(), meta.metadata())
|
||||
&& marker_meta
|
||||
.modified()
|
||||
.ok()
|
||||
.zip(meta_meta.modified().ok())
|
||||
.is_some_and(|(marker_modified, meta_modified)| marker_modified >= meta_modified)
|
||||
&& let Some(sidecar) = read_sidecar(&marker)?
|
||||
{
|
||||
return Ok(sidecar.sha256);
|
||||
}
|
||||
|
||||
let mut entries = Vec::new();
|
||||
collect_store_files(store_dir, store_dir, &mut entries)?;
|
||||
entries.sort();
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
for relative in &entries {
|
||||
let path = store_dir.join(relative);
|
||||
if path == marker {
|
||||
continue;
|
||||
}
|
||||
hasher.update(relative.as_bytes());
|
||||
hasher.update([0]);
|
||||
let file_hash = hash_file(&path)?;
|
||||
hasher.update(file_hash.as_bytes());
|
||||
}
|
||||
let digest = format!("{:x}", hasher.finalize());
|
||||
|
||||
let sidecar = ChecksumSidecar {
|
||||
version: SIDECAR_VERSION,
|
||||
sha256: digest.clone(),
|
||||
size_bytes: entries.len() as u64,
|
||||
modified_unix_secs: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_secs()),
|
||||
};
|
||||
if let Some(parent) = marker.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(&marker, serde_json::to_vec_pretty(&sidecar)?)?;
|
||||
Ok(digest)
|
||||
}
|
||||
|
||||
fn collect_store_files(base: &Path, current: &Path, entries: &mut Vec<String>) -> Result<()> {
|
||||
for entry in fs::read_dir(current)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path
|
||||
.file_name()
|
||||
.is_some_and(|name| name == "checksum.sha256")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if path.is_dir() {
|
||||
collect_store_files(base, &path, entries)?;
|
||||
} else if path.is_file() {
|
||||
let relative = path
|
||||
.strip_prefix(base)
|
||||
.unwrap_or(&path)
|
||||
.to_string_lossy()
|
||||
.replace('\\', "/");
|
||||
entries.push(relative);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn sidecar_round_trip() -> Result<()> {
|
||||
let dir = tempdir()?;
|
||||
let file = dir.path().join("sample.json");
|
||||
fs::write(&file, br#"{"hello":"world"}"#)?;
|
||||
|
||||
let first = content_checksum(&file)?;
|
||||
let second = content_checksum(&file)?;
|
||||
assert_eq!(first, second);
|
||||
|
||||
fs::write(&file, br#"{"hello":"world!"}"#)?;
|
||||
let third = content_checksum(&file)?;
|
||||
assert_ne!(first, third);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
ConvertedDataset, DatasetKind, catalog,
|
||||
store::{
|
||||
self, ConvertedLayout, build_dataset_from_catalog, detect_layout, read_meta, store_dir_for,
|
||||
write_sharded,
|
||||
},
|
||||
};
|
||||
use crate::{
|
||||
args::Config,
|
||||
slice::{self, SliceConfig},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadedDataset {
|
||||
pub dataset: ConvertedDataset,
|
||||
pub content_checksum: String,
|
||||
pub partial: bool,
|
||||
}
|
||||
|
||||
pub fn prepare_dataset(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||
if dataset_kind == DatasetKind::Beir {
|
||||
return super::beir_mix::prepare_beir_mix(config);
|
||||
}
|
||||
|
||||
let converted_path = &config.converted_dataset_path;
|
||||
let layout = detect_layout(converted_path);
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
|
||||
if layout == ConvertedLayout::Missing || config.force_convert {
|
||||
return convert_and_load(dataset_kind, config);
|
||||
}
|
||||
|
||||
load_from_store(dataset_kind, config, &store_dir, true)
|
||||
}
|
||||
|
||||
fn convert_and_load(dataset_kind: DatasetKind, config: &Config) -> Result<LoadedDataset> {
|
||||
let dataset = super::convert(
|
||||
config.raw_dataset_path.as_path(),
|
||||
dataset_kind,
|
||||
config.llm_mode,
|
||||
)
|
||||
.with_context(|| format!("converting {} dataset", dataset_kind.label()))?;
|
||||
|
||||
let store_dir = store_dir_for(&config.converted_dataset_path);
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
prebuild_catalog_slices(&dataset, config)?;
|
||||
let checksum = crate::datasets::store_aggregate_checksum(&store_dir)?;
|
||||
|
||||
Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_from_store(
|
||||
dataset_kind: DatasetKind,
|
||||
config: &Config,
|
||||
store_dir: &std::path::Path,
|
||||
allow_partial: bool,
|
||||
) -> Result<LoadedDataset> {
|
||||
let checksum = crate::datasets::store_aggregate_checksum(store_dir)?;
|
||||
let meta = read_meta(store_dir)?;
|
||||
validate_metadata_fields(&meta.metadata, dataset_kind, config)?;
|
||||
|
||||
if allow_partial && let Some(paragraph_ids) = slice_paragraph_ids_for_fast_path(config)? {
|
||||
let unique: HashSet<String> = paragraph_ids.into_iter().collect();
|
||||
info!(
|
||||
paragraphs = unique.len(),
|
||||
store = %store_dir.display(),
|
||||
"Loading slice-addressed paragraphs from sharded converted store"
|
||||
);
|
||||
let dataset = build_dataset_from_catalog(store_dir, &unique)?;
|
||||
return Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: true,
|
||||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
paragraphs = meta.paragraph_count,
|
||||
"Loading full sharded converted store"
|
||||
);
|
||||
let dataset = store::load_sharded_full(store_dir)?;
|
||||
Ok(LoadedDataset {
|
||||
dataset,
|
||||
content_checksum: checksum,
|
||||
partial: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn slice_paragraph_ids_for_fast_path(config: &Config) -> Result<Option<Vec<String>>> {
|
||||
let Some(manifest_path) = slice::cached_manifest_path(config) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let Some(manifest) = slice::read_manifest_if_exists(&manifest_path)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let slice_config = slice::slice_config_with_limit(config, slice::ledger_target(config));
|
||||
if !slice::manifest_is_complete(&manifest, &slice_config) {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(
|
||||
manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|entry| entry.id.clone())
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
fn validate_metadata_fields(
|
||||
metadata: &super::DatasetMetadata,
|
||||
dataset_kind: DatasetKind,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
if metadata.id != dataset_kind.id() {
|
||||
anyhow::bail!(
|
||||
"converted dataset targets '{}', expected '{}'",
|
||||
metadata.id,
|
||||
dataset_kind.id()
|
||||
);
|
||||
}
|
||||
if metadata.include_unanswerable != config.llm_mode {
|
||||
anyhow::bail!(
|
||||
"converted dataset include_unanswerable mismatch (expected {}, found {})",
|
||||
config.llm_mode,
|
||||
metadata.include_unanswerable
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn prebuild_catalog_slices(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let catalog = catalog()?;
|
||||
let entry = catalog.dataset(dataset.metadata.id.as_str())?;
|
||||
if entry.slices.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
slices = entry.slices.len(),
|
||||
"Prebuilding catalog slice ledgers"
|
||||
);
|
||||
|
||||
for slice_entry in &entry.slices {
|
||||
let slice_config = slice_config_for_catalog_entry(config, slice_entry);
|
||||
match slice::resolve_slice(dataset, &slice_config) {
|
||||
Ok(resolved) => info!(
|
||||
slice = resolved.manifest.slice_id.as_str(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Prebuilt catalog slice ledger"
|
||||
),
|
||||
Err(err) => tracing::warn!(
|
||||
slice = slice_entry.id.as_str(),
|
||||
error = %err,
|
||||
"Failed to prebuild catalog slice ledger"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_config_for_catalog_entry<'a>(
|
||||
config: &'a Config,
|
||||
slice_entry: &'a super::SliceEntry,
|
||||
) -> SliceConfig<'a> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
force_convert: config.force_convert,
|
||||
explicit_slice: Some(slice_entry.id.as_str()),
|
||||
limit: slice_entry.limit,
|
||||
corpus_limit: slice_entry.corpus_limit,
|
||||
slice_seed: slice_entry.seed.unwrap_or(config.slice_seed),
|
||||
llm_mode: slice_entry.include_unanswerable.unwrap_or(config.llm_mode),
|
||||
negative_multiplier: slice_entry
|
||||
.negative_multiplier
|
||||
.unwrap_or(config.negative_multiplier),
|
||||
require_verified_chunks: config.retrieval.require_verified_chunks,
|
||||
}
|
||||
}
|
||||
+40
-145
@@ -1,6 +1,10 @@
|
||||
mod beir;
|
||||
mod beir_mix;
|
||||
mod checksum;
|
||||
mod loader;
|
||||
mod nq;
|
||||
mod squad;
|
||||
mod store;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
@@ -9,7 +13,7 @@ use std::{
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use clap::ValueEnum;
|
||||
use once_cell::sync::OnceCell;
|
||||
@@ -20,38 +24,31 @@ const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml"
|
||||
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetCatalog {
|
||||
datasets: BTreeMap<String, DatasetEntry>,
|
||||
slices: HashMap<String, SliceLocation>,
|
||||
default_dataset: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct DatasetEntry {
|
||||
pub metadata: DatasetMetadata,
|
||||
pub raw_path: PathBuf,
|
||||
pub converted_path: PathBuf,
|
||||
pub include_unanswerable: bool,
|
||||
pub slices: Vec<SliceEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SliceEntry {
|
||||
pub id: String,
|
||||
pub dataset_id: String,
|
||||
pub label: String,
|
||||
pub description: Option<String>,
|
||||
pub limit: Option<usize>,
|
||||
pub corpus_limit: Option<usize>,
|
||||
pub include_unanswerable: Option<bool>,
|
||||
pub seed: Option<u64>,
|
||||
pub negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
struct SliceLocation {
|
||||
dataset_id: String,
|
||||
slice_index: usize,
|
||||
@@ -59,7 +56,6 @@ struct SliceLocation {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ManifestFile {
|
||||
default_dataset: Option<String>,
|
||||
datasets: Vec<ManifestDataset>,
|
||||
}
|
||||
|
||||
@@ -81,6 +77,7 @@ struct ManifestDataset {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ManifestSlice {
|
||||
id: String,
|
||||
label: String,
|
||||
@@ -94,6 +91,8 @@ struct ManifestSlice {
|
||||
include_unanswerable: Option<bool>,
|
||||
#[serde(default)]
|
||||
seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
negative_multiplier: Option<f32>,
|
||||
}
|
||||
|
||||
impl DatasetCatalog {
|
||||
@@ -111,18 +110,19 @@ impl DatasetCatalog {
|
||||
let raw_path = resolve_path(root, &dataset.raw);
|
||||
let converted_path = resolve_path(root, &dataset.converted);
|
||||
|
||||
if !raw_path.exists() {
|
||||
if !raw_path.exists() && dataset.id != "beir" {
|
||||
bail!(
|
||||
"dataset '{}' raw file missing at {}",
|
||||
dataset.id,
|
||||
raw_path.display()
|
||||
);
|
||||
}
|
||||
if !converted_path.exists() {
|
||||
let store_dir = store::store_dir_for(&converted_path);
|
||||
if !converted_path.exists() && !store_dir.join("meta.json").is_file() {
|
||||
warn!(
|
||||
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
|
||||
"dataset '{}' converted store missing at {}; the next conversion run will regenerate it",
|
||||
dataset.id,
|
||||
converted_path.display()
|
||||
store_dir.display()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -139,7 +139,6 @@ impl DatasetCatalog {
|
||||
.clone()
|
||||
.unwrap_or_else(|| dataset.id.clone()),
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
context_token_limit: None,
|
||||
};
|
||||
|
||||
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
|
||||
@@ -154,12 +153,11 @@ impl DatasetCatalog {
|
||||
entry_slices.push(SliceEntry {
|
||||
id: manifest_slice.id.clone(),
|
||||
dataset_id: dataset.id.clone(),
|
||||
label: manifest_slice.label,
|
||||
description: manifest_slice.description,
|
||||
limit: manifest_slice.limit,
|
||||
corpus_limit: manifest_slice.corpus_limit,
|
||||
include_unanswerable: manifest_slice.include_unanswerable,
|
||||
seed: manifest_slice.seed,
|
||||
negative_multiplier: manifest_slice.negative_multiplier,
|
||||
});
|
||||
slices.insert(
|
||||
manifest_slice.id,
|
||||
@@ -176,22 +174,16 @@ impl DatasetCatalog {
|
||||
metadata,
|
||||
raw_path,
|
||||
converted_path,
|
||||
include_unanswerable: dataset.include_unanswerable,
|
||||
slices: entry_slices,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let default_dataset = manifest
|
||||
.default_dataset
|
||||
.or_else(|| datasets.keys().next().cloned())
|
||||
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
|
||||
if datasets.is_empty() {
|
||||
bail!("dataset manifest does not include any datasets");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
datasets,
|
||||
slices,
|
||||
default_dataset,
|
||||
})
|
||||
Ok(Self { datasets, slices })
|
||||
}
|
||||
|
||||
pub fn global() -> Result<&'static Self> {
|
||||
@@ -204,12 +196,6 @@ impl DatasetCatalog {
|
||||
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
|
||||
self.dataset(&self.default_dataset)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
|
||||
let location = self
|
||||
.slices
|
||||
@@ -236,20 +222,27 @@ fn resolve_path(root: &Path, value: &str) -> PathBuf {
|
||||
}
|
||||
}
|
||||
|
||||
pub use beir_mix::{beir_subset_store_summary, beir_subset_stores_ready, mix_content_checksum};
|
||||
pub use checksum::store_aggregate_checksum;
|
||||
pub use loader::{prebuild_catalog_slices, prepare_dataset};
|
||||
pub use store::{
|
||||
ConvertedLayout, content_checksum_for_layout, detect_layout, store_dir_for, write_sharded,
|
||||
};
|
||||
|
||||
pub fn catalog() -> Result<&'static DatasetCatalog> {
|
||||
DatasetCatalog::global()
|
||||
}
|
||||
|
||||
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
pub(crate) fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
|
||||
let catalog = catalog()?;
|
||||
catalog.dataset(kind.id())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum, Default)]
|
||||
pub enum DatasetKind {
|
||||
#[default]
|
||||
SquadV2,
|
||||
NaturalQuestions,
|
||||
#[default]
|
||||
Beir,
|
||||
#[value(name = "fever")]
|
||||
Fever,
|
||||
@@ -390,7 +383,9 @@ impl FromStr for DatasetKind {
|
||||
"scifact" => Ok(Self::Scifact),
|
||||
"nq-beir" | "natural-questions-beir" => Ok(Self::NqBeir),
|
||||
other => {
|
||||
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir.")
|
||||
anyhow::bail!(
|
||||
"unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir."
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -416,16 +411,10 @@ pub struct DatasetMetadata {
|
||||
pub source_prefix: String,
|
||||
#[serde(default)]
|
||||
pub include_unanswerable: bool,
|
||||
#[serde(default)]
|
||||
pub context_token_limit: Option<usize>,
|
||||
}
|
||||
|
||||
impl DatasetMetadata {
|
||||
pub fn for_kind(
|
||||
kind: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Self {
|
||||
pub fn for_kind(kind: DatasetKind, include_unanswerable: bool) -> Self {
|
||||
if let Ok(entry) = dataset_entry_for_kind(kind) {
|
||||
return Self {
|
||||
id: entry.metadata.id.clone(),
|
||||
@@ -434,7 +423,6 @@ impl DatasetMetadata {
|
||||
entity_suffix: entry.metadata.entity_suffix.clone(),
|
||||
source_prefix: entry.metadata.source_prefix.clone(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -445,13 +433,12 @@ impl DatasetMetadata {
|
||||
entity_suffix: kind.entity_suffix().to_string(),
|
||||
source_prefix: kind.source_prefix().to_string(),
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_metadata() -> DatasetMetadata {
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
|
||||
DatasetMetadata::for_kind(DatasetKind::default(), false)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -483,14 +470,15 @@ pub fn convert(
|
||||
raw_path: &Path,
|
||||
dataset: DatasetKind,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
let paragraphs = match dataset {
|
||||
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
|
||||
DatasetKind::NaturalQuestions => {
|
||||
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
|
||||
DatasetKind::NaturalQuestions => nq::convert_nq(raw_path, include_unanswerable)?,
|
||||
DatasetKind::Beir => {
|
||||
bail!(
|
||||
"BEIR mix is prepared via slice-first subset stores; use prepare_beir_mix instead of convert"
|
||||
);
|
||||
}
|
||||
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
|
||||
DatasetKind::Fever
|
||||
| DatasetKind::Fiqa
|
||||
| DatasetKind::HotpotQa
|
||||
@@ -501,11 +489,6 @@ pub fn convert(
|
||||
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
|
||||
};
|
||||
|
||||
let metadata_limit = match dataset {
|
||||
DatasetKind::NaturalQuestions => None,
|
||||
_ => context_token_limit,
|
||||
};
|
||||
|
||||
let generated_at = match dataset {
|
||||
DatasetKind::Beir
|
||||
| DatasetKind::Fever
|
||||
@@ -526,100 +509,12 @@ pub fn convert(
|
||||
|
||||
Ok(ConvertedDataset {
|
||||
generated_at,
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
|
||||
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable),
|
||||
source: source_label,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_beir_mix(
|
||||
include_unanswerable: bool,
|
||||
_context_token_limit: Option<usize>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
if include_unanswerable {
|
||||
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
|
||||
}
|
||||
|
||||
let mut paragraphs = Vec::new();
|
||||
for subset in BEIR_DATASETS {
|
||||
let entry = dataset_entry_for_kind(subset)?;
|
||||
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
|
||||
paragraphs.extend(subset_paragraphs);
|
||||
}
|
||||
|
||||
Ok(paragraphs)
|
||||
}
|
||||
|
||||
fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
|
||||
ensure_parent(converted_path)?;
|
||||
let json =
|
||||
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
|
||||
fs::write(converted_path, json)
|
||||
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
|
||||
}
|
||||
|
||||
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
|
||||
let raw = fs::read_to_string(converted_path)
|
||||
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
|
||||
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
|
||||
if dataset.metadata.id.trim().is_empty() {
|
||||
dataset.metadata = default_metadata();
|
||||
}
|
||||
if dataset.source.is_empty() {
|
||||
dataset.source = converted_path.display().to_string();
|
||||
}
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
pub fn ensure_converted(
|
||||
dataset_kind: DatasetKind,
|
||||
raw_path: &Path,
|
||||
converted_path: &Path,
|
||||
force: bool,
|
||||
include_unanswerable: bool,
|
||||
context_token_limit: Option<usize>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
if force || !converted_path.exists() {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
return Ok(dataset);
|
||||
}
|
||||
|
||||
match read_converted(converted_path) {
|
||||
Ok(dataset)
|
||||
if dataset.metadata.id == dataset_kind.id()
|
||||
&& dataset.metadata.include_unanswerable == include_unanswerable
|
||||
&& dataset.metadata.context_token_limit == context_token_limit =>
|
||||
{
|
||||
Ok(dataset)
|
||||
}
|
||||
_ => {
|
||||
let dataset = convert(
|
||||
raw_path,
|
||||
dataset_kind,
|
||||
include_unanswerable,
|
||||
context_token_limit,
|
||||
)?;
|
||||
write_converted(&dataset, converted_path)?;
|
||||
Ok(dataset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base_timestamp() -> DateTime<Utc> {
|
||||
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
|
||||
}
|
||||
|
||||
@@ -16,11 +16,7 @@ use super::{ConvertedParagraph, ConvertedQuestion};
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::cast_sign_loss
|
||||
)]
|
||||
pub fn convert_nq(
|
||||
raw_path: &Path,
|
||||
include_unanswerable: bool,
|
||||
_context_token_limit: Option<usize>,
|
||||
) -> Result<Vec<ConvertedParagraph>> {
|
||||
pub fn convert_nq(raw_path: &Path, include_unanswerable: bool) -> Result<Vec<ConvertedParagraph>> {
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NqExample {
|
||||
|
||||
@@ -0,0 +1,412 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fs::{self, File, OpenOptions},
|
||||
io::{BufRead, BufReader, Write},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetMetadata,
|
||||
checksum::store_aggregate_checksum,
|
||||
};
|
||||
use crate::slice;
|
||||
|
||||
pub const SHARDED_STORE_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ShardedMeta {
|
||||
pub version: u32,
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub metadata: DatasetMetadata,
|
||||
pub source: String,
|
||||
pub paragraph_count: usize,
|
||||
pub question_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub(crate) struct QuestionRecord {
|
||||
paragraph_id: String,
|
||||
#[serde(flatten)]
|
||||
question: ConvertedQuestion,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuestionCatalog {
|
||||
pub entries: Vec<QuestionRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ConvertedLayout {
|
||||
ShardedStore,
|
||||
Missing,
|
||||
}
|
||||
|
||||
pub fn store_dir_for(converted_path: &Path) -> PathBuf {
|
||||
converted_path
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."))
|
||||
.join(converted_path.file_stem().map_or_else(
|
||||
|| "dataset".to_string(),
|
||||
|stem| stem.to_string_lossy().into(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn detect_layout(converted_path: &Path) -> ConvertedLayout {
|
||||
let store_dir = store_dir_for(converted_path);
|
||||
if store_dir.join("meta.json").is_file() {
|
||||
ConvertedLayout::ShardedStore
|
||||
} else {
|
||||
ConvertedLayout::Missing
|
||||
}
|
||||
}
|
||||
|
||||
fn paragraph_file_name(paragraph_id: &str) -> String {
|
||||
format!("{}.json", slice::paragraph_storage_key(paragraph_id))
|
||||
}
|
||||
|
||||
pub fn paragraph_path(store_dir: &Path, paragraph_id: &str) -> PathBuf {
|
||||
store_dir
|
||||
.join("paragraphs")
|
||||
.join(paragraph_file_name(paragraph_id))
|
||||
}
|
||||
|
||||
pub fn write_sharded(dataset: &ConvertedDataset, store_dir: &Path) -> Result<String> {
|
||||
if store_dir.exists() {
|
||||
fs::remove_dir_all(store_dir)
|
||||
.with_context(|| format!("clearing sharded store {}", store_dir.display()))?;
|
||||
}
|
||||
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||
.with_context(|| format!("creating sharded store {}", store_dir.display()))?;
|
||||
|
||||
let question_count = dataset
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|paragraph| paragraph.questions.len())
|
||||
.sum::<usize>();
|
||||
|
||||
let meta = ShardedMeta {
|
||||
version: SHARDED_STORE_VERSION,
|
||||
generated_at: dataset.generated_at,
|
||||
metadata: dataset.metadata.clone(),
|
||||
source: dataset.source.clone(),
|
||||
paragraph_count: dataset.paragraphs.len(),
|
||||
question_count,
|
||||
};
|
||||
let meta_path = store_dir.join("meta.json");
|
||||
fs::write(
|
||||
&meta_path,
|
||||
serde_json::to_vec_pretty(&meta).context("serialising sharded store metadata")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded metadata {}", meta_path.display()))?;
|
||||
|
||||
let mut questions_file = File::create(store_dir.join("questions.jsonl"))
|
||||
.context("creating questions.jsonl for sharded store")?;
|
||||
let mut paragraph_ids_file = File::create(store_dir.join("paragraph_ids.jsonl"))
|
||||
.context("creating paragraph_ids.jsonl for sharded store")?;
|
||||
|
||||
for paragraph in &dataset.paragraphs {
|
||||
writeln!(paragraph_ids_file, "{}", paragraph.id)
|
||||
.context("writing paragraph id to paragraph_ids.jsonl")?;
|
||||
for question in ¶graph.questions {
|
||||
let record = QuestionRecord {
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
question: question.clone(),
|
||||
};
|
||||
serde_json::to_writer(&mut questions_file, &record)
|
||||
.context("writing question record to questions.jsonl")?;
|
||||
questions_file.write_all(b"\n")?;
|
||||
}
|
||||
|
||||
let path = paragraph_path(store_dir, ¶graph.id);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(
|
||||
&path,
|
||||
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||
}
|
||||
|
||||
let digest = store_aggregate_checksum(store_dir)?;
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
questions = question_count,
|
||||
checksum = %digest,
|
||||
"Wrote sharded converted dataset"
|
||||
);
|
||||
Ok(digest)
|
||||
}
|
||||
|
||||
pub fn read_meta(store_dir: &Path) -> Result<ShardedMeta> {
|
||||
let path = store_dir.join("meta.json");
|
||||
let raw = fs::read_to_string(&path)
|
||||
.with_context(|| format!("reading sharded metadata {}", path.display()))?;
|
||||
serde_json::from_str(&raw)
|
||||
.with_context(|| format!("parsing sharded metadata {}", path.display()))
|
||||
}
|
||||
|
||||
pub fn content_checksum_for_layout(converted_path: &Path) -> Result<String> {
|
||||
match detect_layout(converted_path) {
|
||||
ConvertedLayout::ShardedStore => {
|
||||
crate::datasets::store_aggregate_checksum(&store_dir_for(converted_path))
|
||||
}
|
||||
ConvertedLayout::Missing => Err(anyhow!(
|
||||
"converted dataset missing at {}",
|
||||
converted_path.display()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn load_paragraph(store_dir: &Path, paragraph_id: &str) -> Result<ConvertedParagraph> {
|
||||
let path = paragraph_path(store_dir, paragraph_id);
|
||||
let raw =
|
||||
fs::read(&path).with_context(|| format!("reading sharded paragraph {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing sharded paragraph {}", path.display()))
|
||||
}
|
||||
|
||||
fn load_paragraphs(store_dir: &Path, paragraph_ids: &[String]) -> Result<Vec<ConvertedParagraph>> {
|
||||
paragraph_ids
|
||||
.iter()
|
||||
.map(|paragraph_id| load_paragraph(store_dir, paragraph_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_sharded_partial(
|
||||
store_dir: &Path,
|
||||
paragraph_ids: &[String],
|
||||
) -> Result<ConvertedDataset> {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let mut paragraphs = load_paragraphs(store_dir, paragraph_ids)?;
|
||||
paragraphs.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: meta.generated_at,
|
||||
metadata: meta.metadata,
|
||||
source: meta.source,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_sharded_full(store_dir: &Path) -> Result<ConvertedDataset> {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let ids = load_paragraph_ids(store_dir)?;
|
||||
let paragraphs = load_paragraphs(store_dir, &ids)?;
|
||||
Ok(ConvertedDataset {
|
||||
generated_at: meta.generated_at,
|
||||
metadata: meta.metadata,
|
||||
source: meta.source,
|
||||
paragraphs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_paragraph_ids_set(store_dir: &Path) -> Result<HashSet<String>> {
|
||||
Ok(load_paragraph_ids(store_dir)?.into_iter().collect())
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
pub fn upsert_sharded_paragraphs(
|
||||
store_dir: &Path,
|
||||
paragraphs: &[ConvertedParagraph],
|
||||
) -> Result<()> {
|
||||
if paragraphs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if !store_dir.join("meta.json").is_file() {
|
||||
return Err(anyhow!(
|
||||
"cannot upsert into missing sharded store at {}",
|
||||
store_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
fs::create_dir_all(store_dir.join("paragraphs"))
|
||||
.with_context(|| format!("creating paragraphs directory in {}", store_dir.display()))?;
|
||||
|
||||
let existing = load_paragraph_ids_set(store_dir)?;
|
||||
let questions_path = store_dir.join("questions.jsonl");
|
||||
let mut questions_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&questions_path)
|
||||
.with_context(|| format!("opening question catalog {}", questions_path.display()))?;
|
||||
|
||||
let mut ids_file = None;
|
||||
let mut new_paragraphs = 0usize;
|
||||
let mut new_questions = 0usize;
|
||||
|
||||
for paragraph in paragraphs {
|
||||
let is_new = !existing.contains(¶graph.id);
|
||||
let path = paragraph_path(store_dir, ¶graph.id);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
fs::write(
|
||||
&path,
|
||||
serde_json::to_vec(paragraph).context("serialising sharded paragraph")?,
|
||||
)
|
||||
.with_context(|| format!("writing sharded paragraph {}", path.display()))?;
|
||||
|
||||
if is_new {
|
||||
if ids_file.is_none() {
|
||||
ids_file = Some(
|
||||
OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(store_dir.join("paragraph_ids.jsonl"))
|
||||
.context("opening paragraph_ids.jsonl for append")?,
|
||||
);
|
||||
}
|
||||
if let Some(file) = ids_file.as_mut() {
|
||||
writeln!(file, "{}", paragraph.id).context("appending paragraph id")?;
|
||||
}
|
||||
new_paragraphs += 1;
|
||||
|
||||
for question in ¶graph.questions {
|
||||
let record = QuestionRecord {
|
||||
paragraph_id: paragraph.id.clone(),
|
||||
question: question.clone(),
|
||||
};
|
||||
serde_json::to_writer(&mut questions_file, &record)
|
||||
.context("writing question record to questions.jsonl")?;
|
||||
questions_file.write_all(b"\n")?;
|
||||
new_questions += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if new_paragraphs > 0 || new_questions > 0 {
|
||||
let meta = read_meta(store_dir)?;
|
||||
let updated = ShardedMeta {
|
||||
paragraph_count: meta.paragraph_count + new_paragraphs,
|
||||
question_count: meta.question_count + new_questions,
|
||||
..meta
|
||||
};
|
||||
fs::write(
|
||||
store_dir.join("meta.json"),
|
||||
serde_json::to_vec_pretty(&updated).context("serialising updated sharded metadata")?,
|
||||
)?;
|
||||
store_aggregate_checksum(store_dir)?;
|
||||
info!(
|
||||
store = %store_dir.display(),
|
||||
new_paragraphs,
|
||||
new_questions,
|
||||
"Upserted paragraphs into sharded converted store"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load_paragraph_ids(store_dir: &Path) -> Result<Vec<String>> {
|
||||
let path = store_dir.join("paragraph_ids.jsonl");
|
||||
let file = File::open(&path)
|
||||
.with_context(|| format!("opening paragraph id index {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
reader
|
||||
.lines()
|
||||
.map(|line| {
|
||||
line.context("reading paragraph id index line")
|
||||
.and_then(|value| {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
Err(anyhow!("empty paragraph id in index"))
|
||||
} else {
|
||||
Ok(trimmed.to_string())
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_question_catalog(store_dir: &Path) -> Result<QuestionCatalog> {
|
||||
let path = store_dir.join("questions.jsonl");
|
||||
let file = File::open(&path)
|
||||
.with_context(|| format!("opening question catalog {}", path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut entries = Vec::new();
|
||||
for line in reader.lines() {
|
||||
let line = line.context("reading question catalog line")?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let record: QuestionRecord =
|
||||
serde_json::from_str(&line).context("parsing question catalog record")?;
|
||||
entries.push(record);
|
||||
}
|
||||
Ok(QuestionCatalog { entries })
|
||||
}
|
||||
|
||||
pub fn build_dataset_from_catalog(
|
||||
store_dir: &Path,
|
||||
paragraph_ids: &HashSet<String>,
|
||||
) -> Result<ConvertedDataset> {
|
||||
let catalog = load_question_catalog(store_dir)?;
|
||||
let mut questions_by_paragraph: HashMap<String, Vec<ConvertedQuestion>> = HashMap::new();
|
||||
for entry in catalog.entries {
|
||||
if paragraph_ids.contains(&entry.paragraph_id) {
|
||||
questions_by_paragraph
|
||||
.entry(entry.paragraph_id.clone())
|
||||
.or_default()
|
||||
.push(entry.question);
|
||||
}
|
||||
}
|
||||
|
||||
let mut dataset = load_sharded_partial(
|
||||
store_dir,
|
||||
¶graph_ids.iter().cloned().collect::<Vec<_>>(),
|
||||
)?;
|
||||
for paragraph in &mut dataset.paragraphs {
|
||||
if let Some(questions) = questions_by_paragraph.remove(¶graph.id) {
|
||||
paragraph.questions = questions;
|
||||
} else {
|
||||
paragraph.questions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dataset)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::datasets::{DatasetKind, DatasetMetadata};
|
||||
|
||||
fn sample_dataset() -> ConvertedDataset {
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata: DatasetMetadata::for_kind(DatasetKind::SquadV2, false),
|
||||
source: "test".to_string(),
|
||||
paragraphs: vec![ConvertedParagraph {
|
||||
id: "p1".to_string(),
|
||||
title: "Title".to_string(),
|
||||
context: "Body".to_string(),
|
||||
questions: vec![ConvertedQuestion {
|
||||
id: "q1".to_string(),
|
||||
question: "Question?".to_string(),
|
||||
answers: vec!["Answer".to_string()],
|
||||
is_impossible: false,
|
||||
}],
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn sharded_round_trip() -> Result<()> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let store_dir = dir.path().join("sample");
|
||||
let dataset = sample_dataset();
|
||||
write_sharded(&dataset, &store_dir)?;
|
||||
|
||||
let loaded = load_sharded_full(&store_dir)?;
|
||||
assert_eq!(loaded.paragraphs.len(), 1);
|
||||
assert_eq!(loaded.paragraphs[0].questions[0].id, "q1");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,22 @@
|
||||
//! Database namespace management utilities.
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::Utc;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::user::{Theme, User},
|
||||
types::StoredObject,
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::StoredObject,
|
||||
types::user::{Theme, User},
|
||||
},
|
||||
utils::embedding::EmbeddingProvider,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
corpus::{self, CorpusHandle, CorpusManifest, NamespaceSeedRecord},
|
||||
datasets,
|
||||
snapshot::{self, DbSnapshotState},
|
||||
};
|
||||
|
||||
/// Connect to the evaluation database with fallback auth strategies.
|
||||
pub(crate) async fn connect_eval_db(
|
||||
config: &Config,
|
||||
namespace: &str,
|
||||
@@ -73,7 +73,6 @@ pub(crate) async fn connect_eval_db(
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the namespace contains any corpus data.
|
||||
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
||||
#[derive(Deserialize)]
|
||||
struct CountRow {
|
||||
@@ -89,41 +88,51 @@ pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
|
||||
Ok(rows.first().map_or(0, |row| row.count) > 0)
|
||||
}
|
||||
|
||||
/// Determine if we can reuse an existing namespace based on cached state.
|
||||
fn manifest_matches_runtime(
|
||||
manifest: &CorpusManifest,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
ingestion_fingerprint: &str,
|
||||
) -> bool {
|
||||
let metadata = &manifest.metadata;
|
||||
metadata.ingestion_fingerprint == ingestion_fingerprint
|
||||
&& metadata.embedding_backend == embedding_provider.backend_label()
|
||||
&& metadata.embedding_model == embedding_provider.model_code()
|
||||
&& metadata.embedding_dimension == embedding_provider.dimension()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn can_reuse_namespace(
|
||||
db: &SurrealDbClient,
|
||||
descriptor: &snapshot::Descriptor,
|
||||
manifest: &CorpusManifest,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
namespace: &str,
|
||||
database: &str,
|
||||
dataset_id: &str,
|
||||
slice_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
slice_case_count: usize,
|
||||
) -> Result<bool> {
|
||||
let Some(state) = descriptor.load_db_state().await? else {
|
||||
info!("No namespace state recorded; reseeding corpus from cached shards");
|
||||
if !manifest_matches_runtime(manifest, embedding_provider, ingestion_fingerprint) {
|
||||
info!("Corpus manifest metadata mismatch; rebuilding namespace from cached shards");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let Some(seed) = manifest.metadata.namespace_seed.as_ref() else {
|
||||
info!("No namespace seed recorded in corpus manifest; reseeding");
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
if state.slice_case_count != slice_case_count {
|
||||
if seed.slice_case_count != slice_case_count {
|
||||
info!(
|
||||
requested_cases = slice_case_count,
|
||||
stored_cases = state.slice_case_count,
|
||||
"Skipping live namespace reuse; cached state does not match requested window"
|
||||
stored_cases = seed.slice_case_count,
|
||||
"Skipping namespace reuse; case window mismatch"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if state.dataset_id != dataset_id
|
||||
|| state.slice_id != slice_id
|
||||
|| state.ingestion_fingerprint != ingestion_fingerprint
|
||||
|| state.namespace.as_deref() != Some(namespace)
|
||||
|| state.database.as_deref() != Some(database)
|
||||
{
|
||||
if seed.namespace != namespace || seed.database != database {
|
||||
info!(
|
||||
namespace,
|
||||
database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
|
||||
database, "Corpus manifest namespace metadata mismatch; reseeding"
|
||||
);
|
||||
return Ok(false);
|
||||
}
|
||||
@@ -140,28 +149,20 @@ pub(crate) async fn can_reuse_namespace(
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the current namespace state to allow future reuse checks.
|
||||
pub(crate) async fn record_namespace_state(
|
||||
descriptor: &snapshot::Descriptor,
|
||||
dataset_id: &str,
|
||||
slice_id: &str,
|
||||
ingestion_fingerprint: &str,
|
||||
pub(crate) async fn record_namespace_seed(
|
||||
handle: &mut CorpusHandle,
|
||||
namespace: &str,
|
||||
database: &str,
|
||||
slice_case_count: usize,
|
||||
) {
|
||||
let state = DbSnapshotState {
|
||||
dataset_id: dataset_id.to_string(),
|
||||
slice_id: slice_id.to_string(),
|
||||
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
||||
updated_at: Utc::now(),
|
||||
namespace: Some(namespace.to_string()),
|
||||
database: Some(database.to_string()),
|
||||
handle.manifest.metadata.namespace_seed = Some(NamespaceSeedRecord {
|
||||
namespace: namespace.to_string(),
|
||||
database: database.to_string(),
|
||||
slice_case_count,
|
||||
};
|
||||
if let Err(err) = descriptor.store_db_state(&state).await {
|
||||
warn!(error = %err, "Failed to record namespace state");
|
||||
seeded_at: Utc::now(),
|
||||
});
|
||||
if let Err(err) = corpus::persist_corpus_manifest(handle) {
|
||||
warn!(error = %err, "Failed to record namespace seed in corpus manifest");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,8 +186,17 @@ fn sanitize_identifier(input: &str) -> String {
|
||||
cleaned
|
||||
}
|
||||
|
||||
/// Generate a default namespace name based on dataset and limit.
|
||||
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
|
||||
pub(crate) fn default_namespace(
|
||||
dataset_id: &str,
|
||||
limit: Option<usize>,
|
||||
slice_id: Option<&str>,
|
||||
) -> String {
|
||||
if let Some(slice_id) = slice_id {
|
||||
let sanitized = sanitize_identifier(slice_id);
|
||||
if !sanitized.is_empty() {
|
||||
return format!("eval_{sanitized}");
|
||||
}
|
||||
}
|
||||
let dataset_component = sanitize_identifier(dataset_id);
|
||||
let limit_component = match limit {
|
||||
Some(value) if value > 0 => format!("limit{value}"),
|
||||
@@ -195,12 +205,10 @@ pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> Strin
|
||||
format!("eval_{dataset_component}_{limit_component}")
|
||||
}
|
||||
|
||||
/// Generate the default database name for evaluations.
|
||||
pub(crate) fn default_database() -> String {
|
||||
"retrieval_eval".to_string()
|
||||
}
|
||||
|
||||
/// Ensure the evaluation user exists in the database.
|
||||
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||
let timestamp = datasets::base_timestamp();
|
||||
let user = User {
|
||||
@@ -225,3 +233,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
|
||||
.context("storing evaluation user")?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
||||
sanitize_identifier(code)
|
||||
}
|
||||
@@ -2,13 +2,6 @@ use anyhow::{Context, Result};
|
||||
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime};
|
||||
use tracing::info;
|
||||
|
||||
// Helper functions for index management during namespace reseed
|
||||
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
|
||||
let _ = db;
|
||||
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
|
||||
ensure_runtime(db, dimension)
|
||||
@@ -34,14 +27,39 @@ pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &s
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// // Test helper to force index dimension change
|
||||
// #[allow(dead_code)]
|
||||
// pub async fn change_embedding_length_in_hnsw_indexes(
|
||||
// db: &SurrealDbClient,
|
||||
// dimension: usize,
|
||||
// ) -> Result<()> {
|
||||
// recreate_indexes(db, dimension).await
|
||||
// }
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
||||
|
||||
info!("Warming HNSW caches with sample queries");
|
||||
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT chunk_id
|
||||
FROM text_chunk_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding.clone()))
|
||||
.await
|
||||
.context("warming text chunk HNSW cache")?;
|
||||
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT entity_id
|
||||
FROM knowledge_entity_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding))
|
||||
.await
|
||||
.context("warming knowledge entity HNSW cache")?;
|
||||
|
||||
info!("HNSW cache warming completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -0,0 +1,9 @@
|
||||
mod connect;
|
||||
mod lifecycle;
|
||||
|
||||
pub(crate) use connect::{
|
||||
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
||||
namespace_has_corpus, record_namespace_seed, sanitize_model_code,
|
||||
};
|
||||
pub(crate) use lifecycle::warm_hnsw_cache;
|
||||
pub use lifecycle::{recreate_indexes, reset_namespace};
|
||||
@@ -1,128 +0,0 @@
|
||||
//! Evaluation utilities module - re-exports from focused submodules.
|
||||
|
||||
// Re-export types from the root types module
|
||||
pub use crate::types::*;
|
||||
|
||||
// Re-export from focused modules at crate root (crate-internal only)
|
||||
pub(crate) use crate::cases::{cases_from_manifest, SeededCase};
|
||||
pub(crate) use crate::namespace::{
|
||||
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
|
||||
record_namespace_state,
|
||||
};
|
||||
pub(crate) use crate::settings::{enforce_system_settings, load_or_init_system_settings};
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
args::{self, Config},
|
||||
datasets::ConvertedDataset,
|
||||
slice::{self},
|
||||
};
|
||||
|
||||
/// Grow the slice ledger to contain the target number of cases.
|
||||
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let ledger_limit = ledger_target(config);
|
||||
let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
|
||||
let slice =
|
||||
slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
||||
info!(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
cases = slice.manifest.case_count,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total_paragraphs = slice.manifest.total_paragraphs,
|
||||
"Slice ledger ready"
|
||||
);
|
||||
println!(
|
||||
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
||||
slice.manifest.slice_id,
|
||||
slice.manifest.case_count,
|
||||
slice.manifest.positive_paragraphs,
|
||||
slice.manifest.negative_paragraphs
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
|
||||
match (config.slice_grow, config.limit) {
|
||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||
(Some(grow), None) => Some(grow),
|
||||
(None, limit) => limit,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||
args::ensure_parent(path)?;
|
||||
let mut file = tokio::fs::File::create(path)
|
||||
.await
|
||||
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
||||
for case in cases {
|
||||
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
||||
file.write_all(&line).await?;
|
||||
file.write_all(b"\n").await?;
|
||||
}
|
||||
file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
|
||||
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
|
||||
|
||||
info!("Warming HNSW caches with sample queries");
|
||||
|
||||
// Warm up chunk embedding index - just query the embedding table to load HNSW index
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT chunk_id
|
||||
FROM text_chunk_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding.clone()))
|
||||
.await
|
||||
.context("warming text chunk HNSW cache")?;
|
||||
|
||||
// Warm up entity embedding index
|
||||
let _ = db
|
||||
.client
|
||||
.query(
|
||||
r#"SELECT entity_id
|
||||
FROM knowledge_entity_embedding
|
||||
WHERE embedding <|1,1|> $embedding
|
||||
LIMIT 5"#,
|
||||
)
|
||||
.bind(("embedding", dummy_embedding))
|
||||
.await
|
||||
.context("warming knowledge entity HNSW cache")?;
|
||||
|
||||
info!("HNSW cache warming completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
|
||||
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
|
||||
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
|
||||
}
|
||||
|
||||
pub(crate) fn sanitize_model_code(code: &str) -> String {
|
||||
code.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_ascii_alphanumeric() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Re-export run_evaluation from the pipeline module at crate root
|
||||
pub use crate::pipeline::run_evaluation;
|
||||
@@ -1,13 +1,9 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
|
||||
|
||||
use crate::{args::Config, corpus, eval::connect_eval_db, snapshot::DbSnapshotState};
|
||||
use crate::{args::Config, corpus, db::connect_eval_db};
|
||||
|
||||
pub async fn inspect_question(config: &Config) -> Result<()> {
|
||||
let question_id = config
|
||||
@@ -64,39 +60,26 @@ pub async fn inspect_question(config: &Config) -> Result<()> {
|
||||
);
|
||||
}
|
||||
|
||||
let db_state_path = config
|
||||
.database
|
||||
.inspect_db_state
|
||||
.clone()
|
||||
.unwrap_or_else(|| default_state_path(config, &manifest));
|
||||
if let Some(state) = load_db_state(&db_state_path)? {
|
||||
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
|
||||
match connect_eval_db(config, ns, db_name).await {
|
||||
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
||||
MissingChunks::None => println!(
|
||||
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
||||
),
|
||||
MissingChunks::Missing(list) => println!(
|
||||
"Missing chunks in namespace '{ns}', database '{db_name}': {list:?}"
|
||||
),
|
||||
},
|
||||
Err(err) => {
|
||||
println!(
|
||||
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
||||
);
|
||||
if let Some(seed) = manifest.metadata.namespace_seed.as_ref() {
|
||||
let ns = seed.namespace.as_str();
|
||||
let db_name = seed.database.as_str();
|
||||
match connect_eval_db(config, ns, db_name).await {
|
||||
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
|
||||
MissingChunks::None => println!(
|
||||
"All matching_chunk_ids exist in namespace '{ns}', database '{db_name}'"
|
||||
),
|
||||
MissingChunks::Missing(list) => {
|
||||
println!("Missing chunks in namespace '{ns}', database '{db_name}': {list:?}");
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
println!(
|
||||
"Failed to connect to SurrealDB namespace '{ns}' / database '{db_name}': {err}"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"State file {} is missing namespace/database fields; skipping live DB validation",
|
||||
db_state_path.display()
|
||||
);
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"State file {} not found; skipping live DB validation",
|
||||
db_state_path.display()
|
||||
);
|
||||
println!("Corpus manifest has no namespace seed; skipping live DB validation");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -137,25 +120,6 @@ fn build_chunk_lookup(manifest: &corpus::CorpusManifest) -> HashMap<String, Chun
|
||||
lookup
|
||||
}
|
||||
|
||||
fn default_state_path(config: &Config, manifest: &corpus::CorpusManifest) -> PathBuf {
|
||||
config
|
||||
.cache_dir
|
||||
.join("snapshots")
|
||||
.join(&manifest.metadata.dataset_id)
|
||||
.join(&manifest.metadata.slice_id)
|
||||
.join("db/state.json")
|
||||
}
|
||||
|
||||
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("parsing db state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
enum MissingChunks {
|
||||
None,
|
||||
Missing(Vec<String>),
|
||||
|
||||
+64
-62
@@ -1,55 +1,51 @@
|
||||
mod args;
|
||||
mod cache;
|
||||
mod cases;
|
||||
mod cli;
|
||||
mod context_stats;
|
||||
mod corpus;
|
||||
mod datasets;
|
||||
mod db_helpers;
|
||||
mod eval;
|
||||
mod db;
|
||||
mod inspection;
|
||||
mod namespace;
|
||||
mod openai;
|
||||
mod perf;
|
||||
mod pipeline;
|
||||
mod report;
|
||||
mod settings;
|
||||
mod slice;
|
||||
mod snapshot;
|
||||
mod types;
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::runtime::Builder;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
use tracing_subscriber::{EnvFilter, fmt};
|
||||
|
||||
/// Configure `SurrealDB` environment variables for optimal performance
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::unwrap_used)]
|
||||
fn configure_surrealdb_performance(cpu_count: usize) {
|
||||
// Set environment variables only if they're not already set
|
||||
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 2).to_string());
|
||||
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
||||
|
||||
let max_order_queue = std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 4).to_string());
|
||||
std::env::set_var(
|
||||
"SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
|
||||
max_order_queue,
|
||||
);
|
||||
|
||||
let websocket_concurrent = std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
|
||||
.unwrap_or_else(|_| cpu_count.to_string());
|
||||
std::env::set_var(
|
||||
"SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
|
||||
websocket_concurrent,
|
||||
);
|
||||
|
||||
let websocket_buffer = std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 8).to_string());
|
||||
std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", websocket_buffer);
|
||||
|
||||
let transaction_cache = std::env::var("SURREAL_TRANSACTION_CACHE_SIZE")
|
||||
.unwrap_or_else(|_| (cpu_count * 16).to_string());
|
||||
std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", transaction_cache);
|
||||
// SAFETY: single-threaded setup before SurrealDB clients are created.
|
||||
unsafe {
|
||||
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
|
||||
std::env::set_var(
|
||||
"SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
|
||||
max_order_queue,
|
||||
);
|
||||
std::env::set_var(
|
||||
"SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
|
||||
websocket_concurrent,
|
||||
);
|
||||
std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", websocket_buffer);
|
||||
std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", transaction_cache);
|
||||
}
|
||||
|
||||
info!(
|
||||
indexing_batch_size = %std::env::var("SURREAL_INDEXING_BATCH_SIZE").unwrap(),
|
||||
@@ -62,12 +58,11 @@ fn configure_surrealdb_performance(cpu_count: usize) {
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Create an explicit multi-threaded runtime with optimized configuration
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.worker_threads(std::thread::available_parallelism()?.get())
|
||||
.max_blocking_threads(std::thread::available_parallelism()?.get())
|
||||
.thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
|
||||
.thread_stack_size(10 * 1024 * 1024)
|
||||
.thread_name("eval-retrieval-worker")
|
||||
.build()
|
||||
.context("failed to create tokio runtime")?;
|
||||
@@ -77,7 +72,6 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn async_main() -> anyhow::Result<()> {
|
||||
// Log runtime configuration
|
||||
let cpu_count = std::thread::available_parallelism()?.get();
|
||||
info!(
|
||||
cpu_cores = cpu_count,
|
||||
@@ -87,7 +81,6 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
"Started multi-threaded tokio runtime"
|
||||
);
|
||||
|
||||
// Configure SurrealDB environment variables for better performance
|
||||
configure_surrealdb_performance(cpu_count);
|
||||
|
||||
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
||||
@@ -97,13 +90,22 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
|
||||
let parsed = args::parse()?;
|
||||
|
||||
// Clap handles help automatically, so we don't need to check for it manually
|
||||
|
||||
if parsed.config.inspect_question.is_some() {
|
||||
inspection::inspect_question(&parsed.config).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.status {
|
||||
let status = cli::collect_status(&parsed.config).await?;
|
||||
cli::print_status(&status);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.warm {
|
||||
cli::warm(&parsed.config).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let dataset_kind = parsed.config.dataset;
|
||||
|
||||
if parsed.config.convert_only {
|
||||
@@ -115,7 +117,6 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
dataset_kind,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
@@ -124,56 +125,52 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
parsed.config.raw_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"writing converted dataset to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
println!(
|
||||
"Converted dataset written to {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
);
|
||||
let store_dir = datasets::store_dir_for(&parsed.config.converted_dataset_path);
|
||||
datasets::write_sharded(&dataset, &store_dir)?;
|
||||
datasets::prebuild_catalog_slices(&dataset, &parsed.config)?;
|
||||
println!("Converted dataset written under {}", store_dir.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if parsed.config.require_ready {
|
||||
cli::ensure_query_ready(&parsed.config).await?;
|
||||
}
|
||||
|
||||
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
|
||||
let dataset = crate::datasets::ensure_converted(
|
||||
dataset_kind,
|
||||
parsed.config.raw_dataset_path.as_path(),
|
||||
parsed.config.converted_dataset_path.as_path(),
|
||||
parsed.config.force_convert,
|
||||
parsed.config.llm_mode,
|
||||
parsed.config.context_token_limit(),
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"preparing converted dataset at {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
let loaded =
|
||||
crate::datasets::prepare_dataset(dataset_kind, &parsed.config).with_context(|| {
|
||||
format!(
|
||||
"preparing converted dataset at {}",
|
||||
parsed.config.converted_dataset_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
info!(
|
||||
questions = dataset
|
||||
questions = loaded
|
||||
.dataset
|
||||
.paragraphs
|
||||
.iter()
|
||||
.map(|p| p.questions.len())
|
||||
.sum::<usize>(),
|
||||
paragraphs = dataset.paragraphs.len(),
|
||||
dataset = dataset.metadata.id.as_str(),
|
||||
paragraphs = loaded.dataset.paragraphs.len(),
|
||||
partial = loaded.partial,
|
||||
dataset = loaded.dataset.metadata.id.as_str(),
|
||||
"Dataset ready"
|
||||
);
|
||||
|
||||
if parsed.config.slice_grow.is_some() {
|
||||
eval::grow_slice(&dataset, &parsed.config).context("growing slice ledger")?;
|
||||
slice::grow_slice(&loaded.dataset, &parsed.config).context("growing slice ledger")?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Running retrieval evaluation");
|
||||
let summary = eval::run_evaluation(&dataset, &parsed.config)
|
||||
.await
|
||||
.context("running retrieval evaluation")?;
|
||||
let summary = pipeline::run_evaluation(
|
||||
&loaded.dataset,
|
||||
&parsed.config,
|
||||
Some(loaded.content_checksum.as_str()),
|
||||
)
|
||||
.await
|
||||
.context("running retrieval evaluation")?;
|
||||
|
||||
let report = report::write_reports(
|
||||
&summary,
|
||||
@@ -226,12 +223,17 @@ async fn async_main() -> anyhow::Result<()> {
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
||||
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) | Retrieved context: {chunks} chunks, {tokens} tokens ({tokenizer}, avg {avg_tokens:.0}/query, p95 {p95}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
|
||||
summary.dataset_label,
|
||||
k = summary.k,
|
||||
precision = summary.precision,
|
||||
correct = summary.correct,
|
||||
retrieval_total = summary.retrieval_cases,
|
||||
chunks = summary.retrieved_context.total_chunks,
|
||||
tokens = summary.retrieved_context.total_tokens,
|
||||
tokenizer = summary.retrieved_context.tokenizer,
|
||||
avg_tokens = summary.retrieved_context.avg_tokens_per_query,
|
||||
p95 = summary.retrieved_context.p95_tokens_per_query,
|
||||
json = report.paths.json.display(),
|
||||
md = report.paths.markdown.display(),
|
||||
history = report.history_path.display(),
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use async_openai::{Client, config::OpenAIConfig};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||
pub fn ingestion_openai_client(
|
||||
include_entities: bool,
|
||||
) -> Result<(Arc<Client<OpenAIConfig>>, Option<String>)> {
|
||||
if include_entities {
|
||||
let (client, base_url) = build_client_from_env().context(
|
||||
"OPENAI_API_KEY must be set when --include-entities is enabled (entity extraction uses OpenAI)",
|
||||
)?;
|
||||
Ok((Arc::new(client), Some(base_url)))
|
||||
} else {
|
||||
Ok((Arc::new(Client::with_config(OpenAIConfig::default())), None))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
|
||||
let base_url =
|
||||
|
||||
+11
-7
@@ -7,8 +7,8 @@ use anyhow::{Context, Result};
|
||||
|
||||
use crate::{
|
||||
args,
|
||||
eval::EvaluationSummary,
|
||||
report::{self, EvaluationReport},
|
||||
types::EvaluationSummary,
|
||||
};
|
||||
|
||||
pub fn mirror_perf_outputs(
|
||||
@@ -91,23 +91,26 @@ fn format_duration(value: Option<u128>) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
|
||||
use crate::types::{
|
||||
EvaluationStageTimings, LatencyStats, PerformanceTimings, StageLatency,
|
||||
StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_latency() -> crate::eval::LatencyStats {
|
||||
crate::eval::LatencyStats {
|
||||
fn sample_latency() -> LatencyStats {
|
||||
LatencyStats {
|
||||
avg: 10.0,
|
||||
p50: 8,
|
||||
p95: 15,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
|
||||
crate::eval::StageLatencyBreakdown {
|
||||
fn sample_stage_latency() -> StageLatencyBreakdown {
|
||||
StageLatencyBreakdown {
|
||||
stages: ["embed", "search", "rerank", "resolve_entities", "assemble"]
|
||||
.into_iter()
|
||||
.map(|stage| crate::eval::StageLatency {
|
||||
.map(|stage| StageLatency {
|
||||
stage: stage.to_string(),
|
||||
stats: sample_latency(),
|
||||
})
|
||||
@@ -206,6 +209,7 @@ mod tests {
|
||||
chunk_vector_take: 20,
|
||||
chunk_fts_take: 20,
|
||||
max_chunks_per_entity: 4,
|
||||
retrieved_context: crate::context_stats::aggregate_context_stats(&[]),
|
||||
cases: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
storage::{
|
||||
@@ -20,11 +20,11 @@ use retrieval_pipeline::{
|
||||
|
||||
use crate::{
|
||||
args::Config,
|
||||
cache::EmbeddingCache,
|
||||
cases::SeededCase,
|
||||
corpus,
|
||||
datasets::ConvertedDataset,
|
||||
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
|
||||
slice, snapshot,
|
||||
slice,
|
||||
types::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary},
|
||||
};
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
@@ -41,12 +41,10 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub namespace: String,
|
||||
pub database: String,
|
||||
pub db: Option<SurrealDbClient>,
|
||||
pub descriptor: Option<snapshot::Descriptor>,
|
||||
pub settings: Option<SystemSettings>,
|
||||
pub settings_missing: bool,
|
||||
pub must_reapply_settings: bool,
|
||||
pub embedding_provider: Option<EmbeddingProvider>,
|
||||
pub embedding_cache: Option<EmbeddingCache>,
|
||||
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
pub openai_base_url: Option<String>,
|
||||
pub expected_fingerprint: Option<String>,
|
||||
@@ -67,13 +65,19 @@ pub(super) struct EvaluationContext<'a> {
|
||||
pub summary: Option<EvaluationSummary>,
|
||||
pub diagnostics_path: Option<PathBuf>,
|
||||
pub diagnostics_enabled: bool,
|
||||
pub content_checksum: Option<String>,
|
||||
}
|
||||
|
||||
impl<'a> EvaluationContext<'a> {
|
||||
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
|
||||
pub fn new(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &'a Config,
|
||||
content_checksum: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
dataset,
|
||||
config,
|
||||
content_checksum,
|
||||
stage_timings: EvaluationStageTimings::default(),
|
||||
ledger_limit: None,
|
||||
slice_settings: None,
|
||||
@@ -84,12 +88,10 @@ impl<'a> EvaluationContext<'a> {
|
||||
namespace: String::new(),
|
||||
database: String::new(),
|
||||
db: None,
|
||||
descriptor: None,
|
||||
settings: None,
|
||||
settings_missing: false,
|
||||
must_reapply_settings: false,
|
||||
embedding_provider: None,
|
||||
embedding_cache: None,
|
||||
openai_client: None,
|
||||
openai_base_url: None,
|
||||
expected_fingerprint: None,
|
||||
@@ -133,12 +135,6 @@ impl<'a> EvaluationContext<'a> {
|
||||
.ok_or_else(|| anyhow!("database connection missing"))
|
||||
}
|
||||
|
||||
pub fn descriptor(&self) -> Result<&snapshot::Descriptor> {
|
||||
self.descriptor
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("snapshot descriptor unavailable"))
|
||||
}
|
||||
|
||||
pub fn embedding_provider(&self) -> Result<&EmbeddingProvider> {
|
||||
self.embedding_provider
|
||||
.as_ref()
|
||||
@@ -159,6 +155,10 @@ impl<'a> EvaluationContext<'a> {
|
||||
.ok_or_else(|| anyhow!("corpus handle missing"))
|
||||
}
|
||||
|
||||
pub fn content_checksum(&self) -> Option<&str> {
|
||||
self.content_checksum.as_deref()
|
||||
}
|
||||
|
||||
pub fn evaluation_user(&self) -> Result<&User> {
|
||||
self.eval_user
|
||||
.as_ref()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::{args, types::CaseDiagnostics};
|
||||
|
||||
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
|
||||
args::ensure_parent(path)?;
|
||||
let mut file = tokio::fs::File::create(path)
|
||||
.await
|
||||
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
|
||||
for case in cases {
|
||||
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
|
||||
file.write_all(&line).await?;
|
||||
file.write_all(b"\n").await?;
|
||||
}
|
||||
file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
mod context;
|
||||
mod diagnostics;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
@@ -8,20 +8,40 @@ use crate::{args::Config, datasets::ConvertedDataset, types::EvaluationSummary};
|
||||
|
||||
use context::EvaluationContext;
|
||||
|
||||
async fn run_through_namespace<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &'a Config,
|
||||
content_checksum: Option<String>,
|
||||
) -> Result<EvaluationContext<'a>> {
|
||||
let mut ctx = EvaluationContext::new(dataset, config, content_checksum);
|
||||
stages::prepare_slice(&mut ctx).await?;
|
||||
stages::prepare_db(&mut ctx).await?;
|
||||
stages::prepare_corpus(&mut ctx).await?;
|
||||
stages::prepare_namespace(&mut ctx).await?;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
pub async fn warm_evaluation(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &Config,
|
||||
content_checksum: &str,
|
||||
) -> Result<()> {
|
||||
let _ctx = run_through_namespace(dataset, config, Some(content_checksum.to_string())).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_evaluation(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &Config,
|
||||
content_checksum: Option<&str>,
|
||||
) -> Result<EvaluationSummary> {
|
||||
let mut ctx = EvaluationContext::new(dataset, config);
|
||||
let machine = state::ready();
|
||||
|
||||
let machine = stages::prepare_slice(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_db(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_corpus(machine, &mut ctx).await?;
|
||||
let machine = stages::prepare_namespace(machine, &mut ctx).await?;
|
||||
let machine = stages::run_queries(machine, &mut ctx).await?;
|
||||
let machine = stages::summarize(machine, &mut ctx).await?;
|
||||
let _ = stages::finalize(machine, &mut ctx).await?;
|
||||
|
||||
let mut ctx = EvaluationContext::new(dataset, config, content_checksum.map(str::to_string));
|
||||
stages::prepare_slice(&mut ctx).await?;
|
||||
stages::prepare_db(&mut ctx).await?;
|
||||
stages::prepare_corpus(&mut ctx).await?;
|
||||
stages::prepare_namespace(&mut ctx).await?;
|
||||
stages::run_queries(&mut ctx).await?;
|
||||
stages::summarize(&mut ctx).await?;
|
||||
stages::finalize(&mut ctx).await?;
|
||||
ctx.into_summary()
|
||||
}
|
||||
|
||||
@@ -3,18 +3,12 @@ use std::time::Instant;
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::write_chunk_diagnostics;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{Completed, EvaluationMachine, Summarized},
|
||||
diagnostics::write_chunk_diagnostics,
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
|
||||
pub(crate) async fn finalize(
|
||||
machine: EvaluationMachine<(), Summarized>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Completed> {
|
||||
pub(crate) async fn finalize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::Finalize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -22,19 +16,12 @@ pub(crate) async fn finalize(
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
if let Some(cache) = ctx.embedding_cache.as_ref() {
|
||||
cache
|
||||
.persist()
|
||||
if let Some(path) = ctx.diagnostics_path.as_ref()
|
||||
&& ctx.diagnostics_enabled
|
||||
{
|
||||
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
||||
.await
|
||||
.context("persisting embedding cache")?;
|
||||
}
|
||||
|
||||
if let Some(path) = ctx.diagnostics_path.as_ref() {
|
||||
if ctx.diagnostics_enabled {
|
||||
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
|
||||
.await
|
||||
.with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
|
||||
}
|
||||
.with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -53,7 +40,5 @@ pub(crate) async fn finalize(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.finalize()
|
||||
.map_err(|(_, guard)| map_guard_error("finalize", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,14 +13,3 @@ pub(crate) use prepare_namespace::prepare_namespace;
|
||||
pub(crate) use prepare_slice::prepare_slice;
|
||||
pub(crate) use run_queries::run_queries;
|
||||
pub(crate) use summarize::summarize;
|
||||
|
||||
use anyhow::Result;
|
||||
use state_machines::core::GuardError;
|
||||
|
||||
use super::state::EvaluationMachine;
|
||||
|
||||
fn map_guard_error(event: &str, guard: &GuardError) -> anyhow::Error {
|
||||
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
|
||||
}
|
||||
|
||||
type StageResult<S> = Result<EvaluationMachine<(), S>>;
|
||||
|
||||
@@ -3,19 +3,12 @@ use std::time::Instant;
|
||||
use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{corpus, eval::can_reuse_namespace, slice, snapshot};
|
||||
use crate::{corpus, db::can_reuse_namespace, slice};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, DbReady, EvaluationMachine},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn prepare_corpus(
|
||||
machine: EvaluationMachine<(), DbReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<CorpusReady> {
|
||||
pub(crate) async fn prepare_corpus(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareCorpus;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -31,13 +24,13 @@ pub(crate) async fn prepare_corpus(
|
||||
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
|
||||
.context("selecting slice window for corpus preparation")?;
|
||||
|
||||
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider()?);
|
||||
let ingestion_config = corpus::make_ingestion_config(config);
|
||||
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
|
||||
ctx.dataset(),
|
||||
slice,
|
||||
config.converted_dataset_path.as_path(),
|
||||
&ingestion_config,
|
||||
ctx.content_checksum(),
|
||||
)?;
|
||||
let base_dir = corpus::cached_corpus_dir(
|
||||
&cache_settings,
|
||||
@@ -47,47 +40,38 @@ pub(crate) async fn prepare_corpus(
|
||||
|
||||
if !config.reseed_slice {
|
||||
let requested_cases = window.cases.len();
|
||||
if can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
&descriptor,
|
||||
&ctx.namespace,
|
||||
&ctx.database,
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)?
|
||||
&& can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
&manifest,
|
||||
&embedding_provider,
|
||||
&ctx.namespace,
|
||||
&ctx.database,
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
namespace = ctx.namespace.as_str(),
|
||||
database = ctx.database.as_str(),
|
||||
"Namespace already seeded; reusing cached corpus manifest"
|
||||
);
|
||||
let corpus_handle = corpus::corpus_handle_from_manifest(manifest, base_dir);
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = 0;
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
return machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard));
|
||||
}
|
||||
info!(
|
||||
cache = %base_dir.display(),
|
||||
"Namespace reusable but cached manifest missing; regenerating corpus"
|
||||
namespace = ctx.namespace.as_str(),
|
||||
database = ctx.database.as_str(),
|
||||
"Namespace already seeded; reusing cached corpus manifest"
|
||||
);
|
||||
let corpus_handle = corpus::corpus_handle_from_manifest(manifest, base_dir);
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = 0;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +87,7 @@ pub(crate) async fn prepare_corpus(
|
||||
openai_client,
|
||||
&eval_user_id,
|
||||
config.converted_dataset_path.as_path(),
|
||||
ctx.content_checksum(),
|
||||
ingestion_config.clone(),
|
||||
)
|
||||
.await
|
||||
@@ -126,7 +111,6 @@ pub(crate) async fn prepare_corpus(
|
||||
ctx.corpus_handle = Some(corpus_handle);
|
||||
ctx.expected_fingerprint = Some(expected_fingerprint);
|
||||
ctx.ingestion_duration_ms = ingestion_duration_ms;
|
||||
ctx.descriptor = Some(descriptor);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
@@ -136,7 +120,5 @@ pub(crate) async fn prepare_corpus(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_corpus()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_corpus", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
args::EmbeddingBackend,
|
||||
cache::EmbeddingCache,
|
||||
eval::{
|
||||
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
|
||||
},
|
||||
db::{connect_eval_db, sanitize_model_code},
|
||||
openai,
|
||||
settings::{enforce_system_settings, load_or_init_system_settings},
|
||||
};
|
||||
use common::utils::embedding::{default_embedding_pool_size, EmbeddingProvider};
|
||||
use common::utils::embedding::{EmbeddingProvider, default_embedding_pool_size};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{DbReady, EvaluationMachine, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
pub(crate) async fn prepare_db(
|
||||
machine: EvaluationMachine<(), SlicePrepared>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<DbReady> {
|
||||
pub(crate) async fn prepare_db(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareDb;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -36,19 +27,18 @@ pub(crate) async fn prepare_db(
|
||||
|
||||
let db = connect_eval_db(config, &namespace, &database).await?;
|
||||
|
||||
let (raw_openai_client, openai_base_url) =
|
||||
openai::build_client_from_env().context("building OpenAI client")?;
|
||||
let openai_client = Arc::new(raw_openai_client);
|
||||
let (openai_client, openai_base_url) =
|
||||
openai::ingestion_openai_client(config.ingest.include_entities)
|
||||
.context("building OpenAI client for ingestion")?;
|
||||
|
||||
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
|
||||
let embedding_provider = match config.embedding_backend {
|
||||
crate::args::EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||
EmbeddingBackend::FastEmbed => EmbeddingProvider::new_fastembed(
|
||||
config.embedding_model.clone(),
|
||||
default_embedding_pool_size(),
|
||||
)
|
||||
.await
|
||||
.context("creating FastEmbed provider")?,
|
||||
crate::args::EmbeddingBackend::Hashed => {
|
||||
EmbeddingBackend::Hashed => {
|
||||
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
|
||||
}
|
||||
};
|
||||
@@ -68,30 +58,25 @@ pub(crate) async fn prepare_db(
|
||||
dimension = provider_dimension,
|
||||
"Embedding provider initialised"
|
||||
);
|
||||
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
|
||||
if let Some(base_url) = &openai_base_url {
|
||||
info!(openai_base_url = %base_url, "OpenAI client configured for entity ingestion");
|
||||
}
|
||||
|
||||
let (mut settings, settings_missing) =
|
||||
load_or_init_system_settings(&db, provider_dimension).await?;
|
||||
|
||||
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
|
||||
if let Some(model_code) = embedding_provider.model_code() {
|
||||
let sanitized = sanitize_model_code(&model_code);
|
||||
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
||||
if config.force_convert && path.exists() {
|
||||
tokio::fs::remove_file(&path)
|
||||
.await
|
||||
.with_context(|| format!("removing stale cache {}", path.display()))
|
||||
.ok();
|
||||
}
|
||||
let cache = EmbeddingCache::load(&path).await?;
|
||||
info!(path = %path.display(), "Embedding cache ready");
|
||||
Some(cache)
|
||||
} else {
|
||||
None
|
||||
if config.embedding_backend == EmbeddingBackend::FastEmbed
|
||||
&& let Some(model_code) = embedding_provider.model_code()
|
||||
{
|
||||
let sanitized = sanitize_model_code(&model_code);
|
||||
let path = config.cache_dir.join(format!("{sanitized}.json"));
|
||||
if config.force_convert && path.exists() {
|
||||
tokio::fs::remove_file(&path)
|
||||
.await
|
||||
.with_context(|| format!("removing stale cache {}", path.display()))
|
||||
.ok();
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
let must_reapply_settings = settings_missing;
|
||||
let defer_initial_enforce = settings_missing && !config.reseed_slice;
|
||||
@@ -104,9 +89,8 @@ pub(crate) async fn prepare_db(
|
||||
ctx.must_reapply_settings = must_reapply_settings;
|
||||
ctx.settings = Some(settings);
|
||||
ctx.embedding_provider = Some(embedding_provider);
|
||||
ctx.embedding_cache = embedding_cache;
|
||||
ctx.openai_client = Some(openai_client);
|
||||
ctx.openai_base_url = Some(openai_base_url);
|
||||
ctx.openai_base_url = openai_base_url;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
ctx.record_stage_duration(stage, elapsed);
|
||||
@@ -116,7 +100,5 @@ pub(crate) async fn prepare_db(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_db()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_db", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
cases::cases_from_manifest,
|
||||
corpus,
|
||||
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
|
||||
eval::{
|
||||
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
|
||||
record_namespace_state, warm_hnsw_cache,
|
||||
db::{
|
||||
can_reuse_namespace, ensure_eval_user, record_namespace_seed, recreate_indexes,
|
||||
reset_namespace, warm_hnsw_cache,
|
||||
},
|
||||
settings::enforce_system_settings,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{CorpusReady, EvaluationMachine, NamespaceReady},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn prepare_namespace(
|
||||
machine: EvaluationMachine<(), CorpusReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<NamespaceReady> {
|
||||
pub(crate) async fn prepare_namespace(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareNamespace;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -32,7 +26,6 @@ pub(crate) async fn prepare_namespace(
|
||||
let started = Instant::now();
|
||||
|
||||
let config = ctx.config();
|
||||
let dataset = ctx.dataset();
|
||||
let expected_fingerprint = ctx
|
||||
.expected_fingerprint
|
||||
.as_deref()
|
||||
@@ -60,20 +53,16 @@ pub(crate) async fn prepare_namespace(
|
||||
|
||||
let mut namespace_reused = false;
|
||||
if !config.reseed_slice {
|
||||
namespace_reused = {
|
||||
let slice = ctx.slice()?;
|
||||
can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
ctx.descriptor()?,
|
||||
&namespace,
|
||||
&database,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
namespace_reused = can_reuse_namespace(
|
||||
ctx.db()?,
|
||||
base_manifest,
|
||||
&embedding_provider,
|
||||
&namespace,
|
||||
&database,
|
||||
expected_fingerprint.as_str(),
|
||||
requested_cases,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut namespace_seed_ms = None;
|
||||
@@ -114,34 +103,20 @@ pub(crate) async fn prepare_namespace(
|
||||
"Seeding ingestion corpus into SurrealDB"
|
||||
);
|
||||
}
|
||||
let indexes_disabled = remove_all_indexes(ctx.db()?).await.is_ok();
|
||||
|
||||
let seed_start = Instant::now();
|
||||
corpus::seed_manifest_into_db(ctx.db()?, &manifest_for_seed)
|
||||
.await
|
||||
.context("seeding ingestion corpus from manifest")?;
|
||||
namespace_seed_ms = Some(seed_start.elapsed().as_millis());
|
||||
|
||||
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
|
||||
if indexes_disabled {
|
||||
info!("Recreating indexes after seeding data");
|
||||
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
||||
.await
|
||||
.context("recreating indexes with correct dimension")?;
|
||||
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
||||
}
|
||||
{
|
||||
let slice = ctx.slice()?;
|
||||
record_namespace_state(
|
||||
ctx.descriptor()?,
|
||||
dataset.metadata.id.as_str(),
|
||||
slice.manifest.slice_id.as_str(),
|
||||
expected_fingerprint.as_str(),
|
||||
&namespace,
|
||||
&database,
|
||||
requested_cases,
|
||||
)
|
||||
.await;
|
||||
info!("Recreating indexes after seeding data");
|
||||
recreate_indexes(ctx.db()?, embedding_provider.dimension())
|
||||
.await
|
||||
.context("recreating indexes with correct dimension")?;
|
||||
warm_hnsw_cache(ctx.db()?, embedding_provider.dimension()).await?;
|
||||
|
||||
if let Some(handle) = ctx.corpus_handle.as_mut() {
|
||||
record_namespace_seed(handle, &namespace, &database, requested_cases).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +173,5 @@ pub(crate) async fn prepare_namespace(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_namespace()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_namespace", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,20 +4,13 @@ use anyhow::Context;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
eval::{default_database, default_namespace, ledger_target},
|
||||
db::{default_database, default_namespace},
|
||||
slice,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, Ready, SlicePrepared},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
pub(crate) async fn prepare_slice(
|
||||
machine: EvaluationMachine<(), Ready>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<SlicePrepared> {
|
||||
pub(crate) async fn prepare_slice(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::PrepareSlice;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -25,7 +18,7 @@ pub(crate) async fn prepare_slice(
|
||||
);
|
||||
let started = Instant::now();
|
||||
|
||||
let ledger_limit = ledger_target(ctx.config());
|
||||
let ledger_limit = slice::ledger_target(ctx.config());
|
||||
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
|
||||
let resolved_slice =
|
||||
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
|
||||
@@ -49,7 +42,11 @@ pub(crate) async fn prepare_slice(
|
||||
.db_namespace
|
||||
.clone()
|
||||
.unwrap_or_else(|| {
|
||||
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
|
||||
default_namespace(
|
||||
ctx.dataset().metadata.id.as_str(),
|
||||
ctx.config().limit,
|
||||
ctx.config().slice.as_deref(),
|
||||
)
|
||||
});
|
||||
ctx.database = ctx
|
||||
.config()
|
||||
@@ -66,7 +63,5 @@ pub(crate) async fn prepare_slice(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.prepare_slice()
|
||||
.map_err(|(_, guard)| map_guard_error("prepare_slice", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
use std::{collections::HashSet, sync::Arc, time::Instant};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{Context, anyhow};
|
||||
use common::storage::types::StoredObject;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::eval::{
|
||||
adapt_retrieval_output, build_case_diagnostics, text_contains_answer, CaseDiagnostics,
|
||||
CaseSummary, RetrievedSummary,
|
||||
use crate::{
|
||||
cases::SeededCase,
|
||||
context_stats,
|
||||
types::{
|
||||
CaseDiagnostics, CaseSummary, RetrievedSummary, adapt_retrieval_output,
|
||||
build_case_diagnostics, text_contains_answer,
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::{
|
||||
pipeline::{self, RetrievalConfig, StageTimings},
|
||||
@@ -15,17 +19,10 @@ use retrieval_pipeline::{
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, NamespaceReady, QueriesFinished},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
pub(crate) async fn run_queries(
|
||||
machine: EvaluationMachine<(), NamespaceReady>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<QueriesFinished> {
|
||||
pub(crate) async fn run_queries(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::RunQueries;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -153,7 +150,7 @@ pub(crate) async fn run_queries(
|
||||
.await
|
||||
.context("acquiring query semaphore permit")?;
|
||||
|
||||
let crate::eval::SeededCase {
|
||||
let SeededCase {
|
||||
question_id,
|
||||
question,
|
||||
expected_source,
|
||||
@@ -197,6 +194,7 @@ pub(crate) async fn run_queries(
|
||||
let query_latency = query_start.elapsed().as_millis();
|
||||
|
||||
let candidates = adapt_retrieval_output(result_output);
|
||||
let retrieved_context = context_stats::stats_for_candidates(&candidates);
|
||||
let mut retrieved = Vec::new();
|
||||
let mut match_rank = None;
|
||||
let answers_lower: Vec<String> =
|
||||
@@ -288,6 +286,7 @@ pub(crate) async fn run_queries(
|
||||
reciprocal_rank: Some(reciprocal_rank),
|
||||
ndcg: Some(ndcg),
|
||||
latency_ms: query_latency,
|
||||
retrieved_context,
|
||||
retrieved,
|
||||
};
|
||||
|
||||
@@ -353,9 +352,7 @@ pub(crate) async fn run_queries(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.run_queries()
|
||||
.map_err(|(_, guard)| map_guard_error("run_queries", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::arithmetic_side_effects, clippy::cast_precision_loss)]
|
||||
@@ -394,9 +391,5 @@ fn calculate_ndcg(retrieved: &[RetrievedSummary], k: usize) -> f64 {
|
||||
idcg += rel / (f64::from(i) + 2.0).log2();
|
||||
}
|
||||
|
||||
if idcg == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
dcg / idcg
|
||||
}
|
||||
if idcg == 0.0 { 0.0 } else { dcg / idcg }
|
||||
}
|
||||
|
||||
@@ -3,25 +3,19 @@ use std::time::Instant;
|
||||
use chrono::Utc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::eval::{
|
||||
build_stage_latency_breakdown, compute_latency_stats, EvaluationSummary, PerformanceTimings,
|
||||
use crate::types::{
|
||||
EvaluationSummary, PerformanceTimings, RetrievedContextStats, build_stage_latency_breakdown,
|
||||
compute_latency_stats,
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
context::{EvalStage, EvaluationContext},
|
||||
state::{EvaluationMachine, QueriesFinished, Summarized},
|
||||
};
|
||||
use super::{map_guard_error, StageResult};
|
||||
use super::super::context::{EvalStage, EvaluationContext};
|
||||
|
||||
#[allow(
|
||||
clippy::too_many_lines,
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::cast_precision_loss
|
||||
)]
|
||||
pub(crate) async fn summarize(
|
||||
machine: EvaluationMachine<(), QueriesFinished>,
|
||||
ctx: &mut EvaluationContext<'_>,
|
||||
) -> StageResult<Summarized> {
|
||||
pub(crate) async fn summarize(ctx: &mut EvaluationContext<'_>) -> anyhow::Result<()> {
|
||||
let stage = EvalStage::Summarize;
|
||||
info!(
|
||||
evaluation_stage = stage.label(),
|
||||
@@ -123,6 +117,12 @@ pub(crate) async fn summarize(
|
||||
sum_ndcg / (retrieval_cases as f64)
|
||||
};
|
||||
|
||||
let per_query_context: Vec<RetrievedContextStats> = summaries
|
||||
.iter()
|
||||
.map(|summary| summary.retrieved_context)
|
||||
.collect();
|
||||
let retrieved_context = crate::context_stats::aggregate_context_stats(&per_query_context);
|
||||
|
||||
let active_tuning = ctx
|
||||
.retrieval_config
|
||||
.as_ref()
|
||||
@@ -133,7 +133,7 @@ pub(crate) async fn summarize(
|
||||
openai_base_url: ctx
|
||||
.openai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "<unknown>".to_string()),
|
||||
.unwrap_or_else(|| "n/a (chunk-only ingestion)".to_string()),
|
||||
ingestion_ms: ctx.ingestion_duration_ms,
|
||||
namespace_seed_ms: ctx.namespace_seed_ms,
|
||||
evaluation_stage_ms: ctx.stage_timings.clone(),
|
||||
@@ -217,11 +217,12 @@ pub(crate) async fn summarize(
|
||||
chunk_rrf_use_fts: active_tuning.flags.chunk_rrf_use_fts.as_bool(),
|
||||
ingest_chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
|
||||
ingest_chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
|
||||
ingest_chunks_only: config.ingest.ingest_chunks_only,
|
||||
ingest_chunks_only: !config.ingest.include_entities,
|
||||
ingest_chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
|
||||
chunk_vector_take: active_tuning.chunk_vector_take,
|
||||
chunk_fts_take: active_tuning.chunk_fts_take,
|
||||
max_chunks_per_entity: active_tuning.max_chunks_per_entity,
|
||||
retrieved_context,
|
||||
cases: summaries,
|
||||
});
|
||||
|
||||
@@ -233,7 +234,5 @@ pub(crate) async fn summarize(
|
||||
"completed evaluation stage"
|
||||
);
|
||||
|
||||
machine
|
||||
.summarize()
|
||||
.map_err(|(_, guard)| map_guard_error("summarize", &guard))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: EvaluationMachine,
|
||||
state: EvaluationState,
|
||||
initial: Ready,
|
||||
states: [Ready, SlicePrepared, DbReady, CorpusReady, NamespaceReady, QueriesFinished, Summarized, Completed, Failed],
|
||||
events {
|
||||
prepare_slice { transition: { from: Ready, to: SlicePrepared } }
|
||||
prepare_db { transition: { from: SlicePrepared, to: DbReady } }
|
||||
prepare_corpus { transition: { from: DbReady, to: CorpusReady } }
|
||||
prepare_namespace { transition: { from: CorpusReady, to: NamespaceReady } }
|
||||
run_queries { transition: { from: NamespaceReady, to: QueriesFinished } }
|
||||
summarize { transition: { from: QueriesFinished, to: Summarized } }
|
||||
finalize { transition: { from: Summarized, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: SlicePrepared, to: Failed }
|
||||
transition: { from: DbReady, to: Failed }
|
||||
transition: { from: CorpusReady, to: Failed }
|
||||
transition: { from: NamespaceReady, to: Failed }
|
||||
transition: { from: QueriesFinished, to: Failed }
|
||||
transition: { from: Summarized, to: Failed }
|
||||
transition: { from: Completed, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> EvaluationMachine<(), Ready> {
|
||||
EvaluationMachine::new(())
|
||||
}
|
||||
+83
-218
@@ -7,12 +7,10 @@ use std::{
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::eval::{
|
||||
format_timestamp, CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats,
|
||||
StageLatencyBreakdown,
|
||||
use crate::types::{
|
||||
CaseSummary, EvaluationStageTimings, EvaluationSummary, LatencyStats, RetrievalContextStats,
|
||||
StageLatencyBreakdown, format_timestamp,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReportPaths {
|
||||
@@ -108,6 +106,7 @@ pub struct RetrievalSection {
|
||||
pub ingest_chunk_max_tokens: usize,
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
pub ingest_chunks_only: bool,
|
||||
pub retrieved_context: RetrievalContextStats,
|
||||
}
|
||||
|
||||
const fn default_chunk_rrf_k() -> f32 {
|
||||
@@ -242,6 +241,7 @@ impl EvaluationReport {
|
||||
ingest_chunk_max_tokens: summary.ingest_chunk_max_tokens,
|
||||
ingest_chunk_overlap_tokens: summary.ingest_chunk_overlap_tokens,
|
||||
ingest_chunks_only: summary.ingest_chunks_only,
|
||||
retrieved_context: summary.retrieved_context.clone(),
|
||||
};
|
||||
|
||||
let llm = if summary.llm_cases > 0 {
|
||||
@@ -345,7 +345,7 @@ impl LlmCaseEntry {
|
||||
}
|
||||
|
||||
impl RetrievedSnippet {
|
||||
fn from_summary(entry: &crate::eval::RetrievedSummary) -> Self {
|
||||
fn from_summary(entry: &crate::types::RetrievedSummary) -> Self {
|
||||
Self {
|
||||
rank: entry.rank,
|
||||
source_id: entry.source_id.clone(),
|
||||
@@ -558,6 +558,65 @@ fn render_markdown(report: &EvaluationReport) -> String {
|
||||
} else {
|
||||
md.push_str("| Rerank | disabled |\\n");
|
||||
}
|
||||
write!(
|
||||
md,
|
||||
"| Chunk result cap | {} |\\n",
|
||||
report.retrieval.chunk_result_cap
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
md.push_str("\\n## Retrieved Context Volume\\n\\n");
|
||||
md.push_str("| Metric | Value |\\n| --- | --- |\\n");
|
||||
write!(
|
||||
md,
|
||||
"| Tokenizer | {} |\\n",
|
||||
report.retrieval.retrieved_context.tokenizer
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Queries measured | {} |\\n",
|
||||
report.retrieval.retrieved_context.queries
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total chunks returned | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_chunks
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total characters | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_chars
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Total tokens | {} |\\n",
|
||||
report.retrieval.retrieved_context.total_tokens
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Avg chunks / query | {:.1} |\\n",
|
||||
report.retrieval.retrieved_context.avg_chunks_per_query
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| Avg tokens / query | {:.1} |\\n",
|
||||
report.retrieval.retrieved_context.avg_tokens_per_query
|
||||
)
|
||||
.unwrap();
|
||||
write!(
|
||||
md,
|
||||
"| P50 / P95 / max tokens / query | {} / {} / {} |\\n",
|
||||
report.retrieval.retrieved_context.p50_tokens_per_query,
|
||||
report.retrieval.retrieved_context.p95_tokens_per_query,
|
||||
report.retrieval.retrieved_context.max_tokens_per_query
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if let Some(llm) = &report.llm {
|
||||
md.push_str("\\n## LLM Mode Metrics\\n\\n");
|
||||
@@ -745,11 +804,7 @@ fn prettify_stage(label: &str) -> String {
|
||||
}
|
||||
|
||||
fn bool_badge(value: bool) -> &'static str {
|
||||
if value {
|
||||
"✅"
|
||||
} else {
|
||||
"⚪"
|
||||
}
|
||||
if value { "✅" } else { "⚪" }
|
||||
}
|
||||
|
||||
fn render_retrieved(entries: &[RetrievedSnippet]) -> String {
|
||||
@@ -797,182 +852,6 @@ pub fn dataset_report_dir(report_dir: &Path, dataset_id: &str) -> PathBuf {
|
||||
report_dir.join(sanitize_component(dataset_id))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LegacyHistoryEntry {
|
||||
generated_at: String,
|
||||
run_label: Option<String>,
|
||||
dataset_id: String,
|
||||
dataset_label: String,
|
||||
slice_id: String,
|
||||
slice_seed: u64,
|
||||
slice_window_offset: usize,
|
||||
slice_window_length: usize,
|
||||
slice_cases: usize,
|
||||
slice_total_cases: usize,
|
||||
k: usize,
|
||||
limit: Option<usize>,
|
||||
precision: f64,
|
||||
precision_at_1: f64,
|
||||
precision_at_2: f64,
|
||||
precision_at_3: f64,
|
||||
#[serde(default)]
|
||||
mrr: f64,
|
||||
#[serde(default)]
|
||||
average_ndcg: f64,
|
||||
#[serde(default)]
|
||||
retrieval_cases: usize,
|
||||
#[serde(default)]
|
||||
retrieval_precision: f64,
|
||||
#[serde(default)]
|
||||
llm_cases: usize,
|
||||
#[serde(default)]
|
||||
llm_precision: f64,
|
||||
duration_ms: u128,
|
||||
latency_ms: LatencyStats,
|
||||
embedding_backend: String,
|
||||
embedding_model: Option<String>,
|
||||
ingestion_reused: bool,
|
||||
ingestion_embeddings_reused: bool,
|
||||
rerank_enabled: bool,
|
||||
rerank_keep_top: usize,
|
||||
rerank_pool_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
chunk_result_cap: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_min_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_max_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunk_overlap_tokens: Option<usize>,
|
||||
#[serde(default)]
|
||||
ingest_chunks_only: Option<bool>,
|
||||
#[serde(default)]
|
||||
delta: Option<LegacyHistoryDelta>,
|
||||
openai_base_url: String,
|
||||
ingestion_ms: u128,
|
||||
#[serde(default)]
|
||||
namespace_seed_ms: Option<u128>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LegacyHistoryDelta {
|
||||
precision: f64,
|
||||
precision_at_1: f64,
|
||||
latency_avg_ms: f64,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn convert_legacy_entry(entry: LegacyHistoryEntry) -> EvaluationReport {
|
||||
let overview = OverviewSection {
|
||||
generated_at: entry.generated_at,
|
||||
run_label: entry.run_label,
|
||||
total_cases: entry.slice_cases,
|
||||
filtered_questions: 0,
|
||||
};
|
||||
|
||||
let dataset = DatasetSection {
|
||||
id: entry.dataset_id,
|
||||
label: entry.dataset_label,
|
||||
source: String::new(),
|
||||
includes_unanswerable: entry.llm_cases > 0,
|
||||
require_verified_chunks: true,
|
||||
embedding_backend: entry.embedding_backend,
|
||||
embedding_model: entry.embedding_model,
|
||||
embedding_dimension: 0,
|
||||
};
|
||||
|
||||
let slice = SliceSection {
|
||||
id: entry.slice_id,
|
||||
seed: entry.slice_seed,
|
||||
window_offset: entry.slice_window_offset,
|
||||
window_length: entry.slice_window_length,
|
||||
slice_cases: entry.slice_cases,
|
||||
ledger_total_cases: entry.slice_total_cases,
|
||||
positives: 0,
|
||||
negatives: 0,
|
||||
total_paragraphs: 0,
|
||||
negative_multiplier: 0.0,
|
||||
};
|
||||
|
||||
let retrieval_cases = if entry.retrieval_cases > 0 {
|
||||
entry.retrieval_cases
|
||||
} else {
|
||||
entry.slice_cases.saturating_sub(entry.llm_cases)
|
||||
};
|
||||
let retrieval_precision = if entry.retrieval_precision > 0.0 {
|
||||
entry.retrieval_precision
|
||||
} else {
|
||||
entry.precision
|
||||
};
|
||||
|
||||
let retrieval = RetrievalSection {
|
||||
k: entry.k,
|
||||
cases: retrieval_cases,
|
||||
correct: 0,
|
||||
precision: retrieval_precision,
|
||||
precision_at_1: entry.precision_at_1,
|
||||
precision_at_2: entry.precision_at_2,
|
||||
precision_at_3: entry.precision_at_3,
|
||||
mrr: entry.mrr,
|
||||
average_ndcg: entry.average_ndcg,
|
||||
latency: entry.latency_ms,
|
||||
concurrency: 0,
|
||||
resolve_entities: false,
|
||||
rerank_enabled: entry.rerank_enabled,
|
||||
rerank_pool_size: entry.rerank_pool_size,
|
||||
rerank_keep_top: entry.rerank_keep_top,
|
||||
chunk_result_cap: entry.chunk_result_cap.unwrap_or(5),
|
||||
chunk_rrf_k: default_chunk_rrf_k(),
|
||||
chunk_rrf_vector_weight: default_chunk_rrf_weight(),
|
||||
chunk_rrf_fts_weight: default_chunk_rrf_weight(),
|
||||
chunk_rrf_use_vector: default_chunk_rrf_use(),
|
||||
chunk_rrf_use_fts: default_chunk_rrf_use(),
|
||||
chunk_vector_take: 0,
|
||||
chunk_fts_take: 0,
|
||||
ingest_chunk_min_tokens: entry.ingest_chunk_min_tokens.unwrap_or(256),
|
||||
ingest_chunk_max_tokens: entry.ingest_chunk_max_tokens.unwrap_or(512),
|
||||
ingest_chunk_overlap_tokens: entry.ingest_chunk_overlap_tokens.unwrap_or(50),
|
||||
ingest_chunks_only: entry.ingest_chunks_only.unwrap_or(false),
|
||||
};
|
||||
|
||||
let llm = if entry.llm_cases > 0 {
|
||||
Some(LlmSection {
|
||||
cases: entry.llm_cases,
|
||||
answered: 0,
|
||||
precision: entry.llm_precision,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let performance = PerformanceSection {
|
||||
openai_base_url: entry.openai_base_url,
|
||||
ingestion_ms: entry.ingestion_ms,
|
||||
namespace_seed_ms: entry.namespace_seed_ms,
|
||||
evaluation_stages_ms: EvaluationStageTimings::default(),
|
||||
stage_latency: StageLatencyBreakdown::default(),
|
||||
namespace_reused: false,
|
||||
ingestion_reused: entry.ingestion_reused,
|
||||
embeddings_reused: entry.ingestion_embeddings_reused,
|
||||
ingestion_cache_path: String::new(),
|
||||
corpus_paragraphs: 0,
|
||||
positive_paragraphs_reused: 0,
|
||||
negative_paragraphs_reused: 0,
|
||||
};
|
||||
|
||||
EvaluationReport {
|
||||
overview,
|
||||
dataset,
|
||||
slice,
|
||||
retrieval,
|
||||
llm,
|
||||
performance,
|
||||
misses: Vec::new(),
|
||||
llm_cases: Vec::new(),
|
||||
detailed_report: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
||||
if !path.exists() {
|
||||
return Ok(Vec::new());
|
||||
@@ -981,34 +860,12 @@ fn load_history(path: &Path) -> Result<Vec<EvaluationReport>> {
|
||||
let contents =
|
||||
fs::read(path).with_context(|| format!("reading evaluation log {}", path.display()))?;
|
||||
|
||||
if let Ok(entries) = serde_json::from_slice::<Vec<EvaluationReport>>(&contents) {
|
||||
return Ok(entries);
|
||||
}
|
||||
|
||||
match serde_json::from_slice::<Vec<LegacyHistoryEntry>>(&contents) {
|
||||
Ok(entries) => Ok(entries.into_iter().map(convert_legacy_entry).collect()),
|
||||
Err(err) => {
|
||||
let timestamp = Utc::now().format("%Y%m%dT%H%M%S");
|
||||
let backup_path = path
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."))
|
||||
.join(format!("evaluations.json.corrupted.{timestamp}"));
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
backup = %backup_path.display(),
|
||||
error = %err,
|
||||
"Evaluation history file is corrupted; backing up and starting fresh"
|
||||
);
|
||||
if let Err(e) = fs::rename(path, &backup_path) {
|
||||
warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"Failed to backup corrupted evaluation history"
|
||||
);
|
||||
}
|
||||
Ok(Vec::new())
|
||||
}
|
||||
}
|
||||
serde_json::from_slice(&contents).with_context(|| {
|
||||
format!(
|
||||
"parsing evaluation history at {}; delete the file and re-run if upgrading from an older format",
|
||||
path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBuf> {
|
||||
@@ -1024,9 +881,9 @@ fn record_history(report: &EvaluationReport, report_dir: &Path) -> Result<PathBu
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::eval::{
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedSummary, StageLatency,
|
||||
StageLatencyBreakdown,
|
||||
use crate::types::{
|
||||
EvaluationStageTimings, PerformanceTimings, RetrievedContextStats, RetrievedSummary,
|
||||
StageLatency, StageLatencyBreakdown,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use tempfile::tempdir;
|
||||
@@ -1101,6 +958,7 @@ mod tests {
|
||||
has_verified_chunks: !is_impossible,
|
||||
match_rank: if matched { Some(1) } else { None },
|
||||
latency_ms: 42,
|
||||
retrieved_context: RetrievedContextStats::default(),
|
||||
retrieved: vec![RetrievedSummary {
|
||||
rank: 1,
|
||||
entity_id: "entity1".into(),
|
||||
@@ -1199,6 +1057,13 @@ mod tests {
|
||||
chunk_vector_take: 50,
|
||||
chunk_fts_take: 50,
|
||||
max_chunks_per_entity: 4,
|
||||
retrieved_context: crate::context_stats::aggregate_context_stats(&[
|
||||
RetrievedContextStats {
|
||||
chunk_count: 1,
|
||||
char_count: 10,
|
||||
token_count: 3,
|
||||
},
|
||||
]),
|
||||
cases,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,15 +24,15 @@ pub(crate) async fn enforce_system_settings(
|
||||
updated_settings.embedding_dimensions = provider_dimension as u32;
|
||||
needs_settings_update = true;
|
||||
}
|
||||
if let Some(query_override) = config.query_model.as_deref() {
|
||||
if settings.query_model != query_override {
|
||||
info!(
|
||||
model = query_override,
|
||||
"Overriding system query model for this run"
|
||||
);
|
||||
updated_settings.query_model = query_override.to_string();
|
||||
needs_settings_update = true;
|
||||
}
|
||||
if let Some(query_override) = config.query_model.as_deref()
|
||||
&& settings.query_model != query_override
|
||||
{
|
||||
info!(
|
||||
model = query_override,
|
||||
"Overriding system query model for this run"
|
||||
);
|
||||
updated_settings.query_model = query_override.to_string();
|
||||
needs_settings_update = true;
|
||||
}
|
||||
if needs_settings_update {
|
||||
settings = SystemSettings::update(db, updated_settings)
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::datasets::{BEIR_DATASETS, ConvertedDataset};
|
||||
|
||||
use super::build::{BuildParams, mix_seed};
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
pub(super) fn ordered_question_refs_beir(
|
||||
dataset: &ConvertedDataset,
|
||||
params: &BuildParams,
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||
.iter()
|
||||
.map(|kind| kind.source_prefix())
|
||||
.collect();
|
||||
|
||||
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||
let include = if params.include_impossible {
|
||||
true
|
||||
} else {
|
||||
!question.is_impossible && !question.answers.is_empty()
|
||||
};
|
||||
if !include {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(prefix) = question_prefix(&question.id) else {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
"Skipping BEIR question without expected prefix"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !prefixes.contains(&prefix) {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
prefix = %prefix,
|
||||
"Skipping BEIR question with unknown subset prefix"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||
}
|
||||
}
|
||||
|
||||
if grouped.values().all(std::vec::Vec::is_empty) {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
for prefix in &prefixes {
|
||||
if let Some(entries) = grouped.get_mut(prefix) {
|
||||
let seed = mix_seed(
|
||||
&format!("{}::{prefix}", dataset.metadata.id),
|
||||
params.base_seed,
|
||||
);
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
entries.shuffle(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
let dataset_count = prefixes.len().max(1);
|
||||
let base_quota = target_cases / dataset_count;
|
||||
let mut remainder = target_cases % dataset_count;
|
||||
|
||||
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||
for prefix in &prefixes {
|
||||
let mut quota = base_quota;
|
||||
if remainder > 0 {
|
||||
quota += 1;
|
||||
remainder -= 1;
|
||||
}
|
||||
quotas.insert(*prefix, quota);
|
||||
}
|
||||
|
||||
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||
let mut shortfall = 0usize;
|
||||
|
||||
for prefix in &prefixes {
|
||||
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
||||
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||
let take = quota.min(available);
|
||||
let missing = quota.saturating_sub(take);
|
||||
shortfall += missing;
|
||||
take_counts.insert(*prefix, take);
|
||||
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||
}
|
||||
|
||||
while shortfall > 0 {
|
||||
let mut allocated = false;
|
||||
for prefix in &prefixes {
|
||||
if shortfall == 0 {
|
||||
break;
|
||||
}
|
||||
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||
if spare == 0 {
|
||||
continue;
|
||||
}
|
||||
if let Some(count) = take_counts.get_mut(prefix) {
|
||||
*count += 1;
|
||||
}
|
||||
spare_slots.insert(*prefix, spare - 1);
|
||||
shortfall = shortfall.saturating_sub(1);
|
||||
allocated = true;
|
||||
}
|
||||
if !allocated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||
let mut total_selected = 0usize;
|
||||
for prefix in &prefixes {
|
||||
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||
let mut deque = VecDeque::new();
|
||||
if let Some(entries) = grouped.get(prefix) {
|
||||
for item in entries.iter().take(take) {
|
||||
deque.push_back(*item);
|
||||
total_selected += 1;
|
||||
}
|
||||
}
|
||||
queues.push(deque);
|
||||
}
|
||||
|
||||
if total_selected < target_cases {
|
||||
warn!(
|
||||
requested = target_cases,
|
||||
available = total_selected,
|
||||
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||
);
|
||||
}
|
||||
|
||||
let mut output = Vec::with_capacity(total_selected);
|
||||
loop {
|
||||
let mut progressed = false;
|
||||
for queue in &mut queues {
|
||||
if let Some(item) = queue.pop_front() {
|
||||
output.push(item);
|
||||
progressed = true;
|
||||
}
|
||||
}
|
||||
if !progressed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(super) fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||
if let Some(rest) = question_id.strip_prefix(prefix)
|
||||
&& rest.starts_with('-')
|
||||
{
|
||||
return Some(prefix);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct BuildParams {
|
||||
pub include_impossible: bool,
|
||||
pub base_seed: u64,
|
||||
pub rng_seed: u64,
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
pub(super) fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(dataset_id.as_bytes());
|
||||
hasher.update(seed.to_le_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let mut bytes = [0u8; 8];
|
||||
bytes.copy_from_slice(&digest[..8]);
|
||||
u64::from_le_bytes(bytes)
|
||||
}
|
||||
@@ -1,21 +1,27 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet, VecDeque},
|
||||
collections::{HashMap, HashSet},
|
||||
fmt::Write,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
|
||||
use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::datasets::{
|
||||
ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind, BEIR_DATASETS,
|
||||
use crate::{
|
||||
args::Config,
|
||||
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
|
||||
};
|
||||
|
||||
mod beir;
|
||||
mod build;
|
||||
|
||||
use build::{BuildParams, mix_seed};
|
||||
|
||||
const SLICE_VERSION: u32 = 2;
|
||||
pub const DEFAULT_NEGATIVE_MULTIPLIER: f32 = 4.0;
|
||||
|
||||
@@ -80,8 +86,12 @@ pub enum SliceParagraphKind {
|
||||
Negative,
|
||||
}
|
||||
|
||||
pub fn paragraph_storage_key(paragraph_id: &str) -> String {
|
||||
sanitize_identifier(paragraph_id)
|
||||
}
|
||||
|
||||
pub(crate) fn default_shard_path(paragraph_id: &str) -> String {
|
||||
let sanitized = sanitize_identifier(paragraph_id);
|
||||
let sanitized = paragraph_storage_key(paragraph_id);
|
||||
format!("paragraphs/{sanitized}.json")
|
||||
}
|
||||
|
||||
@@ -210,13 +220,6 @@ struct SliceKey<'a> {
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BuildParams {
|
||||
include_impossible: bool,
|
||||
base_seed: u64,
|
||||
rng_seed: u64,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn resolve_slice<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
@@ -225,15 +228,28 @@ pub fn resolve_slice<'a>(
|
||||
let index = DatasetIndex::build(dataset);
|
||||
|
||||
if let Some(slice_arg) = config.explicit_slice {
|
||||
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
||||
let path = explicit_slice_path(dataset, config, slice_arg);
|
||||
if path.exists() {
|
||||
let (path, manifest) = load_explicit_slice(dataset, &index, config, slice_arg)?;
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest, path)?;
|
||||
info!(
|
||||
slice = %resolved.manifest.slice_id,
|
||||
path = %resolved.path.display(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Using explicitly selected slice"
|
||||
);
|
||||
return Ok(resolved);
|
||||
}
|
||||
let resolved = materialize_slice_ledger(dataset, config, &index, slice_arg, path)?;
|
||||
info!(
|
||||
slice = %resolved.manifest.slice_id,
|
||||
path = %resolved.path.display(),
|
||||
cases = resolved.manifest.case_count,
|
||||
positives = resolved.manifest.positive_paragraphs,
|
||||
negatives = resolved.manifest.negative_paragraphs,
|
||||
"Using explicitly selected slice"
|
||||
"Built catalog slice ledger"
|
||||
);
|
||||
return Ok(resolved);
|
||||
}
|
||||
@@ -256,6 +272,82 @@ pub fn resolve_slice<'a>(
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str());
|
||||
let path = base.join(format!("{slice_id}.json"));
|
||||
materialize_slice_ledger(dataset, config, &index, &slice_id, path)
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
pub fn select_window<'a>(
|
||||
resolved: &'a ResolvedSlice<'a>,
|
||||
offset: usize,
|
||||
limit: Option<usize>,
|
||||
) -> Result<SliceWindow<'a>> {
|
||||
let total = resolved.manifest.case_count;
|
||||
if total == 0 {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' contains no cases",
|
||||
resolved.manifest.slice_id
|
||||
));
|
||||
}
|
||||
if offset >= total {
|
||||
return Err(anyhow!(
|
||||
"slice offset {offset} exceeds available cases ({total})",
|
||||
));
|
||||
}
|
||||
let available = total - offset;
|
||||
let requested = limit.unwrap_or(available).max(1);
|
||||
let length = requested.min(available);
|
||||
let cases = resolved.cases[offset..offset + length].to_vec();
|
||||
let mut seen = HashSet::new();
|
||||
let mut positive_ids = Vec::new();
|
||||
for case in &cases {
|
||||
if seen.insert(case.paragraph.id.as_str()) {
|
||||
positive_ids.push(case.paragraph.id.clone());
|
||||
}
|
||||
}
|
||||
Ok(SliceWindow {
|
||||
offset,
|
||||
length,
|
||||
total_cases: total,
|
||||
cases,
|
||||
positive_paragraph_ids: positive_ids,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
||||
select_window(resolved, 0, None)
|
||||
}
|
||||
|
||||
fn explicit_slice_path(
|
||||
dataset: &ConvertedDataset,
|
||||
config: &SliceConfig<'_>,
|
||||
slice_arg: &str,
|
||||
) -> PathBuf {
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
if explicit_path.exists() {
|
||||
explicit_path.to_path_buf()
|
||||
} else {
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str())
|
||||
.join(format!("{slice_arg}.json"))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn materialize_slice_ledger<'a>(
|
||||
dataset: &'a ConvertedDataset,
|
||||
config: &SliceConfig<'_>,
|
||||
index: &DatasetIndex,
|
||||
slice_id: &str,
|
||||
path: PathBuf,
|
||||
) -> Result<ResolvedSlice<'a>> {
|
||||
let requested_corpus = config
|
||||
.corpus_limit
|
||||
.unwrap_or(dataset.paragraphs.len())
|
||||
.min(dataset.paragraphs.len())
|
||||
.max(1);
|
||||
|
||||
let total_questions = dataset
|
||||
.paragraphs
|
||||
@@ -339,7 +431,7 @@ pub fn resolve_slice<'a>(
|
||||
let mut manifest = manifest.unwrap_or_else(|| {
|
||||
empty_manifest(
|
||||
dataset,
|
||||
slice_id.clone(),
|
||||
slice_id.to_string(),
|
||||
¶ms,
|
||||
requested_corpus,
|
||||
config.negative_multiplier,
|
||||
@@ -396,52 +488,7 @@ pub fn resolve_slice<'a>(
|
||||
);
|
||||
}
|
||||
|
||||
let resolved = manifest_to_resolved(dataset, &index, manifest.clone(), path)?;
|
||||
|
||||
Ok(resolved)
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
pub fn select_window<'a>(
|
||||
resolved: &'a ResolvedSlice<'a>,
|
||||
offset: usize,
|
||||
limit: Option<usize>,
|
||||
) -> Result<SliceWindow<'a>> {
|
||||
let total = resolved.manifest.case_count;
|
||||
if total == 0 {
|
||||
return Err(anyhow!(
|
||||
"slice '{}' contains no cases",
|
||||
resolved.manifest.slice_id
|
||||
));
|
||||
}
|
||||
if offset >= total {
|
||||
return Err(anyhow!(
|
||||
"slice offset {offset} exceeds available cases ({total})",
|
||||
));
|
||||
}
|
||||
let available = total - offset;
|
||||
let requested = limit.unwrap_or(available).max(1);
|
||||
let length = requested.min(available);
|
||||
let cases = resolved.cases[offset..offset + length].to_vec();
|
||||
let mut seen = HashSet::new();
|
||||
let mut positive_ids = Vec::new();
|
||||
for case in &cases {
|
||||
if seen.insert(case.paragraph.id.as_str()) {
|
||||
positive_ids.push(case.paragraph.id.clone());
|
||||
}
|
||||
}
|
||||
Ok(SliceWindow {
|
||||
offset,
|
||||
length,
|
||||
total_cases: total,
|
||||
cases,
|
||||
positive_paragraph_ids: positive_ids,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn full_window<'a>(resolved: &'a ResolvedSlice<'a>) -> Result<SliceWindow<'a>> {
|
||||
select_window(resolved, 0, None)
|
||||
manifest_to_resolved(dataset, index, manifest, path)
|
||||
}
|
||||
|
||||
fn load_explicit_slice(
|
||||
@@ -450,16 +497,7 @@ fn load_explicit_slice(
|
||||
config: &SliceConfig<'_>,
|
||||
slice_arg: &str,
|
||||
) -> Result<(PathBuf, SliceManifest)> {
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
let candidate_path = if explicit_path.exists() {
|
||||
explicit_path.to_path_buf()
|
||||
} else {
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(dataset.metadata.id.as_str())
|
||||
.join(format!("{slice_arg}.json"))
|
||||
};
|
||||
let candidate_path = explicit_slice_path(dataset, config, slice_arg);
|
||||
|
||||
let manifest = read_manifest(&candidate_path)
|
||||
.with_context(|| format!("reading slice manifest at {}", candidate_path.display()))?;
|
||||
@@ -613,7 +651,7 @@ fn ordered_question_refs(
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
if dataset.metadata.id == DatasetKind::Beir.id() {
|
||||
return ordered_question_refs_beir(dataset, params, target_cases);
|
||||
return beir::ordered_question_refs_beir(dataset, params, target_cases);
|
||||
}
|
||||
|
||||
let mut question_refs = Vec::new();
|
||||
@@ -642,171 +680,6 @@ fn ordered_question_refs(
|
||||
Ok(question_refs)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::arithmetic_side_effects)]
|
||||
fn ordered_question_refs_beir(
|
||||
dataset: &ConvertedDataset,
|
||||
params: &BuildParams,
|
||||
target_cases: usize,
|
||||
) -> Result<Vec<(usize, usize)>> {
|
||||
let prefixes: Vec<&str> = BEIR_DATASETS
|
||||
.iter()
|
||||
.map(|kind| kind.source_prefix())
|
||||
.collect();
|
||||
|
||||
let mut grouped: HashMap<&str, Vec<(usize, usize)>> = HashMap::new();
|
||||
for (p_idx, paragraph) in dataset.paragraphs.iter().enumerate() {
|
||||
for (q_idx, question) in paragraph.questions.iter().enumerate() {
|
||||
let include = if params.include_impossible {
|
||||
true
|
||||
} else {
|
||||
!question.is_impossible && !question.answers.is_empty()
|
||||
};
|
||||
if !include {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(prefix) = question_prefix(&question.id) else {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
"Skipping BEIR question without expected prefix"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !prefixes.contains(&prefix) {
|
||||
warn!(
|
||||
question_id = %question.id,
|
||||
prefix = %prefix,
|
||||
"Skipping BEIR question with unknown subset prefix"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
grouped.entry(prefix).or_default().push((p_idx, q_idx));
|
||||
}
|
||||
}
|
||||
|
||||
if grouped.values().all(std::vec::Vec::is_empty) {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
for prefix in &prefixes {
|
||||
if let Some(entries) = grouped.get_mut(prefix) {
|
||||
let seed = mix_seed(
|
||||
&format!("{}::{prefix}", dataset.metadata.id),
|
||||
params.base_seed,
|
||||
);
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
entries.shuffle(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
let dataset_count = prefixes.len().max(1);
|
||||
let base_quota = target_cases / dataset_count;
|
||||
let mut remainder = target_cases % dataset_count;
|
||||
|
||||
let mut quotas: HashMap<&str, usize> = HashMap::new();
|
||||
for prefix in &prefixes {
|
||||
let mut quota = base_quota;
|
||||
if remainder > 0 {
|
||||
quota += 1;
|
||||
remainder -= 1;
|
||||
}
|
||||
quotas.insert(*prefix, quota);
|
||||
}
|
||||
|
||||
let mut take_counts: HashMap<&str, usize> = HashMap::new();
|
||||
let mut spare_slots: HashMap<&str, usize> = HashMap::new();
|
||||
let mut shortfall = 0usize;
|
||||
|
||||
for prefix in &prefixes {
|
||||
let available = grouped.get(prefix).map_or(0, std::vec::Vec::len);
|
||||
let quota = *quotas.get(prefix).unwrap_or(&0);
|
||||
let take = quota.min(available);
|
||||
let missing = quota.saturating_sub(take);
|
||||
shortfall += missing;
|
||||
take_counts.insert(*prefix, take);
|
||||
spare_slots.insert(*prefix, available.saturating_sub(take));
|
||||
}
|
||||
|
||||
while shortfall > 0 {
|
||||
let mut allocated = false;
|
||||
for prefix in &prefixes {
|
||||
if shortfall == 0 {
|
||||
break;
|
||||
}
|
||||
let spare = spare_slots.get(prefix).copied().unwrap_or(0);
|
||||
if spare == 0 {
|
||||
continue;
|
||||
}
|
||||
if let Some(count) = take_counts.get_mut(prefix) {
|
||||
*count += 1;
|
||||
}
|
||||
spare_slots.insert(*prefix, spare - 1);
|
||||
shortfall = shortfall.saturating_sub(1);
|
||||
allocated = true;
|
||||
}
|
||||
if !allocated {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut queues: Vec<VecDeque<(usize, usize)>> = Vec::new();
|
||||
let mut total_selected = 0usize;
|
||||
for prefix in &prefixes {
|
||||
let take = *take_counts.get(prefix).unwrap_or(&0);
|
||||
let mut deque = VecDeque::new();
|
||||
if let Some(entries) = grouped.get(prefix) {
|
||||
for item in entries.iter().take(take) {
|
||||
deque.push_back(*item);
|
||||
total_selected += 1;
|
||||
}
|
||||
}
|
||||
queues.push(deque);
|
||||
}
|
||||
|
||||
if total_selected < target_cases {
|
||||
warn!(
|
||||
requested = target_cases,
|
||||
available = total_selected,
|
||||
"BEIR mix requested more questions than available after balancing; continuing with capped set"
|
||||
);
|
||||
}
|
||||
|
||||
let mut output = Vec::with_capacity(total_selected);
|
||||
loop {
|
||||
let mut progressed = false;
|
||||
for queue in &mut queues {
|
||||
if let Some(item) = queue.pop_front() {
|
||||
output.push(item);
|
||||
progressed = true;
|
||||
}
|
||||
}
|
||||
if !progressed {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"no eligible BEIR questions found; cannot build slice"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn question_prefix(question_id: &str) -> Option<&'static str> {
|
||||
for prefix in BEIR_DATASETS.iter().map(|kind| kind.source_prefix()) {
|
||||
if let Some(rest) = question_id.strip_prefix(prefix) {
|
||||
if rest.starts_with('-') {
|
||||
return Some(prefix);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn ensure_negative_pool(
|
||||
dataset: &ConvertedDataset,
|
||||
@@ -1028,15 +901,47 @@ fn compute_slice_id(key: &SliceKey<'_>) -> Result<String> {
|
||||
}))
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
fn mix_seed(dataset_id: &str, seed: u64) -> u64 {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(dataset_id.as_bytes());
|
||||
hasher.update(seed.to_le_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let mut bytes = [0u8; 8];
|
||||
bytes.copy_from_slice(&digest[..8]);
|
||||
u64::from_le_bytes(bytes)
|
||||
pub fn read_manifest_if_exists(path: &Path) -> Result<Option<SliceManifest>> {
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
read_manifest(path).map(Some)
|
||||
}
|
||||
|
||||
pub fn cached_manifest_path(config: &crate::args::Config) -> Option<PathBuf> {
|
||||
let slice_arg = config.slice.as_deref()?;
|
||||
let explicit_path = Path::new(slice_arg);
|
||||
if explicit_path.exists() {
|
||||
return Some(explicit_path.to_path_buf());
|
||||
}
|
||||
Some(
|
||||
config
|
||||
.cache_dir
|
||||
.join("slices")
|
||||
.join(config.dataset.id())
|
||||
.join(format!("{slice_arg}.json")),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn manifest_is_complete(manifest: &SliceManifest, config: &SliceConfig<'_>) -> bool {
|
||||
let requested_limit = config.limit.unwrap_or(manifest.case_count.max(1)).max(1);
|
||||
if manifest.case_count < requested_limit {
|
||||
return false;
|
||||
}
|
||||
|
||||
let requested_corpus = config
|
||||
.corpus_limit
|
||||
.unwrap_or(manifest.total_paragraphs.max(1))
|
||||
.max(1);
|
||||
let desired_negatives = desired_negative_target(
|
||||
manifest.positive_paragraphs,
|
||||
requested_corpus,
|
||||
manifest
|
||||
.total_paragraphs
|
||||
.max(manifest.positive_paragraphs.max(1)),
|
||||
config.negative_multiplier,
|
||||
);
|
||||
manifest.negative_paragraphs >= desired_negatives
|
||||
}
|
||||
|
||||
fn read_manifest(path: &Path) -> Result<SliceManifest> {
|
||||
@@ -1057,14 +962,37 @@ fn write_manifest(path: &Path, manifest: &SliceManifest) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
use crate::args::Config;
|
||||
|
||||
impl<'a> From<&'a Config> for SliceConfig<'a> {
|
||||
fn from(config: &'a Config) -> Self {
|
||||
slice_config_with_limit(config, None)
|
||||
pub fn ledger_target(config: &Config) -> Option<usize> {
|
||||
match (config.slice_grow, config.limit) {
|
||||
(Some(grow), Some(limit)) => Some(limit.max(grow)),
|
||||
(Some(grow), None) => Some(grow),
|
||||
(None, limit) => limit,
|
||||
}
|
||||
}
|
||||
|
||||
/// Grow the slice ledger to contain the target number of cases.
|
||||
pub fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
|
||||
let ledger_limit = ledger_target(config);
|
||||
let slice_settings = slice_config_with_limit(config, ledger_limit);
|
||||
let slice = resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
|
||||
info!(
|
||||
slice = slice.manifest.slice_id.as_str(),
|
||||
cases = slice.manifest.case_count,
|
||||
positives = slice.manifest.positive_paragraphs,
|
||||
negatives = slice.manifest.negative_paragraphs,
|
||||
total_paragraphs = slice.manifest.total_paragraphs,
|
||||
"Slice ledger ready"
|
||||
);
|
||||
println!(
|
||||
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
|
||||
slice.manifest.slice_id,
|
||||
slice.manifest.case_count,
|
||||
slice.manifest.positive_paragraphs,
|
||||
slice.manifest.negative_paragraphs
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn slice_config_with_limit(config: &Config, limit_override: Option<usize>) -> SliceConfig<'_> {
|
||||
SliceConfig {
|
||||
cache_dir: config.cache_dir.as_path(),
|
||||
@@ -1088,7 +1016,7 @@ mod tests {
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn sample_dataset() -> ConvertedDataset {
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false, None);
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::SquadV2, false);
|
||||
ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata,
|
||||
@@ -1188,11 +1116,13 @@ mod tests {
|
||||
assert_eq!(window.cases.len(), 1);
|
||||
let positive_ids: Vec<&str> = window.positive_ids().collect();
|
||||
assert_eq!(positive_ids.len(), 1);
|
||||
assert!(resolved
|
||||
.manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.any(|entry| entry.id == positive_ids[0]));
|
||||
assert!(
|
||||
resolved
|
||||
.manifest
|
||||
.paragraphs
|
||||
.iter()
|
||||
.any(|entry| entry.id == positive_ids[0])
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1226,7 +1156,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false, None);
|
||||
let metadata = DatasetMetadata::for_kind(DatasetKind::Beir, false);
|
||||
let dataset = ConvertedDataset {
|
||||
generated_at: Utc::now(),
|
||||
metadata,
|
||||
@@ -1240,11 +1170,11 @@ mod tests {
|
||||
rng_seed: 0xBB,
|
||||
};
|
||||
|
||||
let refs = ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
||||
let refs = beir::ordered_question_refs_beir(&dataset, ¶ms, 8)?;
|
||||
let mut per_prefix: HashMap<String, usize> = HashMap::new();
|
||||
for (p_idx, q_idx) in refs {
|
||||
let question = &dataset.paragraphs[p_idx].questions[q_idx];
|
||||
let prefix = question_prefix(&question.id).unwrap_or("unknown");
|
||||
let prefix = beir::question_prefix(&question.id).unwrap_or("unknown");
|
||||
*per_prefix.entry(prefix.to_string()).or_default() += 1;
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{args::Config, slice};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SnapshotMetadata {
|
||||
pub dataset_id: String,
|
||||
pub slice_id: String,
|
||||
pub embedding_backend: String,
|
||||
pub embedding_model: Option<String>,
|
||||
pub embedding_dimension: usize,
|
||||
pub rerank_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbSnapshotState {
|
||||
pub dataset_id: String,
|
||||
pub slice_id: String,
|
||||
pub ingestion_fingerprint: String,
|
||||
pub snapshot_hash: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub namespace: Option<String>,
|
||||
#[serde(default)]
|
||||
pub database: Option<String>,
|
||||
#[serde(default)]
|
||||
pub slice_case_count: usize,
|
||||
}
|
||||
|
||||
pub struct Descriptor {
|
||||
#[allow(dead_code)]
|
||||
metadata: SnapshotMetadata,
|
||||
dir: PathBuf,
|
||||
metadata_hash: String,
|
||||
}
|
||||
|
||||
impl Descriptor {
|
||||
pub fn new(
|
||||
config: &Config,
|
||||
slice: &slice::ResolvedSlice<'_>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Self {
|
||||
let metadata = SnapshotMetadata {
|
||||
dataset_id: slice.manifest.dataset_id.clone(),
|
||||
slice_id: slice.manifest.slice_id.clone(),
|
||||
embedding_backend: embedding_provider.backend_label().to_string(),
|
||||
embedding_model: embedding_provider.model_code(),
|
||||
embedding_dimension: embedding_provider.dimension(),
|
||||
rerank_enabled: config.retrieval.rerank,
|
||||
};
|
||||
|
||||
let dir = config
|
||||
.cache_dir
|
||||
.join("snapshots")
|
||||
.join(&metadata.dataset_id)
|
||||
.join(&metadata.slice_id);
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metadata_hash(&self) -> &str {
|
||||
&self.metadata_hash
|
||||
}
|
||||
|
||||
pub async fn load_db_state(&self) -> Result<Option<DbSnapshotState>> {
|
||||
let path = self.db_state_path();
|
||||
if !path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading namespace state {}", path.display()))?;
|
||||
let state = serde_json::from_slice(&bytes)
|
||||
.with_context(|| format!("deserialising namespace state {}", path.display()))?;
|
||||
Ok(Some(state))
|
||||
}
|
||||
|
||||
pub async fn store_db_state(&self, state: &DbSnapshotState) -> Result<()> {
|
||||
let path = self.db_state_path();
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).await.with_context(|| {
|
||||
format!("creating namespace state directory {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
let blob =
|
||||
serde_json::to_vec_pretty(state).context("serialising namespace state payload")?;
|
||||
fs::write(&path, blob)
|
||||
.await
|
||||
.with_context(|| format!("writing namespace state {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn db_dir(&self) -> PathBuf {
|
||||
self.dir.join("db")
|
||||
}
|
||||
|
||||
fn db_state_path(&self) -> PathBuf {
|
||||
self.db_dir().join("state.json")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn from_parts(metadata: SnapshotMetadata, dir: PathBuf) -> Self {
|
||||
let metadata_hash = compute_hash(&metadata);
|
||||
Self {
|
||||
metadata,
|
||||
dir,
|
||||
metadata_hash,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
fn compute_hash(metadata: &SnapshotMetadata) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(
|
||||
serde_json::to_vec(metadata).expect("snapshot metadata serialisation should succeed"),
|
||||
);
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::unwrap_used, clippy::expect_used)]
|
||||
async fn state_round_trip() {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let metadata = SnapshotMetadata {
|
||||
dataset_id: "dataset".into(),
|
||||
slice_id: "slice".into(),
|
||||
embedding_backend: "hashed".into(),
|
||||
embedding_model: None,
|
||||
embedding_dimension: 128,
|
||||
rerank_enabled: true,
|
||||
};
|
||||
let descriptor = Descriptor::from_parts(
|
||||
metadata,
|
||||
temp_dir
|
||||
.path()
|
||||
.join("snapshots")
|
||||
.join("dataset")
|
||||
.join("slice"),
|
||||
);
|
||||
|
||||
let state = DbSnapshotState {
|
||||
dataset_id: "dataset".into(),
|
||||
slice_id: "slice".into(),
|
||||
ingestion_fingerprint: "fingerprint".into(),
|
||||
snapshot_hash: descriptor.metadata_hash().to_string(),
|
||||
updated_at: Utc::now(),
|
||||
namespace: Some("ns".into()),
|
||||
database: Some("db".into()),
|
||||
slice_case_count: 42,
|
||||
};
|
||||
descriptor.store_db_state(&state).await.unwrap();
|
||||
|
||||
let loaded = descriptor.load_db_state().await.unwrap().unwrap();
|
||||
assert_eq!(loaded.dataset_id, state.dataset_id);
|
||||
assert_eq!(loaded.slice_id, state.slice_id);
|
||||
assert_eq!(loaded.ingestion_fingerprint, state.ingestion_fingerprint);
|
||||
assert_eq!(loaded.snapshot_hash, state.snapshot_hash);
|
||||
assert_eq!(loaded.namespace, state.namespace);
|
||||
assert_eq!(loaded.database, state.database);
|
||||
assert_eq!(loaded.slice_case_count, state.slice_case_count);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user