mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-09 16:02:45 +02:00
Compare commits
364 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8fe4ac9fec | |||
| db43be1606 | |||
| 8e8370b080 | |||
| 84695fa0cc | |||
| 654add98bc | |||
| 244ec0ea25 | |||
| d8416ac711 | |||
| f9f48d1046 | |||
| 30b8a65377 | |||
| 04faa38ee6 | |||
| cdc62dda30 | |||
| ab8ff8b07a | |||
| 79ea007b0a | |||
| a5bc72aedf | |||
| 2e2ea0c4ff | |||
| a090a8c76e | |||
| a8d10f265c | |||
| 0cb1abc6db | |||
| d1a6d9abdf | |||
| d3fa3be3e5 | |||
| a2c9bb848d | |||
| dd881efbf9 | |||
| 2939e4c2a4 | |||
| 1039ec32a4 | |||
| cb906c5b53 | |||
| 08b1612fcb | |||
| 67004c9646 | |||
| 030f0fc17d | |||
| 226b2db43a | |||
| 6f88d87e74 | |||
| bd519ab269 | |||
| f535df7e61 | |||
| 6b7befbd04 | |||
| 0eda65b07e | |||
| 04ee225732 | |||
| 13b7ad6f3a | |||
| 112a6965a4 | |||
| 911e830be5 | |||
| 3196e65172 | |||
| 380c900c86 | |||
| a99e5ada8b | |||
| b0deabaf3f | |||
| a8f0d9fa88 | |||
| 56a1dfddb8 | |||
| 863b921fb4 | |||
| f13791cfcf | |||
| 75c200b2ba | |||
| 1b7c24747a | |||
| 241ad9a089 | |||
| 72578296db | |||
| a0e9387c76 | |||
| 798b1468b6 | |||
| 3b805778b4 | |||
| 07b3e1a0e8 | |||
| 83d39afad4 | |||
| 21e4ab1f42 | |||
| 3c97d8ead5 | |||
| ab68bccb80 | |||
| 99b88c3063 | |||
| 44e5d8a2fc | |||
| 7332347f1a | |||
| 199186e5a3 | |||
| 64728468cd | |||
| c3a7e8dc59 | |||
| 35ff4e1464 | |||
| 2964f1a5a5 | |||
| cb7f625b81 | |||
| dc40cf7663 | |||
| aa0b1462a1 | |||
| 41fc7bb99c | |||
| 61d8d7abe7 | |||
| b7344644dc | |||
| 3742598a6d | |||
| c6a6080e1c | |||
| 1159712724 | |||
| e5e1414f54 | |||
| fcc49b1954 | |||
| 022f4d8575 | |||
| 945a2b7f37 | |||
| ff4ea55cd5 | |||
| c4c76efe92 | |||
| c0fcad5952 | |||
| b0ed69330d | |||
| 5cb15dab45 | |||
| 7403195df5 | |||
| 9faef31387 | |||
| 110f7b8a8f | |||
| f343005af8 | |||
| e1d98b0c35 | |||
| c12d00edaa | |||
| 903585bfef | |||
| f592eb7200 | |||
| c2839f8db3 | |||
| 9a7c57cb19 | |||
| fe5143cd7f | |||
| 3f774302c7 | |||
| 6ea51095e8 | |||
| 62d909bb7e | |||
| 69954cf78e | |||
| 153efd1a98 | |||
| fdf29bb735 | |||
| e150b476c3 | |||
| 2e076c8236 | |||
| a0632c9768 | |||
| 33300d3193 | |||
| b8272d519d | |||
| ec16f2100c | |||
| 43263fa77e | |||
| f1548d18db | |||
| f2bafe0205 | |||
| 9a23c1ea1b | |||
| f567b7198b | |||
| 32141fce6e | |||
| 477f26174c | |||
| 37584ed9fd | |||
| a363c6cc05 | |||
| 811aaec554 | |||
| 7aa7e71287 | |||
| d2772bd09c | |||
| 81825e2525 | |||
| 8289194b6c | |||
| 46c3090ebb | |||
| 5baab2f2a1 | |||
| cf8f6c6d40 | |||
| 6066b1ccaa | |||
| 30b9e7673c | |||
| 724c1f79cc | |||
| aba3c37fc8 | |||
| f5a57f90a0 | |||
| 45cd472d8b | |||
| b93e7b5299 | |||
| bc7891a3e7 | |||
| d173064ffc | |||
| d504903db3 | |||
| 850878d5c3 | |||
| 6ed49f7155 | |||
| 463c562ba7 | |||
| c6a4f13261 | |||
| 0213d0b4c4 | |||
| 13985de6ad | |||
| 1624c08397 | |||
| 336c7d33f1 | |||
| 204d0c5497 | |||
| 959b519c7d | |||
| 24ed70db9c | |||
| 16910acbf5 | |||
| 322d1ec318 | |||
| 6ad625befc | |||
| c2fbdecce0 | |||
| 9fceadb41f | |||
| 8d1be03078 | |||
| 0169c9b525 | |||
| 3f032c6930 | |||
| bcdd3628ef | |||
| 1230ea6126 | |||
| 02198dc21a | |||
| 776a454a88 | |||
| ce006f6ecc | |||
| 3fdee5a3a3 | |||
| 0c09b13a38 | |||
| 6e4aa12fb2 | |||
| 5af217170c | |||
| d4f097ab98 | |||
| a0b80a911e | |||
| 3d420582dc | |||
| d10157ac26 | |||
| e571e6ec76 | |||
| 233df1b79a | |||
| 435547de66 | |||
| 804461ac01 | |||
| c3c0f69bba | |||
| c000ac07b3 | |||
| 5bc48fb30b | |||
| 20fc43638b | |||
| 89635cfd41 | |||
| 42798788a5 | |||
| 56a2e1d801 | |||
| c556d48151 | |||
| a40ed0fe94 | |||
| d01d8b6bd7 | |||
| b956bdbe2b | |||
| 12cbf66d21 | |||
| 42e63600a1 | |||
| b4cf020f36 | |||
| ba0db2649c | |||
| c52c4d18a1 | |||
| bc5c4b32b6 | |||
| bbc1cbc302 | |||
| 60a0d621e1 | |||
| 1a641db503 | |||
| ef1478547e | |||
| 4ab5d3b551 | |||
| de59b864e8 | |||
| 847571729b | |||
| 037bc52a64 | |||
| 8127a9adeb | |||
| efc2e481af | |||
| eed720a548 | |||
| 0938c37d4b | |||
| 1f760248c4 | |||
| d0ab590d1d | |||
| 813e454f89 | |||
| a3c7d1dad0 | |||
| ecee3415f6 | |||
| 30de324879 | |||
| e7c666c054 | |||
| e8e7f610de | |||
| 79680f3245 | |||
| f92e97e541 | |||
| aa8574f4ce | |||
| 85b6291a46 | |||
| 6dcae4471b | |||
| 711e54aea1 | |||
| fd19195148 | |||
| 33380b3bcf | |||
| bdbb23d839 | |||
| a9ad904063 | |||
| 37b7a752b3 | |||
| fe9505e39d | |||
| 5d86be457c | |||
| bc8f050fcc | |||
| fd769018ce | |||
| 0fe253a127 | |||
| ae63416943 | |||
| d04f3faba5 | |||
| 9a3a8f3a1e | |||
| 078b4040bb | |||
| 03c3347652 | |||
| 703480e8b3 | |||
| 16e0611a88 | |||
| 5a1095f538 | |||
| 1532ada09f | |||
| c953dffe56 | |||
| e8daf026ca | |||
| 296dde9ee7 | |||
| e1cf17ac86 | |||
| f545ff8ecb | |||
| 4f2a08b610 | |||
| 730df164cd | |||
| 0f3f6db437 | |||
| 44ee9b6ce9 | |||
| e58ead5cd7 | |||
| 2e70bd0636 | |||
| 972919a515 | |||
| 2d96f06478 | |||
| 81a4a68d70 | |||
| 9b95321869 | |||
| 424ec81a44 | |||
| 9245ff949d | |||
| 8a7ecd635f | |||
| d77f07c626 | |||
| 2b1348aae2 | |||
| 1066f27162 | |||
| ad9215e251 | |||
| 58a9ed0b4f | |||
| b2eb4858ad | |||
| 71063a18af | |||
| df3176842f | |||
| a7eb5105e5 | |||
| 4366c88b42 | |||
| 367f976173 | |||
| ab23909e0f | |||
| 9976fef5a3 | |||
| a3bb73646c | |||
| 981da5a649 | |||
| 69596e8c5d | |||
| 0110ffa5bf | |||
| bdf7a12e9c | |||
| 53820e8b5f | |||
| a3984ab348 | |||
| 61e81ca894 | |||
| fdfe4a05b8 | |||
| 7d75a6c3c0 | |||
| ec3a1463a7 | |||
| 563032febb | |||
| 84e5cdd587 | |||
| ff1c8836ee | |||
| 0a5e37130d | |||
| fa6283485c | |||
| 291c473d00 | |||
| ae4781363f | |||
| 96f2e765f6 | |||
| 766653030d | |||
| 15edc8aa6b | |||
| 4803632f0a | |||
| 6ad4071d63 | |||
| b61dae1ea8 | |||
| 620530bce1 | |||
| 50994c7df9 | |||
| f9a382db6d | |||
| 58ddc1f824 | |||
| 1fad4a5b1a | |||
| 5d1e48d493 | |||
| 624f5d5cbc | |||
| 5a5a91b5cd | |||
| 7c8f658e2a | |||
| 58d8bb6c0c | |||
| b7c228e340 | |||
| bb57844a68 | |||
| 2bc46e5dfe | |||
| 86e939f564 | |||
| d23b5d4021 | |||
| 6efa8cd4ee | |||
| 8936daa53f | |||
| 4dbd517bf6 | |||
| d731e69bf9 | |||
| 6f077e0190 | |||
| 52a27d3983 | |||
| 3e9463e5da | |||
| 3acd320f34 | |||
| b5ad101d3f | |||
| 5f64fbd3fb | |||
| cd7901eefe | |||
| 683e2769f8 | |||
| 8df215d6c3 | |||
| ca598d52c3 | |||
| 200707af90 | |||
| fc20526789 | |||
| 838aa4c1ec | |||
| 2f20d5986d | |||
| c3ccb8c034 | |||
| 8ba853a329 | |||
| 0b874b427a | |||
| d81825c786 | |||
| e3fd7ce20f | |||
| 87fc392eb6 | |||
| dfa597758f | |||
| 0f491f88d3 | |||
| 3f06bf969a | |||
| 6d8cd05c1a | |||
| 40cf2e3b5b | |||
| 2acd00ff8e | |||
| 8f9692b113 | |||
| f860e8ff86 | |||
| 06b70d7d4e | |||
| ceb0147213 | |||
| 6af463ed44 | |||
| 37dd4c7337 | |||
| 3e673e823b | |||
| ce75c4656f | |||
| 83d8452316 | |||
| 13b8fce803 | |||
| 07355ec41e | |||
| 3d53c8b3aa | |||
| b5f9317634 | |||
| 43e5d4d629 | |||
| 7d13c06a51 | |||
| 8e7e762130 | |||
| a348950234 | |||
| c1ad819164 | |||
| 27cf09b81e | |||
| deac2ba0a3 | |||
| 18c84f5c88 | |||
| 1bd2eee8e1 | |||
| 7fbfa0fdbc | |||
| a9aee21db3 | |||
| 9be73c30d4 | |||
| d45ec5b503 | |||
| 2e74468392 | |||
| b0ca889459 | |||
| 30de906ee6 | |||
| b6e06779f9 | |||
| 6ba76f230a | |||
| bb3ed4b034 |
@@ -0,0 +1,2 @@
|
||||
[alias]
|
||||
eval = "run -p evaluations --"
|
||||
@@ -0,0 +1,49 @@
|
||||
- name: Prepare lib dir
|
||||
run: mkdir -p lib
|
||||
|
||||
# Linux
|
||||
- name: Fetch ONNX Runtime (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
ARCH="$(uname -m)"
|
||||
case "$ARCH" in
|
||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||
esac
|
||||
curl -fsSL -o ort.tgz "$URL"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||
|
||||
# macOS
|
||||
- name: Fetch ONNX Runtime (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||
tar -xzf ort.tgz
|
||||
# copy the main dylib; rename to stable name if needed
|
||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||
# optional: ensure a stable name
|
||||
if [ ! -f lib/libonnxruntime.dylib ]; then
|
||||
cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||
fi
|
||||
|
||||
# Windows
|
||||
- name: Fetch ONNX Runtime (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||
Invoke-WebRequest $url -OutFile ort.zip
|
||||
Expand-Archive ort.zip -DestinationPath ort
|
||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||
|
||||
+121
-121
@@ -1,44 +1,8 @@
|
||||
# This file was autogenerated by dist: https://opensource.axo.dev/cargo-dist/
|
||||
#
|
||||
# Copyright 2022-2024, axodotdev
|
||||
# SPDX-License-Identifier: MIT or Apache-2.0
|
||||
#
|
||||
# CI that:
|
||||
#
|
||||
# * checks for a Git Tag that looks like a release
|
||||
# * builds artifacts with dist (archives, installers, hashes)
|
||||
# * uploads those artifacts to temporary workflow zip
|
||||
# * on success, uploads the artifacts to a GitHub Release
|
||||
#
|
||||
# Note that the GitHub Release will be created with a generated
|
||||
# title/body based on your changelogs.
|
||||
|
||||
name: Release
|
||||
permissions:
|
||||
"contents": "write"
|
||||
"packages": "write"
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
# This task will run whenever you push a git tag that looks like a version
|
||||
# like "1.0.0", "v0.1.0-prerelease.1", "my-app/0.1.0", "releases/v1.0.0", etc.
|
||||
# Various formats will be parsed into a VERSION and an optional PACKAGE_NAME, where
|
||||
# PACKAGE_NAME must be the name of a Cargo package in your workspace, and VERSION
|
||||
# must be a Cargo-style SemVer Version (must have at least major.minor.patch).
|
||||
#
|
||||
# If PACKAGE_NAME is specified, then the announcement will be for that
|
||||
# package (erroring out if it doesn't have the given version or isn't dist-able).
|
||||
#
|
||||
# If PACKAGE_NAME isn't specified, then the announcement will be for all
|
||||
# (dist-able) packages in the workspace with that version (this mode is
|
||||
# intended for workspaces with only one dist-able package, or with all dist-able
|
||||
# packages versioned/released in lockstep).
|
||||
#
|
||||
# If you push multiple tags at once, separate instances of this workflow will
|
||||
# spin up, creating an independent announcement for each one. However, GitHub
|
||||
# will hard limit this to 3 tags per commit, as it will assume more tags is a
|
||||
# mistake.
|
||||
#
|
||||
# If there's a prerelease-style suffix to the version, then the release(s)
|
||||
# will be marked as a prerelease.
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
@@ -46,9 +10,8 @@ on:
|
||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
||||
|
||||
jobs:
|
||||
# Run 'dist plan' (or host) to determine what tasks we need to do
|
||||
plan:
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
val: ${{ steps.plan.outputs.manifest }}
|
||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||
@@ -60,52 +23,36 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install dist
|
||||
# we specify bash to get pipefail; it guards against the `curl` command
|
||||
# failing. otherwise `sh` won't catch that `curl` returned non-0
|
||||
shell: bash
|
||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.28.0/cargo-dist-installer.sh | sh"
|
||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.0/cargo-dist-installer.sh | sh"
|
||||
|
||||
- name: Cache dist
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/dist
|
||||
# sure would be cool if github gave us proper conditionals...
|
||||
# so here's a doubly-nested ternary-via-truthiness to try to provide the best possible
|
||||
# functionality based on whether this is a pull_request, and whether it's from a fork.
|
||||
# (PRs run on the *source* but secrets are usually on the *target* -- that's *good*
|
||||
# but also really annoying to build CI around when it needs secrets to work right.)
|
||||
|
||||
- id: plan
|
||||
run: |
|
||||
dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
cat plan-dist-manifest.json
|
||||
echo "manifest=$(jq -c "." plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
- name: "Upload dist-manifest.json"
|
||||
echo "manifest=$(jq -c . plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-plan-dist-manifest
|
||||
path: plan-dist-manifest.json
|
||||
|
||||
# Build and packages all the platform-specific things
|
||||
build-local-artifacts:
|
||||
name: build-local-artifacts (${{ join(matrix.targets, ', ') }})
|
||||
# Let the initial task tell us to not run (currently very blunt)
|
||||
needs:
|
||||
- plan
|
||||
needs: [plan]
|
||||
if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
# Target platforms/runners are computed by dist in create-release.
|
||||
# Each member of the matrix has the following arguments:
|
||||
#
|
||||
# - runner: the github runner
|
||||
# - dist-args: cli flags to pass to dist
|
||||
# - install-dist: expression to run to install dist on the runner
|
||||
#
|
||||
# Typically there will be:
|
||||
# - 1 "global" task that builds universal installers
|
||||
# - N "local" tasks that build each platform's binaries and platform-specific installers
|
||||
matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
container: ${{ matrix.container && matrix.container.image || null }}
|
||||
@@ -114,11 +61,12 @@ jobs:
|
||||
BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json
|
||||
steps:
|
||||
- name: enable windows longpaths
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
run: git config --global core.longpaths true
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Rust non-interactively if not already installed
|
||||
if: ${{ matrix.container }}
|
||||
run: |
|
||||
@@ -126,37 +74,103 @@ jobs:
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install dist
|
||||
run: ${{ matrix.install_dist.run }}
|
||||
# Get the dist-manifest
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
# ===== BEGIN: Injected ORT staging for cargo-dist bundling =====
|
||||
- run: echo "=== BUILD-SETUP START ==="
|
||||
|
||||
# Unix shells
|
||||
- name: Prepare lib dir (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p lib
|
||||
rm -f lib/*
|
||||
|
||||
# Windows PowerShell
|
||||
- name: Prepare lib dir (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
New-Item -ItemType Directory -Force -Path lib | Out-Null
|
||||
# remove contents if any
|
||||
Get-ChildItem -Path lib -Force | Remove-Item -Force -Recurse -ErrorAction SilentlyContinue
|
||||
|
||||
- name: Fetch ONNX Runtime (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
ARCH="$(uname -m)"
|
||||
case "$ARCH" in
|
||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||
esac
|
||||
curl -fsSL -o ort.tgz "$URL"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||
# normalize to stable name if needed
|
||||
[ -f lib/libonnxruntime.so ] || cp -v lib/libonnxruntime.so.* lib/libonnxruntime.so
|
||||
|
||||
- name: Fetch ONNX Runtime (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||
[ -f lib/libonnxruntime.dylib ] || cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||
|
||||
- name: Fetch ONNX Runtime (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||
Invoke-WebRequest $url -OutFile ort.zip
|
||||
Expand-Archive ort.zip -DestinationPath ort
|
||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||
|
||||
- run: |
|
||||
echo "=== BUILD-SETUP END ==="
|
||||
echo "lib/ contents:"
|
||||
ls -l lib || dir lib
|
||||
# ===== END: Injected ORT staging =====
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
${{ matrix.packages_install }}
|
||||
|
||||
- name: Build artifacts
|
||||
run: |
|
||||
# Actually do builds and make zips and whatnot
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
|
||||
- id: cargo-dist
|
||||
name: Post-build
|
||||
# We force bash here just because github makes it really hard to get values up
|
||||
# to "real" actions without writing to env-vars, and writing to env-vars has
|
||||
# inconsistent syntax between shell and powershell.
|
||||
shell: bash
|
||||
run: |
|
||||
# Parse out what we just built and upload it to scratch storage
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
- name: "Upload artifacts"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-local-${{ join(matrix.targets, '_') }}
|
||||
@@ -167,16 +181,16 @@ jobs:
|
||||
build_and_push_docker_image:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
permissions:
|
||||
contents: read # Permission to checkout the repository
|
||||
packages: write # Permission to push Docker image to GHCR
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive # Matches your other checkout steps
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -185,33 +199,28 @@ jobs:
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }} # User triggering the workflow
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
# This action automatically uses the Git tag as the Docker image tag.
|
||||
# For example, a Git tag 'v1.2.3' will result in Docker tag 'ghcr.io/owner/repo:v1.2.3'.
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha # Enable Docker layer caching from GitHub Actions cache
|
||||
cache-to: type=gha,mode=max # Enable Docker layer caching to GitHub Actions cache
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Build and package all the platform-agnostic(ish) things
|
||||
build-global-artifacts:
|
||||
needs:
|
||||
- plan
|
||||
- build-local-artifacts
|
||||
runs-on: "ubuntu-22.04"
|
||||
needs: [plan, build-local-artifacts]
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json
|
||||
@@ -219,92 +228,90 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
# Get all the local artifacts for the global tasks to use (for e.g. checksums)
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: cargo-dist
|
||||
shell: bash
|
||||
run: |
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
|
||||
# Parse out what we just built and upload it to scratch storage
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
- name: "Upload artifacts"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-global
|
||||
path: |
|
||||
${{ steps.cargo-dist.outputs.paths }}
|
||||
${{ env.BUILD_MANIFEST_NAME }}
|
||||
# Determines if we should publish/announce
|
||||
|
||||
host:
|
||||
needs:
|
||||
- plan
|
||||
- build-local-artifacts
|
||||
- build-global-artifacts
|
||||
# Only run if we're "publishing", and only if local and global didn't fail (skipped is fine)
|
||||
needs: [plan, build-local-artifacts, build-global-artifacts]
|
||||
if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
val: ${{ steps.host.outputs.manifest }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
# Fetch artifacts from scratch-storage
|
||||
|
||||
- name: Fetch artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: host
|
||||
shell: bash
|
||||
run: |
|
||||
dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json
|
||||
echo "artifacts uploaded and released successfully"
|
||||
cat dist-manifest.json
|
||||
echo "manifest=$(jq -c "." dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
- name: "Upload dist-manifest.json"
|
||||
echo "manifest=$(jq -c . dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
# Overwrite the previous copy
|
||||
name: artifacts-dist-manifest
|
||||
path: dist-manifest.json
|
||||
# Create a GitHub Release while uploading all files to it
|
||||
- name: "Download GitHub Artifacts"
|
||||
|
||||
- name: Download GitHub Artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: artifacts
|
||||
merge-multiple: true
|
||||
|
||||
- name: Cleanup
|
||||
run: |
|
||||
# Remove the granular manifests
|
||||
rm -f artifacts/*-dist-manifest.json
|
||||
run: rm -f artifacts/*-dist-manifest.json
|
||||
|
||||
- name: Create GitHub Release
|
||||
env:
|
||||
PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}"
|
||||
@@ -312,20 +319,13 @@ jobs:
|
||||
ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}"
|
||||
RELEASE_COMMIT: "${{ github.sha }}"
|
||||
run: |
|
||||
# Write and read notes from a file to avoid quoting breaking things
|
||||
echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt
|
||||
|
||||
gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/*
|
||||
|
||||
announce:
|
||||
needs:
|
||||
- plan
|
||||
- host
|
||||
# use "always() && ..." to allow us to wait for all publish jobs while
|
||||
# still allowing individual publish jobs to skip themselves (for prereleases).
|
||||
# "host" however must run to completion, no skipping allowed!
|
||||
needs: [plan, host]
|
||||
if: ${{ always() && needs.host.result == 'success' }}
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
|
||||
+7
-3
@@ -1,15 +1,18 @@
|
||||
.devenv
|
||||
.direnv
|
||||
.devenv
|
||||
database
|
||||
node_modules
|
||||
config.yaml
|
||||
|
||||
|
||||
target
|
||||
.aider
|
||||
.aider.chat.history.md .aider.input.history .aider.tags.cache.v3 .aider.tags.cache.v3/cache.db
|
||||
result
|
||||
data
|
||||
database
|
||||
|
||||
evaluations/cache/
|
||||
evaluations/reports/
|
||||
|
||||
# Devenv
|
||||
.devenv*
|
||||
devenv.local.nix
|
||||
@@ -21,3 +24,4 @@ devenv.local.nix
|
||||
.pre-commit-config.yaml
|
||||
# html-router/assets/style.css
|
||||
html-router/node_modules
|
||||
.fastembed_cache/
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# Changelog
|
||||
## 1.0.0 (2026-01-02)
|
||||
- **Locally generated embeddings are now default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. This will download a ONNX model and recreate all embeddings. But in most instances it's very worth it. Removing the network bound call to create embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
|
||||
- Added a benchmarks create for evaluating the retrieval process
|
||||
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
|
||||
- Embeddings stored on own table.
|
||||
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
||||
|
||||
## Version 0.2.7 (2025-12-04)
|
||||
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||
- Fix: timezone aware info in scratchpad
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
- Fix: default name for relationships harmonized across application
|
||||
|
||||
## Version 0.2.5 (2025-10-24)
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
- Scratchpad feature, with the feature to convert scratchpads to content.
|
||||
- Added knowledge entity search results to the global search
|
||||
- Backend fixes for improved performance when ingesting and retrieval
|
||||
|
||||
## Version 0.2.4 (2025-10-15)
|
||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||
- Ingestion task archive
|
||||
|
||||
## Version 0.2.3 (2025-10-12)
|
||||
- Fix changing vector dimensions on a fresh database (#3)
|
||||
|
||||
## Version 0.2.2 (2025-10-07)
|
||||
- Support for ingestion of PDF files
|
||||
- Improved ingestion speed
|
||||
- Fix deletion of items work as expected
|
||||
- Fix enabling GPT-5 use via OpenAI API
|
||||
|
||||
## Version 0.2.1 (2025-09-24)
|
||||
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||
|
||||
## Version 0.2.0 (2025-09-23)
|
||||
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||
- Enabled audio ingestion and merged the new storage backend.
|
||||
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||
|
||||
## Version 0.1.4 (2025-07-01)
|
||||
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||
|
||||
## Version 0.1.3 (2025-06-08)
|
||||
- Added support for AI providers beyond OpenAI.
|
||||
- Made the HTTP port configurable for deployments.
|
||||
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||
|
||||
## Version 0.1.2 (2025-05-26)
|
||||
- Introduced full-text search across indexed knowledge.
|
||||
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||
|
||||
## Version 0.1.1 (2025-05-13)
|
||||
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||
- Made the data storage path configurable.
|
||||
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||
|
||||
## Version 0.1.0 (2025-05-06)
|
||||
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||
- Shipped a Tailwind/daisyUI web UI with responsive layouts, modals, content viewers, and editing flows.
|
||||
- Provided readability-based content ingestion, API/HTML ingress routes, and Docker/Docker Compose tooling.
|
||||
Generated
+1673
-171
File diff suppressed because it is too large
Load Diff
+95
-13
@@ -5,30 +5,112 @@ members = [
|
||||
"api-router",
|
||||
"html-router",
|
||||
"ingestion-pipeline",
|
||||
"composite-retrieval",
|
||||
"json-stream-parser"
|
||||
"retrieval-pipeline",
|
||||
"json-stream-parser",
|
||||
"evaluations"
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
axum = { version = "0.8", features = ["multipart", "macros"] }
|
||||
serde_json = "1.0.128"
|
||||
thiserror = "1.0.63"
|
||||
anyhow = "1.0.94"
|
||||
tracing = "0.1.40"
|
||||
surrealdb = { version = "2", features = ["kv-mem"] }
|
||||
futures = "0.3.31"
|
||||
async-openai = "0.24.1"
|
||||
async-openai = "0.29.3"
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1.88"
|
||||
axum-htmx = "0.7.0"
|
||||
axum_session = "0.16"
|
||||
axum_session_auth = "0.16"
|
||||
axum_session_surreal = "0.4"
|
||||
axum_typed_multipart = "0.16"
|
||||
tempfile = "3.12.0"
|
||||
axum = { version = "0.8", features = ["multipart", "macros"] }
|
||||
chrono-tz = "0.10.1"
|
||||
chrono = { version = "0.4.39", features = ["serde"] }
|
||||
config = "0.15.4"
|
||||
dom_smoothie = "0.10.0"
|
||||
futures = "0.3.31"
|
||||
headless_chrome = "1.0.17"
|
||||
include_dir = "0.7.4"
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
serde_json = "1.0.128"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sha2 = "0.10.8"
|
||||
surrealdb-migrations = "2.2.2"
|
||||
surrealdb = { version = "2", features = ["kv-mem"] }
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||
unicode-normalization = "0.1.24"
|
||||
thiserror = "1.0.63"
|
||||
tokio-util = { version = "0.7.15", features = ["io"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tower-http = { version = "0.6.2", features = ["fs", "compression-full"] }
|
||||
tower-serve-static = "0.1.1"
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
url = { version = "2.5.2", features = ["serde"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
tokio-retry = "0.3.0"
|
||||
base64 = "0.22.1"
|
||||
object_store = { version = "0.11.2" }
|
||||
bytes = "1.7.1"
|
||||
state-machines = "0.2.0"
|
||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||
|
||||
# The profile that 'dist' will build with
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
lto = "thin"
|
||||
|
||||
[workspace.lints.clippy]
|
||||
# Performance-focused lints
|
||||
perf = { level = "warn", priority = -1 }
|
||||
vec_init_then_push = "warn"
|
||||
large_stack_frames = "warn"
|
||||
redundant_allocation = "warn"
|
||||
single_char_pattern = "warn"
|
||||
string_extend_chars = "warn"
|
||||
format_in_format_args = "warn"
|
||||
slow_vector_initialization = "warn"
|
||||
inefficient_to_string = "warn"
|
||||
implicit_clone = "warn"
|
||||
redundant_clone = "warn"
|
||||
|
||||
# Security-focused lints
|
||||
arithmetic_side_effects = "warn"
|
||||
indexing_slicing = "warn"
|
||||
unwrap_used = "warn"
|
||||
expect_used = "warn"
|
||||
panic = "warn"
|
||||
unimplemented = "warn"
|
||||
todo = "warn"
|
||||
|
||||
# Async/Network lints
|
||||
async_yields_async = "warn"
|
||||
await_holding_invalid_type = "warn"
|
||||
rc_buffer = "warn"
|
||||
|
||||
# Maintainability-focused lints
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
clone_on_ref_ptr = "warn"
|
||||
float_cmp = "warn"
|
||||
manual_string_new = "warn"
|
||||
uninlined_format_args = "warn"
|
||||
unused_self = "warn"
|
||||
must_use_candidate = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
missing_panics_doc = "warn"
|
||||
module_name_repetitions = "warn"
|
||||
wildcard_dependencies = "warn"
|
||||
missing_docs_in_private_items = "warn"
|
||||
|
||||
# Allow noisy lints that don't add value for this project
|
||||
needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
cargo_common_metadata = "allow"
|
||||
multiple-crate-versions = "allow"
|
||||
module_name_repetition = "allow"
|
||||
|
||||
+32
-34
@@ -1,53 +1,51 @@
|
||||
# === Builder Stage ===
|
||||
FROM clux/muslrust:1.86.0-stable as builder
|
||||
|
||||
# === Builder ===
|
||||
FROM rust:1.86-bookworm AS builder
|
||||
WORKDIR /usr/src/minne
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cache deps
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir -p api-router common composite-retrieval html-router ingestion-pipeline json-stream-parser main worker
|
||||
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
|
||||
COPY api-router/Cargo.toml ./api-router/
|
||||
COPY common/Cargo.toml ./common/
|
||||
COPY composite-retrieval/Cargo.toml ./composite-retrieval/
|
||||
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
|
||||
COPY html-router/Cargo.toml ./html-router/
|
||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
||||
COPY main/Cargo.toml ./main/
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Build with the MUSL target
|
||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Copy the rest of the source code
|
||||
# Build
|
||||
COPY . .
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
||||
|
||||
# Build the final application binary with the MUSL target
|
||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker
|
||||
# === Runtime ===
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# === Runtime Stage ===
|
||||
FROM alpine:latest
|
||||
# Chromium + runtime deps + OpenMP for ORT
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
||||
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
||||
libgomp1 libstdc++6 curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apk update && apk add --no-cache \
|
||||
chromium \
|
||||
nss \
|
||||
freetype \
|
||||
harfbuzz \
|
||||
ca-certificates \
|
||||
ttf-freefont \
|
||||
font-noto-emoji \
|
||||
&& \
|
||||
rm -rf /var/cache/apk/*
|
||||
# ONNX Runtime (CPU). Change if you bump ort.
|
||||
ARG ORT_VERSION=1.22.0
|
||||
RUN mkdir -p /opt/onnxruntime && \
|
||||
curl -fsSL -o /tmp/ort.tgz \
|
||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||
|
||||
ENV CHROME_BIN=/usr/bin/chromium-browser \
|
||||
CHROME_PATH=/usr/lib/chromium/ \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt
|
||||
ENV CHROME_BIN=/usr/bin/chromium \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
||||
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
||||
|
||||
# Create a non-root user to run the application
|
||||
RUN adduser -D -h /home/appuser appuser
|
||||
WORKDIR /home/appuser
|
||||
# Non-root
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
WORKDIR /home/appuser
|
||||
|
||||
# Copy the compiled binary from the builder stage (note the target path)
|
||||
COPY --from=builder /usr/src/minne/target/x86_64-unknown-linux-musl/release/main /usr/local/bin/main
|
||||
|
||||
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
||||
EXPOSE 3000
|
||||
# EXPOSE 8000-9000
|
||||
|
||||
CMD ["main"]
|
||||
|
||||
@@ -0,0 +1,661 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
@@ -0,0 +1,66 @@
|
||||
# Minne
|
||||
|
||||
**A graph-powered personal knowledge base that makes storing easy.**
|
||||
|
||||
Capture content effortlessly, let AI discover connections, and explore your knowledge visually. Self-hosted and privacy-focused.
|
||||
|
||||
[](https://github.com/perstarkse/minne/actions/workflows/release.yml)
|
||||
[](https://www.gnu.org/licenses/agpl-3.0)
|
||||
[](https://github.com/perstarkse/minne/releases/latest)
|
||||
|
||||

|
||||
|
||||
## Try It
|
||||
|
||||
**[Live Demo](https://minne-demo.stark.pub)** — Read-only demo deployment
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
|
||||
# Set your OpenAI API key in docker-compose.yml, then:
|
||||
docker compose up -d
|
||||
|
||||
# Open http://localhost:3000
|
||||
```
|
||||
|
||||
Or with Nix (with environment variables set):
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
Pre-built binaries for Windows, macOS, and Linux are available on the [Releases](https://github.com/perstarkse/minne/releases/latest) page.
|
||||
|
||||
## Features
|
||||
|
||||
- **Fast** — Rust backend with server-side rendering and HTMX for snappy interactions
|
||||
- **Search & Chat** — Search or use conversational AI to find and reason about content
|
||||
- **Knowledge Graph** — Visual exploration with automatic or manual relationship curation
|
||||
- **Hybrid Retrieval** — Vector similarity + full-text for relevant results
|
||||
- **Multi-Format** — Ingest text, URLs, PDFs, audio, and images
|
||||
- **Self-Hosted** — Your data, your server, any OpenAI-compatible API
|
||||
|
||||
## Documentation
|
||||
|
||||
| Guide | Description |
|
||||
|-------|-------------|
|
||||
| [Installation](docs/installation.md) | Docker, Nix, binaries, source builds |
|
||||
| [Configuration](docs/configuration.md) | Environment variables, config.yaml, AI setup |
|
||||
| [Features](docs/features.md) | Search, Chat, Graph, Reranking, Ingestion |
|
||||
| [Architecture](docs/architecture.md) | Tech stack, crate structure, data flow |
|
||||
| [Vision](docs/vision.md) | Philosophy, roadmap, related projects |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
Rust • Axum • HTMX • SurrealDB • FastEmbed
|
||||
|
||||
## Contributing
|
||||
|
||||
Feature requests and contributions welcome. See [Vision](docs/vision.md) for roadmap.
|
||||
|
||||
## License
|
||||
|
||||
[AGPL-3.0](LICENSE)
|
||||
@@ -2,6 +2,10 @@
|
||||
name = "api-router"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{storage::db::SurrealDbClient, utils::config::AppConfig};
|
||||
use common::{
|
||||
storage::{db::SurrealDbClient, store::StorageManager},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub db: Arc<SurrealDbClient>,
|
||||
pub config: AppConfig,
|
||||
pub storage: StorageManager,
|
||||
}
|
||||
|
||||
impl ApiState {
|
||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub async fn new(
|
||||
config: &AppConfig,
|
||||
storage: StorageManager,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
@@ -23,9 +30,10 @@ impl ApiState {
|
||||
|
||||
surreal_db_client.apply_migrations().await?;
|
||||
|
||||
let app_state = ApiState {
|
||||
let app_state = Self {
|
||||
db: surreal_db_client.clone(),
|
||||
config: config.clone(),
|
||||
storage,
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
|
||||
@@ -27,40 +27,40 @@ impl From<AppError> for ApiError {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
ApiError::InternalError("Internal server error".to_string())
|
||||
Self::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||
AppError::NotFound(msg) => Self::NotFound(msg),
|
||||
AppError::Validation(msg) => Self::ValidationError(msg),
|
||||
AppError::Auth(msg) => Self::Unauthorized(msg),
|
||||
_ => Self::InternalError("Internal server error".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_response) = match self {
|
||||
ApiError::InternalError(message) => (
|
||||
Self::InternalError(message) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::ValidationError(message) => (
|
||||
Self::ValidationError(message) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::NotFound(message) => (
|
||||
Self::NotFound(message) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::Unauthorized(message) => (
|
||||
Self::Unauthorized(message) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
|
||||
+11
-3
@@ -6,7 +6,7 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::{categories::get_categories, ingress::ingest_data};
|
||||
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready};
|
||||
|
||||
pub mod api_state;
|
||||
pub mod error;
|
||||
@@ -19,9 +19,17 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
ApiState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
// Public, unauthenticated endpoints (for k8s/systemd probes)
|
||||
let public = Router::new()
|
||||
.route("/ready", get(ready))
|
||||
.route("/live", get(live));
|
||||
|
||||
// Protected API endpoints (require auth)
|
||||
let protected = Router::new()
|
||||
.route("/ingress", post(ingest_data))
|
||||
.route("/categories", get(get_categories))
|
||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
||||
|
||||
public.merge(protected)
|
||||
}
|
||||
|
||||
@@ -13,14 +13,12 @@ pub async fn api_auth(
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
let api_key = extract_api_key(&request)
|
||||
.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
||||
let user = user.ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
let user =
|
||||
user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
@@ -37,7 +35,7 @@ fn extract_api_key(request: &Request) -> Option<String> {
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(|s| s.trim()))
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
|
||||
})
|
||||
.map(String::from)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -8,6 +8,7 @@ use common::{
|
||||
},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
@@ -29,9 +30,11 @@ pub async fn ingest_data(
|
||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
info!("Received input: {:?}", input);
|
||||
let user_id = user.id;
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new(file, &state.db, &user.id, &state.config).map_err(AppError::from)
|
||||
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)
|
||||
.map_err(AppError::from)
|
||||
}))
|
||||
.await?;
|
||||
|
||||
@@ -40,17 +43,15 @@ pub async fn ingest_data(
|
||||
input.context,
|
||||
input.category,
|
||||
file_infos,
|
||||
user.id.as_str(),
|
||||
&user_id,
|
||||
)?;
|
||||
|
||||
let futures: Vec<_> = payloads
|
||||
.into_iter()
|
||||
.map(|object| {
|
||||
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
|
||||
})
|
||||
.map(|object| IngestionTask::create_and_add_to_db(object, user_id.clone(), &state.db))
|
||||
.collect();
|
||||
|
||||
try_join_all(futures).await.map_err(AppError::from)?;
|
||||
try_join_all(futures).await?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
|
||||
/// Liveness probe: always returns 200 to indicate the process is running.
|
||||
pub async fn live() -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(json!({"status": "ok"})))
|
||||
}
|
||||
@@ -1,2 +1,4 @@
|
||||
pub mod categories;
|
||||
pub mod ingress;
|
||||
pub mod liveness;
|
||||
pub mod readiness;
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::api_state::ApiState;
|
||||
|
||||
/// Readiness probe: returns 200 if core dependencies are ready, else 503.
|
||||
pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
|
||||
match state.db.client.query("RETURN true").await {
|
||||
Ok(_) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"status": "ok",
|
||||
"checks": { "db": "ok" }
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"status": "error",
|
||||
"checks": { "db": "fail" },
|
||||
"reason": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
[core]
|
||||
path = "./"
|
||||
schema = "less"
|
||||
|
||||
[db]
|
||||
address = "ws://localhost:8000"
|
||||
username = "root_user"
|
||||
password = "root_password"
|
||||
ns = "test"
|
||||
db = "test"
|
||||
+25
-17
@@ -2,6 +2,10 @@
|
||||
name = "common"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
@@ -17,28 +21,32 @@ async-openai = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
dom_smoothie = { workspace = true }
|
||||
|
||||
async-trait = "0.1.88"
|
||||
axum_session = { workspace = true }
|
||||
axum_session_auth = { workspace = true }
|
||||
axum_session_surreal = { workspace = true}
|
||||
axum_typed_multipart = { workspace = true}
|
||||
chrono = { version = "0.4.39", features = ["serde"] }
|
||||
chrono-tz = "0.10.1"
|
||||
config = "0.15.4"
|
||||
mime = "0.3.17"
|
||||
mime_guess = "2.0.5"
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
sha2 = "0.10.8"
|
||||
url = { version = "2.5.2", features = ["serde"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
surrealdb-migrations = "2.2.2"
|
||||
include_dir = { workspace = true }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-autoreload = { workspace = true }
|
||||
minijinja-embed = { workspace = true }
|
||||
minijinja-contrib = {workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
chrono-tz = { workspace = true }
|
||||
config = { workspace = true }
|
||||
mime = { workspace = true }
|
||||
mime_guess = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
surrealdb-migrations = { workspace = true }
|
||||
tokio-retry = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
minijinja-autoreload = "2.5.0"
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
include_dir = "0.7.4"
|
||||
|
||||
[features]
|
||||
test-utils = []
|
||||
|
||||
@@ -13,6 +13,11 @@ CREATE system_settings:current CONTENT {
|
||||
require_email_verification: false,
|
||||
query_model: "gpt-4o-mini",
|
||||
processing_model: "gpt-4o-mini",
|
||||
embedding_model: "text-embedding-3-small",
|
||||
voice_processing_model: "whisper-1",
|
||||
image_processing_model: "gpt-4o-mini",
|
||||
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
|
||||
embedding_dimensions: 1536,
|
||||
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
||||
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
||||
};
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Runtime-managed: text_content FTS indexes now created at startup via the shared Surreal helper.
|
||||
-- This migration is intentionally left as a no-op to avoid heavy index builds during migration.
|
||||
+7
@@ -0,0 +1,7 @@
|
||||
DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;
|
||||
|
||||
UPDATE system_settings:current SET
|
||||
embedding_model = "text-embedding-3-small",
|
||||
embedding_dimensions = 1536
|
||||
WHERE embedding_model == NONE && embedding_dimensions == NONE;
|
||||
@@ -0,0 +1,7 @@
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||
|
||||
UPDATE system_settings:current SET
|
||||
image_processing_model = "gpt-4o-mini",
|
||||
image_processing_prompt = "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis."
|
||||
WHERE image_processing_model == NONE && image_processing_prompt == NONE;
|
||||
@@ -0,0 +1 @@
|
||||
-- No-op: legacy `job` table was superseded by `ingestion_task`; kept for migration order compatibility.
|
||||
@@ -0,0 +1,5 @@
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
|
||||
UPDATE system_settings:current SET
|
||||
voice_processing_model = "whisper-1"
|
||||
WHERE voice_processing_model == NONE;
|
||||
@@ -0,0 +1,115 @@
|
||||
-- Align timestamp fields with SurrealDB native datetime type.
|
||||
|
||||
-- User timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON user FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON user FLEXIBLE;
|
||||
|
||||
UPDATE user SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE user SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON user TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON user TYPE datetime;
|
||||
|
||||
-- Text content timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON text_content FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_content FLEXIBLE;
|
||||
|
||||
UPDATE text_content SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE text_content SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_content TYPE datetime;
|
||||
|
||||
REBUILD INDEX text_content_created_at_idx ON text_content;
|
||||
|
||||
-- Text chunk timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON text_chunk FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_chunk FLEXIBLE;
|
||||
|
||||
UPDATE text_chunk SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE text_chunk SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON text_chunk TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_chunk TYPE datetime;
|
||||
|
||||
-- Knowledge entity timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON knowledge_entity FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON knowledge_entity FLEXIBLE;
|
||||
|
||||
UPDATE knowledge_entity SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE knowledge_entity SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON knowledge_entity TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON knowledge_entity TYPE datetime;
|
||||
|
||||
REBUILD INDEX knowledge_entity_created_at_idx ON knowledge_entity;
|
||||
|
||||
-- Conversation timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON conversation FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON conversation FLEXIBLE;
|
||||
|
||||
UPDATE conversation SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE conversation SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON conversation TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON conversation TYPE datetime;
|
||||
|
||||
REBUILD INDEX conversation_created_at_idx ON conversation;
|
||||
|
||||
-- Message timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON message FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON message FLEXIBLE;
|
||||
|
||||
UPDATE message SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE message SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON message TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON message TYPE datetime;
|
||||
|
||||
REBUILD INDEX message_updated_at_idx ON message;
|
||||
|
||||
-- Ingestion task timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON ingestion_task FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON ingestion_task FLEXIBLE;
|
||||
|
||||
UPDATE ingestion_task SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE ingestion_task SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON ingestion_task TYPE datetime;
|
||||
|
||||
REBUILD INDEX idx_ingestion_task_created ON ingestion_task;
|
||||
|
||||
-- File timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON file FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON file FLEXIBLE;
|
||||
|
||||
UPDATE file SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE file SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON file TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON file TYPE datetime;
|
||||
@@ -0,0 +1 @@
|
||||
-- Runtime-managed: FTS indexes now built at startup; migration retained as a no-op.
|
||||
@@ -0,0 +1,173 @@
|
||||
-- State machine migration for ingestion_task records
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS state ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS max_attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS scheduled_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS locked_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS worker_id ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_code ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_message ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_error_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS priority ON TABLE ingestion_task TYPE option<number>;
|
||||
|
||||
REMOVE FIELD status ON TABLE ingestion_task;
|
||||
DEFINE FIELD status ON TABLE ingestion_task TYPE option<object>;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_state_sched ON TABLE ingestion_task FIELDS state, scheduled_at;
|
||||
|
||||
LET $needs_migration = (SELECT count() AS count FROM type::table('ingestion_task') WHERE state = NONE)[0].count;
|
||||
|
||||
IF $needs_migration > 0 THEN {
|
||||
-- Created -> Pending
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF created_at != NONE THEN created_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Created";
|
||||
|
||||
-- InProgress -> Processing
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Processing",
|
||||
attempts = IF status.attempts != NONE THEN status.attempts ELSE 1 END,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
locked_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "InProgress";
|
||||
|
||||
-- Completed -> Succeeded
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Succeeded",
|
||||
attempts = 1,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Completed";
|
||||
|
||||
-- Error -> DeadLetter (terminal failure)
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "DeadLetter",
|
||||
attempts = 3,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = status.message,
|
||||
last_error_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Error";
|
||||
|
||||
-- Cancelled -> Cancelled
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Cancelled",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Cancelled";
|
||||
|
||||
-- Fallback for any remaining records missing state
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE;
|
||||
} END;
|
||||
|
||||
-- Ensure defaults for newly added fields
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET max_attempts = 3
|
||||
WHERE max_attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET lease_duration_secs = 300
|
||||
WHERE lease_duration_secs = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET attempts = 0
|
||||
WHERE attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET priority = 0
|
||||
WHERE priority = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END
|
||||
WHERE scheduled_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET locked_at = NONE
|
||||
WHERE locked_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET worker_id = NONE
|
||||
WHERE worker_id != NONE AND worker_id = "";
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_code = NONE
|
||||
WHERE error_code = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_message = NONE
|
||||
WHERE error_message = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET last_error_at = NONE
|
||||
WHERE last_error_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET status = NONE
|
||||
WHERE status != NONE;
|
||||
@@ -0,0 +1,24 @@
|
||||
-- Add scratchpad table and schema
|
||||
|
||||
-- Define scratchpad table and schema
|
||||
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||
|
||||
-- Standard fields from stored_object! macro
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||
|
||||
-- Custom fields from the Scratchpad struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||
|
||||
-- Indexes based on query patterns
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Remove HNSW indexes from base tables (now created at runtime on *_embedding tables)
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
|
||||
|
||||
-- Remove FTS indexes (now created at runtime via indexes.rs)
|
||||
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON text_content;
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
|
||||
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
|
||||
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
|
||||
|
||||
-- Remove legacy analyzers (recreated at runtime with updated configuration)
|
||||
REMOVE ANALYZER IF EXISTS app_default_fts_analyzer;
|
||||
REMOVE ANALYZER IF EXISTS app_en_fts_analyzer;
|
||||
@@ -0,0 +1,23 @@
|
||||
-- Move chunk/entity embeddings to dedicated tables for index efficiency.
|
||||
|
||||
-- Text chunk embeddings table
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||
|
||||
-- Knowledge entity embeddings table
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||
@@ -0,0 +1,23 @@
|
||||
-- Copy embeddings from base tables to dedicated tables
|
||||
-- This runs BEFORE the field removal migration
|
||||
|
||||
FOR $chunk IN (SELECT * FROM text_chunk WHERE embedding != NONE AND array::len(embedding) > 0) {
|
||||
CREATE text_chunk_embedding CONTENT {
|
||||
chunk_id: $chunk.id,
|
||||
embedding: $chunk.embedding,
|
||||
user_id: $chunk.user_id,
|
||||
source_id: $chunk.source_id,
|
||||
created_at: $chunk.created_at,
|
||||
updated_at: $chunk.updated_at
|
||||
};
|
||||
};
|
||||
|
||||
FOR $entity IN (SELECT * FROM knowledge_entity WHERE embedding != NONE AND array::len(embedding) > 0) {
|
||||
CREATE knowledge_entity_embedding CONTENT {
|
||||
entity_id: $entity.id,
|
||||
embedding: $entity.embedding,
|
||||
user_id: $entity.user_id,
|
||||
created_at: $entity.created_at,
|
||||
updated_at: $entity.updated_at
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Drop legacy embedding fields from base tables; embeddings now live in *_embedding tables.
|
||||
REMOVE FIELD IF EXISTS embedding ON TABLE text_chunk;
|
||||
REMOVE FIELD IF EXISTS embedding ON TABLE knowledge_entity;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Add embedding_backend field to system_settings for visibility of active backend
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
|
||||
|
||||
-- Set default to 'openai' for existing installs to preserve backward compatibility
|
||||
UPDATE system_settings:current SET
|
||||
embedding_backend = 'openai'
|
||||
WHERE embedding_backend == NONE;
|
||||
@@ -0,0 +1,97 @@
|
||||
-- Enforce SCHEMAFULL on all tables and define missing fields
|
||||
|
||||
-- 1. Define missing fields for ingestion_task (formerly job, but now ingestion_task)
|
||||
DEFINE TABLE OVERWRITE ingestion_task SCHEMAFULL;
|
||||
|
||||
-- Core Fields
|
||||
DEFINE FIELD IF NOT EXISTS id ON ingestion_task TYPE record<ingestion_task>;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;
|
||||
|
||||
-- State Machine Fields
|
||||
DEFINE FIELD IF NOT EXISTS state ON ingestion_task TYPE string ASSERT $value IN ['Pending', 'Reserved', 'Processing', 'Succeeded', 'Failed', 'Cancelled', 'DeadLetter'];
|
||||
DEFINE FIELD IF NOT EXISTS attempts ON ingestion_task TYPE int DEFAULT 0;
|
||||
DEFINE FIELD IF NOT EXISTS max_attempts ON ingestion_task TYPE int DEFAULT 3;
|
||||
DEFINE FIELD IF NOT EXISTS scheduled_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||
DEFINE FIELD IF NOT EXISTS locked_at ON ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON ingestion_task TYPE int DEFAULT 300;
|
||||
DEFINE FIELD IF NOT EXISTS worker_id ON ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_code ON ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_message ON ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_error_at ON ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS priority ON ingestion_task TYPE int DEFAULT 0;
|
||||
|
||||
-- Content Payload (IngestionPayload Enum)
|
||||
DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS content.Url ON ingestion_task TYPE option<object>;
|
||||
DEFINE FIELD IF NOT EXISTS content.Text ON ingestion_task TYPE option<object>;
|
||||
DEFINE FIELD IF NOT EXISTS content.File ON ingestion_task TYPE option<object>;
|
||||
|
||||
-- Content: Url Variant
|
||||
DEFINE FIELD IF NOT EXISTS content.Url.url ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Url.context ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Url.category ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Url.user_id ON ingestion_task TYPE string;
|
||||
|
||||
-- Content: Text Variant
|
||||
DEFINE FIELD IF NOT EXISTS content.Text.text ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Text.context ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Text.category ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.Text.user_id ON ingestion_task TYPE string;
|
||||
|
||||
-- Content: File Variant
|
||||
DEFINE FIELD IF NOT EXISTS content.File.context ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.category ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.user_id ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info ON ingestion_task TYPE object;
|
||||
|
||||
-- Content: File.file_info (FileInfo Struct)
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.id ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.created_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.updated_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.sha256 ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.path ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.file_name ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.mime_type ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content.File.file_info.user_id ON ingestion_task TYPE string;
|
||||
|
||||
-- 2. Enforce SCHEMAFULL on all other tables
|
||||
DEFINE TABLE OVERWRITE analytics SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE conversation SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE file SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE knowledge_entity SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE message SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE relates_to SCHEMAFULL TYPE RELATION;
|
||||
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||
DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE system_settings SCHEMAFULL;
|
||||
DEFINE TABLE OVERWRITE text_chunk SCHEMAFULL;
|
||||
-- text_content must have fields defined before enforcing SCHEMAFULL
|
||||
DEFINE TABLE OVERWRITE text_content SCHEMAFULL;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
|
||||
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
|
||||
|
||||
DEFINE TABLE OVERWRITE user SCHEMAFULL;
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -147,6 +147,7 @@\n\n DEFINE FIELD OVERWRITE script_name ON script_migration TYPE string;\n DEFINE FIELD OVERWRITE executed_at ON script_migration TYPE datetime VALUE time::now() READONLY;\n+\n # Defines the schema for the 'system_settings' table.\n\n DEFINE TABLE IF NOT EXISTS system_settings SCHEMALESS;\n","events":null}
|
||||
@@ -1 +0,0 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -198,11 +198,11 @@\n DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;\n # UrlInfo is a struct, store as object\n DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;\n-DEFINE FIELD IF NOT EXISTS instructions ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;\n DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n-# Indexes based on query patterns (get_latest_text_contents, get_text_contents_by_category)\n+# Indexes based on query patterns\n DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\n DEFINE INDEX IF NOT EXISTS text_content_category_idx ON text_content FIELDS category;\n","events":null}
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -242,7 +242,7 @@\n\n # Defines the schema for the 'text_content' table.\n\n-DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n@@ -254,10 +254,24 @@\n DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;\n # UrlInfo is a struct, store as object\n DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;\n+DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;\n+\n DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;\n DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n+# FileInfo fields\n+DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;\n+\n # Indexes based on query patterns\n DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\n","events":null}
|
||||
File diff suppressed because one or more lines are too long
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;
|
||||
|
||||
# Custom fields from the Conversation struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS file SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;
|
||||
|
||||
# Custom fields from the FileInfo struct
|
||||
DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
# Defines the schema for the 'ingestion_task' table (used by IngestionTask).
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS job SCHEMALESS;
|
||||
DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;
|
||||
|
||||
# Custom fields from the IngestionTask struct
|
||||
# IngestionPayload is complex, store as object
|
||||
DEFINE FIELD IF NOT EXISTS content ON job TYPE object;
|
||||
# IngestionTaskStatus can hold data (InProgress), store as object
|
||||
DEFINE FIELD IF NOT EXISTS status ON job TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;
|
||||
|
||||
# Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks
|
||||
DEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;
|
||||
DEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS status;
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;
|
||||
|
||||
# Custom fields from the KnowledgeEntity struct
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;
|
||||
@@ -15,16 +15,12 @@ DEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;
|
||||
# metadata is Option<serde_json::Value>, store as object
|
||||
DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option<object>;
|
||||
|
||||
# Define embedding as a standard array of floats for schema definition
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array<float>;
|
||||
# The specific vector nature is handled by the index definition below
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
||||
|
||||
# Indexes based on build_indexes and query patterns
|
||||
# The INDEX definition correctly specifies the vector properties
|
||||
DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
-- Indexes based on build_indexes and query patterns
|
||||
-- HNSW index now defined on knowledge_entity_embedding table for better memory usage
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Defines the schema for the 'knowledge_entity_embedding' table.
|
||||
-- Separate table to optimize HNSW index creation memory usage
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||
|
||||
-- Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||
|
||||
-- Custom fields
|
||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||
|
||||
-- Indexes
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS message SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;
|
||||
|
||||
# Custom fields from the Message struct
|
||||
DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Defines the schema for the 'scratchpad' table.
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||
|
||||
# Standard fields from stored_object! macro
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||
|
||||
# Custom fields from the Scratchpad struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||
|
||||
# Indexes based on query patterns
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||
@@ -7,5 +7,10 @@ DEFINE FIELD IF NOT EXISTS registrations_enabled ON system_settings TYPE bool;
|
||||
DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;
|
||||
DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;
|
||||
DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
|
||||
@@ -3,21 +3,15 @@
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;
|
||||
|
||||
# Custom fields from the TextChunk struct
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;
|
||||
|
||||
# Define embedding as a standard array of floats for schema definition
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array<float>;
|
||||
# The specific vector nature is handled by the index definition below
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;
|
||||
|
||||
# Indexes based on build_indexes and query patterns (delete_by_source_id)
|
||||
# The INDEX definition correctly specifies the vector properties
|
||||
DEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- Defines the schema for the 'text_chunk_embedding' table.
|
||||
-- Separate table to optimize HNSW index creation memory usage
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||
|
||||
# Custom fields
|
||||
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||
|
||||
-- Indexes
|
||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||
@@ -1,10 +1,10 @@
|
||||
# Defines the schema for the 'text_content' table.
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;
|
||||
DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
|
||||
|
||||
# Custom fields from the TextContent struct
|
||||
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
||||
@@ -12,10 +12,24 @@ DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
|
||||
# UrlInfo is a struct, store as object
|
||||
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
|
||||
|
||||
# FileInfo fields
|
||||
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
|
||||
|
||||
# Indexes based on query patterns
|
||||
DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS user SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;
|
||||
|
||||
# Custom fields from the User struct
|
||||
DEFINE FIELD IF NOT EXISTS email ON user TYPE string;
|
||||
|
||||
+3
-2
@@ -5,6 +5,7 @@ use tokio::task::JoinError;
|
||||
use crate::storage::types::file_info::FileError;
|
||||
|
||||
// Core internal errors
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AppError {
|
||||
#[error("Database error: {0}")]
|
||||
@@ -29,8 +30,8 @@ pub enum AppError {
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Reqwest error: {0}")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
#[error("Tiktoken error: {0}")]
|
||||
Tiktoken(#[from] anyhow::Error),
|
||||
#[error("Anyhow error: {0}")]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
#[error("Ingestion Processing error: {0}")]
|
||||
Processing(String),
|
||||
#[error("DOM smoothie error: {0}")]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::doc_markdown)]
|
||||
//! Shared utilities and storage helpers for the workspace crates.
|
||||
pub mod error;
|
||||
pub mod storage;
|
||||
pub mod utils;
|
||||
|
||||
+87
-10
@@ -7,17 +7,20 @@ use include_dir::{include_dir, Dir};
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
use surrealdb::{
|
||||
engine::any::{connect, Any},
|
||||
opt::auth::Root,
|
||||
opt::auth::{Namespace, Root},
|
||||
Error, Notification, Surreal,
|
||||
};
|
||||
use surrealdb_migrations::MigrationRunner;
|
||||
use tracing::debug;
|
||||
|
||||
/// Embedded SurrealDB migration directory packaged with the crate.
|
||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurrealDbClient {
|
||||
pub client: Surreal<Any>,
|
||||
}
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub trait ProvidesDb {
|
||||
fn db(&self) -> &Arc<SurrealDbClient>;
|
||||
}
|
||||
@@ -47,9 +50,28 @@ impl SurrealDbClient {
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn new_with_namespace_user(
|
||||
address: &str,
|
||||
namespace: &str,
|
||||
username: &str,
|
||||
password: &str,
|
||||
database: &str,
|
||||
) -> Result<Self, Error> {
|
||||
let db = connect(address).await?;
|
||||
db.signin(Namespace {
|
||||
namespace,
|
||||
username,
|
||||
password,
|
||||
})
|
||||
.await?;
|
||||
db.use_ns(namespace).use_db(database).await?;
|
||||
Ok(SurrealDbClient { client: db })
|
||||
}
|
||||
|
||||
pub async fn create_session_store(
|
||||
&self,
|
||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||
debug!("Creating session store");
|
||||
SessionStore::new(
|
||||
Some(self.client.clone().into()),
|
||||
SessionConfig::default()
|
||||
@@ -65,6 +87,7 @@ impl SurrealDbClient {
|
||||
/// the database and selecting the appropriate namespace and database, but before
|
||||
/// the application starts performing operations that rely on the schema.
|
||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||
debug!("Applying migrations");
|
||||
MigrationRunner::new(&self.client)
|
||||
.load_files(&MIGRATIONS_DIR)
|
||||
.up()
|
||||
@@ -74,15 +97,6 @@ impl SurrealDbClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Operation to rebuild indexes
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
||||
.query("REBUILD INDEX IF EXISTS idx_embeddings_entities ON knowledge_entity")
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -100,6 +114,19 @@ impl SurrealDbClient {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Operation to upsert an object in SurrealDB, replacing any existing record
|
||||
/// with the same ID. Useful for idempotent ingestion flows.
|
||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: StoredObject + Send + Sync + 'static,
|
||||
{
|
||||
let id = item.get_id().to_string();
|
||||
self.client
|
||||
.upsert((T::table_name(), id))
|
||||
.content(item)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject
|
||||
///
|
||||
/// # Returns
|
||||
@@ -238,6 +265,56 @@ mod tests {
|
||||
assert!(fetch_post.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_item_overwrites_existing_records() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize schema");
|
||||
|
||||
let mut dummy = Dummy {
|
||||
id: "abc".to_string(),
|
||||
name: "first".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
|
||||
db.store_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to store initial record");
|
||||
|
||||
dummy.name = "updated".to_string();
|
||||
let upserted = db
|
||||
.upsert_item(dummy.clone())
|
||||
.await
|
||||
.expect("Failed to upsert record");
|
||||
assert!(upserted.is_some());
|
||||
|
||||
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
|
||||
assert_eq!(fetched.unwrap().name, "updated");
|
||||
|
||||
let new_record = Dummy {
|
||||
id: "def".to_string(),
|
||||
name: "brand-new".to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
};
|
||||
db.upsert_item(new_record.clone())
|
||||
.await
|
||||
.expect("Failed to upsert new record");
|
||||
|
||||
let fetched_new: Option<Dummy> = db
|
||||
.get_item(&new_record.id)
|
||||
.await
|
||||
.expect("fetch inserted via upsert");
|
||||
assert_eq!(fetched_new, Some(new_record));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_applying_migrations() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -0,0 +1,795 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use futures::future::try_join_all;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map, Value};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
|
||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct HnswIndexSpec {
|
||||
index_name: &'static str,
|
||||
table: &'static str,
|
||||
options: &'static str,
|
||||
}
|
||||
|
||||
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
|
||||
[
|
||||
HnswIndexSpec {
|
||||
index_name: "idx_embedding_text_chunk_embedding",
|
||||
table: "text_chunk_embedding",
|
||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
||||
},
|
||||
HnswIndexSpec {
|
||||
index_name: "idx_embedding_knowledge_entity_embedding",
|
||||
table: "knowledge_entity_embedding",
|
||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
|
||||
[
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_fts_idx",
|
||||
table: "text_content",
|
||||
field: "text",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_context_fts_idx",
|
||||
table: "text_content",
|
||||
field: "context",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_file_name_fts_idx",
|
||||
table: "text_content",
|
||||
field: "file_info.file_name",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_url_fts_idx",
|
||||
table: "text_content",
|
||||
field: "url_info.url",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_content_url_title_fts_idx",
|
||||
table: "text_content",
|
||||
field: "url_info.title",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "knowledge_entity_fts_name_idx",
|
||||
table: "knowledge_entity",
|
||||
field: "name",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "knowledge_entity_fts_description_idx",
|
||||
table: "knowledge_entity",
|
||||
field: "description",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
FtsIndexSpec {
|
||||
index_name: "text_chunk_fts_chunk_idx",
|
||||
table: "text_chunk",
|
||||
field: "chunk",
|
||||
analyzer: Some(FTS_ANALYZER_NAME),
|
||||
method: "BM25",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
impl HnswIndexSpec {
|
||||
fn definition_if_not_exists(&self, dimension: usize) -> String {
|
||||
format!(
|
||||
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} \
|
||||
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
dimension = dimension,
|
||||
options = self.options,
|
||||
)
|
||||
}
|
||||
|
||||
fn definition_overwrite(&self, dimension: usize) -> String {
|
||||
format!(
|
||||
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} \
|
||||
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
dimension = dimension,
|
||||
options = self.options,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct FtsIndexSpec {
|
||||
index_name: &'static str,
|
||||
table: &'static str,
|
||||
field: &'static str,
|
||||
analyzer: Option<&'static str>,
|
||||
method: &'static str,
|
||||
}
|
||||
|
||||
impl FtsIndexSpec {
|
||||
fn definition(&self) -> String {
|
||||
let analyzer_clause = self
|
||||
.analyzer
|
||||
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
|
||||
.unwrap_or_default();
|
||||
|
||||
format!(
|
||||
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
field = self.field,
|
||||
)
|
||||
}
|
||||
|
||||
fn overwrite_definition(&self) -> String {
|
||||
let analyzer_clause = self
|
||||
.analyzer
|
||||
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
|
||||
.unwrap_or_default();
|
||||
|
||||
format!(
|
||||
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
|
||||
index = self.index_name,
|
||||
table = self.table,
|
||||
field = self.field,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||
pub async fn ensure_runtime_indexes(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
ensure_runtime_indexes_inner(db, embedding_dimension)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_indexes_inner(db)
|
||||
.await
|
||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
||||
}
|
||||
|
||||
async fn ensure_runtime_indexes_inner(
|
||||
db: &SurrealDbClient,
|
||||
embedding_dimension: usize,
|
||||
) -> Result<()> {
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
for spec in fts_index_specs() {
|
||||
if index_exists(db, spec.table, spec.index_name).await? {
|
||||
continue;
|
||||
}
|
||||
// We need to create these sequentially otherwise SurrealDB errors with read/write clash
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition(),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||
match hnsw_index_state(db, &spec, embedding_dimension).await? {
|
||||
HnswIndexState::Missing => {
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition_if_not_exists(embedding_dimension),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await
|
||||
}
|
||||
HnswIndexState::Matches => {
|
||||
let status = get_index_status(db, spec.index_name, spec.table).await?;
|
||||
if status.eq_ignore_ascii_case("error") {
|
||||
warn!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
"HNSW index found in error state; triggering rebuild"
|
||||
);
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition_overwrite(embedding_dimension),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
HnswIndexState::Different(existing) => {
|
||||
info!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
existing_dimension = existing,
|
||||
target_dimension = embedding_dimension,
|
||||
"Overwriting HNSW index to match new embedding dimension"
|
||||
);
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition_overwrite(embedding_dimension),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -> Result<String> {
|
||||
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query(info_query)
|
||||
.await
|
||||
.context("checking index status")?;
|
||||
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
||||
|
||||
let info = match info {
|
||||
Some(i) => i,
|
||||
None => return Ok("unknown".to_string()),
|
||||
};
|
||||
|
||||
let building = info.get("building");
|
||||
let status = building
|
||||
.and_then(|b| b.get("status"))
|
||||
.and_then(|s| s.as_str())
|
||||
.unwrap_or("ready")
|
||||
.to_string();
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
|
||||
debug!("Rebuilding indexes with concurrent definitions");
|
||||
create_fts_analyzer(db).await?;
|
||||
|
||||
for spec in fts_index_specs() {
|
||||
if !index_exists(db, spec.table, spec.index_name).await? {
|
||||
debug!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
"Skipping FTS rebuild because index is missing"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.overwrite_definition(),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||
if !index_exists(db, spec.table, spec.index_name).await? {
|
||||
debug!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
"Skipping HNSW rebuild because index is missing"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
|
||||
warn!(
|
||||
index = spec.index_name,
|
||||
table = spec.table,
|
||||
"HNSW index missing dimension; skipping rebuild"
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
create_index_with_polling(
|
||||
db,
|
||||
spec.definition_overwrite(dimension),
|
||||
spec.index_name,
|
||||
spec.table,
|
||||
Some(spec.table),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||
}
|
||||
|
||||
async fn existing_hnsw_dimension(
|
||||
db: &SurrealDbClient,
|
||||
spec: &HnswIndexSpec,
|
||||
) -> Result<Option<usize>> {
|
||||
let Some(indexes) = table_index_definitions(db, spec.table).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(definition) = indexes
|
||||
.get(spec.index_name)
|
||||
.and_then(|details| details.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
|
||||
}
|
||||
|
||||
async fn hnsw_index_state(
|
||||
db: &SurrealDbClient,
|
||||
spec: &HnswIndexSpec,
|
||||
expected_dimension: usize,
|
||||
) -> Result<HnswIndexState> {
|
||||
match existing_hnsw_dimension(db, spec).await? {
|
||||
None => Ok(HnswIndexState::Missing),
|
||||
Some(current_dimension) if current_dimension == expected_dimension => {
|
||||
Ok(HnswIndexState::Matches)
|
||||
}
|
||||
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
|
||||
}
|
||||
}
|
||||
|
||||
enum HnswIndexState {
|
||||
Missing,
|
||||
Matches,
|
||||
Different(u64),
|
||||
}
|
||||
|
||||
fn extract_dimension(definition: &str) -> Option<u64> {
|
||||
definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.and_then(|rest| rest.split_whitespace().next())
|
||||
.and_then(|token| token.trim_end_matches(';').parse::<u64>().ok())
|
||||
}
|
||||
|
||||
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||
// Prefer snowball stemming when supported; fall back to ascii-only when the filter
|
||||
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||
// an existing analyzer definition.
|
||||
let snowball_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
);
|
||||
|
||||
match db.client.query(snowball_query).await {
|
||||
Ok(res) => {
|
||||
if res.check().is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
warn!(
|
||||
"Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})",
|
||||
FTS_ANALYZER_NAME
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
"Snowball analyzer creation errored; attempting ascii fallback definition"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let fallback_query = format!(
|
||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii;",
|
||||
analyzer = FTS_ANALYZER_NAME
|
||||
);
|
||||
|
||||
let res = db
|
||||
.client
|
||||
.query(fallback_query)
|
||||
.await
|
||||
.context("creating fallback FTS analyzer")?;
|
||||
|
||||
if let Err(err) = res.check() {
|
||||
warn!(
|
||||
error = %err,
|
||||
"Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})",
|
||||
FTS_ANALYZER_NAME
|
||||
);
|
||||
return Err(err).context("failed to create fallback FTS analyzer");
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only",
|
||||
FTS_ANALYZER_NAME
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_index_with_polling(
|
||||
db: &SurrealDbClient,
|
||||
definition: String,
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
progress_table: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let expected_total = match progress_table {
|
||||
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||
format!("counting rows in {table} for index {index_name} progress")
|
||||
})?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: usize = 3;
|
||||
loop {
|
||||
attempts += 1;
|
||||
let res = db
|
||||
.client
|
||||
.query(definition.clone())
|
||||
.await
|
||||
.with_context(|| format!("creating index {index_name} on table {table}"))?;
|
||||
match res.check() {
|
||||
Ok(_) => break,
|
||||
Err(err) => {
|
||||
let msg = err.to_string();
|
||||
let conflict = msg.contains("read or write conflict");
|
||||
warn!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
error = ?err,
|
||||
attempt = attempts,
|
||||
definition = %definition,
|
||||
"Index definition failed"
|
||||
);
|
||||
if conflict && attempts < MAX_ATTEMPTS {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
continue;
|
||||
}
|
||||
return Err(err).with_context(|| {
|
||||
format!("index definition failed for {index_name} on {table}")
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
expected_rows = ?expected_total,
|
||||
"Index definition submitted; waiting for build to finish"
|
||||
);
|
||||
|
||||
poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await
|
||||
}
|
||||
|
||||
async fn poll_index_build_status(
|
||||
db: &SurrealDbClient,
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
total_rows: Option<u64>,
|
||||
poll_every: Duration,
|
||||
) -> Result<()> {
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(poll_every).await;
|
||||
|
||||
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||
let mut info_res =
|
||||
db.client.query(info_query).await.with_context(|| {
|
||||
format!("checking index build status for {index_name} on {table}")
|
||||
})?;
|
||||
|
||||
let info: Option<Value> = info_res
|
||||
.take(0)
|
||||
.context("failed to deserialize INFO FOR INDEX result")?;
|
||||
|
||||
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
|
||||
warn!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
"INFO FOR INDEX returned no data; assuming index definition might be missing"
|
||||
);
|
||||
break;
|
||||
};
|
||||
|
||||
match snapshot.progress_pct {
|
||||
Some(pct) => debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
initial = snapshot.initial,
|
||||
pending = snapshot.pending,
|
||||
updated = snapshot.updated,
|
||||
processed = snapshot.processed,
|
||||
total = snapshot.total_rows,
|
||||
progress_pct = format_args!("{pct:.1}"),
|
||||
"Index build status"
|
||||
),
|
||||
None => debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
initial = snapshot.initial,
|
||||
pending = snapshot.pending,
|
||||
updated = snapshot.updated,
|
||||
processed = snapshot.processed,
|
||||
"Index build status"
|
||||
),
|
||||
}
|
||||
|
||||
if snapshot.is_ready() {
|
||||
debug!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
elapsed = ?started_at.elapsed(),
|
||||
processed = snapshot.processed,
|
||||
total = snapshot.total_rows,
|
||||
"Index is ready"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if snapshot.status.eq_ignore_ascii_case("error") {
|
||||
warn!(
|
||||
index = %index_name,
|
||||
table = %table,
|
||||
status = snapshot.status,
|
||||
"Index build reported error status; stopping polling"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct IndexBuildSnapshot {
|
||||
status: String,
|
||||
initial: u64,
|
||||
pending: u64,
|
||||
updated: u64,
|
||||
processed: u64,
|
||||
total_rows: Option<u64>,
|
||||
progress_pct: Option<f64>,
|
||||
}
|
||||
|
||||
impl IndexBuildSnapshot {
|
||||
fn is_ready(&self) -> bool {
|
||||
self.status.eq_ignore_ascii_case("ready")
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_index_build_info(
|
||||
info: Option<Value>,
|
||||
total_rows: Option<u64>,
|
||||
) -> Option<IndexBuildSnapshot> {
|
||||
let info = info?;
|
||||
let building = info.get("building");
|
||||
|
||||
let status = building
|
||||
.and_then(|b| b.get("status"))
|
||||
.and_then(|s| s.as_str())
|
||||
// If there's no `building` block at all, treat as "ready" (index not building anymore)
|
||||
.unwrap_or("ready")
|
||||
.to_string();
|
||||
|
||||
let initial = building
|
||||
.and_then(|b| b.get("initial"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let pending = building
|
||||
.and_then(|b| b.get("pending"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let updated = building
|
||||
.and_then(|b| b.get("updated"))
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
||||
let processed = initial.saturating_add(updated);
|
||||
|
||||
let progress_pct = total_rows.map(|total| {
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
((processed as f64 / total as f64).min(1.0)) * 100.0
|
||||
}
|
||||
});
|
||||
|
||||
Some(IndexBuildSnapshot {
|
||||
status,
|
||||
initial,
|
||||
pending,
|
||||
updated,
|
||||
processed,
|
||||
total_rows,
|
||||
progress_pct,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CountRow {
|
||||
count: u64,
|
||||
}
|
||||
|
||||
async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
|
||||
let query = format!("SELECT count() AS count FROM {table} GROUP ALL;");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.with_context(|| format!("counting rows in {table}"))?;
|
||||
let rows: Vec<CountRow> = response
|
||||
.take(0)
|
||||
.context("failed to deserialize count() response")?;
|
||||
Ok(rows.first().map_or(0, |r| r.count))
|
||||
}
|
||||
|
||||
async fn table_index_definitions(
|
||||
db: &SurrealDbClient,
|
||||
table: &str,
|
||||
) -> Result<Option<Map<String, Value>>> {
|
||||
let info_query = format!("INFO FOR TABLE {table};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(info_query)
|
||||
.await
|
||||
.with_context(|| format!("fetching table info for {}", table))?;
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
.context("failed to take table info response")?;
|
||||
|
||||
let info_json: Value =
|
||||
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
|
||||
|
||||
Ok(info_json
|
||||
.get("Object")
|
||||
.and_then(|o| o.get("indexes"))
|
||||
.and_then(|i| i.get("Object"))
|
||||
.and_then(|i| i.as_object())
|
||||
.cloned())
|
||||
}
|
||||
|
||||
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
|
||||
let Some(indexes) = table_index_definitions(db, table).await? else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
Ok(indexes.contains_key(index_name))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_reports_progress() {
|
||||
let info = json!({
|
||||
"building": {
|
||||
"initial": 56894,
|
||||
"pending": 0,
|
||||
"status": "indexing",
|
||||
"updated": 0
|
||||
}
|
||||
});
|
||||
|
||||
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
|
||||
assert_eq!(
|
||||
snapshot,
|
||||
IndexBuildSnapshot {
|
||||
status: "indexing".to_string(),
|
||||
initial: 56894,
|
||||
pending: 0,
|
||||
updated: 0,
|
||||
processed: 56894,
|
||||
total_rows: Some(61081),
|
||||
progress_pct: Some((56894_f64 / 61081_f64) * 100.0),
|
||||
}
|
||||
);
|
||||
assert!(!snapshot.is_ready());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
|
||||
// Surreal returns `{}` when the index exists but isn't building.
|
||||
let info = json!({});
|
||||
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
|
||||
assert!(snapshot.is_ready());
|
||||
assert_eq!(snapshot.processed, 0);
|
||||
assert_eq!(snapshot.progress_pct, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_dimension_parses_value() {
|
||||
let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
|
||||
assert_eq!(extract_dimension(definition), Some(1536));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_runtime_indexes_is_idempotent() {
|
||||
let namespace = "indexes_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
|
||||
// First run creates everything
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Second run should be a no-op and still succeed
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("second index creation");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_hnsw_index_overwrites_dimension() {
|
||||
let namespace = "indexes_dim";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("in-memory db");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("migrations should succeed");
|
||||
|
||||
// Create initial index with default dimension
|
||||
ensure_runtime_indexes(&db, 1536)
|
||||
.await
|
||||
.expect("initial index creation");
|
||||
|
||||
// Change dimension and ensure overwrite path is exercised
|
||||
ensure_runtime_indexes(&db, 128)
|
||||
.await
|
||||
.expect("overwritten index creation");
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,4 @@
|
||||
pub mod db;
|
||||
pub mod indexes;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
@@ -0,0 +1,840 @@
|
||||
use std::io::ErrorKind;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result as AnyResult};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use object_store::local::LocalFileSystem;
|
||||
use object_store::memory::InMemory;
|
||||
use object_store::{path::Path as ObjPath, ObjectStore};
|
||||
|
||||
use crate::utils::config::{AppConfig, StorageKind};
|
||||
|
||||
pub type DynStore = Arc<dyn ObjectStore>;
|
||||
|
||||
/// Storage manager with persistent state and proper lifecycle management.
|
||||
#[derive(Clone)]
|
||||
pub struct StorageManager {
|
||||
// Store from objectstore wrapped as dyn
|
||||
store: DynStore,
|
||||
// Simple enum to track which kind
|
||||
backend_kind: StorageKind,
|
||||
// Where on disk
|
||||
local_base: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl StorageManager {
|
||||
/// Create a new StorageManager with the specified configuration.
|
||||
///
|
||||
/// This method validates the configuration and creates the appropriate
|
||||
/// storage backend with proper initialization.
|
||||
pub async fn new(cfg: &AppConfig) -> object_store::Result<Self> {
|
||||
let backend_kind = cfg.storage.clone();
|
||||
let (store, local_base) = create_storage_backend(cfg).await?;
|
||||
|
||||
Ok(Self {
|
||||
store,
|
||||
backend_kind,
|
||||
local_base,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a StorageManager with a custom storage backend.
|
||||
///
|
||||
/// This method is useful for testing scenarios where you want to inject
|
||||
/// a specific storage backend.
|
||||
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
|
||||
Self {
|
||||
store,
|
||||
backend_kind,
|
||||
local_base: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the storage backend kind.
|
||||
pub fn backend_kind(&self) -> &StorageKind {
|
||||
&self.backend_kind
|
||||
}
|
||||
|
||||
/// Access the resolved local base directory when using the local backend.
|
||||
pub fn local_base_path(&self) -> Option<&Path> {
|
||||
self.local_base.as_deref()
|
||||
}
|
||||
|
||||
/// Resolve an object location to a filesystem path when using the local backend.
|
||||
///
|
||||
/// Returns `None` when the backend is not local or when the provided location includes
|
||||
/// unsupported components (absolute paths or parent traversals).
|
||||
pub fn resolve_local_path(&self, location: &str) -> Option<PathBuf> {
|
||||
let base = self.local_base_path()?;
|
||||
let relative = Path::new(location);
|
||||
if relative.is_absolute()
|
||||
|| relative
|
||||
.components()
|
||||
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(base.join(relative))
|
||||
}
|
||||
|
||||
/// Store bytes at the specified location.
|
||||
///
|
||||
/// This operation persists data using the underlying storage backend.
|
||||
/// For memory backends, data persists for the lifetime of the StorageManager.
|
||||
pub async fn put(&self, location: &str, data: Bytes) -> object_store::Result<()> {
|
||||
let path = ObjPath::from(location);
|
||||
let payload = object_store::PutPayload::from_bytes(data);
|
||||
self.store.put(&path, payload).await.map(|_| ())
|
||||
}
|
||||
|
||||
/// Retrieve bytes from the specified location.
|
||||
///
|
||||
/// Returns the full contents buffered in memory.
|
||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||
let path = ObjPath::from(location);
|
||||
let result = self.store.get(&path).await?;
|
||||
result.bytes().await
|
||||
}
|
||||
|
||||
/// Get a streaming handle for large objects.
|
||||
///
|
||||
/// Returns a fallible stream of Bytes chunks suitable for large file processing.
|
||||
pub async fn get_stream(
|
||||
&self,
|
||||
location: &str,
|
||||
) -> object_store::Result<BoxStream<'static, object_store::Result<Bytes>>> {
|
||||
let path = ObjPath::from(location);
|
||||
let result = self.store.get(&path).await?;
|
||||
Ok(result.into_stream())
|
||||
}
|
||||
|
||||
/// Delete all objects below the specified prefix.
|
||||
///
|
||||
/// For local filesystem backends, this also attempts to clean up empty directories.
|
||||
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
||||
let prefix_path = ObjPath::from(prefix);
|
||||
let locations = self
|
||||
.store
|
||||
.list(Some(&prefix_path))
|
||||
.map_ok(|m| m.location)
|
||||
.boxed();
|
||||
self.store
|
||||
.delete_stream(locations)
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
|
||||
// Cleanup filesystem directories only for local backend
|
||||
if matches!(self.backend_kind, StorageKind::Local) {
|
||||
self.cleanup_filesystem_directories(prefix).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all objects below the specified prefix.
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&str>,
|
||||
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
|
||||
let prefix_path = prefix.map(ObjPath::from);
|
||||
self.store.list(prefix_path.as_ref()).try_collect().await
|
||||
}
|
||||
|
||||
/// Check if an object exists at the specified location.
|
||||
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
||||
let path = ObjPath::from(location);
|
||||
self.store
|
||||
.head(&path)
|
||||
.await
|
||||
.map(|_| true)
|
||||
.or_else(|e| match e {
|
||||
object_store::Error::NotFound { .. } => Ok(false),
|
||||
_ => Err(e),
|
||||
})
|
||||
}
|
||||
|
||||
/// Cleanup filesystem directories for local backend.
|
||||
///
|
||||
/// This is a best-effort cleanup and ignores errors.
|
||||
async fn cleanup_filesystem_directories(&self, prefix: &str) -> object_store::Result<()> {
|
||||
if !matches!(self.backend_kind, StorageKind::Local) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(base) = &self.local_base else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let relative = Path::new(prefix);
|
||||
if relative.is_absolute()
|
||||
|| relative
|
||||
.components()
|
||||
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
|
||||
{
|
||||
tracing::warn!(
|
||||
prefix = %prefix,
|
||||
"Skipping directory cleanup for unsupported prefix components"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut current = base.join(relative);
|
||||
|
||||
while current.starts_with(base) && current.as_path() != base.as_path() {
|
||||
match tokio::fs::remove_dir(¤t).await {
|
||||
Ok(()) => {}
|
||||
Err(err) => match err.kind() {
|
||||
ErrorKind::NotFound => {}
|
||||
ErrorKind::DirectoryNotEmpty => break,
|
||||
_ => tracing::debug!(
|
||||
error = %err,
|
||||
path = %current.display(),
|
||||
"Failed to remove directory during cleanup"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
if let Some(parent) = current.parent() {
|
||||
current = parent.to_path_buf();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a storage backend based on configuration.
|
||||
///
|
||||
/// This factory function handles the creation and initialization of different
|
||||
/// storage backends with proper error handling and validation.
|
||||
async fn create_storage_backend(
|
||||
cfg: &AppConfig,
|
||||
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
|
||||
match cfg.storage {
|
||||
StorageKind::Local => {
|
||||
let base = resolve_base_dir(cfg);
|
||||
if !base.exists() {
|
||||
tokio::fs::create_dir_all(&base).await.map_err(|e| {
|
||||
object_store::Error::Generic {
|
||||
store: "LocalFileSystem",
|
||||
source: e.into(),
|
||||
}
|
||||
})?;
|
||||
}
|
||||
let store = LocalFileSystem::new_with_prefix(base.clone())?;
|
||||
Ok((Arc::new(store), Some(base)))
|
||||
}
|
||||
StorageKind::Memory => {
|
||||
let store = InMemory::new();
|
||||
Ok((Arc::new(store), None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Testing utilities for storage operations.
|
||||
///
|
||||
/// This module provides specialized utilities for testing scenarios with
|
||||
/// automatic memory backend setup and proper test isolation.
|
||||
#[cfg(test)]
|
||||
pub mod testing {
|
||||
use super::*;
|
||||
use crate::utils::config::{AppConfig, PdfIngestMode};
|
||||
use uuid;
|
||||
|
||||
/// Create a test configuration with memory storage.
|
||||
///
|
||||
/// This provides a ready-to-use configuration for testing scenarios
|
||||
/// that don't require filesystem persistence.
|
||||
pub fn test_config_memory() -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: "/tmp/unused".into(), // Ignored for memory storage
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Memory,
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a test configuration with local storage.
|
||||
///
|
||||
/// This provides a ready-to-use configuration for testing scenarios
|
||||
/// that require actual filesystem operations.
|
||||
pub fn test_config_local() -> AppConfig {
|
||||
let base = format!("/tmp/minne_test_storage_{}", uuid::Uuid::new_v4());
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: base.into(),
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// A specialized StorageManager for testing scenarios.
|
||||
///
|
||||
/// This provides automatic setup for memory storage with proper isolation
|
||||
/// and cleanup capabilities for test environments.
|
||||
#[derive(Clone)]
|
||||
pub struct TestStorageManager {
|
||||
storage: StorageManager,
|
||||
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
}
|
||||
|
||||
impl TestStorageManager {
|
||||
/// Create a new TestStorageManager with memory backend.
|
||||
///
|
||||
/// This is the preferred method for unit tests as it provides
|
||||
/// fast execution and complete isolation.
|
||||
pub async fn new_memory() -> object_store::Result<Self> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new TestStorageManager with local filesystem backend.
|
||||
///
|
||||
/// This method creates a temporary directory that will be automatically
|
||||
/// cleaned up when the TestStorageManager is dropped.
|
||||
pub async fn new_local() -> object_store::Result<Self> {
|
||||
let cfg = test_config_local();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
let resolved = storage
|
||||
.local_base_path()
|
||||
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()));
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a TestStorageManager with custom configuration.
|
||||
pub async fn with_config(cfg: &AppConfig) -> object_store::Result<Self> {
|
||||
let storage = StorageManager::new(cfg).await?;
|
||||
let temp_dir = if matches!(cfg.storage, StorageKind::Local) {
|
||||
storage
|
||||
.local_base_path()
|
||||
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying StorageManager.
|
||||
pub fn storage(&self) -> &StorageManager {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
/// Clone the underlying StorageManager.
|
||||
pub fn clone_storage(&self) -> StorageManager {
|
||||
self.storage.clone()
|
||||
}
|
||||
|
||||
/// Store test data at the specified location.
|
||||
pub async fn put(&self, location: &str, data: &[u8]) -> object_store::Result<()> {
|
||||
self.storage.put(location, Bytes::from(data.to_vec())).await
|
||||
}
|
||||
|
||||
/// Retrieve test data from the specified location.
|
||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||
self.storage.get(location).await
|
||||
}
|
||||
|
||||
/// Delete test data below the specified prefix.
|
||||
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
||||
self.storage.delete_prefix(prefix).await
|
||||
}
|
||||
|
||||
/// Check if test data exists at the specified location.
|
||||
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
||||
self.storage.exists(location).await
|
||||
}
|
||||
|
||||
/// List all test objects below the specified prefix.
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&str>,
|
||||
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
|
||||
self.storage.list(prefix).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestStorageManager {
|
||||
fn drop(&mut self) {
|
||||
// Clean up temporary directories for local storage
|
||||
if let Some((_, path)) = &self._temp_dir {
|
||||
if path.exists() {
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience macro for creating memory storage tests.
|
||||
///
|
||||
/// This macro simplifies the creation of test storage with memory backend.
|
||||
#[macro_export]
|
||||
macro_rules! test_storage_memory {
|
||||
() => {{
|
||||
async move {
|
||||
$crate::storage::store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test memory storage")
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/// Convenience macro for creating local storage tests.
|
||||
///
|
||||
/// This macro simplifies the creation of test storage with local filesystem backend.
|
||||
#[macro_export]
|
||||
macro_rules! test_storage_local {
|
||||
() => {{
|
||||
async move {
|
||||
$crate::storage::store::testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("Failed to create test local storage")
|
||||
}
|
||||
}};
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the absolute base directory used for local storage from config.
|
||||
///
|
||||
/// If `data_dir` is relative, it is resolved against the current working directory.
|
||||
pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf {
|
||||
if cfg.data_dir.starts_with('/') {
|
||||
PathBuf::from(&cfg.data_dir)
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| PathBuf::from("."))
|
||||
.join(&cfg.data_dir)
|
||||
}
|
||||
}
|
||||
|
||||
/// Split an absolute filesystem path into `(parent_dir, file_name)`.
|
||||
pub fn split_abs_path(path: &str) -> AnyResult<(PathBuf, String)> {
|
||||
let pb = PathBuf::from(path);
|
||||
let parent = pb
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("Path has no parent: {path}"))?
|
||||
.to_path_buf();
|
||||
let file = pb
|
||||
.file_name()
|
||||
.ok_or_else(|| anyhow!("Path has no file name: {path}"))?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
Ok((parent, file))
|
||||
}
|
||||
|
||||
/// Split a logical object location `"a/b/c"` into `("a/b", "c")`.
|
||||
pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
|
||||
if let Some((p, f)) = path.rsplit_once('/') {
|
||||
return Ok((p.to_string(), f.to_string()));
|
||||
}
|
||||
Err(anyhow!("Object path has no separator: {path}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||
use bytes::Bytes;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_config(root: &str) -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: root.into(),
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn test_config_memory() -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: "/tmp/unused".into(), // Ignored for memory storage
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Memory,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_basic_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
assert!(storage.local_base_path().is_none());
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
let data = b"test data for storage manager";
|
||||
|
||||
// Test put and get
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_local_basic_operations() {
|
||||
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
||||
let cfg = test_config(&base);
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
let resolved_base = storage
|
||||
.local_base_path()
|
||||
.expect("resolved base dir")
|
||||
.to_path_buf();
|
||||
assert_eq!(resolved_base, PathBuf::from(&base));
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
let data = b"test data for local storage";
|
||||
|
||||
// Test put and get
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
let object_dir = resolved_base.join("test/data");
|
||||
tokio::fs::metadata(&object_dir)
|
||||
.await
|
||||
.expect("object directory exists after write");
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
assert!(
|
||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||
"object directory should be removed"
|
||||
);
|
||||
tokio::fs::metadata(&resolved_base)
|
||||
.await
|
||||
.expect("base directory remains intact");
|
||||
|
||||
// Clean up
|
||||
let _ = tokio::fs::remove_dir_all(&base).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_persistence() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
let location = "persistence/test.txt";
|
||||
let data1 = b"first data";
|
||||
let data2 = b"second data";
|
||||
|
||||
// Put first data
|
||||
storage
|
||||
.put(location, Bytes::from(data1.to_vec()))
|
||||
.await
|
||||
.expect("put first");
|
||||
|
||||
// Retrieve and verify first data
|
||||
let retrieved1 = storage.get(location).await.expect("get first");
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
|
||||
// Overwrite with second data
|
||||
storage
|
||||
.put(location, Bytes::from(data2.to_vec()))
|
||||
.await
|
||||
.expect("put second");
|
||||
|
||||
// Retrieve and verify second data
|
||||
let retrieved2 = storage.get(location).await.expect("get second");
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
|
||||
// Data persists across multiple operations using the same StorageManager
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_list_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
// Create multiple files
|
||||
let files = vec![
|
||||
("dir1/file1.txt", b"content1"),
|
||||
("dir1/file2.txt", b"content2"),
|
||||
("dir2/file3.txt", b"content3"),
|
||||
];
|
||||
|
||||
for (location, data) in &files {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
}
|
||||
|
||||
// Test listing without prefix
|
||||
let all_files = storage.list(None).await.expect("list all");
|
||||
assert_eq!(all_files.len(), 3);
|
||||
|
||||
// Test listing with prefix
|
||||
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
|
||||
assert_eq!(dir1_files.len(), 2);
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file1.txt")));
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file2.txt")));
|
||||
|
||||
// Test listing non-existent prefix
|
||||
let empty_files = storage
|
||||
.list(Some("nonexistent/"))
|
||||
.await
|
||||
.expect("list nonexistent");
|
||||
assert_eq!(empty_files.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_stream_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
let location = "stream/test.bin";
|
||||
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
||||
|
||||
// Put large data
|
||||
storage
|
||||
.put(location, Bytes::from(content.clone()))
|
||||
.await
|
||||
.expect("put large data");
|
||||
|
||||
// Get as stream
|
||||
let mut stream = storage.get_stream(location).await.expect("get stream");
|
||||
let mut collected = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.expect("stream chunk");
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
assert_eq!(collected, content);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_with_custom_backend() {
|
||||
use object_store::memory::InMemory;
|
||||
|
||||
// Create custom memory backend
|
||||
let custom_store = InMemory::new();
|
||||
let storage = StorageManager::with_backend(Arc::new(custom_store), StorageKind::Memory);
|
||||
|
||||
let location = "custom/test.txt";
|
||||
let data = b"custom backend test";
|
||||
|
||||
// Test operations with custom backend
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(storage.exists(location).await.expect("exists"));
|
||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_error_handling() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
// Test getting non-existent file
|
||||
let result = storage.get("nonexistent.txt").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test checking existence of non-existent file
|
||||
let exists = storage
|
||||
.exists("nonexistent.txt")
|
||||
.await
|
||||
.expect("exists check");
|
||||
assert!(!exists);
|
||||
|
||||
// Test listing with invalid location (should not panic)
|
||||
let _result = storage.get("").await;
|
||||
// This may or may not error depending on the backend implementation
|
||||
// The important thing is that it doesn't panic
|
||||
}
|
||||
|
||||
// TestStorageManager tests
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_memory() {
|
||||
let test_storage = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
|
||||
let location = "test/storage/file.txt";
|
||||
let data = b"test data with TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
|
||||
// Test list
|
||||
let files = test_storage
|
||||
.list(Some("test/storage/"))
|
||||
.await
|
||||
.expect("list");
|
||||
assert_eq!(files.len(), 1);
|
||||
|
||||
// Test delete
|
||||
test_storage
|
||||
.delete_prefix("test/storage/")
|
||||
.await
|
||||
.expect("delete");
|
||||
assert!(!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_local() {
|
||||
let test_storage = testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
|
||||
let location = "test/local/file.txt";
|
||||
let data = b"test data with local TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
|
||||
// The storage should be automatically cleaned up when test_storage is dropped
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_isolation() {
|
||||
let storage1 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 1");
|
||||
let storage2 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 2");
|
||||
|
||||
let location = "isolation/test.txt";
|
||||
let data1 = b"storage 1 data";
|
||||
let data2 = b"storage 2 data";
|
||||
|
||||
// Put different data in each storage
|
||||
storage1.put(location, data1).await.expect("put storage 1");
|
||||
storage2.put(location, data2).await.expect("put storage 2");
|
||||
|
||||
// Verify isolation
|
||||
let retrieved1 = storage1.get(location).await.expect("get storage 1");
|
||||
let retrieved2 = storage2.get(location).await.expect("get storage 2");
|
||||
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_config() {
|
||||
let cfg = testing::test_config_memory();
|
||||
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
||||
.await
|
||||
.expect("create test storage with config");
|
||||
|
||||
let location = "config/test.txt";
|
||||
let data = b"test data with custom config";
|
||||
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Verify it's using memory backend
|
||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||
}
|
||||
}
|
||||
@@ -71,6 +71,7 @@ impl Analytics {
|
||||
// We need to use a direct query for COUNT aggregation
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CountResult {
|
||||
/// Total user count.
|
||||
count: i64,
|
||||
}
|
||||
|
||||
@@ -81,7 +82,7 @@ impl Analytics {
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
Ok(result.map_or(0, |r| r.count))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,10 @@ impl Conversation {
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||
.patch(PatchOp::replace("/updated_at", Utc::now()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(Utc::now()),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,9 @@
|
||||
#![allow(
|
||||
clippy::result_large_err,
|
||||
clippy::needless_pass_by_value,
|
||||
clippy::implicit_clone,
|
||||
clippy::semicolon_if_nothing_returned
|
||||
)]
|
||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
@@ -38,6 +44,7 @@ impl IngestionPayload {
|
||||
/// # Returns
|
||||
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
|
||||
/// (one per file/content type). On failure, returns an `AppError`.
|
||||
#[allow(clippy::similar_names)]
|
||||
pub fn create_ingestion_payload(
|
||||
content: Option<String>,
|
||||
context: String,
|
||||
|
||||
@@ -1,113 +1,538 @@
|
||||
use futures::Stream;
|
||||
use surrealdb::{opt::PatchOp, Notification};
|
||||
#![allow(
|
||||
clippy::cast_possible_wrap,
|
||||
clippy::items_after_statements,
|
||||
clippy::arithmetic_side_effects,
|
||||
clippy::cast_sign_loss,
|
||||
clippy::missing_docs_in_private_items,
|
||||
clippy::trivially_copy_pass_by_ref,
|
||||
clippy::expect_used
|
||||
)]
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::Duration as ChronoDuration;
|
||||
use state_machines::state_machine;
|
||||
use surrealdb::sql::Datetime as SurrealDatetime;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::ingestion_payload::IngestionPayload;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum IngestionTaskStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
attempts: u32,
|
||||
last_attempt: DateTime<Utc>,
|
||||
},
|
||||
Completed,
|
||||
Error(String),
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
pub const DEFAULT_LEASE_SECS: i64 = 300;
|
||||
pub const DEFAULT_PRIORITY: i32 = 0;
|
||||
|
||||
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||
pub enum TaskState {
|
||||
#[serde(rename = "Pending")]
|
||||
#[default]
|
||||
Pending,
|
||||
#[serde(rename = "Reserved")]
|
||||
Reserved,
|
||||
#[serde(rename = "Processing")]
|
||||
Processing,
|
||||
#[serde(rename = "Succeeded")]
|
||||
Succeeded,
|
||||
#[serde(rename = "Failed")]
|
||||
Failed,
|
||||
#[serde(rename = "Cancelled")]
|
||||
Cancelled,
|
||||
#[serde(rename = "DeadLetter")]
|
||||
DeadLetter,
|
||||
}
|
||||
|
||||
impl TaskState {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Succeeded",
|
||||
TaskState::Failed => "Failed",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "DeadLetter",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
TaskState::Succeeded | TaskState::Cancelled | TaskState::DeadLetter
|
||||
)
|
||||
}
|
||||
|
||||
pub fn display_label(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Completed",
|
||||
TaskState::Failed => "Retrying",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "Dead Letter",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct TaskErrorInfo {
|
||||
pub code: Option<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum TaskTransition {
|
||||
StartProcessing,
|
||||
Succeed,
|
||||
Fail,
|
||||
Cancel,
|
||||
DeadLetter,
|
||||
Release,
|
||||
}
|
||||
|
||||
impl TaskTransition {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskTransition::StartProcessing => "start_processing",
|
||||
TaskTransition::Succeed => "succeed",
|
||||
TaskTransition::Fail => "fail",
|
||||
TaskTransition::Cancel => "cancel",
|
||||
TaskTransition::DeadLetter => "deadletter",
|
||||
TaskTransition::Release => "release",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod lifecycle {
|
||||
use super::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: TaskLifecycleMachine,
|
||||
initial: Pending,
|
||||
states: [Pending, Reserved, Processing, Succeeded, Failed, Cancelled, DeadLetter],
|
||||
events {
|
||||
reserve {
|
||||
transition: { from: Pending, to: Reserved }
|
||||
transition: { from: Failed, to: Reserved }
|
||||
}
|
||||
start_processing {
|
||||
transition: { from: Reserved, to: Processing }
|
||||
}
|
||||
succeed {
|
||||
transition: { from: Processing, to: Succeeded }
|
||||
}
|
||||
fail {
|
||||
transition: { from: Processing, to: Failed }
|
||||
}
|
||||
cancel {
|
||||
transition: { from: Pending, to: Cancelled }
|
||||
transition: { from: Reserved, to: Cancelled }
|
||||
transition: { from: Processing, to: Cancelled }
|
||||
}
|
||||
deadletter {
|
||||
transition: { from: Failed, to: DeadLetter }
|
||||
}
|
||||
release {
|
||||
transition: { from: Reserved, to: Pending }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
|
||||
TaskLifecycleMachine::new(())
|
||||
}
|
||||
|
||||
pub(super) fn reserved() -> TaskLifecycleMachine<(), Reserved> {
|
||||
pending()
|
||||
.reserve()
|
||||
.expect("reserve transition from Pending should exist")
|
||||
}
|
||||
|
||||
pub(super) fn processing() -> TaskLifecycleMachine<(), Processing> {
|
||||
reserved()
|
||||
.start_processing()
|
||||
.expect("start_processing transition from Reserved should exist")
|
||||
}
|
||||
|
||||
pub(super) fn failed() -> TaskLifecycleMachine<(), Failed> {
|
||||
processing()
|
||||
.fail()
|
||||
.expect("fail transition from Processing should exist")
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_transition(state: &TaskState, event: TaskTransition) -> AppError {
|
||||
AppError::Validation(format!(
|
||||
"Invalid task transition: {} -> {}",
|
||||
state.as_str(),
|
||||
event.as_str()
|
||||
))
|
||||
}
|
||||
|
||||
stored_object!(IngestionTask, "ingestion_task", {
|
||||
content: IngestionPayload,
|
||||
status: IngestionTaskStatus,
|
||||
user_id: String
|
||||
state: TaskState,
|
||||
user_id: String,
|
||||
attempts: u32,
|
||||
max_attempts: u32,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime")]
|
||||
scheduled_at: chrono::DateTime<chrono::Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
locked_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
lease_duration_secs: i64,
|
||||
worker_id: Option<String>,
|
||||
error_code: Option<String>,
|
||||
error_message: Option<String>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
last_error_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
priority: i32
|
||||
});
|
||||
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
|
||||
impl IngestionTask {
|
||||
pub async fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
pub fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
content,
|
||||
status: IngestionTaskStatus::Created,
|
||||
state: TaskState::Pending,
|
||||
user_id,
|
||||
attempts: 0,
|
||||
max_attempts: MAX_ATTEMPTS,
|
||||
scheduled_at: now,
|
||||
locked_at: None,
|
||||
lease_duration_secs: DEFAULT_LEASE_SECS,
|
||||
worker_id: None,
|
||||
error_code: None,
|
||||
error_message: None,
|
||||
last_error_at: None,
|
||||
priority: DEFAULT_PRIORITY,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new job and stores it in the database
|
||||
pub fn can_retry(&self) -> bool {
|
||||
self.attempts < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn lease_duration(&self) -> Duration {
|
||||
Duration::from_secs(self.lease_duration_secs.max(0) as u64)
|
||||
}
|
||||
|
||||
pub async fn create_and_add_to_db(
|
||||
content: IngestionPayload,
|
||||
user_id: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let task = Self::new(content, user_id).await;
|
||||
|
||||
let task = Self::new(content, user_id);
|
||||
db.store_item(task.clone()).await?;
|
||||
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
// Update job status
|
||||
pub async fn update_status(
|
||||
id: &str,
|
||||
status: IngestionTaskStatus,
|
||||
pub async fn claim_next_ready(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let _job: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/status", status))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::sql::Datetime::default(),
|
||||
worker_id: &str,
|
||||
now: chrono::DateTime<chrono::Utc>,
|
||||
lease_duration: Duration,
|
||||
) -> Result<Option<IngestionTask>, AppError> {
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok());
|
||||
debug_assert!(lifecycle::failed().reserve().is_ok());
|
||||
|
||||
const CLAIM_QUERY: &str = r#"
|
||||
UPDATE (
|
||||
SELECT * FROM type::table($table)
|
||||
WHERE state IN $candidate_states
|
||||
AND scheduled_at <= $now
|
||||
AND (
|
||||
attempts < max_attempts
|
||||
OR state IN $sticky_states
|
||||
)
|
||||
AND (
|
||||
locked_at = NONE
|
||||
OR time::unix($now) - time::unix(locked_at) >= lease_duration_secs
|
||||
)
|
||||
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
|
||||
LIMIT 1
|
||||
)
|
||||
SET state = $reserved_state,
|
||||
attempts = if state IN $increment_states THEN
|
||||
if attempts + 1 > max_attempts THEN max_attempts ELSE attempts + 1 END
|
||||
ELSE
|
||||
attempts
|
||||
END,
|
||||
locked_at = $now,
|
||||
worker_id = $worker_id,
|
||||
lease_duration_secs = $lease_secs,
|
||||
updated_at = $now
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CLAIM_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind((
|
||||
"candidate_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind((
|
||||
"sticky_states",
|
||||
vec![TaskState::Reserved.as_str(), TaskState::Processing.as_str()],
|
||||
))
|
||||
.bind((
|
||||
"increment_states",
|
||||
vec![TaskState::Pending.as_str(), TaskState::Failed.as_str()],
|
||||
))
|
||||
.bind(("reserved_state", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", worker_id.to_string()))
|
||||
.bind(("lease_secs", lease_duration.as_secs() as i64))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
let task: Option<IngestionTask> = result.take(0)?;
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
/// Listen for new jobs
|
||||
pub async fn listen_for_tasks(
|
||||
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const START_PROCESSING_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $processing,
|
||||
updated_at = $now,
|
||||
locked_at = $now
|
||||
WHERE state = $reserved AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(START_PROCESSING_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::StartProcessing))
|
||||
}
|
||||
|
||||
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const COMPLETE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $succeeded,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(COMPLETE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("succeeded", TaskState::Succeeded.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Succeed))
|
||||
}
|
||||
|
||||
pub async fn mark_failed(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
retry_delay: Duration,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<impl Stream<Item = Result<Notification<Self>, surrealdb::Error>>, surrealdb::Error>
|
||||
{
|
||||
db.listen::<Self>().await
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let now = chrono::Utc::now();
|
||||
let retry_at = now
|
||||
+ ChronoDuration::from_std(retry_delay).unwrap_or_else(|_| ChronoDuration::seconds(30));
|
||||
|
||||
const FAIL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $failed,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $retry_at,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(FAIL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("retry_at", SurrealDatetime::from(retry_at)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Fail))
|
||||
}
|
||||
|
||||
/// Get all unfinished tasks, ie newly created and in progress up two times
|
||||
pub async fn get_unfinished_tasks(db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||
let jobs: Vec<Self> = db
|
||||
pub async fn mark_dead_letter(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
const DEAD_LETTER_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $dead,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $failed
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(DEAD_LETTER_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("dead", TaskState::DeadLetter.as_str()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::DeadLetter))
|
||||
}
|
||||
|
||||
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const CANCEL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $cancelled,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state IN $allow_states
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CANCEL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("cancelled", TaskState::Cancelled.as_str()))
|
||||
.bind((
|
||||
"allow_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Cancel))
|
||||
}
|
||||
|
||||
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const RELEASE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $pending,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state = $reserved
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(RELEASE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("pending", TaskState::Pending.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Release))
|
||||
}
|
||||
|
||||
pub async fn get_unfinished_tasks(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let tasks: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE
|
||||
status = 'Created'
|
||||
OR (
|
||||
status.InProgress != NONE
|
||||
AND status.InProgress.attempts < $max_attempts
|
||||
)
|
||||
ORDER BY created_at ASC",
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE state IN $active_states
|
||||
ORDER BY scheduled_at ASC, created_at ASC",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
],
|
||||
))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
Ok(tasks)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
|
||||
// Helper function to create a test ingestion payload
|
||||
fn create_test_payload(user_id: &str) -> IngestionPayload {
|
||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
text: "Test content".to_string(),
|
||||
context: "Test context".to_string(),
|
||||
@@ -116,180 +541,197 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn memory_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("in-memory surrealdb")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_ingestion_task() {
|
||||
async fn test_new_task_defaults() {
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
// Verify task properties
|
||||
assert_eq!(task.user_id, user_id);
|
||||
assert_eq!(task.content, payload);
|
||||
assert!(matches!(task.status, IngestionTaskStatus::Created));
|
||||
assert!(!task.id.is_empty());
|
||||
assert_eq!(task.state, TaskState::Pending);
|
||||
assert_eq!(task.attempts, 0);
|
||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||
assert!(task.locked_at.is_none());
|
||||
assert!(task.worker_id.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_add_to_db() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_create_and_store_task() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create and store task
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
let created =
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("store");
|
||||
|
||||
let stored: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&created.id)
|
||||
.await
|
||||
.expect("Failed to create and add task to db");
|
||||
.expect("fetch");
|
||||
|
||||
// Query to verify task was stored
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE user_id = '{}'",
|
||||
IngestionTask::table_name(),
|
||||
user_id
|
||||
);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify task is in the database
|
||||
assert!(!tasks.is_empty(), "Task should exist in the database");
|
||||
let stored_task = &tasks[0];
|
||||
assert_eq!(stored_task.user_id, user_id);
|
||||
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
|
||||
let stored = stored.expect("task exists");
|
||||
assert_eq!(stored.id, created.id);
|
||||
assert_eq!(stored.state, TaskState::Pending);
|
||||
assert_eq!(stored.attempts, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_status() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_claim_and_transition() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
// Create task manually
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task_id = task.id.clone();
|
||||
let worker_id = "worker-1";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim");
|
||||
|
||||
// Store task
|
||||
db.store_item(task).await.expect("Failed to store task");
|
||||
let claimed = claimed.expect("task claimed");
|
||||
assert_eq!(claimed.state, TaskState::Reserved);
|
||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||
|
||||
// Update status to InProgress
|
||||
let now = Utc::now();
|
||||
let new_status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: now,
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
assert_eq!(processing.state, TaskState::Processing);
|
||||
|
||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||
assert!(succeeded.worker_id.is_none());
|
||||
assert!(succeeded.locked_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fail_and_dead_letter() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let worker_id = "worker-dead";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim")
|
||||
.expect("claimed");
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
|
||||
let error_info = TaskErrorInfo {
|
||||
code: Some("pipeline_error".into()),
|
||||
message: "failed".into(),
|
||||
};
|
||||
|
||||
IngestionTask::update_status(&task_id, new_status.clone(), &db)
|
||||
let failed = processing
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||
.await
|
||||
.expect("Failed to update status");
|
||||
.expect("failed update");
|
||||
assert_eq!(failed.state, TaskState::Failed);
|
||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||
assert!(failed.worker_id.is_none());
|
||||
assert!(failed.locked_at.is_none());
|
||||
assert!(failed.scheduled_at > now);
|
||||
|
||||
// Verify status updated
|
||||
let updated_task: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&task_id)
|
||||
let dead = failed
|
||||
.mark_dead_letter(error_info.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to get updated task");
|
||||
.expect("dead letter");
|
||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||
}
|
||||
|
||||
assert!(updated_task.is_some());
|
||||
let updated_task = updated_task.unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_mark_processing_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
assert_eq!(attempts, 1);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.mark_processing(&db)
|
||||
.await
|
||||
.expect_err("processing should fail without reservation");
|
||||
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> start_processing"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected InProgress status"),
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_tasks() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_mark_failed_requires_processing() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create tasks with different statuses
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
max_attempts_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
|
||||
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
error_task.status = IngestionTaskStatus::Error("Test error".to_string());
|
||||
|
||||
// Store all tasks
|
||||
db.store_item(created_task)
|
||||
let err = task
|
||||
.mark_failed(
|
||||
TaskErrorInfo {
|
||||
code: None,
|
||||
message: "boom".into(),
|
||||
},
|
||||
Duration::from_secs(30),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
db.store_item(in_progress_task)
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
db.store_item(max_attempts_task)
|
||||
.await
|
||||
.expect("Failed to store max-attempts task");
|
||||
db.store_item(completed_task)
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
db.store_item(error_task)
|
||||
.await
|
||||
.expect("Failed to store error task");
|
||||
.expect_err("failing should require processing state");
|
||||
|
||||
// Get unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> fail"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_release_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.release(&db)
|
||||
.await
|
||||
.expect("Failed to get unfinished tasks");
|
||||
.expect_err("release should require reserved state");
|
||||
|
||||
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
|
||||
assert_eq!(unfinished_tasks.len(), 2);
|
||||
|
||||
let statuses: Vec<_> = unfinished_tasks
|
||||
.iter()
|
||||
.map(|task| match &task.status {
|
||||
IngestionTaskStatus::Created => "Created",
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
if *attempts < MAX_ATTEMPTS {
|
||||
"InProgress<MAX"
|
||||
} else {
|
||||
"InProgress>=MAX"
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed",
|
||||
IngestionTaskStatus::Error(_) => "Error",
|
||||
IngestionTaskStatus::Cancelled => "Cancelled",
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert!(statuses.contains(&"Created"));
|
||||
assert!(statuses.contains(&"InProgress<MAX"));
|
||||
assert!(!statuses.contains(&"InProgress>=MAX"));
|
||||
assert!(!statuses.contains(&"Completed"));
|
||||
assert!(!statuses.contains(&"Error"));
|
||||
assert!(!statuses.contains(&"Cancelled"));
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> release"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
#![allow(
|
||||
clippy::missing_docs_in_private_items,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::match_same_arms,
|
||||
clippy::format_push_string,
|
||||
clippy::uninlined_format_args,
|
||||
clippy::explicit_iter_loop,
|
||||
clippy::items_after_statements,
|
||||
clippy::get_first,
|
||||
clippy::redundant_closure_for_method_calls
|
||||
)]
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient, stored_object,
|
||||
error::AppError, storage::db::SurrealDbClient,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
@@ -33,16 +52,54 @@ impl From<String> for KnowledgeEntityType {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct KnowledgeEntitySearchResult {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
|
||||
pub source_id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub entity_type: KnowledgeEntityType,
|
||||
#[serde(default)]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub user_id: String,
|
||||
|
||||
pub score: f32,
|
||||
#[serde(default)]
|
||||
pub highlighted_name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_description: Option<String>,
|
||||
}
|
||||
|
||||
stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
source_id: String,
|
||||
name: String,
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
/// Vector search result including hydrated entity.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
|
||||
pub struct KnowledgeEntityVectorResult {
|
||||
pub entity: KnowledgeEntity,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl KnowledgeEntity {
|
||||
pub fn new(
|
||||
source_id: String,
|
||||
@@ -50,7 +107,6 @@ impl KnowledgeEntity {
|
||||
description: String,
|
||||
entity_type: KnowledgeEntityType,
|
||||
metadata: Option<serde_json::Value>,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
@@ -63,11 +119,54 @@ impl KnowledgeEntity {
|
||||
description,
|
||||
entity_type,
|
||||
metadata,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
db: &SurrealDbClient,
|
||||
search_terms: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
let sql = r#"
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
source_id,
|
||||
name,
|
||||
description,
|
||||
entity_type,
|
||||
metadata,
|
||||
user_id,
|
||||
search::highlight('<b>', '</b>', 0) AS highlighted_name,
|
||||
search::highlight('<b>', '</b>', 1) AS highlighted_description,
|
||||
(
|
||||
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
|
||||
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END
|
||||
) AS score
|
||||
FROM knowledge_entity
|
||||
WHERE
|
||||
(
|
||||
name @0@ $terms OR
|
||||
description @1@ $terms
|
||||
)
|
||||
AND user_id = $user_id
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit;
|
||||
"#;
|
||||
|
||||
Ok(db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("terms", search_terms.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", limit))
|
||||
.await?
|
||||
.take(0)?)
|
||||
}
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -82,6 +181,89 @@ impl KnowledgeEntity {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Atomically store a knowledge entity and its embedding.
|
||||
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
|
||||
pub async fn store_with_embedding(
|
||||
entity: KnowledgeEntity,
|
||||
embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());
|
||||
|
||||
let query = format!(
|
||||
"
|
||||
BEGIN TRANSACTION;
|
||||
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
|
||||
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;
|
||||
",
|
||||
entity_table = Self::table_name(),
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity.id.clone()))
|
||||
.bind(("entity", entity))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
entity_id,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH entity_id;
|
||||
"#,
|
||||
emb_table = KnowledgeEntityEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
|
||||
response = response.check().map_err(AppError::Database)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| KnowledgeEntityVectorResult {
|
||||
entity: r.entity_id,
|
||||
score: r.score,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
id: &str,
|
||||
name: &str,
|
||||
@@ -94,36 +276,298 @@ impl KnowledgeEntity {
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
name, description, entity_type
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input).await?;
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
let user_id = Self::get_user_id_by_id(id, db_client).await?;
|
||||
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing($table, $id)
|
||||
SET name = $name,
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type,
|
||||
embedding = $embedding
|
||||
RETURN AFTER",
|
||||
"BEGIN TRANSACTION;
|
||||
UPDATE type::thing($table, $id)
|
||||
SET name = $name,
|
||||
description = $description,
|
||||
updated_at = $updated_at,
|
||||
entity_type = $entity_type;
|
||||
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("emb_table", KnowledgeEntityEmbedding::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.bind(("name", name.to_string()))
|
||||
.bind(("updated_at", Utc::now()))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("entity_type", entity_type.to_owned()))
|
||||
.bind(("embedding", embedding))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.bind(("description", description.to_string()))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
user_id: String,
|
||||
}
|
||||
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
|
||||
rows.get(0)
|
||||
.map(|r| r.user_id.clone())
|
||||
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities in the database.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It follows the same
|
||||
/// pattern as the text chunk update:
|
||||
/// 1. Re-defines the vector index with the new dimensions.
|
||||
/// 2. Fetches all existing entities.
|
||||
/// 3. Sequentially regenerates the embedding for each and updates the record.
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all knowledge entities. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
|
||||
// Fetch all entities first
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Just updating the idx");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} entities to process.", total_entities);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
for entity in all_entities.iter() {
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
openai_client,
|
||||
&embedding_input,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Check embedding lengths
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|f| f.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
transaction_query.push_str(&format!(
|
||||
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id
|
||||
));
|
||||
}
|
||||
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for knowledge entities completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
|
||||
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
|
||||
pub async fn update_all_embeddings_with_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
) -> Result<(), AppError> {
|
||||
let new_dimensions = provider.dimension();
|
||||
info!(
|
||||
dimensions = new_dimensions,
|
||||
backend = provider.backend_label(),
|
||||
"Starting re-embedding process for all knowledge entities"
|
||||
);
|
||||
|
||||
// Fetch all entities first
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Just updating the index.");
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!(entities = total_entities, "Found entities to process");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all entities...");
|
||||
|
||||
for (i, entity) in all_entities.iter().enumerate() {
|
||||
if i > 0 && i % 100 == 0 {
|
||||
info!(
|
||||
progress = i,
|
||||
total = total_entities,
|
||||
"Re-embedding progress"
|
||||
);
|
||||
}
|
||||
|
||||
let embedding_input = format!(
|
||||
"name: {}, description: {}, type: {:?}",
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = provider
|
||||
.embed(&embedding_input)
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
entity.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
|
||||
info!("Removing old index and clearing embeddings...");
|
||||
|
||||
// Explicitly remove the index first. This prevents background HNSW maintenance from crashing
|
||||
// when we delete/replace data, dealing with a known SurrealDB panic.
|
||||
db.client
|
||||
.query(format!(
|
||||
"REMOVE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {};",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
db.client
|
||||
.query(format!(
|
||||
"DELETE FROM {};",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
// Perform DB updates in a single transaction
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|f| f.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
transaction_query.push_str(&format!(
|
||||
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
entity_id = type::thing('knowledge_entity', '{id}'), \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id
|
||||
));
|
||||
}
|
||||
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
// Execute the entire atomic operation
|
||||
db.client
|
||||
.query(transaction_query)
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
info!("Re-embedding process for knowledge entities completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_knowledge_entity_creation() {
|
||||
@@ -133,7 +577,6 @@ mod tests {
|
||||
let description = "Test Description".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let metadata = Some(json!({"key": "value"}));
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -142,7 +585,6 @@ mod tests {
|
||||
description.clone(),
|
||||
entity_type.clone(),
|
||||
metadata.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -152,7 +594,6 @@ mod tests {
|
||||
assert_eq!(entity.description, description);
|
||||
assert_eq!(entity.entity_type, entity_type);
|
||||
assert_eq!(entity.metadata, metadata);
|
||||
assert_eq!(entity.embedding, embedding);
|
||||
assert_eq!(entity.user_id, user_id);
|
||||
assert!(!entity.id.is_empty());
|
||||
}
|
||||
@@ -216,20 +657,25 @@ mod tests {
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create two entities with the same source_id
|
||||
let source_id = "source123".to_string();
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -239,7 +685,6 @@ mod tests {
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -251,18 +696,18 @@ mod tests {
|
||||
"Different Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
// Store the entities
|
||||
db.store_item(entity1)
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2)
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(different_entity.clone())
|
||||
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store different entity");
|
||||
|
||||
@@ -311,6 +756,162 @@ mod tests {
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
// Note: We can't easily test the patch method without mocking the OpenAI client
|
||||
// and the generate_embedding function. This would require more complex setup.
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.expect("vector search");
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let source_id = "src".to_string();
|
||||
let entity = KnowledgeEntity::new(
|
||||
source_id.clone(),
|
||||
"hello".to_string(),
|
||||
"world".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity with embedding");
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {}",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
assert_eq!(stored_embeddings.len(), 1);
|
||||
|
||||
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
|
||||
.await
|
||||
.expect("fetch embedding");
|
||||
assert!(fetched_emb.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
assert_eq!(res.entity.id, entity.id);
|
||||
assert_eq!(res.entity.source_id, source_id);
|
||||
assert_eq!(res.entity.name, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let e1 = KnowledgeEntity::new(
|
||||
"s1".to_string(),
|
||||
"entity one".to_string(),
|
||||
"desc".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
let e2 = KnowledgeEntity::new(
|
||||
"s2".to_string(),
|
||||
"entity two".to_string(),
|
||||
"desc".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e1");
|
||||
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store e2");
|
||||
|
||||
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
|
||||
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
|
||||
assert!(stored_e1.is_some() && stored_e2.is_some());
|
||||
|
||||
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
|
||||
.client
|
||||
.query(format!(
|
||||
"SELECT * FROM {}",
|
||||
KnowledgeEntityEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.expect("query embeddings")
|
||||
.take(0)
|
||||
.expect("take embeddings");
|
||||
assert_eq!(stored_embeddings.len(), 2);
|
||||
|
||||
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
|
||||
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.expect("vector search");
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].entity.id, e2.id);
|
||||
assert_eq!(results[1].entity.id, e1.id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||
entity_id: RecordId,
|
||||
embedding: Vec<f32>,
|
||||
/// Denormalized user id for query scoping
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl KnowledgeEntityEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
|
||||
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
||||
COMMIT TRANSACTION;",
|
||||
table = Self::table_name(),
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
||||
res.check().map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new knowledge entity embedding
|
||||
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get embedding by entity ID
|
||||
pub async fn get_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Get embeddings for multiple entities in batch
|
||||
pub async fn get_by_entity_ids(
|
||||
entity_ids: &[RecordId],
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<HashMap<String, Vec<f32>>, AppError> {
|
||||
if entity_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let ids_list: Vec<RecordId> = entity_ids.to_vec();
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
||||
Self::table_name()
|
||||
);
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("entity_ids", ids_list))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(embeddings
|
||||
.into_iter()
|
||||
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Delete embedding by entity ID
|
||||
pub async fn delete_by_entity_id(
|
||||
entity_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("entity_id", entity_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
||||
#[allow(clippy::items_after_statements)]
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
||||
let mut res = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
for row in ids {
|
||||
Self::delete_by_entity_id(&row.id, db).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use chrono::Utc;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
fn build_knowledge_entity_with_id(
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
) -> KnowledgeEntity {
|
||||
KnowledgeEntity {
|
||||
id: key.to_owned(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
source_id: source_id.to_owned(),
|
||||
name: "Test entity".to_owned(),
|
||||
description: "Desc".to_owned(),
|
||||
entity_type: KnowledgeEntityType::Document,
|
||||
metadata: None,
|
||||
user_id: user_id.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-1";
|
||||
let source_id = "source-ke";
|
||||
|
||||
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by entity_id")
|
||||
.expect("Expected embedding to exist");
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.entity_id, entity_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_entity_id() {
|
||||
let db = setup_test_db().await;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-delete";
|
||||
let source_id = "source-del";
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
assert!(existing.is_some());
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by entity_id");
|
||||
|
||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
assert!(after.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_entity_and_embedding() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user_store";
|
||||
let source_id = "source_store";
|
||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
|
||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
||||
assert!(stored_entity.is_some());
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||
.await
|
||||
.expect("Failed to fetch embedding");
|
||||
assert!(stored_embedding.is_some());
|
||||
let stored_embedding = stored_embedding.unwrap();
|
||||
assert_eq!(stored_embedding.user_id, user_id);
|
||||
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
let user_id = "user_ke";
|
||||
let source_id = "shared-ke";
|
||||
let other_source = "other-ke";
|
||||
|
||||
let entity1 = build_knowledge_entity_with_id("entity-s1", source_id, user_id);
|
||||
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
||||
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
|
||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||
let other_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity_other.id);
|
||||
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none()
|
||||
);
|
||||
assert!(
|
||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none()
|
||||
);
|
||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_knowledge_entity_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 16"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fetch_entity_via_record_id() {
|
||||
let db = setup_test_db().await;
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("set test index dimension");
|
||||
let user_id = "user_ke";
|
||||
let entity_key = "entity-fetch";
|
||||
let source_id = "source-fetch";
|
||||
|
||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||
.await
|
||||
.expect("Failed to store entity with embedding");
|
||||
|
||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
entity_id: KnowledgeEntity,
|
||||
}
|
||||
|
||||
let mut res = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;",
|
||||
)
|
||||
.bind(("id", entity_rid.clone()))
|
||||
.await
|
||||
.expect("failed to fetch embedding with FETCH");
|
||||
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
let fetched_entity = &rows[0].entity_id;
|
||||
assert_eq!(fetched_entity.id, entity_key);
|
||||
assert_eq!(fetched_entity.name, "Test entity");
|
||||
assert_eq!(fetched_entity.user_id, user_id);
|
||||
}
|
||||
}
|
||||
@@ -41,20 +41,21 @@ impl KnowledgeRelationship {
|
||||
}
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
|
||||
r#"DELETE relates_to:`{rel_id}`;
|
||||
RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
|
||||
SET
|
||||
metadata.user_id = '{}',
|
||||
metadata.source_id = '{}',
|
||||
metadata.relationship_type = '{}'"#,
|
||||
self.in_,
|
||||
self.id,
|
||||
self.out,
|
||||
self.metadata.user_id,
|
||||
self.metadata.source_id,
|
||||
self.metadata.relationship_type
|
||||
metadata.user_id = '{user_id}',
|
||||
metadata.source_id = '{source_id}',
|
||||
metadata.relationship_type = '{relationship_type}'"#,
|
||||
rel_id = self.id,
|
||||
in_id = self.in_,
|
||||
out_id = self.out,
|
||||
user_id = self.metadata.user_id.as_str(),
|
||||
source_id = self.metadata.source_id.as_str(),
|
||||
relationship_type = self.metadata.relationship_type.as_str()
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
db_client.query(query).await?.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -64,8 +65,7 @@ impl KnowledgeRelationship {
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
|
||||
source_id
|
||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
@@ -75,13 +75,33 @@ impl KnowledgeRelationship {
|
||||
|
||||
pub async fn delete_relationship_by_id(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!("DELETE relates_to:`{}`", id);
|
||||
let mut authorized_result = db_client
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
|
||||
))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
||||
|
||||
db_client.query(query).await?;
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
.query(format!("SELECT * FROM relates_to:`{id}`"))
|
||||
.await?;
|
||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
||||
|
||||
Ok(())
|
||||
if existing.is_some() {
|
||||
Err(AppError::Auth(
|
||||
"Not authorized to delete relationship".into(),
|
||||
))
|
||||
} else {
|
||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||
}
|
||||
} else {
|
||||
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +115,6 @@ mod tests {
|
||||
let source_id = "source123".to_string();
|
||||
let description = format!("Description for {}", name);
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity = KnowledgeEntity::new(
|
||||
@@ -104,7 +123,6 @@ mod tests {
|
||||
description,
|
||||
entity_type,
|
||||
None,
|
||||
embedding,
|
||||
user_id,
|
||||
);
|
||||
|
||||
@@ -141,7 +159,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship() {
|
||||
async fn test_store_and_verify_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
@@ -149,6 +167,10 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
@@ -161,7 +183,7 @@ mod tests {
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
@@ -189,7 +211,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id() {
|
||||
async fn test_store_and_delete_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
@@ -209,33 +231,120 @@ mod tests {
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
|
||||
// Store the relationship
|
||||
// Store relationship
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Delete the relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
|
||||
// Ensure relationship exists before deletion attempt
|
||||
let mut existing_before_delete = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let before_results: Vec<KnowledgeRelationship> =
|
||||
existing_before_delete.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
|
||||
// Delete relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
|
||||
// Query to verify the relationship was deleted
|
||||
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
// Query to verify relationship was deleted
|
||||
let mut result = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship no longer exists
|
||||
// Verify relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id() {
|
||||
async fn test_delete_relationship_by_id_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
let owner_user_id = "owner-user".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
owner_user_id.clone(),
|
||||
source_id,
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
let mut before_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before unauthorized delete attempt"
|
||||
);
|
||||
|
||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||
&relationship.id,
|
||||
"different-user",
|
||||
&db,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected authorization error when deleting someone else's relationship"),
|
||||
}
|
||||
|
||||
let mut after_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
"Relationship should still exist after unauthorized delete attempt"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stored_object;
|
||||
@@ -56,7 +57,7 @@ impl fmt::Display for Message {
|
||||
pub fn format_history(history: &[Message]) -> String {
|
||||
history
|
||||
.iter()
|
||||
.map(|msg| format!("{}", msg))
|
||||
.map(|msg| format!("{msg}"))
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::unsafe_derive_deserialize)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub mod analytics;
|
||||
pub mod conversation;
|
||||
@@ -5,11 +6,14 @@ pub mod file_info;
|
||||
pub mod ingestion_payload;
|
||||
pub mod ingestion_task;
|
||||
pub mod knowledge_entity;
|
||||
pub mod knowledge_entity_embedding;
|
||||
pub mod knowledge_relationship;
|
||||
pub mod message;
|
||||
pub mod scratchpad;
|
||||
pub mod system_prompts;
|
||||
pub mod system_settings;
|
||||
pub mod text_chunk;
|
||||
pub mod text_chunk_embedding;
|
||||
pub mod text_content;
|
||||
pub mod user;
|
||||
|
||||
@@ -20,7 +24,7 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! stored_object {
|
||||
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use surrealdb::sql::Thing;
|
||||
use $crate::storage::types::StoredObject;
|
||||
@@ -83,7 +87,36 @@ macro_rules! stored_object {
|
||||
Ok(DateTime::<Utc>::from(dt))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[allow(clippy::ref_option)]
|
||||
fn serialize_option_datetime<S>(
|
||||
date: &Option<DateTime<Utc>>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match date {
|
||||
Some(dt) => serializer
|
||||
.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[allow(clippy::ref_option)]
|
||||
fn deserialize_option_datetime<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<DateTime<Utc>>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
|
||||
Ok(value.map(DateTime::<Utc>::from))
|
||||
}
|
||||
|
||||
|
||||
$(#[$struct_attr])*
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct $name {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
@@ -92,7 +125,7 @@ macro_rules! stored_object {
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
$(pub $field: $ty),*
|
||||
$( $(#[$field_attr])* pub $field: $ty),*
|
||||
}
|
||||
|
||||
impl StoredObject for $name {
|
||||
|
||||
@@ -0,0 +1,502 @@
|
||||
use chrono::Utc as ChronoUtc;
|
||||
use surrealdb::opt::PatchOp;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(Scratchpad, "scratchpad", {
|
||||
user_id: String,
|
||||
title: String,
|
||||
content: String,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with="deserialize_datetime")]
|
||||
last_saved_at: DateTime<Utc>,
|
||||
is_dirty: bool,
|
||||
#[serde(default)]
|
||||
is_archived: bool,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
archived_at: Option<DateTime<Utc>>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
ingested_at: Option<DateTime<Utc>>
|
||||
});
|
||||
|
||||
impl Scratchpad {
|
||||
pub fn new(user_id: String, title: String) -> Self {
|
||||
let now = ChronoUtc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
title,
|
||||
content: String::new(),
|
||||
last_saved_at: now,
|
||||
is_dirty: false,
|
||||
is_archived: false,
|
||||
archived_at: None,
|
||||
ingested_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_by_user(user_id: &str, db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||
let scratchpads: Vec<Scratchpad> = db.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND (is_archived = false OR is_archived IS NONE) ORDER BY updated_at DESC")
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(scratchpads)
|
||||
}
|
||||
|
||||
pub async fn get_archived_by_user(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<Self>, AppError> {
|
||||
let scratchpads: Vec<Scratchpad> = db.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND is_archived = true ORDER BY archived_at DESC, updated_at DESC")
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(scratchpads)
|
||||
}
|
||||
|
||||
pub async fn get_by_id(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
let scratchpad: Option<Scratchpad> = db.get_item(id).await?;
|
||||
|
||||
let scratchpad =
|
||||
scratchpad.ok_or_else(|| AppError::NotFound("Scratchpad not found".to_string()))?;
|
||||
|
||||
if scratchpad.user_id != user_id {
|
||||
return Err(AppError::Auth(
|
||||
"You don't have access to this scratchpad".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(scratchpad)
|
||||
}
|
||||
|
||||
pub async fn update_content(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
new_content: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
// First verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if scratchpad.is_archived {
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/content", new_content.to_string()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.patch(PatchOp::replace(
|
||||
"/last_saved_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.patch(PatchOp::replace("/is_dirty", false))
|
||||
.await?;
|
||||
|
||||
// Return the updated scratchpad
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
|
||||
pub async fn update_title(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
new_title: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
// First verify ownership
|
||||
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(ChronoUtc::now()),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
// First verify ownership
|
||||
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
let _: Option<Self> = db.client.delete((Self::table_name(), id)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn archive(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
mark_ingested: bool,
|
||||
) -> Result<Self, AppError> {
|
||||
// Verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if scratchpad.is_archived {
|
||||
if mark_ingested && scratchpad.ingested_at.is_none() {
|
||||
// Ensure ingested_at is set if required
|
||||
let surreal_now = surrealdb::Datetime::from(ChronoUtc::now());
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||
.await?;
|
||||
return Self::get_by_id(id, user_id, db).await;
|
||||
}
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let surreal_now = surrealdb::Datetime::from(now);
|
||||
let mut update = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/is_archived", true))
|
||||
.patch(PatchOp::replace("/archived_at", surreal_now.clone()))
|
||||
.patch(PatchOp::replace("/updated_at", surreal_now.clone()));
|
||||
|
||||
update = if mark_ingested {
|
||||
update.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||
} else {
|
||||
update.patch(PatchOp::remove("/ingested_at"))
|
||||
};
|
||||
|
||||
let _updated: Option<Self> = update.await?;
|
||||
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
|
||||
pub async fn restore(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
// Verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if !scratchpad.is_archived {
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let surreal_now = surrealdb::Datetime::from(now);
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/is_archived", false))
|
||||
.patch(PatchOp::remove("/archived_at"))
|
||||
.patch(PatchOp::remove("/ingested_at"))
|
||||
.patch(PatchOp::replace("/updated_at", surreal_now))
|
||||
.await?;
|
||||
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
let title = "Test Scratchpad";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify scratchpad properties
|
||||
assert_eq!(scratchpad.user_id, user_id);
|
||||
assert_eq!(scratchpad.title, title);
|
||||
assert_eq!(scratchpad.content, "");
|
||||
assert!(!scratchpad.is_dirty);
|
||||
assert!(!scratchpad.is_archived);
|
||||
assert!(scratchpad.archived_at.is_none());
|
||||
assert!(scratchpad.ingested_at.is_none());
|
||||
assert!(!scratchpad.id.is_empty());
|
||||
|
||||
// Store the scratchpad
|
||||
let result = db.store_item(scratchpad.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Scratchpad> = db
|
||||
.get_item(&scratchpad.id)
|
||||
.await
|
||||
.expect("Failed to retrieve scratchpad");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, scratchpad.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
assert!(!retrieved.is_archived);
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create multiple scratchpads
|
||||
let scratchpad1 = Scratchpad::new(user_id.to_string(), "First".to_string());
|
||||
let scratchpad2 = Scratchpad::new(user_id.to_string(), "Second".to_string());
|
||||
let scratchpad3 = Scratchpad::new("other_user".to_string(), "Other".to_string());
|
||||
|
||||
// Store them
|
||||
let scratchpad1_id = scratchpad1.id.clone();
|
||||
let scratchpad2_id = scratchpad2.id.clone();
|
||||
db.store_item(scratchpad1).await.unwrap();
|
||||
db.store_item(scratchpad2).await.unwrap();
|
||||
db.store_item(scratchpad3).await.unwrap();
|
||||
|
||||
// Archive one of the user's scratchpads
|
||||
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get scratchpads for user_id
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
|
||||
assert_eq!(user_scratchpads.len(), 1);
|
||||
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
|
||||
|
||||
// Verify they belong to the user
|
||||
for scratchpad in &user_scratchpads {
|
||||
assert_eq!(scratchpad.user_id, user_id);
|
||||
}
|
||||
|
||||
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(archived.len(), 1);
|
||||
assert_eq!(archived[0].id, scratchpad2_id);
|
||||
assert!(archived[0].is_archived);
|
||||
assert!(archived[0].ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||
.await
|
||||
.expect("Failed to archive");
|
||||
assert!(archived.is_archived);
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.ingested_at.is_some());
|
||||
|
||||
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.expect("Failed to restore");
|
||||
assert!(!restored.is_archived);
|
||||
assert!(restored.archived_at.is_none());
|
||||
assert!(restored.ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let new_content = "Updated content";
|
||||
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(updated.content, new_content);
|
||||
assert!(!updated.is_dirty);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
// Delete should succeed
|
||||
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it's gone
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
assert!(retrieved.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
|
||||
// Verify it still exists
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to create test database");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||
assert!(retrieved.created_at.timestamp() > 0);
|
||||
assert!(retrieved.updated_at.timestamp() > 0);
|
||||
assert!(retrieved.last_saved_at.timestamp() > 0);
|
||||
|
||||
// Test that optional datetime fields work correctly
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
|
||||
// Archive the scratchpad to test optional datetime handling
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.archived_at.unwrap().timestamp() > 0);
|
||||
assert!(archived.ingested_at.is_none());
|
||||
}
|
||||
}
|
||||
@@ -54,3 +54,10 @@ Guidelines:
|
||||
7. Only create relationships between existing KnowledgeEntities.
|
||||
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
||||
9. A new relationship MUST include a newly created KnowledgeEntity."#;
|
||||
|
||||
pub static DEFAULT_IMAGE_PROCESSING_PROMPT: &str = r#"Analyze this image and respond based on its primary content:
|
||||
- If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.
|
||||
- If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.
|
||||
- For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a "Text:" heading.
|
||||
|
||||
Respond directly with the analysis."#;
|
||||
|
||||
@@ -11,8 +11,16 @@ pub struct SystemSettings {
|
||||
pub require_email_verification: bool,
|
||||
pub query_model: String,
|
||||
pub processing_model: String,
|
||||
pub embedding_model: String,
|
||||
pub embedding_dimensions: u32,
|
||||
/// Active embedding backend ("openai", "fastembed", "hashed"). Read-only, synced from config.
|
||||
#[serde(default)]
|
||||
pub embedding_backend: Option<String>,
|
||||
pub query_system_prompt: String,
|
||||
pub ingestion_system_prompt: String,
|
||||
pub image_processing_model: String,
|
||||
pub image_processing_prompt: String,
|
||||
pub voice_processing_model: String,
|
||||
}
|
||||
|
||||
impl StoredObject for SystemSettings {
|
||||
@@ -26,18 +34,6 @@ impl StoredObject for SystemSettings {
|
||||
}
|
||||
|
||||
impl SystemSettings {
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
|
||||
if settings.is_none() {
|
||||
let created_settings = Self::new();
|
||||
let stored: Option<Self> = db.store_item(created_settings).await?;
|
||||
return stored.ok_or(AppError::Validation("Failed to initialize settings".into()));
|
||||
}
|
||||
|
||||
settings.ok_or(AppError::Validation("Failed to initialize settings".into()))
|
||||
}
|
||||
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let settings: Option<Self> = db.get_item("current").await?;
|
||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
||||
@@ -57,27 +53,112 @@ impl SystemSettings {
|
||||
))
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
id: "current".to_string(),
|
||||
query_system_prompt: crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
ingestion_system_prompt:
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
.to_string(),
|
||||
query_model: "gpt-4o-mini".to_string(),
|
||||
processing_model: "gpt-4o-mini".to_string(),
|
||||
registrations_enabled: true,
|
||||
require_email_verification: false,
|
||||
/// Syncs SystemSettings with the active embedding provider's properties.
|
||||
/// Updates embedding_backend, embedding_model, and embedding_dimensions if they differ.
|
||||
/// Returns true if any settings were changed.
|
||||
pub async fn sync_from_embedding_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
) -> Result<(Self, bool), AppError> {
|
||||
let mut settings = Self::get_current(db).await?;
|
||||
let mut needs_update = false;
|
||||
|
||||
let backend_label = provider.backend_label().to_string();
|
||||
let provider_dimensions = provider.dimension() as u32;
|
||||
let provider_model = provider.model_code();
|
||||
|
||||
// Sync backend label
|
||||
if settings.embedding_backend.as_deref() != Some(&backend_label) {
|
||||
settings.embedding_backend = Some(backend_label);
|
||||
needs_update = true;
|
||||
}
|
||||
|
||||
// Sync dimensions
|
||||
if settings.embedding_dimensions != provider_dimensions {
|
||||
tracing::info!(
|
||||
old_dimensions = settings.embedding_dimensions,
|
||||
new_dimensions = provider_dimensions,
|
||||
"Embedding dimensions changed, updating SystemSettings"
|
||||
);
|
||||
settings.embedding_dimensions = provider_dimensions;
|
||||
needs_update = true;
|
||||
}
|
||||
|
||||
// Sync model if provider has one
|
||||
if let Some(model) = provider_model {
|
||||
if settings.embedding_model != model {
|
||||
tracing::info!(
|
||||
old_model = %settings.embedding_model,
|
||||
new_model = %model,
|
||||
"Embedding model changed, updating SystemSettings"
|
||||
);
|
||||
settings.embedding_model = model;
|
||||
needs_update = true;
|
||||
}
|
||||
}
|
||||
|
||||
if needs_update {
|
||||
settings = Self::update(db, settings).await?;
|
||||
}
|
||||
|
||||
Ok((settings, needs_update))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::indexes::ensure_runtime_indexes;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use async_openai::Client;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn get_hnsw_index_dimension(
|
||||
db: &SurrealDbClient,
|
||||
table_name: &str,
|
||||
index_name: &str,
|
||||
) -> u32 {
|
||||
let query = format!("INFO FOR TABLE {table_name};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
|
||||
let info: surrealdb::Value = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("Failed to convert info to json");
|
||||
|
||||
let indexes = info_json["Object"]["indexes"]["Object"]
|
||||
.as_object()
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.get("Strand"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.expect("Index definition missing DIMENSION clause");
|
||||
|
||||
let dimension_token = dimension_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.expect("Dimension value missing in definition")
|
||||
.trim_end_matches(';');
|
||||
|
||||
dimension_token
|
||||
.parse::<u32>()
|
||||
.expect("Dimension value is not a valid number")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
@@ -88,9 +169,12 @@ mod tests {
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Test initialization of system settings
|
||||
let settings = SystemSettings::ensure_initialized(&db)
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
.expect("Failed to apply migrations");
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get system settings");
|
||||
|
||||
// Verify initial state after initialization
|
||||
assert_eq!(settings.id, "current");
|
||||
@@ -98,17 +182,22 @@ mod tests {
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
);
|
||||
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
||||
// Dont test these for now, having a hard time getting the formatting exactly the same
|
||||
// assert_eq!(
|
||||
// settings.query_system_prompt,
|
||||
// crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
// );
|
||||
// assert_eq!(
|
||||
// settings.ingestion_system_prompt,
|
||||
// crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
// );
|
||||
|
||||
// Test idempotency - ensure calling it again doesn't change anything
|
||||
let settings_again = SystemSettings::ensure_initialized(&db)
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
let settings_again = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to get settings after initialization");
|
||||
|
||||
@@ -133,9 +222,9 @@ mod tests {
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Test get_current method
|
||||
let settings = SystemSettings::get_current(&db)
|
||||
@@ -157,12 +246,12 @@ mod tests {
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Initialize settings
|
||||
SystemSettings::ensure_initialized(&db)
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to initialize system settings");
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create updated settings
|
||||
let mut updated_settings = SystemSettings::new();
|
||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
||||
updated_settings.id = "current".to_string();
|
||||
updated_settings.registrations_enabled = false;
|
||||
updated_settings.require_email_verification = true;
|
||||
@@ -211,21 +300,155 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_method() {
|
||||
let settings = SystemSettings::new();
|
||||
async fn test_migration_after_changing_embedding_length() {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
|
||||
let initial_chunk = TextChunk::new(
|
||||
"source1".into(),
|
||||
"This chunk has the original dimension".into(),
|
||||
"user1".into(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||
.await
|
||||
.expect("Failed to store initial chunk with embedding");
|
||||
|
||||
async fn simulate_reembedding(
|
||||
db: &SurrealDbClient,
|
||||
target_dimension: usize,
|
||||
initial_chunk: TextChunk,
|
||||
) {
|
||||
db.query(
|
||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let define_index_query = format!(
|
||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
target_dimension
|
||||
);
|
||||
db.query(define_index_query)
|
||||
.await
|
||||
.expect("Re-defining index should succeed");
|
||||
|
||||
let new_embedding = vec![0.5; target_dimension];
|
||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||
|
||||
let update_result = db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("id", initial_chunk.id.clone()))
|
||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||
.bind(("embedding", new_embedding))
|
||||
.await;
|
||||
|
||||
assert!(update_result.is_ok());
|
||||
}
|
||||
|
||||
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||
let target_dimension = 1536usize;
|
||||
simulate_reembedding(&db, target_dimension, initial_chunk).await;
|
||||
|
||||
let migration_result = db.apply_migrations().await;
|
||||
|
||||
assert!(
|
||||
migration_result.is_ok(),
|
||||
"Migrations should not fail: {:?}",
|
||||
migration_result.err()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
|
||||
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
||||
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
|
||||
.await
|
||||
.expect("failed to build runtime indexes");
|
||||
|
||||
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(settings.id.len() > 0);
|
||||
assert_eq!(settings.registrations_enabled, true);
|
||||
assert_eq!(settings.require_email_verification, false);
|
||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||
assert_eq!(
|
||||
settings.query_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
"embedding size should match initial system settings"
|
||||
);
|
||||
|
||||
let new_dimension = 768;
|
||||
let new_model = "new-test-embedding-model".to_string();
|
||||
|
||||
current_settings.embedding_dimensions = new_dimension;
|
||||
current_settings.embedding_model = new_model.clone();
|
||||
|
||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
|
||||
assert_eq!(
|
||||
updated_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should reflect the new embedding dimension"
|
||||
);
|
||||
|
||||
let openai_client = Client::new();
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
|
||||
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"text_chunk_embedding",
|
||||
"idx_embedding_text_chunk_embedding",
|
||||
)
|
||||
.await;
|
||||
let knowledge_dimension = get_hnsw_index_dimension(
|
||||
&db,
|
||||
"knowledge_entity_embedding",
|
||||
"idx_embedding_knowledge_entity_embedding",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
"text_chunk index dimension should update"
|
||||
);
|
||||
assert_eq!(
|
||||
settings.ingestion_system_prompt,
|
||||
crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
||||
knowledge_dimension, new_dimension,
|
||||
"knowledge_entity index dimension should update"
|
||||
);
|
||||
|
||||
let persisted_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to reload updated settings");
|
||||
assert_eq!(
|
||||
persisted_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should persist new embedding dimension"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,34 @@
|
||||
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
stored_object!(TextChunk, "text_chunk", {
|
||||
source_id: String,
|
||||
chunk: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String
|
||||
});
|
||||
|
||||
/// Search result including hydrated chunk.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||
pub struct TextChunkSearchResult {
|
||||
pub chunk: TextChunk,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
impl TextChunk {
|
||||
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
@@ -17,7 +36,6 @@ impl TextChunk {
|
||||
updated_at: now,
|
||||
source_id,
|
||||
chunk,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
@@ -35,176 +53,833 @@ impl TextChunk {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Atomically store a text chunk and its embedding.
|
||||
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
|
||||
pub async fn store_with_embedding(
|
||||
chunk: TextChunk,
|
||||
embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let chunk_id = chunk.id.clone();
|
||||
let source_id = chunk.source_id.clone();
|
||||
let user_id = chunk.user_id.clone();
|
||||
|
||||
let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone());
|
||||
|
||||
// Create both records in a single transaction so we don't orphan embeddings or chunks
|
||||
let response = db
|
||||
.client
|
||||
.query("BEGIN TRANSACTION;")
|
||||
.query(format!(
|
||||
"CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;",
|
||||
chunk_table = Self::table_name(),
|
||||
))
|
||||
.query(format!(
|
||||
"CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;",
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
))
|
||||
.query("COMMIT TRANSACTION;")
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.bind(("chunk", chunk))
|
||||
.bind(("emb_id", emb.id.clone()))
|
||||
.bind(("emb", emb))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
response.check().map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
chunk_id: TextChunk,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
chunk_id,
|
||||
embedding,
|
||||
vector::similarity::cosine(embedding, $embedding) AS score
|
||||
FROM {emb_table}
|
||||
WHERE user_id = $user_id
|
||||
AND embedding <|{take},100|> $embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT {take}
|
||||
FETCH chunk_id;
|
||||
"#,
|
||||
emb_table = TextChunkEmbedding::table_name(),
|
||||
take = take
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| TextChunkSearchResult {
|
||||
chunk: r.chunk_id,
|
||||
score: r.score,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Full-text search over text chunks using the BM25 FTS index.
|
||||
pub async fn fts_search(
|
||||
take: usize,
|
||||
terms: &str,
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Row {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
#[serde(deserialize_with = "deserialize_datetime")]
|
||||
created_at: DateTime<Utc>,
|
||||
#[serde(deserialize_with = "deserialize_datetime")]
|
||||
updated_at: DateTime<Utc>,
|
||||
source_id: String,
|
||||
chunk: String,
|
||||
user_id: String,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
let limit = i64::try_from(take).unwrap_or(i64::MAX);
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
source_id,
|
||||
chunk,
|
||||
user_id,
|
||||
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END AS score
|
||||
FROM {chunk_table}
|
||||
WHERE chunk @0@ $terms
|
||||
AND user_id = $user_id
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit;
|
||||
"#,
|
||||
chunk_table = Self::table_name(),
|
||||
);
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("terms", terms.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", limit))
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
|
||||
|
||||
response = response.check().map_err(AppError::Database)?;
|
||||
|
||||
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let chunk = TextChunk {
|
||||
id: r.id,
|
||||
created_at: r.created_at,
|
||||
updated_at: r.updated_at,
|
||||
source_id: r.source_id,
|
||||
chunk: r.chunk,
|
||||
user_id: r.user_id,
|
||||
};
|
||||
|
||||
TextChunkSearchResult {
|
||||
chunk,
|
||||
score: r.score,
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
|
||||
///
|
||||
/// This is a costly operation that should be run in the background. It performs these steps:
|
||||
/// 1. **Fetches All Chunks**: Loads all existing text_chunk records into memory.
|
||||
/// 2. **Generates All Embeddings**: Creates new embeddings for every chunk. If any fails or
|
||||
/// has the wrong dimension, the entire operation is aborted before any DB changes are made.
|
||||
/// 3. **Executes Atomic Transaction**: All data updates and the index recreation are
|
||||
/// performed in a single, all-or-nothing database transaction.
|
||||
pub async fn update_all_embeddings(
|
||||
db: &SurrealDbClient,
|
||||
openai_client: &Client<OpenAIConfig>,
|
||||
new_model: &str,
|
||||
new_dimensions: u32,
|
||||
) -> Result<(), AppError> {
|
||||
info!(
|
||||
"Starting re-embedding process for all text chunks. New dimensions: {}",
|
||||
new_dimensions
|
||||
);
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Just updating the idx");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
for chunk in &all_chunks {
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
openai_client,
|
||||
&chunk.chunk,
|
||||
new_model,
|
||||
new_dimensions,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions as usize {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Perform DB updates in a single transaction against the embedding table
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
// Use the chunk id as the embedding record id to keep a 1:1 mapping
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
|
||||
chunk_id = type::thing('text_chunk', '{id}'), \
|
||||
source_id = '{source_id}', \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
info!("Re-embedding process for text chunks completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
|
||||
///
|
||||
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
|
||||
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
|
||||
pub async fn update_all_embeddings_with_provider(
|
||||
db: &SurrealDbClient,
|
||||
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||
) -> Result<(), AppError> {
|
||||
let new_dimensions = provider.dimension();
|
||||
info!(
|
||||
dimensions = new_dimensions,
|
||||
backend = provider.backend_label(),
|
||||
"Starting re-embedding process for all text chunks"
|
||||
);
|
||||
|
||||
// Fetch all chunks first
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Just updating the index.");
|
||||
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!(chunks = total_chunks, "Found chunks to process");
|
||||
|
||||
// Generate all new embeddings in memory
|
||||
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
|
||||
info!("Generating new embeddings for all chunks...");
|
||||
|
||||
for (i, chunk) in all_chunks.iter().enumerate() {
|
||||
if i > 0 && i % 100 == 0 {
|
||||
info!(progress = i, total = total_chunks, "Re-embedding progress");
|
||||
}
|
||||
|
||||
let embedding = provider
|
||||
.embed(&chunk.chunk)
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
let err_msg = format!(
|
||||
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
|
||||
chunk.id, embedding.len(), new_dimensions
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
return Err(AppError::InternalError(err_msg));
|
||||
}
|
||||
new_embeddings.insert(
|
||||
chunk.id.clone(),
|
||||
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
|
||||
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
|
||||
info!("Removing old index and clearing embeddings...");
|
||||
|
||||
// Explicitly remove the index first. This prevents background HNSW maintenance from crashing
|
||||
// when we delete/replace data, dealing with a known SurrealDB panic.
|
||||
db.client
|
||||
.query(format!(
|
||||
"REMOVE INDEX idx_embedding_text_chunk_embedding ON TABLE {};",
|
||||
TextChunkEmbedding::table_name()
|
||||
))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
db.client
|
||||
.query(format!("DELETE FROM {};", TextChunkEmbedding::table_name()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
// Perform DB updates in a single transaction against the embedding table
|
||||
info!("Applying embedding updates in a transaction...");
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
|
||||
chunk_id = type::thing('text_chunk', '{id}'), \
|
||||
source_id = '{source_id}', \
|
||||
embedding = {embedding}, \
|
||||
user_id = '{user_id}', \
|
||||
created_at = time::now(), \
|
||||
updated_at = time::now();",
|
||||
id = id,
|
||||
embedding = embedding_str,
|
||||
user_id = user_id,
|
||||
source_id = source_id
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
}
|
||||
|
||||
write!(
|
||||
&mut transaction_query,
|
||||
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.client
|
||||
.query(transaction_query)
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
info!("Re-embedding process for text chunks completed successfully.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
|
||||
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
|
||||
use surrealdb::RecordId;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
|
||||
let snowball_sql = r#"
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
if let Err(err) = db.client.query(snowball_sql).await {
|
||||
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
|
||||
let fallback_sql = r#"
|
||||
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
"#;
|
||||
|
||||
db.client
|
||||
.query(fallback_sql)
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_text_chunk_creation() {
|
||||
// Test basic object creation
|
||||
let source_id = "source123".to_string();
|
||||
let chunk = "This is a text chunk for testing embeddings".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk.clone(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
let text_chunk = TextChunk::new(source_id.clone(), chunk.clone(), user_id.clone());
|
||||
|
||||
// Check that the fields are set correctly
|
||||
assert_eq!(text_chunk.source_id, source_id);
|
||||
assert_eq!(text_chunk.chunk, chunk);
|
||||
assert_eq!(text_chunk.embedding, embedding);
|
||||
assert_eq!(text_chunk.user_id, user_id);
|
||||
assert!(!text_chunk.id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
// Create test data
|
||||
let source_id = "source123".to_string();
|
||||
let chunk1 = "First chunk from the same source".to_string();
|
||||
let chunk2 = "Second chunk from the same source".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
// Create two chunks with the same source_id
|
||||
let text_chunk1 = TextChunk::new(
|
||||
let chunk1 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk1,
|
||||
embedding.clone(),
|
||||
"First chunk from the same source".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let text_chunk2 = TextChunk::new(
|
||||
let chunk2 = TextChunk::new(
|
||||
source_id.clone(),
|
||||
chunk2,
|
||||
embedding.clone(),
|
||||
"Second chunk from the same source".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create a chunk with a different source_id
|
||||
let different_source_id = "different_source".to_string();
|
||||
let different_chunk = TextChunk::new(
|
||||
different_source_id.clone(),
|
||||
"different_source".to_string(),
|
||||
"Different source chunk".to_string(),
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store the chunks
|
||||
db.store_item(text_chunk1)
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("Failed to store text chunk 1");
|
||||
db.store_item(text_chunk2)
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("Failed to store text chunk 2");
|
||||
db.store_item(different_chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store different chunk");
|
||||
.expect("store chunk2");
|
||||
TextChunk::store_with_embedding(
|
||||
different_chunk.clone(),
|
||||
vec![0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("store different chunk");
|
||||
|
||||
// Delete by source_id
|
||||
TextChunk::delete_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete chunks by source_id");
|
||||
|
||||
// Verify all chunks with the original source_id are deleted
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
);
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
0,
|
||||
"All chunks with the source_id should be deleted"
|
||||
);
|
||||
assert_eq!(remaining.len(), 0);
|
||||
|
||||
// Verify the different source_id chunk still exists
|
||||
let different_query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
different_source_id
|
||||
);
|
||||
let different_remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(different_query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
"different_source"
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(
|
||||
different_remaining.len(),
|
||||
1,
|
||||
"Chunk with different source_id should still exist"
|
||||
);
|
||||
assert_eq!(different_remaining.len(), 1);
|
||||
assert_eq!(different_remaining[0].id, different_chunk.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_nonexistent_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
// Create a chunk with a real source_id
|
||||
let real_source_id = "real_source".to_string();
|
||||
let chunk = "Test chunk".to_string();
|
||||
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
|
||||
|
||||
// Store the chunk
|
||||
db.store_item(text_chunk)
|
||||
.await
|
||||
.expect("Failed to store text chunk");
|
||||
|
||||
// Delete using nonexistent source_id
|
||||
let nonexistent_source_id = "nonexistent_source";
|
||||
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
|
||||
.await
|
||||
.expect("Delete operation with nonexistent source_id should not fail");
|
||||
|
||||
// Verify the real chunk still exists
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
let chunk = TextChunk::new(
|
||||
real_source_id.clone(),
|
||||
"Test chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk");
|
||||
|
||||
TextChunk::delete_by_source_id("nonexistent_source", &db)
|
||||
.await
|
||||
.expect("Delete should succeed");
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query(query)
|
||||
.query(format!(
|
||||
"SELECT * FROM {} WHERE source_id = '{}'",
|
||||
TextChunk::table_name(),
|
||||
real_source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed")
|
||||
.take(0)
|
||||
.expect("Failed to get query results");
|
||||
assert_eq!(remaining.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
let source_id = "store-src".to_string();
|
||||
let user_id = "user_store".to_string();
|
||||
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some());
|
||||
let stored_chunk = stored_chunk.unwrap();
|
||||
assert_eq!(stored_chunk.source_id, source_id);
|
||||
assert_eq!(stored_chunk.user_id, user_id);
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some());
|
||||
let embedding = embedding.unwrap();
|
||||
assert_eq!(embedding.chunk_id, rid);
|
||||
assert_eq!(embedding.user_id, user_id);
|
||||
assert_eq!(embedding.source_id, source_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_with_runtime_indexes() {
|
||||
let namespace = "test_ns_runtime";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
// Ensure runtime indexes are built with the expected dimension.
|
||||
let embedding_dimension = 3usize;
|
||||
ensure_runtime_indexes(&db, embedding_dimension)
|
||||
.await
|
||||
.expect("ensure runtime indexes");
|
||||
|
||||
let chunk = TextChunk::new(
|
||||
"runtime_src".to_string(),
|
||||
"runtime chunk body".to_string(),
|
||||
"runtime_user".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store with embedding");
|
||||
|
||||
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
|
||||
assert!(stored_chunk.is_some(), "chunk should be stored");
|
||||
|
||||
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
|
||||
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
|
||||
.await
|
||||
.expect("get embedding");
|
||||
assert!(embedding.is_some(), "embedding should exist");
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
1,
|
||||
"Chunk with real source_id should still exist"
|
||||
embedding.unwrap().embedding.len(),
|
||||
embedding_dimension,
|
||||
"embedding dimension should match runtime index"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_single_result() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let source_id = "src".to_string();
|
||||
let user_id = "user".to_string();
|
||||
let chunk = TextChunk::new(
|
||||
source_id.clone(),
|
||||
"hello world".to_string(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store");
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
let res = &results[0];
|
||||
assert_eq!(res.chunk.id, chunk.id);
|
||||
assert_eq!(res.chunk.source_id, source_id);
|
||||
assert_eq!(res.chunk.chunk, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_orders_by_similarity() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let user_id = "user".to_string();
|
||||
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
|
||||
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].chunk.id, chunk2.id);
|
||||
assert_eq!(results[1].chunk.id, chunk1.id);
|
||||
assert!(results[0].score >= results[1].score);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_returns_empty_when_no_chunks() {
|
||||
let namespace = "fts_chunk_ns_empty";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
|
||||
let results = TextChunk::fts_search(5, "hello", &db, "user")
|
||||
.await
|
||||
.expect("fts search");
|
||||
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_single_result() {
|
||||
let namespace = "fts_chunk_ns_single";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
|
||||
let user_id = "fts_user";
|
||||
let chunk = TextChunk::new(
|
||||
"fts_src".to_string(),
|
||||
"rustaceans love rust".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(chunk.clone()).await.expect("store chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
|
||||
let results = TextChunk::fts_search(3, "rust", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].chunk.id, chunk.id);
|
||||
assert!(results[0].score.is_finite(), "expected a finite FTS score");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fts_search_orders_by_score_and_filters_user() {
|
||||
let namespace = "fts_chunk_ns_order";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
ensure_chunk_fts_index(&db).await;
|
||||
|
||||
let user_id = "fts_user_order";
|
||||
let high_score_chunk = TextChunk::new(
|
||||
"src1".to_string(),
|
||||
"apple apple apple pie recipe".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
let low_score_chunk = TextChunk::new(
|
||||
"src2".to_string(),
|
||||
"apple tart".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
let other_user_chunk = TextChunk::new(
|
||||
"src3".to_string(),
|
||||
"apple orchard guide".to_string(),
|
||||
"other_user".to_string(),
|
||||
);
|
||||
|
||||
db.store_item(high_score_chunk.clone())
|
||||
.await
|
||||
.expect("store high score chunk");
|
||||
db.store_item(low_score_chunk.clone())
|
||||
.await
|
||||
.expect("store low score chunk");
|
||||
db.store_item(other_user_chunk)
|
||||
.await
|
||||
.expect("store other user chunk");
|
||||
rebuild_indexes(&db).await.expect("rebuild indexes");
|
||||
|
||||
let results = TextChunk::fts_search(3, "apple", &db, user_id)
|
||||
.await
|
||||
.expect("fts search");
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
|
||||
assert!(
|
||||
ids.contains(&high_score_chunk.id.as_str())
|
||||
&& ids.contains(&low_score_chunk.id.as_str()),
|
||||
"expected only the two chunks for the same user"
|
||||
);
|
||||
assert!(
|
||||
results[0].score >= results[1].score,
|
||||
"expected results ordered by descending score"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||
/// Record link to the owning text_chunk
|
||||
chunk_id: RecordId,
|
||||
/// Denormalized source id for bulk deletes
|
||||
source_id: String,
|
||||
/// Embedding vector
|
||||
embedding: Vec<f32>,
|
||||
/// Denormalized user id (for scoping + permissions)
|
||||
user_id: String
|
||||
});
|
||||
|
||||
impl TextChunkEmbedding {
|
||||
/// Recreate the HNSW index with a new embedding dimension.
|
||||
///
|
||||
/// This is useful when the embedding length changes; Surreal requires the
|
||||
/// index definition to be recreated with the updated dimension.
|
||||
pub async fn redefine_hnsw_index(
|
||||
db: &SurrealDbClient,
|
||||
dimension: usize,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table};
|
||||
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
||||
COMMIT TRANSACTION;",
|
||||
table = Self::table_name(),
|
||||
);
|
||||
|
||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
||||
res.check().map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new text chunk embedding
|
||||
///
|
||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID),
|
||||
/// not "text_chunk:uuid".
|
||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
|
||||
Self {
|
||||
// NOTE: `stored_object!` macro defines `id` as `String`
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
// Create a record<text_chunk> link: text_chunk:<chunk_id>
|
||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||
source_id,
|
||||
embedding,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a single embedding by its chunk RecordId
|
||||
pub async fn get_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(embeddings.into_iter().next())
|
||||
}
|
||||
|
||||
/// Delete embeddings for a given chunk RecordId
|
||||
pub async fn delete_by_chunk_id(
|
||||
chunk_id: &RecordId,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||
Self::table_name()
|
||||
);
|
||||
|
||||
db.client
|
||||
.query(query)
|
||||
.bind(("chunk_id", chunk_id.clone()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||
///
|
||||
/// This uses a subquery to the `text_chunk` table:
|
||||
///
|
||||
/// DELETE FROM text_chunk_embedding
|
||||
/// WHERE chunk_id IN (SELECT id FROM text_chunk WHERE source_id = $source_id)
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
#[allow(clippy::missing_docs_in_private_items)]
|
||||
#[derive(Deserialize)]
|
||||
struct IdRow {
|
||||
id: RecordId,
|
||||
}
|
||||
let ids_query = format!(
|
||||
"SELECT id FROM {} WHERE source_id = $source_id",
|
||||
TextChunk::table_name()
|
||||
);
|
||||
let mut res = db
|
||||
.client
|
||||
.query(ids_query)
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
||||
|
||||
if ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let delete_query = format!(
|
||||
"DELETE FROM {} WHERE chunk_id IN $chunk_ids",
|
||||
Self::table_name()
|
||||
);
|
||||
db.client
|
||||
.query(delete_query)
|
||||
.bind((
|
||||
"chunk_ids",
|
||||
ids.into_iter().map(|row| row.id).collect::<Vec<_>>(),
|
||||
))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::db::SurrealDbClient;
|
||||
use surrealdb::Value as SurrealValue;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Helper to create an in-memory DB and apply migrations
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
/// Helper: create a text_chunk with a known key, return its RecordId
|
||||
async fn create_text_chunk_with_id(
|
||||
db: &SurrealDbClient,
|
||||
key: &str,
|
||||
source_id: &str,
|
||||
user_id: &str,
|
||||
) -> RecordId {
|
||||
let chunk = TextChunk {
|
||||
id: key.to_owned(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
source_id: source_id.to_owned(),
|
||||
chunk: "Some test chunk text".to_owned(),
|
||||
user_id: user_id.to_owned(),
|
||||
};
|
||||
|
||||
db.store_item(chunk)
|
||||
.await
|
||||
.expect("Failed to create text_chunk");
|
||||
|
||||
RecordId::from_table_key(TextChunk::table_name(), key)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_get_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_a";
|
||||
let chunk_key = "chunk-123";
|
||||
let source_id = "source-1";
|
||||
|
||||
// 1) Create a text_chunk with a known key
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
|
||||
// 2) Create and store an embedding for that chunk
|
||||
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
source_id.to_string(),
|
||||
embedding_vec.clone(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
|
||||
// 3) Fetch it via get_by_chunk_id
|
||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding by chunk_id");
|
||||
|
||||
assert!(fetched.is_some(), "Expected an embedding to be found");
|
||||
let fetched = fetched.unwrap();
|
||||
|
||||
assert_eq!(fetched.user_id, user_id);
|
||||
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||
assert_eq!(fetched.embedding, embedding_vec);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_chunk_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_b";
|
||||
let chunk_key = "chunk-delete";
|
||||
let source_id = "source-del";
|
||||
|
||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
||||
|
||||
let emb = TextChunkEmbedding::new(
|
||||
chunk_key,
|
||||
source_id.to_string(),
|
||||
vec![0.4_f32, 0.5, 0.6],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
|
||||
// Ensure it exists
|
||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding before delete");
|
||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||
|
||||
// Delete by chunk_id
|
||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to delete by chunk_id");
|
||||
|
||||
// Ensure it no longer exists
|
||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||
.await
|
||||
.expect("Failed to get embedding after delete");
|
||||
assert!(after.is_none(), "Embedding should have been deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let user_id = "user_c";
|
||||
let source_id = "shared-source";
|
||||
let other_source = "other-source";
|
||||
|
||||
// Two chunks with the same source_id
|
||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
|
||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
|
||||
|
||||
// One chunk with a different source_id
|
||||
let chunk_other_rid =
|
||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
|
||||
|
||||
// Create embeddings for all three
|
||||
let emb1 = TextChunkEmbedding::new(
|
||||
"chunk-s1",
|
||||
source_id.to_string(),
|
||||
vec![0.1],
|
||||
user_id.to_string(),
|
||||
);
|
||||
let emb2 = TextChunkEmbedding::new(
|
||||
"chunk-s2",
|
||||
source_id.to_string(),
|
||||
vec![0.2],
|
||||
user_id.to_string(),
|
||||
);
|
||||
let emb3 = TextChunkEmbedding::new(
|
||||
"chunk-other",
|
||||
other_source.to_string(),
|
||||
vec![0.3],
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
// Update length on index
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
for emb in [emb1, emb2, emb3] {
|
||||
let _: Option<TextChunkEmbedding> = db
|
||||
.client
|
||||
.create(TextChunkEmbedding::table_name())
|
||||
.content(emb)
|
||||
.await
|
||||
.expect("Failed to store embedding")
|
||||
.take()
|
||||
.expect("Failed to deserialize stored embedding");
|
||||
}
|
||||
|
||||
// Sanity check: they all exist
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
|
||||
// Delete embeddings by source_id (shared-source)
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete by source_id");
|
||||
|
||||
// Chunks from shared-source should have no embeddings
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
// The other chunk should still have its embedding
|
||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Change the index dimension from default (1536) to a smaller test value.
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||
.await
|
||||
.expect("failed to redefine index");
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 8"),
|
||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redefine_hnsw_index_is_idempotent() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("first redefine failed");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||
.await
|
||||
.expect("second redefine failed");
|
||||
|
||||
let mut info_res = db
|
||||
.client
|
||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||
.await
|
||||
.expect("info query failed");
|
||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
||||
let info_json: serde_json::Value =
|
||||
serde_json::to_value(info).expect("failed to convert info to json");
|
||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
idx_sql.contains("DIMENSION 4"),
|
||||
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -5,10 +5,57 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::file_info::FileInfo;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct TextContentSearchResult {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
|
||||
pub text: String,
|
||||
#[serde(default)]
|
||||
pub file_info: Option<FileInfo>,
|
||||
#[serde(default)]
|
||||
pub url_info: Option<UrlInfo>,
|
||||
#[serde(default)]
|
||||
pub context: Option<String>,
|
||||
pub category: String,
|
||||
pub user_id: String,
|
||||
|
||||
pub score: f32,
|
||||
// Highlighted fields from the query aliases
|
||||
#[serde(default)]
|
||||
pub highlighted_text: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_category: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_context: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_file_name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_url: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_url_title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct UrlInfo {
|
||||
#[serde(default)]
|
||||
pub url: String,
|
||||
#[serde(default)]
|
||||
pub title: String,
|
||||
#[serde(default)]
|
||||
pub image_id: String,
|
||||
}
|
||||
|
||||
@@ -58,11 +105,82 @@ impl TextContent {
|
||||
.patch(PatchOp::replace("/context", context))
|
||||
.patch(PatchOp::replace("/category", category))
|
||||
.patch(PatchOp::replace("/text", text))
|
||||
.patch(PatchOp::replace("/updated_at", now))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn has_other_with_file(
|
||||
file_id: &str,
|
||||
exclude_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT VALUE id FROM type::table($table_name) WHERE file_info.id = $file_id AND id != type::thing($table_name, $exclude_id) LIMIT 1",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("file_id", file_id.to_owned()))
|
||||
.bind(("exclude_id", exclude_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let existing: Option<surrealdb::sql::Thing> = response.take(0)?;
|
||||
|
||||
Ok(existing.is_some())
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
db: &SurrealDbClient,
|
||||
search_terms: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<TextContentSearchResult>, AppError> {
|
||||
let sql = r#"
|
||||
SELECT
|
||||
*,
|
||||
search::highlight('<b>', '</b>', 0) AS highlighted_text,
|
||||
search::highlight('<b>', '</b>', 1) AS highlighted_category,
|
||||
search::highlight('<b>', '</b>', 2) AS highlighted_context,
|
||||
search::highlight('<b>', '</b>', 3) AS highlighted_file_name,
|
||||
search::highlight('<b>', '</b>', 4) AS highlighted_url,
|
||||
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
|
||||
(
|
||||
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
|
||||
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END +
|
||||
IF search::score(2) != NONE THEN search::score(2) ELSE 0 END +
|
||||
IF search::score(3) != NONE THEN search::score(3) ELSE 0 END +
|
||||
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
|
||||
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
|
||||
) AS score
|
||||
FROM text_content
|
||||
WHERE
|
||||
(
|
||||
text @0@ $terms OR
|
||||
category @1@ $terms OR
|
||||
context @2@ $terms OR
|
||||
file_info.file_name @3@ $terms OR
|
||||
url_info.url @4@ $terms OR
|
||||
url_info.title @5@ $terms
|
||||
)
|
||||
AND user_id = $user_id
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit;
|
||||
"#;
|
||||
|
||||
Ok(db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("terms", search_terms.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", limit))
|
||||
.await?
|
||||
.take(0)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -182,4 +300,64 @@ mod tests {
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_has_other_with_file_detects_shared_usage() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let file_info = FileInfo {
|
||||
id: "file-1".to_string(),
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
sha256: "sha-test".to_string(),
|
||||
path: "user123/file-1/test.txt".to_string(),
|
||||
file_name: "test.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
user_id: user_id.clone(),
|
||||
};
|
||||
|
||||
let content_a = TextContent::new(
|
||||
"First".to_string(),
|
||||
Some("ctx-a".to_string()),
|
||||
"category".to_string(),
|
||||
Some(file_info.clone()),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
let content_b = TextContent::new(
|
||||
"Second".to_string(),
|
||||
Some("ctx-b".to_string()),
|
||||
"category".to_string(),
|
||||
Some(file_info.clone()),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
db.store_item(content_a.clone())
|
||||
.await
|
||||
.expect("Failed to store first content");
|
||||
db.store_item(content_b.clone())
|
||||
.await
|
||||
.expect("Failed to store second content");
|
||||
|
||||
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check for shared file usage");
|
||||
assert!(has_other);
|
||||
|
||||
let _removed: Option<TextContent> = db
|
||||
.delete_item(&content_b.id)
|
||||
.await
|
||||
.expect("Failed to delete second content");
|
||||
|
||||
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check shared usage after delete");
|
||||
assert!(!has_other_after);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,33 @@
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
use axum_session_auth::Authentication;
|
||||
use chrono_tz::Tz;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::text_chunk::TextChunk;
|
||||
use super::{
|
||||
conversation::Conversation, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
|
||||
conversation::Conversation,
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
system_settings::SystemSettings,
|
||||
text_content::TextContent,
|
||||
};
|
||||
use chrono::Duration;
|
||||
use futures::try_join;
|
||||
|
||||
/// Result row for returning user category.
|
||||
#[derive(Deserialize)]
|
||||
pub struct CategoryResponse {
|
||||
/// Category name tied to the user.
|
||||
category: String,
|
||||
}
|
||||
|
||||
stored_object!(User, "user", {
|
||||
stored_object!(
|
||||
#[allow(clippy::unsafe_derive_deserialize)]
|
||||
User, "user", {
|
||||
email: String,
|
||||
password: String,
|
||||
anonymous: bool,
|
||||
@@ -28,11 +40,11 @@ stored_object!(User, "user", {
|
||||
#[async_trait]
|
||||
impl Authentication<User, String, Surreal<Any>> for User {
|
||||
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
|
||||
let db = db.unwrap();
|
||||
let db = db.ok_or_else(|| anyhow!("Database handle missing"))?;
|
||||
Ok(db
|
||||
.select((Self::table_name(), userid.as_str()))
|
||||
.await?
|
||||
.unwrap())
|
||||
.ok_or_else(|| anyhow!("User {userid} not found"))?)
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
@@ -48,20 +60,109 @@ impl Authentication<User, String, Surreal<Any>> for User {
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures a timezone string parses, defaulting to UTC when invalid.
|
||||
fn validate_timezone(input: &str) -> String {
|
||||
use chrono_tz::Tz;
|
||||
|
||||
// Check if it's a valid IANA timezone identifier
|
||||
match input.parse::<Tz>() {
|
||||
Ok(_) => input.to_owned(),
|
||||
Err(_) => {
|
||||
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
|
||||
"UTC".to_owned()
|
||||
}
|
||||
if input.parse::<Tz>().is_ok() {
|
||||
return input.to_owned();
|
||||
}
|
||||
|
||||
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
|
||||
"UTC".to_owned()
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct DashboardStats {
|
||||
pub total_documents: i64,
|
||||
pub new_documents_week: i64,
|
||||
pub total_entities: i64,
|
||||
pub new_entities_week: i64,
|
||||
pub total_conversations: i64,
|
||||
pub new_conversations_week: i64,
|
||||
pub total_text_chunks: i64,
|
||||
pub new_text_chunks_week: i64,
|
||||
}
|
||||
|
||||
/// Helper for aggregating `SurrealDB` count responses.
|
||||
#[derive(Deserialize)]
|
||||
struct CountResult {
|
||||
/// Row count returned by the query.
|
||||
count: i64,
|
||||
}
|
||||
|
||||
impl User {
|
||||
/// Counts all objects of a given type belonging to the user.
|
||||
async fn count_total<T: crate::storage::types::StoredObject>(
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<i64, AppError> {
|
||||
let result: Option<CountResult> = db
|
||||
.client
|
||||
.query("SELECT count() as count FROM type::table($table) WHERE user_id = $user_id GROUP ALL")
|
||||
.bind(("table", T::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
Ok(result.map_or(0, |r| r.count))
|
||||
}
|
||||
|
||||
/// Counts objects of a given type created after a specific timestamp.
|
||||
async fn count_since<T: crate::storage::types::StoredObject>(
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
since: chrono::DateTime<chrono::Utc>,
|
||||
) -> Result<i64, AppError> {
|
||||
let result: Option<CountResult> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT count() as count FROM type::table($table) WHERE user_id = $user_id AND created_at >= $since GROUP ALL",
|
||||
)
|
||||
.bind(("table", T::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("since", surrealdb::Datetime::from(since)))
|
||||
.await?
|
||||
.take(0)?;
|
||||
Ok(result.map_or(0, |r| r.count))
|
||||
}
|
||||
|
||||
pub async fn get_dashboard_stats(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<DashboardStats, AppError> {
|
||||
let since = chrono::Utc::now()
|
||||
.checked_sub_signed(Duration::days(7))
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
|
||||
let (
|
||||
total_documents,
|
||||
new_documents_week,
|
||||
total_entities,
|
||||
new_entities_week,
|
||||
total_conversations,
|
||||
new_conversations_week,
|
||||
total_text_chunks,
|
||||
new_text_chunks_week,
|
||||
) = try_join!(
|
||||
Self::count_total::<TextContent>(db, user_id),
|
||||
Self::count_since::<TextContent>(db, user_id, since),
|
||||
Self::count_total::<KnowledgeEntity>(db, user_id),
|
||||
Self::count_since::<KnowledgeEntity>(db, user_id, since),
|
||||
Self::count_total::<Conversation>(db, user_id),
|
||||
Self::count_since::<Conversation>(db, user_id, since),
|
||||
Self::count_total::<TextChunk>(db, user_id),
|
||||
Self::count_since::<TextChunk>(db, user_id, since)
|
||||
)?;
|
||||
|
||||
Ok(DashboardStats {
|
||||
total_documents,
|
||||
new_documents_week,
|
||||
total_entities,
|
||||
new_entities_week,
|
||||
total_conversations,
|
||||
new_conversations_week,
|
||||
total_text_chunks,
|
||||
new_text_chunks_week,
|
||||
})
|
||||
}
|
||||
pub async fn create_new(
|
||||
email: String,
|
||||
password: String,
|
||||
@@ -78,7 +179,7 @@ impl User {
|
||||
let now = Utc::now();
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"LET $count = (SELECT count() FROM type::table($table))[0].count;
|
||||
@@ -95,8 +196,8 @@ impl User {
|
||||
.bind(("id", id))
|
||||
.bind(("email", email))
|
||||
.bind(("password", password))
|
||||
.bind(("created_at", now))
|
||||
.bind(("updated_at", now))
|
||||
.bind(("created_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("timezone", validated_tz))
|
||||
.await?
|
||||
.take(1)?;
|
||||
@@ -127,7 +228,7 @@ impl User {
|
||||
password: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM user
|
||||
@@ -145,7 +246,7 @@ impl User {
|
||||
email: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
|
||||
.bind(("email", email.to_string()))
|
||||
@@ -159,7 +260,7 @@ impl User {
|
||||
api_key: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
|
||||
.bind(("api_key", api_key.to_string()))
|
||||
@@ -171,10 +272,10 @@ impl User {
|
||||
|
||||
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
|
||||
// Generate a secure random API key
|
||||
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
|
||||
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Update the user record with the new API key
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
@@ -195,7 +296,7 @@ impl User {
|
||||
}
|
||||
|
||||
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
@@ -251,6 +352,7 @@ impl User {
|
||||
) -> Result<Vec<String>, AppError> {
|
||||
#[derive(Deserialize)]
|
||||
struct EntityTypeResponse {
|
||||
/// Raw entity type value from the database.
|
||||
entity_type: String,
|
||||
}
|
||||
|
||||
@@ -266,7 +368,10 @@ impl User {
|
||||
// Extract the entity types from the response
|
||||
let entity_types: Vec<String> = response
|
||||
.into_iter()
|
||||
.map(|item| format!("{:?}", item.entity_type))
|
||||
.map(|item| {
|
||||
let normalized = KnowledgeEntityType::from(item.entity_type);
|
||||
format!("{normalized:?}")
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(entity_types)
|
||||
@@ -356,7 +461,7 @@ impl User {
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
|
||||
.bind(("table_name", User::table_name()))
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("timezone", timezone.to_string()))
|
||||
.await?;
|
||||
@@ -442,19 +547,43 @@ impl User {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
status = 'Created'
|
||||
OR (
|
||||
status.InProgress != NONE
|
||||
AND status.InProgress.attempts < $max_attempts
|
||||
)
|
||||
)
|
||||
ORDER BY created_at DESC",
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
state IN $active_states
|
||||
OR (state = $failed_state AND attempts < max_attempts)
|
||||
)
|
||||
ORDER BY scheduled_at ASC, created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("failed_state", TaskState::Failed.as_str()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
/// Gets all ingestion tasks for the specified user ordered by newest first
|
||||
pub async fn get_all_ingestion_tasks(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
ORDER BY created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("max_attempts", 3))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -511,6 +640,9 @@ impl User {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
@@ -596,6 +728,122 @@ mod tests {
|
||||
assert!(nonexistent.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "unfinished_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
let payload = IngestionPayload::Text {
|
||||
text: "Test".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
};
|
||||
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(created_task.clone())
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
|
||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
processing_task.state = TaskState::Processing;
|
||||
processing_task.attempts = 1;
|
||||
db.store_item(processing_task.clone())
|
||||
.await
|
||||
.expect("Failed to store processing task");
|
||||
|
||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_retry_task.state = TaskState::Failed;
|
||||
failed_retry_task.attempts = 1;
|
||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||
db.store_item(failed_retry_task.clone())
|
||||
.await
|
||||
.expect("Failed to store retryable failed task");
|
||||
|
||||
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_blocked_task.state = TaskState::Failed;
|
||||
failed_blocked_task.attempts = MAX_ATTEMPTS;
|
||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||
db.store_item(failed_blocked_task.clone())
|
||||
.await
|
||||
.expect("Failed to store blocked task");
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
completed_task.state = TaskState::Succeeded;
|
||||
db.store_item(completed_task.clone())
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task)
|
||||
.await
|
||||
.expect("Failed to store other user task");
|
||||
|
||||
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch unfinished tasks");
|
||||
|
||||
let unfinished_ids: HashSet<String> =
|
||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||
|
||||
assert!(unfinished_ids.contains(&created_task.id));
|
||||
assert!(unfinished_ids.contains(&processing_task.id));
|
||||
assert!(unfinished_ids.contains(&failed_retry_task.id));
|
||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||
assert_eq!(unfinished_ids.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "archive_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
let payload = IngestionPayload::Text {
|
||||
text: "One".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
};
|
||||
|
||||
// Oldest task
|
||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
||||
first.updated_at = first.created_at;
|
||||
first.state = TaskState::Succeeded;
|
||||
db.store_item(first.clone()).await.expect("store first");
|
||||
|
||||
// Latest task
|
||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
second.state = TaskState::Processing;
|
||||
db.store_item(second.clone()).await.expect("store second");
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task).await.expect("store other");
|
||||
|
||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("fetch all tasks");
|
||||
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0].id, second.id); // newest first
|
||||
assert_eq!(tasks[1].id, first.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
// Setup test database
|
||||
@@ -816,4 +1064,56 @@ mod tests {
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
|
||||
assert_eq!(retrieved[0].id, most_recent.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_latest_text_contents_returns_last_five() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "latest_text_user";
|
||||
|
||||
let mut inserted_ids = Vec::new();
|
||||
let base_time = chrono::Utc::now() - chrono::Duration::minutes(60);
|
||||
|
||||
for i in 0..12 {
|
||||
let mut item = TextContent::new(
|
||||
format!("Text {}", i),
|
||||
Some(format!("Context {}", i)),
|
||||
"Category".to_string(),
|
||||
None,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
let timestamp = base_time + chrono::Duration::minutes(i);
|
||||
item.created_at = timestamp;
|
||||
item.updated_at = timestamp;
|
||||
|
||||
db.store_item(item.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
|
||||
inserted_ids.push(item.id.clone());
|
||||
}
|
||||
|
||||
let latest = User::get_latest_text_contents(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch latest text contents");
|
||||
|
||||
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
||||
|
||||
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
|
||||
expected_ids.reverse();
|
||||
|
||||
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
||||
assert_eq!(
|
||||
returned_ids, expected_ids,
|
||||
"Latest items did not match expectation"
|
||||
);
|
||||
|
||||
for window in latest.windows(2) {
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
"Results are not ordered by created_at descending"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+136
-2
@@ -1,6 +1,49 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
|
||||
/// Selects the embedding backend for vector generation.
|
||||
#[derive(Clone, Deserialize, Debug, Default, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingBackend {
|
||||
/// Use OpenAI-compatible API for embeddings.
|
||||
OpenAI,
|
||||
/// Use FastEmbed local embeddings (default).
|
||||
#[default]
|
||||
FastEmbed,
|
||||
/// Use deterministic hashed embeddings (for testing).
|
||||
Hashed,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum StorageKind {
|
||||
Local,
|
||||
Memory,
|
||||
}
|
||||
|
||||
/// Default storage backend when none is configured.
|
||||
fn default_storage_kind() -> StorageKind {
|
||||
StorageKind::Local
|
||||
}
|
||||
|
||||
/// Selects the strategy used for PDF ingestion.
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum PdfIngestMode {
|
||||
/// Only rely on classic text extraction (no LLM fallbacks).
|
||||
Classic,
|
||||
/// Prefer fast text extraction, but fall back to the LLM rendering path when needed.
|
||||
LlmFirst,
|
||||
}
|
||||
|
||||
/// Default PDF ingestion mode when unset.
|
||||
fn default_pdf_ingest_mode() -> PdfIngestMode {
|
||||
PdfIngestMode::LlmFirst
|
||||
}
|
||||
|
||||
/// Application configuration loaded from files and environment variables.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
pub struct AppConfig {
|
||||
pub openai_api_key: String,
|
||||
@@ -9,19 +52,110 @@ pub struct AppConfig {
|
||||
pub surrealdb_password: String,
|
||||
pub surrealdb_namespace: String,
|
||||
pub surrealdb_database: String,
|
||||
// #[serde(default = "default_data_dir")]
|
||||
#[serde(default = "default_data_dir")]
|
||||
pub data_dir: String,
|
||||
pub http_port: u16,
|
||||
#[serde(default = "default_base_url")]
|
||||
pub openai_base_url: String,
|
||||
#[serde(default = "default_storage_kind")]
|
||||
pub storage: StorageKind,
|
||||
#[serde(default = "default_pdf_ingest_mode")]
|
||||
pub pdf_ingest_mode: PdfIngestMode,
|
||||
#[serde(default = "default_reranking_enabled")]
|
||||
pub reranking_enabled: bool,
|
||||
#[serde(default)]
|
||||
pub reranking_pool_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub fastembed_cache_dir: Option<String>,
|
||||
#[serde(default)]
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub retrieval_strategy: Option<String>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
}
|
||||
|
||||
/// Default data directory for persisted assets.
|
||||
fn default_data_dir() -> String {
|
||||
"./data".to_string()
|
||||
}
|
||||
|
||||
/// Default base URL used for OpenAI-compatible APIs.
|
||||
fn default_base_url() -> String {
|
||||
"https://api.openai.com/v1".to_string()
|
||||
}
|
||||
|
||||
/// Whether reranking is enabled by default.
|
||||
fn default_reranking_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn ensure_ort_path() {
|
||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||
return;
|
||||
}
|
||||
if let Ok(mut exe) = env::current_exe() {
|
||||
exe.pop();
|
||||
|
||||
if cfg!(target_os = "windows") {
|
||||
for p in [
|
||||
exe.join("onnxruntime.dll"),
|
||||
exe.join("lib").join("onnxruntime.dll"),
|
||||
] {
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
let name = if cfg!(target_os = "macos") {
|
||||
"libonnxruntime.dylib"
|
||||
} else {
|
||||
"libonnxruntime.so"
|
||||
};
|
||||
let p = exe.join("lib").join(name);
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
openai_api_key: String::new(),
|
||||
surrealdb_address: String::new(),
|
||||
surrealdb_username: String::new(),
|
||||
surrealdb_password: String::new(),
|
||||
surrealdb_namespace: String::new(),
|
||||
surrealdb_database: String::new(),
|
||||
data_dir: default_data_dir(),
|
||||
http_port: 0,
|
||||
openai_base_url: default_base_url(),
|
||||
storage: default_storage_kind(),
|
||||
pdf_ingest_mode: default_pdf_ingest_mode(),
|
||||
reranking_enabled: default_reranking_enabled(),
|
||||
reranking_pool_size: None,
|
||||
fastembed_cache_dir: None,
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
retrieval_strategy: None,
|
||||
embedding_backend: EmbeddingBackend::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads the application configuration from the environment and optional config file.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
ensure_ort_path();
|
||||
|
||||
let config = Config::builder()
|
||||
.add_source(File::with_name("config").required(false))
|
||||
.add_source(Environment::default())
|
||||
.build()?;
|
||||
|
||||
Ok(config.try_deserialize()?)
|
||||
config.try_deserialize()
|
||||
}
|
||||
|
||||
@@ -1,15 +1,328 @@
|
||||
use async_openai::types::CreateEmbeddingRequestArgs;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::error::AppError;
|
||||
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
|
||||
/// Supported embedding backends.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum EmbeddingBackend {
|
||||
#[default]
|
||||
OpenAI,
|
||||
FastEmbed,
|
||||
Hashed,
|
||||
}
|
||||
|
||||
impl std::str::FromStr for EmbeddingBackend {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"openai" => Ok(Self::OpenAI),
|
||||
"hashed" => Ok(Self::Hashed),
|
||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||
other => Err(anyhow!(
|
||||
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around the chosen embedding backend.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingProvider {
|
||||
/// Concrete backend implementation.
|
||||
inner: EmbeddingInner,
|
||||
}
|
||||
|
||||
/// Concrete embedding implementations.
|
||||
#[derive(Clone)]
|
||||
enum EmbeddingInner {
|
||||
/// Uses an `OpenAI`-compatible API.
|
||||
OpenAI {
|
||||
/// Client used to issue embedding requests.
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
/// Model identifier for the API.
|
||||
model: String,
|
||||
/// Expected output dimensions.
|
||||
dimensions: u32,
|
||||
},
|
||||
/// Generates deterministic hashed embeddings without external calls.
|
||||
Hashed {
|
||||
/// Output vector length.
|
||||
dimension: usize,
|
||||
},
|
||||
/// Uses `FastEmbed` running locally.
|
||||
FastEmbed {
|
||||
/// Shared `FastEmbed` model.
|
||||
model: Arc<Mutex<TextEmbedding>>,
|
||||
/// Model metadata used for info logging.
|
||||
model_name: EmbeddingModel,
|
||||
/// Output vector length.
|
||||
dimension: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
pub fn backend_label(&self) -> &'static str {
|
||||
match self.inner {
|
||||
EmbeddingInner::Hashed { .. } => "hashed",
|
||||
EmbeddingInner::FastEmbed { .. } => "fastembed",
|
||||
EmbeddingInner::OpenAI { .. } => "openai",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimension(&self) -> usize {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
|
||||
*dimension
|
||||
}
|
||||
EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_code(&self) -> Option<String> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
||||
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
|
||||
EmbeddingInner::Hashed { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
let mut guard = model.lock().await;
|
||||
let embeddings = guard
|
||||
.embed(vec![text.to_owned()], None)
|
||||
.context("generating fastembed vector")?;
|
||||
embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
} => {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.clone())
|
||||
.input([text])
|
||||
.dimensions(*dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { model, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut guard = model.lock().await;
|
||||
guard
|
||||
.embed(texts, None)
|
||||
.context("generating fastembed batch embeddings")
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
} => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.clone())
|
||||
.input(texts)
|
||||
.dimensions(*dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embeddings: Vec<Vec<f32>> = response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|item| item.embedding)
|
||||
.collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_openai(
|
||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||
model: String,
|
||||
dimensions: u32,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
inner: EmbeddingInner::OpenAI {
|
||||
client,
|
||||
model,
|
||||
dimensions,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
||||
let model_name = if let Some(code) = model_override {
|
||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
||||
} else {
|
||||
EmbeddingModel::default()
|
||||
};
|
||||
|
||||
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
.context("joining FastEmbed initialisation task")??;
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
model: Arc::new(Mutex::new(model)),
|
||||
model_name,
|
||||
dimension,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::Hashed {
|
||||
dimension: dimension.max(1),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an embedding provider based on application configuration.
|
||||
///
|
||||
/// Dispatches to the appropriate constructor based on `config.embedding_backend`:
|
||||
/// - `OpenAI`: Requires a valid OpenAI client
|
||||
/// - `FastEmbed`: Uses local embedding model
|
||||
/// - `Hashed`: Uses deterministic hashed embeddings (for testing)
|
||||
pub async fn from_config(
|
||||
config: &crate::utils::config::AppConfig,
|
||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||
) -> Result<Self> {
|
||||
use crate::utils::config::EmbeddingBackend;
|
||||
|
||||
match config.embedding_backend {
|
||||
EmbeddingBackend::OpenAI => {
|
||||
let client = openai_client
|
||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
||||
// Use defaults that match SystemSettings initial values
|
||||
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
|
||||
}
|
||||
EmbeddingBackend::FastEmbed => {
|
||||
// Use nomic-embed-text-v1.5 as the default FastEmbed model
|
||||
Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await
|
||||
}
|
||||
EmbeddingBackend::Hashed => Self::new_hashed(384),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for hashed embeddings
|
||||
/// Generates a hashed embedding vector without external dependencies.
|
||||
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
|
||||
let dim = dimension.max(1);
|
||||
let mut vector = vec![0.0f32; dim];
|
||||
if text.is_empty() {
|
||||
return vector;
|
||||
}
|
||||
|
||||
for token in tokens(text) {
|
||||
let idx = bucket(&token, dim);
|
||||
if let Some(slot) = vector.get_mut(idx) {
|
||||
*slot += 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
for value in &mut vector {
|
||||
*value /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
vector
|
||||
}
|
||||
|
||||
/// Tokenizes the text into alphanumeric lowercase tokens.
|
||||
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
|
||||
text.split(|c: char| !c.is_ascii_alphanumeric())
|
||||
.filter(|token| !token.is_empty())
|
||||
.map(str::to_ascii_lowercase)
|
||||
}
|
||||
|
||||
/// Buckets a token into the hashed embedding vector.
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
fn bucket(token: &str, dimension: usize) -> usize {
|
||||
let safe_dimension = dimension.max(1);
|
||||
let mut hasher = DefaultHasher::new();
|
||||
token.hash(&mut hasher);
|
||||
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
|
||||
}
|
||||
|
||||
// Backward compatibility function
|
||||
pub async fn generate_embedding_with_provider(
|
||||
provider: &EmbeddingProvider,
|
||||
input: &str,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
provider.embed(input).await.map_err(AppError::from)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
||||
///
|
||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
||||
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
|
||||
/// comparisons, vector search, and other natural language processing tasks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client`: The OpenAI client instance used to make API requests.
|
||||
/// * `client`: The `OpenAI` client instance used to make API requests.
|
||||
/// * `input`: The text string to generate embeddings for.
|
||||
///
|
||||
/// # Returns
|
||||
@@ -21,15 +334,20 @@ use crate::error::AppError;
|
||||
/// # Errors
|
||||
///
|
||||
/// This function can return a `AppError` in the following cases:
|
||||
/// * If the OpenAI API request fails
|
||||
/// * If the `OpenAI` API request fails
|
||||
/// * If the request building fails
|
||||
/// * If no embedding data is received in the response
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub async fn generate_embedding(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let model = SystemSettings::get_current(db).await?;
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model("text-embedding-3-small")
|
||||
.model(model.embedding_model)
|
||||
.dimensions(model.embedding_dimensions)
|
||||
.input([input])
|
||||
.build()?;
|
||||
|
||||
@@ -46,3 +364,36 @@ pub async fn generate_embedding(
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Generates an embedding vector using a specific model and dimension.
|
||||
///
|
||||
/// This is used for the re-embedding process where the model and dimensions
|
||||
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
|
||||
pub async fn generate_embedding_with_params(
|
||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
input: &str,
|
||||
model: &str,
|
||||
dimensions: u32,
|
||||
) -> Result<Vec<f32>, AppError> {
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model)
|
||||
.input([input])
|
||||
.dimensions(dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
let embedding = response
|
||||
.data
|
||||
.first()
|
||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
|
||||
.embedding
|
||||
.clone();
|
||||
|
||||
debug!(
|
||||
"Embedding was created with {:?} dimensions",
|
||||
embedding.len()
|
||||
);
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ pub use minijinja_contrib;
|
||||
pub use minijinja_embed;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub trait ProvidesTemplateEngine {
|
||||
fn template_engine(&self) -> &Arc<TemplateEngine>;
|
||||
}
|
||||
@@ -59,13 +60,13 @@ impl TemplateEngine {
|
||||
match self {
|
||||
// Only compile this arm for debug builds
|
||||
#[cfg(debug_assertions)]
|
||||
TemplateEngine::AutoReload(reloader) => {
|
||||
Self::AutoReload(reloader) => {
|
||||
let env = reloader.acquire_env()?;
|
||||
env.get_template(name)?.render(ctx)
|
||||
}
|
||||
// Only compile this arm for release builds
|
||||
#[cfg(not(debug_assertions))]
|
||||
TemplateEngine::Embedded(env) => env.get_template(name)?.render(ctx),
|
||||
Self::Embedded(env) => env.get_template(name)?.render(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,19 +79,17 @@ impl TemplateEngine {
|
||||
match self {
|
||||
// Only compile this arm for debug builds
|
||||
#[cfg(debug_assertions)]
|
||||
TemplateEngine::AutoReload(reloader) => {
|
||||
let env = reloader.acquire_env()?;
|
||||
let template = env.get_template(template_name)?;
|
||||
let mut state = template.eval_to_state(context)?;
|
||||
state.render_block(block_name)
|
||||
}
|
||||
Self::AutoReload(reloader) => reloader
|
||||
.acquire_env()?
|
||||
.get_template(template_name)?
|
||||
.eval_to_state(context)?
|
||||
.render_block(block_name),
|
||||
// Only compile this arm for release builds
|
||||
#[cfg(not(debug_assertions))]
|
||||
TemplateEngine::Embedded(env) => {
|
||||
let template = env.get_template(template_name)?;
|
||||
let mut state = template.eval_to_state(context)?;
|
||||
state.render_block(block_name)
|
||||
}
|
||||
Self::Embedded(env) => env
|
||||
.get_template(template_name)?
|
||||
.eval_to_state(context)?
|
||||
.render_block(block_name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
use surrealdb::Error;
|
||||
use tracing::debug;
|
||||
|
||||
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
/// This function queries the database for all records in a specified table that have
|
||||
/// a matching `source_id` field. It's commonly used to find related entities or
|
||||
/// track the origin of database entries.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a `Result` containing either:
|
||||
/// * `Ok(Vec<T>)` - A vector of matching records deserialized into type `T`
|
||||
/// * `Err(Error)` - An error if the database query fails
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This function will return a `Error` if:
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
|
||||
db.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: String,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
|
||||
debug!("{}", query);
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::knowledge_relationship::KnowledgeRelationship;
|
||||
use common::storage::types::StoredObject;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_source_ids() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create some test entities with different source_ids
|
||||
let source_id1 = "source123".to_string();
|
||||
let source_id2 = "source456".to_string();
|
||||
let source_id3 = "source789".to_string();
|
||||
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Entity with source_id1
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
source_id1.clone(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Entity with source_id2
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
source_id2.clone(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Another entity with source_id1
|
||||
let entity3 = KnowledgeEntity::new(
|
||||
source_id1.clone(),
|
||||
"Entity 3".to_string(),
|
||||
"Description 3".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Entity with source_id3
|
||||
let entity4 = KnowledgeEntity::new(
|
||||
source_id3.clone(),
|
||||
"Entity 4".to_string(),
|
||||
"Description 4".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
db.store_item(entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 1");
|
||||
db.store_item(entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 2");
|
||||
db.store_item(entity3.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 3");
|
||||
db.store_item(entity4.clone())
|
||||
.await
|
||||
.expect("Failed to store entity 4");
|
||||
|
||||
// Test finding entities by multiple source_ids
|
||||
let source_ids = vec![source_id1.clone(), source_id2.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
|
||||
.await
|
||||
.expect("Failed to find entities by source_ids");
|
||||
|
||||
// Should find 3 entities (2 with source_id1, 1 with source_id2)
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
3,
|
||||
"Should find 3 entities with the specified source_ids"
|
||||
);
|
||||
|
||||
// Check that entities with source_id1 and source_id2 are found
|
||||
let found_source_ids: Vec<String> =
|
||||
found_entities.iter().map(|e| e.source_id.clone()).collect();
|
||||
assert!(
|
||||
found_source_ids.contains(&source_id1),
|
||||
"Should find entities with source_id1"
|
||||
);
|
||||
assert!(
|
||||
found_source_ids.contains(&source_id2),
|
||||
"Should find entities with source_id2"
|
||||
);
|
||||
assert!(
|
||||
!found_source_ids.contains(&source_id3),
|
||||
"Should not find entities with source_id3"
|
||||
);
|
||||
|
||||
// Test finding entities by a single source_id
|
||||
let single_source_id = vec![source_id1.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
single_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to find entities by single source_id");
|
||||
|
||||
// Should find 2 entities with source_id1
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
2,
|
||||
"Should find 2 entities with source_id1"
|
||||
);
|
||||
|
||||
// Check that all found entities have source_id1
|
||||
for entity in found_entities {
|
||||
assert_eq!(
|
||||
entity.source_id, source_id1,
|
||||
"All found entities should have source_id1"
|
||||
);
|
||||
}
|
||||
|
||||
// Test finding entities with non-existent source_id
|
||||
let non_existent_source_id = vec!["non_existent_source".to_string()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
non_existent_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to find entities by non-existent source_id");
|
||||
|
||||
// Should find 0 entities
|
||||
assert_eq!(
|
||||
found_entities.len(),
|
||||
0,
|
||||
"Should find 0 entities with non-existent source_id"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_entities_by_relationship_by_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
// Create some test entities
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create the central entity we'll query relationships for
|
||||
let central_entity = KnowledgeEntity::new(
|
||||
"central_source".to_string(),
|
||||
"Central Entity".to_string(),
|
||||
"Central Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create related entities
|
||||
let related_entity1 = KnowledgeEntity::new(
|
||||
"related_source1".to_string(),
|
||||
"Related Entity 1".to_string(),
|
||||
"Related Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let related_entity2 = KnowledgeEntity::new(
|
||||
"related_source2".to_string(),
|
||||
"Related Entity 2".to_string(),
|
||||
"Related Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Create an unrelated entity
|
||||
let unrelated_entity = KnowledgeEntity::new(
|
||||
"unrelated_source".to_string(),
|
||||
"Unrelated Entity".to_string(),
|
||||
"Unrelated Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
// Store all entities
|
||||
let central_entity = db
|
||||
.store_item(central_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store central entity")
|
||||
.unwrap();
|
||||
let related_entity1 = db
|
||||
.store_item(related_entity1.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 1")
|
||||
.unwrap();
|
||||
let related_entity2 = db
|
||||
.store_item(related_entity2.clone())
|
||||
.await
|
||||
.expect("Failed to store related entity 2")
|
||||
.unwrap();
|
||||
let _unrelated_entity = db
|
||||
.store_item(unrelated_entity.clone())
|
||||
.await
|
||||
.expect("Failed to store unrelated entity")
|
||||
.unwrap();
|
||||
|
||||
// Create relationships
|
||||
let source_id = "relationship_source".to_string();
|
||||
|
||||
// Create relationship 1: central -> related1
|
||||
let relationship1 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity1.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
// Create relationship 2: central -> related2
|
||||
let relationship2 = KnowledgeRelationship::new(
|
||||
central_entity.id.clone(),
|
||||
related_entity2.id.clone(),
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
// Store relationships
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 1");
|
||||
relationship2
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship 2");
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(related_entities.len() > 0, "Should find related entities");
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod graph;
|
||||
pub mod vector;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
||||
use std::collections::HashMap;
|
||||
use vector::find_items_by_vector_similarity;
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
|
||||
use common::{error::AppError, utils::embedding::generate_embedding};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
input_text: &str,
|
||||
db_client: &Surreal<Any>,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE embedding <|{},40|> {:?} AND user_id = '{}' ORDER BY distance", table, take, input_embedding, user_id);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
[graph]
|
||||
targets = []
|
||||
all-features = false
|
||||
no-default-features = false
|
||||
|
||||
[output]
|
||||
feature-depth = 1
|
||||
|
||||
[advisories]
|
||||
ignore = []
|
||||
|
||||
[licenses]
|
||||
allow = [
|
||||
"GPL-3.0",
|
||||
"AGPL-3.0",
|
||||
"LGPL-2.1",
|
||||
"MIT",
|
||||
"Apache-2.0",
|
||||
"MPL-2.0",
|
||||
"BSD-2-Clause",
|
||||
"BSD-3-Clause",
|
||||
"ISC",
|
||||
"CC0-1.0",
|
||||
"Unlicense",
|
||||
"Zlib",
|
||||
"Unicode-3.0",
|
||||
"CDLA-Permissive-2.0",
|
||||
]
|
||||
confidence-threshold = 0.8
|
||||
exceptions = [
|
||||
{ allow = ["BUSL-1.1"], crate = "surrealdb"},
|
||||
{ allow = ["BUSL-1.1"], crate = "surrealdb-core"}
|
||||
]
|
||||
|
||||
[licenses.private]
|
||||
ignore = true
|
||||
registries = []
|
||||
|
||||
[bans]
|
||||
multiple-versions = "warn"
|
||||
wildcards = "allow"
|
||||
highlight = "all"
|
||||
workspace-default-features = "allow"
|
||||
external-default-features = "allow"
|
||||
allow = []
|
||||
deny = []
|
||||
skip = []
|
||||
skip-tree = []
|
||||
|
||||
[sources]
|
||||
unknown-registry = "warn"
|
||||
unknown-git = "warn"
|
||||
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
|
||||
allow-git = []
|
||||
|
||||
[sources.allow-org]
|
||||
github = []
|
||||
gitlab = []
|
||||
bitbucket = []
|
||||
+79
-9
@@ -3,10 +3,10 @@
|
||||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1746681099,
|
||||
"lastModified": 1761839147,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "a7f2ea275621391209fd702f5ddced32dd56a4e2",
|
||||
"rev": "bb7849648b68035f6b910120252c22b28195cf54",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -16,13 +16,31 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs",
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1733328505,
|
||||
"lastModified": 1761588595,
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
||||
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -40,10 +58,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1746537231,
|
||||
"lastModified": 1760663237,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "fa466640195d38ec97cf0493d6d6882bc4d14969",
|
||||
"rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -74,10 +92,25 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1746576598,
|
||||
"lastModified": 1761672384,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "b3582c75c7f21ce0b429898980eddbbf05c68e55",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nixos",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1761880412,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a7fc11be66bdfb5cdde611ee5ce381c183da8386",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -90,11 +123,48 @@
|
||||
"root": {
|
||||
"inputs": {
|
||||
"devenv": "devenv",
|
||||
"fenix": "fenix",
|
||||
"git-hooks": "git-hooks",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs": "nixpkgs_2",
|
||||
"pre-commit-hooks": [
|
||||
"git-hooks"
|
||||
],
|
||||
"rust-overlay": "rust-overlay"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "rust-lang",
|
||||
"ref": "nightly",
|
||||
"repo": "rust-analyzer",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"rust-overlay": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761878277,
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "6604534e44090c917db714faa58d47861657690c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+10
@@ -11,14 +11,24 @@
|
||||
pkgs.openssl
|
||||
pkgs.nodejs
|
||||
pkgs.vscode-langservers-extracted
|
||||
pkgs.cargo-dist
|
||||
pkgs.cargo-xwin
|
||||
pkgs.clang
|
||||
pkgs.onnxruntime
|
||||
];
|
||||
|
||||
languages.rust = {
|
||||
enable = true;
|
||||
components = ["rustc" "clippy" "rustfmt" "cargo" "rust-analyzer"];
|
||||
channel = "nightly";
|
||||
targets = ["x86_64-unknown-linux-gnu" "x86_64-pc-windows-msvc"];
|
||||
mold.enable = true;
|
||||
};
|
||||
|
||||
env = {
|
||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||
};
|
||||
|
||||
processes = {
|
||||
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
|
||||
};
|
||||
|
||||
+7
-11
@@ -1,15 +1,11 @@
|
||||
# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json
|
||||
inputs:
|
||||
fenix:
|
||||
url: github:nix-community/fenix
|
||||
nixpkgs:
|
||||
url: github:nixos/nixpkgs/nixpkgs-unstable
|
||||
|
||||
# If you're using non-OSS software, you can set allowUnfree to true.
|
||||
rust-overlay:
|
||||
url: github:oxalica/rust-overlay
|
||||
inputs:
|
||||
nixpkgs:
|
||||
follows: nixpkgs
|
||||
allowUnfree: true
|
||||
|
||||
# If you're willing to use a package that's vulnerable
|
||||
# permittedInsecurePackages:
|
||||
# - "openssl-1.1.1w"
|
||||
|
||||
# If you have more than one devenv you can merge them
|
||||
#imports:
|
||||
# - ./backend
|
||||
|
||||
+4
-1
@@ -4,9 +4,11 @@ members = ["cargo:."]
|
||||
# Config for 'dist'
|
||||
[dist]
|
||||
# The preferred dist version to use in CI (Cargo.toml SemVer syntax)
|
||||
cargo-dist-version = "0.28.0"
|
||||
cargo-dist-version = "0.30.0"
|
||||
# CI backends to support
|
||||
ci = "github"
|
||||
# Extra static files to include in each App (path relative to this Cargo.toml's dir)
|
||||
include = ["lib"]
|
||||
# The installers to generate for each app
|
||||
installers = []
|
||||
# Target platforms to build apps for (Rust target-triple syntax)
|
||||
@@ -17,3 +19,4 @@ allow-dirty = ["ci"]
|
||||
[dist.github-custom-runners]
|
||||
x86_64-unknown-linux-gnu = "ubuntu-22.04"
|
||||
x86_64-unknown-linux-musl = "ubuntu-22.04"
|
||||
x86_64-pc-windows-msvc = "windows-latest"
|
||||
|
||||
+6
-5
@@ -1,5 +1,3 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
minne:
|
||||
build: .
|
||||
@@ -12,8 +10,11 @@ services:
|
||||
SURREALDB_PASSWORD: "root_password"
|
||||
SURREALDB_DATABASE: "test"
|
||||
SURREALDB_NAMESPACE: "test"
|
||||
OPENAI_API_KEY: "sk-key"
|
||||
# RUST_LOG: "info"
|
||||
OPENAI_API_KEY: "sk-add-your-key"
|
||||
DATA_DIR: "./data"
|
||||
HTTP_PORT: 3000
|
||||
RUST_LOG: "info"
|
||||
RERANKING_ENABLED: false ## Change to true to enable reranking
|
||||
depends_on:
|
||||
- surrealdb
|
||||
networks:
|
||||
@@ -29,7 +30,7 @@ services:
|
||||
- ./database:/database # Mounts a 'database' folder from your project directory
|
||||
command: >
|
||||
start
|
||||
--log debug
|
||||
--log info
|
||||
--user root_user
|
||||
--pass root_password
|
||||
rocksdb:./database/database.db
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# Architecture
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Layer | Technology |
|
||||
|-------|------------|
|
||||
| Backend | Rust with Axum (SSR) |
|
||||
| Frontend | HTML + HTMX + minimal JS |
|
||||
| Database | SurrealDB (graph, document, vector) |
|
||||
| AI | OpenAI-compatible API |
|
||||
| Web Processing | Headless Chromium |
|
||||
|
||||
## Crate Structure
|
||||
|
||||
```
|
||||
minne/
|
||||
├── main/ # Combined server + worker binary
|
||||
├── api-router/ # REST API routes
|
||||
├── html-router/ # SSR web interface
|
||||
├── ingestion-pipeline/ # Content processing pipeline
|
||||
├── retrieval-pipeline/ # Search and retrieval logic
|
||||
├── common/ # Shared types, storage, utilities
|
||||
├── evaluations/ # Benchmarking framework
|
||||
└── json-stream-parser/ # Streaming JSON utilities
|
||||
```
|
||||
|
||||
## Process Modes
|
||||
|
||||
| Binary | Purpose |
|
||||
|--------|---------|
|
||||
| `main` | All-in-one: serves UI and processes content |
|
||||
| `server` | UI and API only (no background processing) |
|
||||
| `worker` | Background processing only (no UI) |
|
||||
|
||||
Split deployment is useful for scaling or resource isolation.
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
Content In → Ingestion Pipeline → SurrealDB
|
||||
↓
|
||||
Entity Extraction
|
||||
↓
|
||||
Embedding Generation
|
||||
↓
|
||||
Graph Relationships
|
||||
|
||||
Query → Retrieval Pipeline → Results
|
||||
↓
|
||||
Vector Search + FTS
|
||||
↓
|
||||
RRF Fusion → (Optional Rerank) → Response
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
SurrealDB stores:
|
||||
|
||||
- **TextContent** — Raw ingested content
|
||||
- **TextChunk** — Chunked content with embeddings
|
||||
- **KnowledgeEntity** — Extracted entities (people, concepts, etc.)
|
||||
- **KnowledgeRelationship** — Connections between entities
|
||||
- **User** — Authentication and preferences
|
||||
- **SystemSettings** — Model configuration
|
||||
|
||||
Embeddings are stored in dedicated tables with HNSW indexes for fast vector search.
|
||||
|
||||
## Retrieval Strategy
|
||||
|
||||
1. **Collect candidates** — Vector similarity + full-text search
|
||||
2. **Merge ranks** — Reciprocal Rank Fusion (RRF)
|
||||
3. **Attach context** — Link chunks to parent entities
|
||||
4. **Rerank** (optional) — Cross-encoder reranking
|
||||
5. **Return** — Top-k results with metadata
|
||||
@@ -0,0 +1,89 @@
|
||||
# Configuration
|
||||
|
||||
Minne can be configured via environment variables or a `config.yaml` file. Environment variables take precedence.
|
||||
|
||||
## Required Settings
|
||||
|
||||
| Variable | Description | Example |
|
||||
|----------|-------------|---------|
|
||||
| `OPENAI_API_KEY` | API key for OpenAI-compatible endpoint | `sk-...` |
|
||||
| `SURREALDB_ADDRESS` | WebSocket address of SurrealDB | `ws://127.0.0.1:8000` |
|
||||
| `SURREALDB_USERNAME` | SurrealDB username | `root_user` |
|
||||
| `SURREALDB_PASSWORD` | SurrealDB password | `root_password` |
|
||||
| `SURREALDB_DATABASE` | Database name | `minne_db` |
|
||||
| `SURREALDB_NAMESPACE` | Namespace | `minne_ns` |
|
||||
|
||||
|
||||
## Optional Settings
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `HTTP_PORT` | Server port | `3000` |
|
||||
| `DATA_DIR` | Local data directory | `./data` |
|
||||
| `OPENAI_BASE_URL` | Custom AI provider URL | OpenAI default |
|
||||
| `RUST_LOG` | Logging level | `info` |
|
||||
| `STORAGE` | Storage backend (`local`, `memory`) | `local` |
|
||||
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
||||
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
|
||||
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
|
||||
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
||||
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
||||
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
|
||||
|
||||
### Reranking (Optional)
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `RERANKING_ENABLED` | Enable FastEmbed reranking | `false` |
|
||||
| `RERANKING_POOL_SIZE` | Concurrent reranker workers | - |
|
||||
|
||||
> [!NOTE]
|
||||
> Enabling reranking downloads ~1.1 GB of model data on first startup.
|
||||
|
||||
## Example config.yaml
|
||||
|
||||
```yaml
|
||||
surrealdb_address: "ws://127.0.0.1:8000"
|
||||
surrealdb_username: "root_user"
|
||||
surrealdb_password: "root_password"
|
||||
surrealdb_database: "minne_db"
|
||||
surrealdb_namespace: "minne_ns"
|
||||
openai_api_key: "sk-your-key-here"
|
||||
data_dir: "./minne_data"
|
||||
http_port: 3000
|
||||
|
||||
# New settings
|
||||
storage: "local"
|
||||
pdf_ingest_mode: "llm-first"
|
||||
embedding_backend: "fastembed"
|
||||
|
||||
# Optional reranking
|
||||
reranking_enabled: true
|
||||
reranking_pool_size: 2
|
||||
```
|
||||
|
||||
## AI Provider Setup
|
||||
|
||||
Minne works with any OpenAI-compatible API that supports structured outputs.
|
||||
|
||||
### OpenAI (Default)
|
||||
|
||||
Set `OPENAI_API_KEY` only. The default base URL points to OpenAI.
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
OPENAI_API_KEY="ollama"
|
||||
OPENAI_BASE_URL="http://localhost:11434/v1"
|
||||
```
|
||||
|
||||
### Other Providers
|
||||
|
||||
Any provider exposing an OpenAI-compatible endpoint works. Set `OPENAI_BASE_URL` accordingly.
|
||||
|
||||
## Model Selection
|
||||
|
||||
1. Access `/admin` in your Minne instance
|
||||
2. Select models for content processing and chat
|
||||
3. **Content Processing**: Must support structured outputs
|
||||
4. **Embedding Dimensions**: Update when changing embedding models (e.g., 1536 for `text-embedding-3-small`)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Features
|
||||
|
||||
## Search vs Chat
|
||||
|
||||
**Search** — Use when you know what you're looking for. Full-text search matches query terms across your content.
|
||||
|
||||
**Chat** — Use when exploring concepts or reasoning about your knowledge. The AI analyzes your query and retrieves relevant context from your entire knowledge base.
|
||||
|
||||
## Content Processing
|
||||
|
||||
Minne automatically processes saved content:
|
||||
|
||||
1. **Web scraping** extracts readable text from URLs (via headless Chrome)
|
||||
2. **Text analysis** identifies key concepts and relationships
|
||||
3. **Graph creation** builds connections between related content
|
||||
4. **Embedding generation** enables semantic search
|
||||
|
||||
## Knowledge Graph
|
||||
|
||||
Explore your knowledge as an interactive network:
|
||||
|
||||
- **Manual curation** — Create entities and relationships yourself
|
||||
- **AI automation** — Let AI extract entities and discover relationships
|
||||
- **Hybrid approach** — AI suggests connections for your approval
|
||||
|
||||
The D3-based graph visualization shows entities as nodes and relationships as edges.
|
||||
|
||||
## Hybrid Retrieval
|
||||
|
||||
Minne combines multiple retrieval strategies:
|
||||
|
||||
- **Vector similarity** — Semantic matching via embeddings
|
||||
- **Full-text search** — Keyword matching with BM25
|
||||
- **Graph traversal** — Following relationships between entities
|
||||
|
||||
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
|
||||
|
||||
## Reranking (Optional)
|
||||
|
||||
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
|
||||
|
||||
**Trade-offs:**
|
||||
- Downloads ~1.1 GB of model data
|
||||
- Adds latency per query
|
||||
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
|
||||
|
||||
Enable via `RERANKING_ENABLED=true`. See [Configuration](./configuration.md).
|
||||
|
||||
## Multi-Format Ingestion
|
||||
|
||||
Supported content types:
|
||||
- Plain text and notes
|
||||
- URLs (web pages)
|
||||
- PDF documents
|
||||
- Audio files
|
||||
- Images
|
||||
|
||||
## Scratchpad
|
||||
|
||||
Quickly capture content without committing to permanent storage. Convert to full content when ready.
|
||||
|
||||
## iOS Shortcut
|
||||
|
||||
Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture from your phone.
|
||||
@@ -0,0 +1,67 @@
|
||||
# Installation
|
||||
|
||||
Minne can be installed through several methods. Choose the one that best fits your setup.
|
||||
|
||||
## Docker Compose (Recommended)
|
||||
|
||||
The fastest way to get Minne running with all dependencies:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium automatically.
|
||||
|
||||
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
|
||||
|
||||
## Nix
|
||||
|
||||
Run Minne directly with Nix (includes Chromium):
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
Configure via environment variables or a `config.yaml` file. See [Configuration](./configuration.md).
|
||||
|
||||
## Pre-built Binaries
|
||||
|
||||
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**Requirements:**
|
||||
- SurrealDB instance (local or remote)
|
||||
- Chromium (for web scraping)
|
||||
|
||||
## Build from Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
cargo build --release --bin main
|
||||
```
|
||||
|
||||
The binary will be at `target/release/main`.
|
||||
|
||||
**Requirements:**
|
||||
- Rust toolchain
|
||||
- SurrealDB accessible at configured address
|
||||
- Chromium in PATH
|
||||
|
||||
## Process Modes
|
||||
|
||||
Minne offers flexible deployment:
|
||||
|
||||
| Binary | Description |
|
||||
|--------|-------------|
|
||||
| `main` | Combined server + worker (recommended) |
|
||||
| `server` | Web interface and API only |
|
||||
| `worker` | Background processing only |
|
||||
|
||||
For most users, `main` is the right choice. Split deployments are useful for resource optimization or scaling.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [Configuration](./configuration.md) — Environment variables and config.yaml
|
||||
- [Features](./features.md) — What Minne can do
|
||||
@@ -0,0 +1,48 @@
|
||||
# Vision
|
||||
|
||||
## The "Why" Behind Minne
|
||||
|
||||
Personal knowledge management has always fascinated me. I wanted something that made it incredibly easy to capture content—snippets of text, URLs, media—while automatically discovering connections between ideas. But I also wanted control over my knowledge structure.
|
||||
|
||||
Traditional tools like Logseq and Obsidian are excellent, but manual linking often becomes a hindrance. Fully automated systems sometimes miss important context or create relationships I wouldn't have chosen.
|
||||
|
||||
Minne offers the best of both worlds: effortless capture with AI-assisted relationship discovery, but with flexibility to manually curate, edit, or override connections. Let AI handle the heavy lifting, take full control yourself, or use a hybrid approach where AI suggests and you approve.
|
||||
|
||||
## Design Principles
|
||||
|
||||
- **Capture should be instant** — No friction between thought and storage
|
||||
- **Connections should emerge** — AI finds relationships you might miss
|
||||
- **Control should be optional** — Automate by default, curate when it matters
|
||||
- **Privacy should be default** — Self-hosted, your data stays yours
|
||||
|
||||
## Roadmap
|
||||
|
||||
### Near-term
|
||||
|
||||
- [ ] TUI frontend with system editor integration
|
||||
- [ ] Enhanced retrieval recall via improved reranking
|
||||
- [ ] Additional content type support (e-books, research papers)
|
||||
|
||||
### Medium-term
|
||||
|
||||
- [ ] Embedded SurrealDB option (zero-config `nix run` with just `OPENAI_API_KEY`)
|
||||
- [ ] Browser extension for seamless capture
|
||||
- [ ] Mobile-native apps
|
||||
|
||||
### Long-term
|
||||
|
||||
- [ ] Federated knowledge sharing (opt-in)
|
||||
- [ ] Local LLM integration (fully offline operation)
|
||||
- [ ] Plugin system for custom entity extractors
|
||||
|
||||
## Related Projects
|
||||
|
||||
If Minne isn't quite right for you, check out:
|
||||
|
||||
- [Karakeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder) — Excellent bookmark/read-later with AI tagging
|
||||
- [Logseq](https://logseq.com/) — Outliner-based PKM with manual linking
|
||||
- [Obsidian](https://obsidian.md/) — Markdown-based PKM with plugin ecosystem
|
||||
|
||||
## Contributing
|
||||
|
||||
Feature requests and contributions are welcome. Minne was built for personal use first, but the self-hosted community benefits when we share.
|
||||
@@ -0,0 +1,35 @@
|
||||
[package]
|
||||
name = "evaluations"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
common = { path = "../common" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
ingestion-pipeline = { path = "../ingestion-pipeline" }
|
||||
futures = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
text-splitter = { workspace = true }
|
||||
unicode-normalization = { workspace = true }
|
||||
rand = "0.8"
|
||||
sha2 = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
once_cell = "1.19"
|
||||
serde_yaml = "0.9"
|
||||
criterion = "0.5"
|
||||
state-machines = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive", "env"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
@@ -0,0 +1,212 @@
|
||||
# Evaluations
|
||||
|
||||
The `evaluations` crate provides a retrieval evaluation framework for benchmarking Minne's information retrieval pipeline against standard datasets.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run SQuAD v2.0 evaluation (vector-only, recommended)
|
||||
cargo run --package evaluations -- --ingest-chunks-only
|
||||
|
||||
# Run a specific dataset
|
||||
cargo run --package evaluations -- --dataset fiqa --ingest-chunks-only
|
||||
|
||||
# Convert dataset only (no evaluation)
|
||||
cargo run --package evaluations -- --convert-only
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. SurrealDB
|
||||
|
||||
Start a SurrealDB instance before running evaluations:
|
||||
|
||||
```bash
|
||||
docker-compose up -d surrealdb
|
||||
```
|
||||
|
||||
Or using the default endpoint configuration:
|
||||
|
||||
```bash
|
||||
surreal start --user root_user --pass root_password
|
||||
```
|
||||
|
||||
### 2. Download Raw Datasets
|
||||
|
||||
Raw datasets must be downloaded manually and placed in `evaluations/data/raw/`. See [Dataset Sources](#dataset-sources) below for links and formats.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
evaluations/
|
||||
├── data/
|
||||
│ ├── raw/ # Downloaded raw datasets (manual)
|
||||
│ │ ├── squad/ # SQuAD v2.0
|
||||
│ │ ├── nq-dev/ # Natural Questions
|
||||
│ │ ├── fiqa/ # BEIR: FiQA-2018
|
||||
│ │ ├── fever/ # BEIR: FEVER
|
||||
│ │ ├── hotpotqa/ # BEIR: HotpotQA
|
||||
│ │ └── ... # Other BEIR subsets
|
||||
│ └── converted/ # Auto-generated (Minne JSON format)
|
||||
├── cache/ # Ingestion and embedding caches
|
||||
├── reports/ # Evaluation output (JSON + Markdown)
|
||||
├── manifest.yaml # Dataset and slice definitions
|
||||
└── src/ # Evaluation source code
|
||||
```
|
||||
|
||||
## Dataset Sources
|
||||
|
||||
### SQuAD v2.0
|
||||
|
||||
Download and place at `data/raw/squad/dev-v2.0.json`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/squad
|
||||
curl -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json \
|
||||
-o evaluations/data/raw/squad/dev-v2.0.json
|
||||
```
|
||||
|
||||
### Natural Questions (NQ)
|
||||
|
||||
Download and place at `data/raw/nq-dev/dev-all.jsonl`:
|
||||
|
||||
```bash
|
||||
mkdir -p evaluations/data/raw/nq-dev
|
||||
# Download from Google's Natural Questions page or HuggingFace
|
||||
# File: dev-all.jsonl (simplified JSONL format)
|
||||
```
|
||||
|
||||
Source: [Google Natural Questions](https://ai.google.com/research/NaturalQuestions)
|
||||
|
||||
### BEIR Datasets
|
||||
|
||||
All BEIR datasets follow the same format structure:
|
||||
|
||||
```
|
||||
data/raw/<dataset>/
|
||||
├── corpus.jsonl # Document corpus
|
||||
├── queries.jsonl # Query set
|
||||
└── qrels/
|
||||
└── test.tsv # Relevance judgments (or dev.tsv)
|
||||
```
|
||||
|
||||
Download datasets from the [BEIR Benchmark repository](https://github.com/beir-cellar/beir). Each dataset zip extracts to the required directory structure.
|
||||
|
||||
| Dataset | Directory |
|
||||
|------------|---------------|
|
||||
| FEVER | `fever/` |
|
||||
| FiQA-2018 | `fiqa/` |
|
||||
| HotpotQA | `hotpotqa/` |
|
||||
| NFCorpus | `nfcorpus/` |
|
||||
| Quora | `quora/` |
|
||||
| TREC-COVID | `trec-covid/` |
|
||||
| SciFact | `scifact/` |
|
||||
| NQ (BEIR) | `nq/` |
|
||||
|
||||
Example download:
|
||||
|
||||
```bash
|
||||
cd evaluations/data/raw
|
||||
curl -L https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip -o fiqa.zip
|
||||
unzip fiqa.zip && rm fiqa.zip
|
||||
```
|
||||
|
||||
## Dataset Conversion
|
||||
|
||||
Raw datasets are automatically converted to Minne's internal JSON format on first run. To force reconversion:
|
||||
|
||||
```bash
|
||||
cargo run --package evaluations -- --force-convert
|
||||
```
|
||||
|
||||
Converted files are saved to `data/converted/` and cached for subsequent runs.
|
||||
|
||||
## CLI Reference
|
||||
|
||||
### Common Options
|
||||
|
||||
| Flag | Description | Default |
|
||||
|------|-------------|---------|
|
||||
| `--dataset <NAME>` | Dataset to evaluate | `squad-v2` |
|
||||
| `--limit <N>` | Max questions to evaluate (0 = all) | `200` |
|
||||
| `--k <N>` | Precision@k cutoff | `5` |
|
||||
| `--slice <ID>` | Use a predefined slice from manifest | — |
|
||||
| `--rerank` | Enable FastEmbed reranking stage | disabled |
|
||||
| `--embedding-backend <BE>` | `fastembed` or `hashed` | `fastembed` |
|
||||
| `--ingest-chunks-only` | Skip entity extraction, ingest only text chunks | disabled |
|
||||
|
||||
> [!TIP]
|
||||
> Use `--ingest-chunks-only` when evaluating vector-only retrieval strategies. This skips the LLM-based entity extraction and graph generation, significantly speeding up ingestion while focusing on pure chunk-based vector search.
|
||||
|
||||
### Available Datasets
|
||||
|
||||
```
|
||||
squad-v2, natural-questions, beir, fever, fiqa, hotpotqa,
|
||||
nfcorpus, quora, trec-covid, scifact, nq-beir
|
||||
```
|
||||
|
||||
### Database Configuration
|
||||
|
||||
| Flag | Environment | Default |
|
||||
|------|-------------|---------|
|
||||
| `--db-endpoint` | `EVAL_DB_ENDPOINT` | `ws://127.0.0.1:8000` |
|
||||
| `--db-username` | `EVAL_DB_USERNAME` | `root_user` |
|
||||
| `--db-password` | `EVAL_DB_PASSWORD` | `root_password` |
|
||||
| `--db-namespace` | `EVAL_DB_NAMESPACE` | auto-generated |
|
||||
| `--db-database` | `EVAL_DB_DATABASE` | auto-generated |
|
||||
|
||||
### Example Runs
|
||||
|
||||
```bash
|
||||
# Vector-only evaluation (recommended for benchmarking)
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 200
|
||||
|
||||
# Full FiQA evaluation with reranking
|
||||
cargo run --package evaluations -- \
|
||||
--dataset fiqa \
|
||||
--ingest-chunks-only \
|
||||
--limit 500 \
|
||||
--rerank \
|
||||
--k 10
|
||||
|
||||
# Use a predefined slice for reproducibility
|
||||
cargo run --package evaluations -- --slice fiqa-test-200 --ingest-chunks-only
|
||||
|
||||
# Run the mixed BEIR benchmark
|
||||
cargo run --package evaluations -- --dataset beir --slice beir-mix-600 --ingest-chunks-only
|
||||
```
|
||||
|
||||
## Slices
|
||||
|
||||
Slices are predefined, reproducible subsets defined in `manifest.yaml`. Each slice specifies:
|
||||
|
||||
- **limit**: Number of questions
|
||||
- **corpus_limit**: Maximum corpus size
|
||||
- **seed**: Fixed RNG seed for reproducibility
|
||||
|
||||
View available slices in [manifest.yaml](./manifest.yaml).
|
||||
|
||||
## Reports
|
||||
|
||||
Evaluations generate reports in `reports/`:
|
||||
|
||||
- **JSON**: Full structured results (`*-report.json`)
|
||||
- **Markdown**: Human-readable summary with sample mismatches (`*-report.md`)
|
||||
- **History**: Timestamped run history (`history/`)
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
```bash
|
||||
# Log per-stage performance timings
|
||||
cargo run --package evaluations -- --perf-log-console
|
||||
|
||||
# Save telemetry to file
|
||||
cargo run --package evaluations -- --perf-log-json ./perf.json
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
See [../LICENSE](../LICENSE).
|
||||
@@ -0,0 +1,168 @@
|
||||
default_dataset: squad-v2
|
||||
datasets:
|
||||
- id: squad-v2
|
||||
label: "SQuAD v2.0"
|
||||
category: "SQuAD v2.0"
|
||||
entity_suffix: "SQuAD"
|
||||
source_prefix: "squad"
|
||||
raw: "data/raw/squad/dev-v2.0.json"
|
||||
converted: "data/converted/squad-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: squad-dev-200
|
||||
label: "SQuAD dev (200)"
|
||||
description: "Deterministic 200-case slice for local eval"
|
||||
limit: 200
|
||||
corpus_limit: 2000
|
||||
seed: 0x5eed2025
|
||||
- id: natural-questions-dev
|
||||
label: "Natural Questions (dev)"
|
||||
category: "Natural Questions"
|
||||
entity_suffix: "Natural Questions"
|
||||
source_prefix: "nq"
|
||||
raw: "data/raw/nq-dev/dev-all.jsonl"
|
||||
converted: "data/converted/nq-dev-minne.json"
|
||||
include_unanswerable: true
|
||||
slices:
|
||||
- id: nq-dev-200
|
||||
label: "NQ dev (200)"
|
||||
description: "200-case slice of the dev set"
|
||||
limit: 200
|
||||
corpus_limit: 2000
|
||||
include_unanswerable: false
|
||||
seed: 0x5eed2025
|
||||
- id: beir
|
||||
label: "BEIR mix"
|
||||
category: "BEIR"
|
||||
entity_suffix: "BEIR"
|
||||
source_prefix: "beir"
|
||||
raw: "data/raw/beir"
|
||||
converted: "data/converted/beir-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: beir-mix-600
|
||||
label: "BEIR mix (600)"
|
||||
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
|
||||
limit: 600
|
||||
corpus_limit: 6000
|
||||
seed: 0x5eed2025
|
||||
- id: fever
|
||||
label: "FEVER (BEIR)"
|
||||
category: "FEVER"
|
||||
entity_suffix: "FEVER"
|
||||
source_prefix: "fever"
|
||||
raw: "data/raw/fever"
|
||||
converted: "data/converted/fever-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: fever-test-200
|
||||
label: "FEVER test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: fiqa
|
||||
label: "FiQA-2018 (BEIR)"
|
||||
category: "FiQA-2018"
|
||||
entity_suffix: "FiQA"
|
||||
source_prefix: "fiqa"
|
||||
raw: "data/raw/fiqa"
|
||||
converted: "data/converted/fiqa-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: fiqa-test-200
|
||||
label: "FiQA test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: hotpotqa
|
||||
label: "HotpotQA (BEIR)"
|
||||
category: "HotpotQA"
|
||||
entity_suffix: "HotpotQA"
|
||||
source_prefix: "hotpotqa"
|
||||
raw: "data/raw/hotpotqa"
|
||||
converted: "data/converted/hotpotqa-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: hotpotqa-test-200
|
||||
label: "HotpotQA test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: nfcorpus
|
||||
label: "NFCorpus (BEIR)"
|
||||
category: "NFCorpus"
|
||||
entity_suffix: "NFCorpus"
|
||||
source_prefix: "nfcorpus"
|
||||
raw: "data/raw/nfcorpus"
|
||||
converted: "data/converted/nfcorpus-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: nfcorpus-test-200
|
||||
label: "NFCorpus test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: quora
|
||||
label: "Quora (IR)"
|
||||
category: "Quora"
|
||||
entity_suffix: "Quora"
|
||||
source_prefix: "quora"
|
||||
raw: "data/raw/quora"
|
||||
converted: "data/converted/quora-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: quora-test-200
|
||||
label: "Quora test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: trec-covid
|
||||
label: "TREC-COVID (BEIR)"
|
||||
category: "TREC-COVID"
|
||||
entity_suffix: "TREC-COVID"
|
||||
source_prefix: "trec-covid"
|
||||
raw: "data/raw/trec-covid"
|
||||
converted: "data/converted/trec-covid-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: trec-covid-test-200
|
||||
label: "TREC-COVID test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
- id: scifact
|
||||
label: "SciFact (BEIR)"
|
||||
category: "SciFact"
|
||||
entity_suffix: "SciFact"
|
||||
source_prefix: "scifact"
|
||||
raw: "data/raw/scifact"
|
||||
converted: "data/converted/scifact-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: scifact-test-200
|
||||
label: "SciFact test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 3000
|
||||
seed: 0x5eed2025
|
||||
- id: nq-beir
|
||||
label: "Natural Questions (BEIR)"
|
||||
category: "Natural Questions"
|
||||
entity_suffix: "Natural Questions"
|
||||
source_prefix: "nq-beir"
|
||||
raw: "data/raw/nq"
|
||||
converted: "data/converted/nq-beir-minne.json"
|
||||
include_unanswerable: false
|
||||
slices:
|
||||
- id: nq-beir-test-200
|
||||
label: "NQ (BEIR) test (200)"
|
||||
description: "200-case slice from BEIR test qrels"
|
||||
limit: 200
|
||||
corpus_limit: 5000
|
||||
seed: 0x5eed2025
|
||||
@@ -0,0 +1,506 @@
|
||||
use std::{
|
||||
env,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use clap::{Args, Parser, ValueEnum};
|
||||
use retrieval_pipeline::RetrievalStrategy;
|
||||
|
||||
use crate::datasets::DatasetKind;
|
||||
|
||||
fn workspace_root() -> PathBuf {
|
||||
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf()
|
||||
}
|
||||
|
||||
fn default_report_dir() -> PathBuf {
|
||||
workspace_root().join("evaluations/reports")
|
||||
}
|
||||
|
||||
fn default_cache_dir() -> PathBuf {
|
||||
workspace_root().join("evaluations/cache")
|
||||
}
|
||||
|
||||
fn default_ingestion_cache_dir() -> PathBuf {
|
||||
workspace_root().join("evaluations/cache/ingested")
|
||||
}
|
||||
|
||||
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
#[value(rename_all = "lowercase")]
|
||||
pub enum EmbeddingBackend {
|
||||
Hashed,
|
||||
#[default]
|
||||
FastEmbed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EmbeddingBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Hashed => write!(f, "hashed"),
|
||||
Self::FastEmbed => write!(f, "fastembed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
pub struct RetrievalSettings {
|
||||
/// Override chunk vector candidate cap
|
||||
#[arg(long)]
|
||||
pub chunk_vector_take: Option<usize>,
|
||||
|
||||
/// Override chunk FTS candidate cap
|
||||
#[arg(long)]
|
||||
pub chunk_fts_take: Option<usize>,
|
||||
|
||||
/// Override average characters per token used for budgeting
|
||||
#[arg(long)]
|
||||
pub chunk_avg_chars_per_token: Option<usize>,
|
||||
|
||||
/// Override maximum chunks attached per entity
|
||||
#[arg(long)]
|
||||
pub max_chunks_per_entity: Option<usize>,
|
||||
|
||||
/// Enable the FastEmbed reranking stage
|
||||
#[arg(long = "rerank", action = clap::ArgAction::SetTrue, default_value_t = false)]
|
||||
pub rerank: bool,
|
||||
|
||||
/// Reranking engine pool size / parallelism
|
||||
#[arg(long, default_value_t = 4)]
|
||||
pub rerank_pool_size: usize,
|
||||
|
||||
/// Keep top-N entities after reranking
|
||||
#[arg(long, default_value_t = 10)]
|
||||
pub rerank_keep_top: usize,
|
||||
|
||||
/// Cap the number of chunks returned by retrieval (revised strategy)
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub chunk_result_cap: usize,
|
||||
|
||||
/// Reciprocal rank fusion k value for revised chunk merging
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_k: Option<f32>,
|
||||
|
||||
/// Weight for vector ranks in revised RRF
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_vector_weight: Option<f32>,
|
||||
|
||||
/// Weight for chunk FTS ranks in revised RRF
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_fts_weight: Option<f32>,
|
||||
|
||||
/// Include vector ranks in revised RRF (default: true)
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_use_vector: Option<bool>,
|
||||
|
||||
/// Include chunk FTS ranks in revised RRF (default: true)
|
||||
#[arg(long)]
|
||||
pub chunk_rrf_use_fts: Option<bool>,
|
||||
|
||||
/// Require verified chunks (disable with --llm-mode)
|
||||
#[arg(skip = true)]
|
||||
pub require_verified_chunks: bool,
|
||||
|
||||
/// Select the retrieval pipeline strategy
|
||||
#[arg(long, default_value_t = RetrievalStrategy::Default)]
|
||||
pub strategy: RetrievalStrategy,
|
||||
}
|
||||
|
||||
impl Default for RetrievalSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
chunk_vector_take: None,
|
||||
chunk_fts_take: None,
|
||||
chunk_avg_chars_per_token: None,
|
||||
max_chunks_per_entity: None,
|
||||
rerank: false,
|
||||
rerank_pool_size: 4,
|
||||
rerank_keep_top: 10,
|
||||
chunk_result_cap: 5,
|
||||
chunk_rrf_k: None,
|
||||
chunk_rrf_vector_weight: None,
|
||||
chunk_rrf_fts_weight: None,
|
||||
chunk_rrf_use_vector: None,
|
||||
chunk_rrf_use_fts: None,
|
||||
require_verified_chunks: true,
|
||||
strategy: RetrievalStrategy::Default,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
pub struct IngestConfig {
|
||||
/// Directory for ingestion corpora caches
|
||||
#[arg(long, default_value_os_t = default_ingestion_cache_dir())]
|
||||
pub ingestion_cache_dir: PathBuf,
|
||||
|
||||
/// Minimum tokens per chunk for ingestion
|
||||
#[arg(long, default_value_t = 256)]
|
||||
pub ingest_chunk_min_tokens: usize,
|
||||
|
||||
/// Maximum tokens per chunk for ingestion
|
||||
#[arg(long, default_value_t = 512)]
|
||||
pub ingest_chunk_max_tokens: usize,
|
||||
|
||||
/// Overlap between chunks during ingestion (tokens)
|
||||
#[arg(long, default_value_t = 50)]
|
||||
pub ingest_chunk_overlap_tokens: usize,
|
||||
|
||||
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
|
||||
#[arg(long)]
|
||||
pub ingest_chunks_only: bool,
|
||||
|
||||
/// Number of paragraphs to ingest concurrently
|
||||
#[arg(long, default_value_t = 10)]
|
||||
pub ingestion_batch_size: usize,
|
||||
|
||||
/// Maximum retries for ingestion failures per paragraph
|
||||
#[arg(long, default_value_t = 3)]
|
||||
pub ingestion_max_retries: usize,
|
||||
|
||||
/// Recompute embeddings for cached corpora without re-running ingestion
|
||||
#[arg(long, alias = "refresh-embeddings")]
|
||||
pub refresh_embeddings_only: bool,
|
||||
|
||||
/// Delete cached paragraph shards before rebuilding the ingestion corpus
|
||||
#[arg(long)]
|
||||
pub slice_reset_ingestion: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
pub struct DatabaseArgs {
|
||||
/// SurrealDB server endpoint
|
||||
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
||||
pub db_endpoint: String,
|
||||
|
||||
/// SurrealDB root username
|
||||
#[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")]
|
||||
pub db_username: String,
|
||||
|
||||
/// SurrealDB root password
|
||||
#[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")]
|
||||
pub db_password: String,
|
||||
|
||||
/// Override the namespace used on the SurrealDB server
|
||||
#[arg(long, env = "EVAL_DB_NAMESPACE")]
|
||||
pub db_namespace: Option<String>,
|
||||
|
||||
/// Override the database used on the SurrealDB server
|
||||
#[arg(long, env = "EVAL_DB_DATABASE")]
|
||||
pub db_database: Option<String>,
|
||||
|
||||
/// Path to inspect DB state
|
||||
#[arg(long)]
|
||||
pub inspect_db_state: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Config {
|
||||
/// Convert the selected dataset and exit
|
||||
#[arg(long)]
|
||||
pub convert_only: bool,
|
||||
|
||||
/// Regenerate the converted dataset even if it already exists
|
||||
#[arg(long, alias = "refresh")]
|
||||
pub force_convert: bool,
|
||||
|
||||
/// Dataset to evaluate
|
||||
#[arg(long, default_value_t = DatasetKind::default())]
|
||||
pub dataset: DatasetKind,
|
||||
|
||||
/// Enable LLM-assisted evaluation features (includes unanswerable cases)
|
||||
#[arg(long)]
|
||||
pub llm_mode: bool,
|
||||
|
||||
/// Cap the slice corpus size (positives + negatives)
|
||||
#[arg(long)]
|
||||
pub corpus_limit: Option<usize>,
|
||||
|
||||
/// Path to the raw dataset (defaults per dataset)
|
||||
#[arg(long)]
|
||||
pub raw: Option<PathBuf>,
|
||||
|
||||
/// Path to write/read the converted dataset (defaults per dataset)
|
||||
#[arg(long)]
|
||||
pub converted: Option<PathBuf>,
|
||||
|
||||
/// Directory to write evaluation reports
|
||||
#[arg(long, default_value_os_t = default_report_dir())]
|
||||
pub report_dir: PathBuf,
|
||||
|
||||
/// Precision@k cutoff
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub k: usize,
|
||||
|
||||
/// Limit the number of questions evaluated (0 = all)
|
||||
#[arg(long = "limit", default_value_t = 200)]
|
||||
pub limit_arg: usize,
|
||||
|
||||
/// Number of mismatches to surface in the Markdown summary
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub sample: usize,
|
||||
|
||||
/// Disable context cropping when converting datasets (ingest entire documents)
|
||||
#[arg(long)]
|
||||
pub full_context: bool,
|
||||
|
||||
#[command(flatten)]
|
||||
pub retrieval: RetrievalSettings,
|
||||
|
||||
/// Concurrency level
|
||||
#[arg(long, default_value_t = 1)]
|
||||
pub concurrency: usize,
|
||||
|
||||
/// Embedding backend
|
||||
#[arg(long, default_value_t = EmbeddingBackend::FastEmbed)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
|
||||
/// FastEmbed model code
|
||||
#[arg(long)]
|
||||
pub embedding_model: Option<String>,
|
||||
|
||||
/// Directory for embedding caches
|
||||
#[arg(long, default_value_os_t = default_cache_dir())]
|
||||
pub cache_dir: PathBuf,
|
||||
|
||||
#[command(flatten)]
|
||||
pub ingest: IngestConfig,
|
||||
|
||||
/// Include entity descriptions and categories in JSON reports
|
||||
#[arg(long)]
|
||||
pub detailed_report: bool,
|
||||
|
||||
/// Use a cached dataset slice by id or path
|
||||
#[arg(long)]
|
||||
pub slice: Option<String>,
|
||||
|
||||
/// Ignore cached corpus state and rebuild the slice's SurrealDB corpus
|
||||
#[arg(long)]
|
||||
pub reseed_slice: bool,
|
||||
|
||||
/// Slice seed
|
||||
#[arg(skip = DEFAULT_SLICE_SEED)]
|
||||
pub slice_seed: u64,
|
||||
|
||||
/// Grow the slice ledger to contain at least this many answerable cases, then exit
|
||||
#[arg(long)]
|
||||
pub slice_grow: Option<usize>,
|
||||
|
||||
/// Evaluate questions starting at this offset within the slice
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub slice_offset: usize,
|
||||
|
||||
/// Target negative-to-positive paragraph ratio for slice growth
|
||||
#[arg(long, default_value_t = crate::slice::DEFAULT_NEGATIVE_MULTIPLIER)]
|
||||
pub negative_multiplier: f32,
|
||||
|
||||
/// Annotate the run; label is stored in JSON/Markdown reports
|
||||
#[arg(long)]
|
||||
pub label: Option<String>,
|
||||
|
||||
/// Write per-query chunk diagnostics JSONL to the provided path
|
||||
#[arg(long, alias = "chunk-diagnostics")]
|
||||
pub chunk_diagnostics_path: Option<PathBuf>,
|
||||
|
||||
/// Inspect an ingestion cache question and exit
|
||||
#[arg(long)]
|
||||
pub inspect_question: Option<String>,
|
||||
|
||||
/// Path to an ingestion cache manifest JSON for inspection mode
|
||||
#[arg(long)]
|
||||
pub inspect_manifest: Option<PathBuf>,
|
||||
|
||||
/// Override the SurrealDB system settings query model
|
||||
#[arg(long)]
|
||||
pub query_model: Option<String>,
|
||||
|
||||
/// Write structured performance telemetry JSON to the provided path
|
||||
#[arg(long)]
|
||||
pub perf_log_json: Option<PathBuf>,
|
||||
|
||||
/// Directory that receives timestamped perf JSON copies
|
||||
#[arg(long)]
|
||||
pub perf_log_dir: Option<PathBuf>,
|
||||
|
||||
/// Print per-stage performance timings to stdout after the run
|
||||
#[arg(long, alias = "perf-log")]
|
||||
pub perf_log_console: bool,
|
||||
|
||||
#[command(flatten)]
|
||||
pub database: DatabaseArgs,
|
||||
|
||||
// Computed fields (not arguments)
|
||||
#[arg(skip)]
|
||||
pub raw_dataset_path: PathBuf,
|
||||
#[arg(skip)]
|
||||
pub converted_dataset_path: PathBuf,
|
||||
#[arg(skip)]
|
||||
pub limit: Option<usize>,
|
||||
#[arg(skip)]
|
||||
pub summary_sample: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn context_token_limit(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn finalize(&mut self) -> Result<()> {
|
||||
// Handle dataset paths
|
||||
if let Some(raw) = &self.raw {
|
||||
self.raw_dataset_path = raw.clone();
|
||||
} else {
|
||||
self.raw_dataset_path = self.dataset.default_raw_path();
|
||||
}
|
||||
|
||||
if let Some(converted) = &self.converted {
|
||||
self.converted_dataset_path = converted.clone();
|
||||
} else {
|
||||
self.converted_dataset_path = self.dataset.default_converted_path();
|
||||
}
|
||||
|
||||
// Handle limit
|
||||
if self.limit_arg == 0 {
|
||||
self.limit = None;
|
||||
} else {
|
||||
self.limit = Some(self.limit_arg);
|
||||
}
|
||||
|
||||
// Handle sample
|
||||
self.summary_sample = self.sample.max(1);
|
||||
|
||||
// Handle retrieval settings
|
||||
self.retrieval.require_verified_chunks = !self.llm_mode;
|
||||
|
||||
if self.dataset == DatasetKind::Beir {
|
||||
self.negative_multiplier = 9.0;
|
||||
}
|
||||
|
||||
// Validations
|
||||
if self.ingest.ingest_chunk_min_tokens == 0
|
||||
|| self.ingest.ingest_chunk_min_tokens >= self.ingest.ingest_chunk_max_tokens
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"--ingest-chunk-min-tokens must be greater than zero and less than --ingest-chunk-max-tokens (got {} >= {})",
|
||||
self.ingest.ingest_chunk_min_tokens,
|
||||
self.ingest.ingest_chunk_max_tokens
|
||||
));
|
||||
}
|
||||
|
||||
if self.ingest.ingest_chunk_overlap_tokens >= self.ingest.ingest_chunk_min_tokens {
|
||||
return Err(anyhow!(
|
||||
"--ingest-chunk-overlap-tokens ({}) must be less than --ingest-chunk-min-tokens ({})",
|
||||
self.ingest.ingest_chunk_overlap_tokens,
|
||||
self.ingest.ingest_chunk_min_tokens
|
||||
));
|
||||
}
|
||||
|
||||
if self.retrieval.rerank && self.retrieval.rerank_pool_size == 0 {
|
||||
return Err(anyhow!(
|
||||
"--rerank-pool must be greater than zero when reranking is enabled"
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(k) = self.retrieval.chunk_rrf_k {
|
||||
if k <= 0.0 || !k.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-k must be a positive, finite number (got {k})"
|
||||
));
|
||||
}
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight {
|
||||
if weight < 0.0 || !weight.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
}
|
||||
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight {
|
||||
if weight < 0.0 || !weight.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if self.concurrency == 0 {
|
||||
return Err(anyhow!("--concurrency must be greater than zero"));
|
||||
}
|
||||
|
||||
if self.embedding_backend == EmbeddingBackend::Hashed && self.embedding_model.is_some() {
|
||||
return Err(anyhow!(
|
||||
"--embedding-model cannot be used with the 'hashed' embedding backend"
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(query_model) = &self.query_model {
|
||||
if query_model.trim().is_empty() {
|
||||
return Err(anyhow!("--query-model requires a non-empty model name"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(grow) = self.slice_grow {
|
||||
if grow == 0 {
|
||||
return Err(anyhow!("--slice-grow must be greater than zero"));
|
||||
}
|
||||
}
|
||||
|
||||
if self.negative_multiplier <= 0.0 || !self.negative_multiplier.is_finite() {
|
||||
return Err(anyhow!(
|
||||
"--negative-multiplier must be a positive finite number"
|
||||
));
|
||||
}
|
||||
|
||||
// Handle corpus limit logic
|
||||
if let Some(limit) = self.limit {
|
||||
if let Some(corpus_limit) = self.corpus_limit {
|
||||
if corpus_limit < limit {
|
||||
self.corpus_limit = Some(limit);
|
||||
}
|
||||
} else {
|
||||
let default_multiplier = 10usize;
|
||||
let mut computed = limit.saturating_mul(default_multiplier);
|
||||
if computed < limit {
|
||||
computed = limit;
|
||||
}
|
||||
let max_cap = 1_000usize;
|
||||
if computed > max_cap {
|
||||
computed = max_cap;
|
||||
}
|
||||
self.corpus_limit = Some(computed);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle perf log dir env var fallback
|
||||
if self.perf_log_dir.is_none() {
|
||||
if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
|
||||
if !dir.trim().is_empty() {
|
||||
self.perf_log_dir = Some(PathBuf::from(dir));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ParsedArgs {
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
pub fn parse() -> Result<ParsedArgs> {
|
||||
let mut config = Config::parse();
|
||||
config.finalize()?;
|
||||
Ok(ParsedArgs { config })
|
||||
}
|
||||
|
||||
pub fn ensure_parent(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating parent directory for {}", path.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::{Path, PathBuf},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct EmbeddingCacheData {
|
||||
entities: HashMap<String, Vec<f32>>,
|
||||
chunks: HashMap<String, Vec<f32>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbeddingCache {
|
||||
path: Arc<PathBuf>,
|
||||
data: Arc<Mutex<EmbeddingCacheData>>,
|
||||
dirty: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl EmbeddingCache {
|
||||
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
let data = if path.exists() {
|
||||
let raw = tokio::fs::read(&path)
|
||||
.await
|
||||
.with_context(|| format!("reading embedding cache {}", path.display()))?;
|
||||
serde_json::from_slice(&raw)
|
||||
.with_context(|| format!("parsing embedding cache {}", path.display()))?
|
||||
} else {
|
||||
EmbeddingCacheData::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
path: Arc::new(path),
|
||||
data: Arc::new(Mutex::new(data)),
|
||||
dirty: Arc::new(AtomicBool::new(false)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.entities.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.entities.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
|
||||
let guard = self.data.lock().await;
|
||||
guard.chunks.get(id).cloned()
|
||||
}
|
||||
|
||||
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
|
||||
let mut guard = self.data.lock().await;
|
||||
guard.chunks.insert(id, embedding);
|
||||
self.dirty.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> Result<()> {
|
||||
if !self.dirty.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guard = self.data.lock().await;
|
||||
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
|
||||
if let Some(parent) = self.path.parent() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.with_context(|| format!("creating cache directory {}", parent.display()))?;
|
||||
}
|
||||
tokio::fs::write(&*self.path, body)
|
||||
.await
|
||||
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
|
||||
self.dirty.store(false, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
//! Case generation from corpus manifests.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::corpus;
|
||||
|
||||
/// A test case for retrieval evaluation derived from a manifest question.
|
||||
pub(crate) struct SeededCase {
|
||||
pub question_id: String,
|
||||
pub question: String,
|
||||
pub expected_source: String,
|
||||
pub answers: Vec<String>,
|
||||
pub paragraph_id: String,
|
||||
pub paragraph_title: String,
|
||||
pub expected_chunk_ids: Vec<String>,
|
||||
pub is_impossible: bool,
|
||||
pub has_verified_chunks: bool,
|
||||
}
|
||||
|
||||
/// Convert a corpus manifest into seeded evaluation cases.
|
||||
pub(crate) fn cases_from_manifest(manifest: &corpus::CorpusManifest) -> Vec<SeededCase> {
|
||||
let mut title_map = HashMap::new();
|
||||
for paragraph in &manifest.paragraphs {
|
||||
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
|
||||
}
|
||||
|
||||
let include_impossible = manifest.metadata.include_unanswerable;
|
||||
let require_verified_chunks = manifest.metadata.require_verified_chunks;
|
||||
|
||||
manifest
|
||||
.questions
|
||||
.iter()
|
||||
.filter(|question| {
|
||||
should_include_question(question, include_impossible, require_verified_chunks)
|
||||
})
|
||||
.map(|question| {
|
||||
let title = title_map
|
||||
.get(question.paragraph_id.as_str())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
SeededCase {
|
||||
question_id: question.question_id.clone(),
|
||||
question: question.question_text.clone(),
|
||||
expected_source: question.text_content_id.clone(),
|
||||
answers: question.answers.clone(),
|
||||
paragraph_id: question.paragraph_id.clone(),
|
||||
paragraph_title: title,
|
||||
expected_chunk_ids: question.matching_chunk_ids.clone(),
|
||||
is_impossible: question.is_impossible,
|
||||
has_verified_chunks: !question.matching_chunk_ids.is_empty(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn should_include_question(
|
||||
question: &corpus::CorpusQuestion,
|
||||
include_impossible: bool,
|
||||
require_verified_chunks: bool,
|
||||
) -> bool {
|
||||
if !include_impossible && question.is_impossible {
|
||||
return false;
|
||||
}
|
||||
if require_verified_chunks && question.matching_chunk_ids.is_empty() {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::corpus::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk};
|
||||
use crate::corpus::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION};
|
||||
use chrono::Utc;
|
||||
use common::storage::types::text_content::TextContent;
|
||||
|
||||
fn sample_manifest() -> CorpusManifest {
|
||||
let paragraphs = vec![
|
||||
CorpusParagraph {
|
||||
paragraph_id: "p1".to_string(),
|
||||
title: "Alpha".to_string(),
|
||||
text_content: TextContent::new(
|
||||
"alpha context".to_string(),
|
||||
None,
|
||||
"test".to_string(),
|
||||
None,
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::<EmbeddedTextChunk>::new(),
|
||||
},
|
||||
CorpusParagraph {
|
||||
paragraph_id: "p2".to_string(),
|
||||
title: "Beta".to_string(),
|
||||
text_content: TextContent::new(
|
||||
"beta context".to_string(),
|
||||
None,
|
||||
"test".to_string(),
|
||||
None,
|
||||
None,
|
||||
"user".to_string(),
|
||||
),
|
||||
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
|
||||
relationships: Vec::new(),
|
||||
chunks: Vec::<EmbeddedTextChunk>::new(),
|
||||
},
|
||||
];
|
||||
let questions = vec![
|
||||
CorpusQuestion {
|
||||
question_id: "q1".to_string(),
|
||||
paragraph_id: "p1".to_string(),
|
||||
text_content_id: "tc-alpha".to_string(),
|
||||
question_text: "What is Alpha?".to_string(),
|
||||
answers: vec!["Alpha".to_string()],
|
||||
is_impossible: false,
|
||||
matching_chunk_ids: vec!["chunk-alpha".to_string()],
|
||||
},
|
||||
CorpusQuestion {
|
||||
question_id: "q2".to_string(),
|
||||
paragraph_id: "p1".to_string(),
|
||||
text_content_id: "tc-alpha".to_string(),
|
||||
question_text: "Unanswerable?".to_string(),
|
||||
answers: Vec::new(),
|
||||
is_impossible: true,
|
||||
matching_chunk_ids: Vec::new(),
|
||||
},
|
||||
CorpusQuestion {
|
||||
question_id: "q3".to_string(),
|
||||
paragraph_id: "p2".to_string(),
|
||||
text_content_id: "tc-beta".to_string(),
|
||||
question_text: "Where is Beta?".to_string(),
|
||||
answers: vec!["Beta".to_string()],
|
||||
is_impossible: false,
|
||||
matching_chunk_ids: Vec::new(),
|
||||
},
|
||||
];
|
||||
CorpusManifest {
|
||||
version: MANIFEST_VERSION,
|
||||
metadata: CorpusMetadata {
|
||||
dataset_id: "ds".to_string(),
|
||||
dataset_label: "Dataset".to_string(),
|
||||
slice_id: "slice".to_string(),
|
||||
include_unanswerable: true,
|
||||
require_verified_chunks: true,
|
||||
ingestion_fingerprint: "fp".to_string(),
|
||||
embedding_backend: "test".to_string(),
|
||||
embedding_model: None,
|
||||
embedding_dimension: 3,
|
||||
converted_checksum: "chk".to_string(),
|
||||
generated_at: Utc::now(),
|
||||
paragraph_count: paragraphs.len(),
|
||||
question_count: questions.len(),
|
||||
chunk_min_tokens: 1,
|
||||
chunk_max_tokens: 10,
|
||||
chunk_only: false,
|
||||
},
|
||||
paragraphs,
|
||||
questions,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cases_respect_mode_filters() {
|
||||
let mut manifest = sample_manifest();
|
||||
manifest.metadata.include_unanswerable = false;
|
||||
manifest.metadata.require_verified_chunks = true;
|
||||
|
||||
let strict_cases = cases_from_manifest(&manifest);
|
||||
assert_eq!(strict_cases.len(), 1);
|
||||
assert_eq!(strict_cases[0].question_id, "q1");
|
||||
assert_eq!(strict_cases[0].paragraph_title, "Alpha");
|
||||
|
||||
let mut llm_manifest = manifest.clone();
|
||||
llm_manifest.metadata.include_unanswerable = true;
|
||||
llm_manifest.metadata.require_verified_chunks = false;
|
||||
|
||||
let llm_cases = cases_from_manifest(&llm_manifest);
|
||||
let ids: Vec<_> = llm_cases
|
||||
.iter()
|
||||
.map(|case| case.question_id.as_str())
|
||||
.collect();
|
||||
assert_eq!(ids, vec!["q1", "q2", "q3"]);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user