mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-12 17:24:26 +02:00
Compare commits
163 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 28e8ede478 | |||
| 00453fdcbe | |||
| c53ec8c0a1 | |||
| 60cf63292a | |||
| ac0d34bfbd | |||
| 4e20da538d | |||
| 15c9f18f6e | |||
| 7b850769c9 | |||
| 2a28243213 | |||
| b22c351785 | |||
| 3897345ab3 | |||
| 5c2d2e24d3 | |||
| c70141de35 | |||
| 2aa92b6ad7 | |||
| d3443d4153 | |||
| e3bb2935d0 | |||
| 93d11b66eb | |||
| 125b856c49 | |||
| bc41a619ce | |||
| ba8c36da1e | |||
| 5724f11dc1 | |||
| 189adb1a5f | |||
| 97beb91710 | |||
| 85336d77a3 | |||
| 9d5e7cd794 | |||
| 30bb59f243 | |||
| 224a7db451 | |||
| 4579725130 | |||
| 0b08801c90 | |||
| 45d13230a6 | |||
| 0acdba4f54 | |||
| 9609880cff | |||
| 31d585b59f | |||
| 890a4b381d | |||
| 2d630e2af9 | |||
| 9ec11e1f79 | |||
| c60db0fb56 | |||
| f5f0454904 | |||
| 18aadab8ee | |||
| 414d2f5b34 | |||
| 293440b0ee | |||
| 041d9bd81f | |||
| b4383bb227 | |||
| 6c7b586fc5 | |||
| 1927149ce9 | |||
| a52dc802de | |||
| 000852c94c | |||
| 6a5d631287 | |||
| b965c5a2e6 | |||
| 79e46e9c09 | |||
| f22a1e5ba4 | |||
| 4d237ff6d9 | |||
| eb928cdb0e | |||
| 1490852a09 | |||
| b0b01182d7 | |||
| 679308aa1d | |||
| f93c06b347 | |||
| a3f207beb1 | |||
| e07199adfc | |||
| f22cac891c | |||
| b89171d934 | |||
| 0133eead63 | |||
| e5d2b6605f | |||
| bbad91d55b | |||
| 96846ad664 | |||
| 269bcec659 | |||
| 7c738c4b30 | |||
| cb88127fcb | |||
| 49e1fbd985 | |||
| f2fa5bbbcc | |||
| a3bc6fba98 | |||
| ece744d5a0 | |||
| a9fda67209 | |||
| fa7f407306 | |||
| b25cfb4633 | |||
| 0df2b9810c | |||
| 354dc727c1 | |||
| 037057d108 | |||
| 9f17c6c2b0 | |||
| 17f252e630 | |||
| db43be1606 | |||
| 8e8370b080 | |||
| 84695fa0cc | |||
| 654add98bc | |||
| 244ec0ea25 | |||
| d8416ac711 | |||
| f9f48d1046 | |||
| 30b8a65377 | |||
| 04faa38ee6 | |||
| cdc62dda30 | |||
| ab8ff8b07a | |||
| 79ea007b0a | |||
| a5bc72aedf | |||
| 2e2ea0c4ff | |||
| a090a8c76e | |||
| a8d10f265c | |||
| 0cb1abc6db | |||
| d1a6d9abdf | |||
| d3fa3be3e5 | |||
| a2c9bb848d | |||
| dd881efbf9 | |||
| 2939e4c2a4 | |||
| 1039ec32a4 | |||
| cb906c5b53 | |||
| 08b1612fcb | |||
| 67004c9646 | |||
| 030f0fc17d | |||
| 226b2db43a | |||
| 6f88d87e74 | |||
| bd519ab269 | |||
| f535df7e61 | |||
| 6b7befbd04 | |||
| 0eda65b07e | |||
| 04ee225732 | |||
| 13b7ad6f3a | |||
| 112a6965a4 | |||
| 911e830be5 | |||
| 3196e65172 | |||
| 380c900c86 | |||
| a99e5ada8b | |||
| b0deabaf3f | |||
| a8f0d9fa88 | |||
| 56a1dfddb8 | |||
| 863b921fb4 | |||
| f13791cfcf | |||
| 75c200b2ba | |||
| 1b7c24747a | |||
| 241ad9a089 | |||
| 72578296db | |||
| a0e9387c76 | |||
| 798b1468b6 | |||
| 3b805778b4 | |||
| 07b3e1a0e8 | |||
| 83d39afad4 | |||
| 21e4ab1f42 | |||
| 3c97d8ead5 | |||
| ab68bccb80 | |||
| 99b88c3063 | |||
| 44e5d8a2fc | |||
| 7332347f1a | |||
| 199186e5a3 | |||
| 64728468cd | |||
| c3a7e8dc59 | |||
| 35ff4e1464 | |||
| 2964f1a5a5 | |||
| cb7f625b81 | |||
| dc40cf7663 | |||
| aa0b1462a1 | |||
| 41fc7bb99c | |||
| 61d8d7abe7 | |||
| b7344644dc | |||
| 3742598a6d | |||
| c6a6080e1c | |||
| 1159712724 | |||
| e5e1414f54 | |||
| fcc49b1954 | |||
| 022f4d8575 | |||
| 945a2b7f37 | |||
| ff4ea55cd5 | |||
| c4c76efe92 | |||
| c0fcad5952 | |||
| b0ed69330d | |||
| 5cb15dab45 |
@@ -0,0 +1,2 @@
|
|||||||
|
[alias]
|
||||||
|
eval = "run -p evaluations --"
|
||||||
+124
-117
@@ -1,44 +1,8 @@
|
|||||||
# This file was autogenerated by dist: https://opensource.axo.dev/cargo-dist/
|
|
||||||
#
|
|
||||||
# Copyright 2022-2024, axodotdev
|
|
||||||
# SPDX-License-Identifier: MIT or Apache-2.0
|
|
||||||
#
|
|
||||||
# CI that:
|
|
||||||
#
|
|
||||||
# * checks for a Git Tag that looks like a release
|
|
||||||
# * builds artifacts with dist (archives, installers, hashes)
|
|
||||||
# * uploads those artifacts to temporary workflow zip
|
|
||||||
# * on success, uploads the artifacts to a GitHub Release
|
|
||||||
#
|
|
||||||
# Note that the GitHub Release will be created with a generated
|
|
||||||
# title/body based on your changelogs.
|
|
||||||
|
|
||||||
name: Release
|
name: Release
|
||||||
permissions:
|
permissions:
|
||||||
"contents": "write"
|
contents: write
|
||||||
"packages": "write"
|
packages: write
|
||||||
|
|
||||||
# This task will run whenever you push a git tag that looks like a version
|
|
||||||
# like "1.0.0", "v0.1.0-prerelease.1", "my-app/0.1.0", "releases/v1.0.0", etc.
|
|
||||||
# Various formats will be parsed into a VERSION and an optional PACKAGE_NAME, where
|
|
||||||
# PACKAGE_NAME must be the name of a Cargo package in your workspace, and VERSION
|
|
||||||
# must be a Cargo-style SemVer Version (must have at least major.minor.patch).
|
|
||||||
#
|
|
||||||
# If PACKAGE_NAME is specified, then the announcement will be for that
|
|
||||||
# package (erroring out if it doesn't have the given version or isn't dist-able).
|
|
||||||
#
|
|
||||||
# If PACKAGE_NAME isn't specified, then the announcement will be for all
|
|
||||||
# (dist-able) packages in the workspace with that version (this mode is
|
|
||||||
# intended for workspaces with only one dist-able package, or with all dist-able
|
|
||||||
# packages versioned/released in lockstep).
|
|
||||||
#
|
|
||||||
# If you push multiple tags at once, separate instances of this workflow will
|
|
||||||
# spin up, creating an independent announcement for each one. However, GitHub
|
|
||||||
# will hard limit this to 3 tags per commit, as it will assume more tags is a
|
|
||||||
# mistake.
|
|
||||||
#
|
|
||||||
# If there's a prerelease-style suffix to the version, then the release(s)
|
|
||||||
# will be marked as a prerelease.
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
push:
|
push:
|
||||||
@@ -46,9 +10,8 @@ on:
|
|||||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
- '**[0-9]+.[0-9]+.[0-9]+*'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# Run 'dist plan' (or host) to determine what tasks we need to do
|
|
||||||
plan:
|
plan:
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: ubuntu-22.04
|
||||||
outputs:
|
outputs:
|
||||||
val: ${{ steps.plan.outputs.manifest }}
|
val: ${{ steps.plan.outputs.manifest }}
|
||||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||||
@@ -60,52 +23,45 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install Nix
|
||||||
|
uses: cachix/install-nix-action@v27
|
||||||
|
with:
|
||||||
|
extra_nix_config: |
|
||||||
|
experimental-features = nix-command flakes
|
||||||
|
|
||||||
|
- name: Verify ort-version matches nixpkgs onnxruntime
|
||||||
|
run: nix flake check --system x86_64-linux -L
|
||||||
|
|
||||||
- name: Install dist
|
- name: Install dist
|
||||||
# we specify bash to get pipefail; it guards against the `curl` command
|
|
||||||
# failing. otherwise `sh` won't catch that `curl` returned non-0
|
|
||||||
shell: bash
|
shell: bash
|
||||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.28.0/cargo-dist-installer.sh | sh"
|
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.3/cargo-dist-installer.sh | sh"
|
||||||
|
|
||||||
- name: Cache dist
|
- name: Cache dist
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: cargo-dist-cache
|
name: cargo-dist-cache
|
||||||
path: ~/.cargo/bin/dist
|
path: ~/.cargo/bin/dist
|
||||||
# sure would be cool if github gave us proper conditionals...
|
|
||||||
# so here's a doubly-nested ternary-via-truthiness to try to provide the best possible
|
|
||||||
# functionality based on whether this is a pull_request, and whether it's from a fork.
|
|
||||||
# (PRs run on the *source* but secrets are usually on the *target* -- that's *good*
|
|
||||||
# but also really annoying to build CI around when it needs secrets to work right.)
|
|
||||||
- id: plan
|
- id: plan
|
||||||
run: |
|
run: |
|
||||||
dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json
|
dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json
|
||||||
echo "dist ran successfully"
|
echo "dist ran successfully"
|
||||||
cat plan-dist-manifest.json
|
cat plan-dist-manifest.json
|
||||||
echo "manifest=$(jq -c "." plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
echo "manifest=$(jq -c . plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||||
- name: "Upload dist-manifest.json"
|
|
||||||
|
- name: Upload dist-manifest.json
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: artifacts-plan-dist-manifest
|
name: artifacts-plan-dist-manifest
|
||||||
path: plan-dist-manifest.json
|
path: plan-dist-manifest.json
|
||||||
|
|
||||||
# Builds and packages all the platform-specific things
|
|
||||||
build-local-artifacts:
|
build-local-artifacts:
|
||||||
name: build-local-artifacts (${{ join(matrix.targets, ', ') }})
|
name: build-local-artifacts (${{ join(matrix.targets, ', ') }})
|
||||||
# Let the initial task tell us to not run (currently very blunt)
|
needs: [plan]
|
||||||
needs:
|
|
||||||
- plan
|
|
||||||
if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }}
|
if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
# Target platforms/runners are computed by dist in create-release.
|
|
||||||
# Each member of the matrix has the following arguments:
|
|
||||||
#
|
|
||||||
# - runner: the github runner
|
|
||||||
# - dist-args: cli flags to pass to dist
|
|
||||||
# - install-dist: expression to run to install dist on the runner
|
|
||||||
#
|
|
||||||
# Typically there will be:
|
|
||||||
# - 1 "global" task that builds universal installers
|
|
||||||
# - N "local" tasks that build each platform's binaries and platform-specific installers
|
|
||||||
matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }}
|
matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }}
|
||||||
runs-on: ${{ matrix.runner }}
|
runs-on: ${{ matrix.runner }}
|
||||||
container: ${{ matrix.container && matrix.container.image || null }}
|
container: ${{ matrix.container && matrix.container.image || null }}
|
||||||
@@ -114,11 +70,16 @@ jobs:
|
|||||||
BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json
|
BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json
|
||||||
steps:
|
steps:
|
||||||
- name: enable windows longpaths
|
- name: enable windows longpaths
|
||||||
run: |
|
run: git config --global core.longpaths true
|
||||||
git config --global core.longpaths true
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Load ONNX Runtime version
|
||||||
|
shell: bash
|
||||||
|
run: echo "ORT_VER=$(tr -d '[:space:]' < ort-version)" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Install Rust non-interactively if not already installed
|
- name: Install Rust non-interactively if not already installed
|
||||||
if: ${{ matrix.container }}
|
if: ${{ matrix.container }}
|
||||||
run: |
|
run: |
|
||||||
@@ -126,37 +87,97 @@ jobs:
|
|||||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Install dist
|
- name: Install dist
|
||||||
run: ${{ matrix.install_dist.run }}
|
run: ${{ matrix.install_dist.run }}
|
||||||
# Get the dist-manifest
|
|
||||||
- name: Fetch local artifacts
|
- name: Fetch local artifacts
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
pattern: artifacts-*
|
pattern: artifacts-*
|
||||||
path: target/distrib/
|
path: target/distrib/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
|
# ===== BEGIN: Injected ORT staging for cargo-dist bundling =====
|
||||||
|
- run: echo "=== BUILD-SETUP START ==="
|
||||||
|
|
||||||
|
# Unix shells
|
||||||
|
- name: Prepare lib dir (Unix)
|
||||||
|
if: runner.os != 'Windows'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir -p lib
|
||||||
|
rm -f lib/*
|
||||||
|
|
||||||
|
# Windows PowerShell
|
||||||
|
- name: Prepare lib dir (Windows)
|
||||||
|
if: runner.os == 'Windows'
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
New-Item -ItemType Directory -Force -Path lib | Out-Null
|
||||||
|
# remove contents if any
|
||||||
|
Get-ChildItem -Path lib -Force | Remove-Item -Force -Recurse -ErrorAction SilentlyContinue
|
||||||
|
|
||||||
|
- name: Fetch ONNX Runtime (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
ARCH="$(uname -m)"
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||||
|
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||||
|
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
curl -fsSL -o ort.tgz "$URL"
|
||||||
|
tar -xzf ort.tgz
|
||||||
|
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||||
|
# normalize to stable name if needed
|
||||||
|
[ -f lib/libonnxruntime.so ] || cp -v lib/libonnxruntime.so.* lib/libonnxruntime.so
|
||||||
|
|
||||||
|
- name: Fetch ONNX Runtime (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||||
|
tar -xzf ort.tgz
|
||||||
|
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||||
|
[ -f lib/libonnxruntime.dylib ] || cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||||
|
|
||||||
|
- name: Fetch ONNX Runtime (Windows)
|
||||||
|
if: runner.os == 'Windows'
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||||
|
Invoke-WebRequest $url -OutFile ort.zip
|
||||||
|
Expand-Archive ort.zip -DestinationPath ort
|
||||||
|
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||||
|
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||||
|
|
||||||
|
- run: |
|
||||||
|
echo "=== BUILD-SETUP END ==="
|
||||||
|
echo "lib/ contents:"
|
||||||
|
ls -l lib || dir lib
|
||||||
|
# ===== END: Injected ORT staging =====
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
${{ matrix.packages_install }}
|
${{ matrix.packages_install }}
|
||||||
|
|
||||||
- name: Build artifacts
|
- name: Build artifacts
|
||||||
run: |
|
run: |
|
||||||
# Actually do builds and make zips and whatnot
|
|
||||||
dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json
|
dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json
|
||||||
echo "dist ran successfully"
|
echo "dist ran successfully"
|
||||||
|
|
||||||
- id: cargo-dist
|
- id: cargo-dist
|
||||||
name: Post-build
|
name: Post-build
|
||||||
# We force bash here just because github makes it really hard to get values up
|
|
||||||
# to "real" actions without writing to env-vars, and writing to env-vars has
|
|
||||||
# inconsistent syntax between shell and powershell.
|
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
# Parse out what we just built and upload it to scratch storage
|
|
||||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||||
dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT"
|
dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||||
- name: "Upload artifacts"
|
|
||||||
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: artifacts-build-local-${{ join(matrix.targets, '_') }}
|
name: artifacts-build-local-${{ join(matrix.targets, '_') }}
|
||||||
@@ -170,13 +191,13 @@ jobs:
|
|||||||
needs: [plan]
|
needs: [plan]
|
||||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||||
permissions:
|
permissions:
|
||||||
contents: read # Permission to checkout the repository
|
contents: read
|
||||||
packages: write # Permission to push Docker image to GHCR
|
packages: write
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive # Matches your other checkout steps
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
@@ -185,7 +206,7 @@ jobs:
|
|||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }} # User triggering the workflow
|
username: ${{ github.actor }}
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Extract Docker metadata
|
- name: Extract Docker metadata
|
||||||
@@ -193,8 +214,6 @@ jobs:
|
|||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: ghcr.io/${{ github.repository }}
|
images: ghcr.io/${{ github.repository }}
|
||||||
# This action automatically uses the Git tag as the Docker image tag.
|
|
||||||
# For example, a Git tag 'v1.2.3' will result in Docker tag 'ghcr.io/owner/repo:v1.2.3'.
|
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
@@ -203,15 +222,12 @@ jobs:
|
|||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
cache-from: type=gha # Enable Docker layer caching from GitHub Actions cache
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max # Enable Docker layer caching to GitHub Actions cache
|
cache-to: type=gha,mode=max
|
||||||
|
|
||||||
# Build and package all the platform-agnostic(ish) things
|
|
||||||
build-global-artifacts:
|
build-global-artifacts:
|
||||||
needs:
|
needs: [plan, build-local-artifacts]
|
||||||
- plan
|
runs-on: ubuntu-22.04
|
||||||
- build-local-artifacts
|
|
||||||
runs-on: "ubuntu-22.04"
|
|
||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json
|
BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json
|
||||||
@@ -219,92 +235,90 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Install cached dist
|
- name: Install cached dist
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: cargo-dist-cache
|
name: cargo-dist-cache
|
||||||
path: ~/.cargo/bin/
|
path: ~/.cargo/bin/
|
||||||
- run: chmod +x ~/.cargo/bin/dist
|
- run: chmod +x ~/.cargo/bin/dist
|
||||||
# Get all the local artifacts for the global tasks to use (for e.g. checksums)
|
|
||||||
- name: Fetch local artifacts
|
- name: Fetch local artifacts
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
pattern: artifacts-*
|
pattern: artifacts-*
|
||||||
path: target/distrib/
|
path: target/distrib/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- id: cargo-dist
|
- id: cargo-dist
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json
|
dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json
|
||||||
echo "dist ran successfully"
|
echo "dist ran successfully"
|
||||||
|
|
||||||
# Parse out what we just built and upload it to scratch storage
|
|
||||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||||
jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT"
|
jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||||
- name: "Upload artifacts"
|
|
||||||
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: artifacts-build-global
|
name: artifacts-build-global
|
||||||
path: |
|
path: |
|
||||||
${{ steps.cargo-dist.outputs.paths }}
|
${{ steps.cargo-dist.outputs.paths }}
|
||||||
${{ env.BUILD_MANIFEST_NAME }}
|
${{ env.BUILD_MANIFEST_NAME }}
|
||||||
# Determines if we should publish/announce
|
|
||||||
host:
|
host:
|
||||||
needs:
|
needs: [plan, build-local-artifacts, build-global-artifacts]
|
||||||
- plan
|
|
||||||
- build-local-artifacts
|
|
||||||
- build-global-artifacts
|
|
||||||
# Only run if we're "publishing", and only if local and global didn't fail (skipped is fine)
|
|
||||||
if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }}
|
if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }}
|
||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: ubuntu-22.04
|
||||||
outputs:
|
outputs:
|
||||||
val: ${{ steps.host.outputs.manifest }}
|
val: ${{ steps.host.outputs.manifest }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Install cached dist
|
- name: Install cached dist
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: cargo-dist-cache
|
name: cargo-dist-cache
|
||||||
path: ~/.cargo/bin/
|
path: ~/.cargo/bin/
|
||||||
- run: chmod +x ~/.cargo/bin/dist
|
- run: chmod +x ~/.cargo/bin/dist
|
||||||
# Fetch artifacts from scratch-storage
|
|
||||||
- name: Fetch artifacts
|
- name: Fetch artifacts
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
pattern: artifacts-*
|
pattern: artifacts-*
|
||||||
path: target/distrib/
|
path: target/distrib/
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- id: host
|
- id: host
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json
|
dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json
|
||||||
echo "artifacts uploaded and released successfully"
|
echo "artifacts uploaded and released successfully"
|
||||||
cat dist-manifest.json
|
cat dist-manifest.json
|
||||||
echo "manifest=$(jq -c "." dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
echo "manifest=$(jq -c . dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||||
- name: "Upload dist-manifest.json"
|
|
||||||
|
- name: Upload dist-manifest.json
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
# Overwrite the previous copy
|
|
||||||
name: artifacts-dist-manifest
|
name: artifacts-dist-manifest
|
||||||
path: dist-manifest.json
|
path: dist-manifest.json
|
||||||
# Create a GitHub Release while uploading all files to it
|
|
||||||
- name: "Download GitHub Artifacts"
|
- name: Download GitHub Artifacts
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
pattern: artifacts-*
|
pattern: artifacts-*
|
||||||
path: artifacts
|
path: artifacts
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- name: Cleanup
|
- name: Cleanup
|
||||||
run: |
|
run: rm -f artifacts/*-dist-manifest.json
|
||||||
# Remove the granular manifests
|
|
||||||
rm -f artifacts/*-dist-manifest.json
|
|
||||||
- name: Create GitHub Release
|
- name: Create GitHub Release
|
||||||
env:
|
env:
|
||||||
PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}"
|
PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}"
|
||||||
@@ -312,20 +326,13 @@ jobs:
|
|||||||
ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}"
|
ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}"
|
||||||
RELEASE_COMMIT: "${{ github.sha }}"
|
RELEASE_COMMIT: "${{ github.sha }}"
|
||||||
run: |
|
run: |
|
||||||
# Write and read notes from a file to avoid quoting breaking things
|
|
||||||
echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt
|
echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt
|
||||||
|
|
||||||
gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/*
|
gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/*
|
||||||
|
|
||||||
announce:
|
announce:
|
||||||
needs:
|
needs: [plan, host]
|
||||||
- plan
|
|
||||||
- host
|
|
||||||
# use "always() && ..." to allow us to wait for all publish jobs while
|
|
||||||
# still allowing individual publish jobs to skip themselves (for prereleases).
|
|
||||||
# "host" however must run to completion, no skipping allowed!
|
|
||||||
if: ${{ always() && needs.host.result == 'success' }}
|
if: ${{ always() && needs.host.result == 'success' }}
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: ubuntu-22.04
|
||||||
env:
|
env:
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ result
|
|||||||
data
|
data
|
||||||
database
|
database
|
||||||
|
|
||||||
|
evaluations/cache/
|
||||||
|
evaluations/reports/
|
||||||
|
|
||||||
# Devenv
|
# Devenv
|
||||||
.devenv*
|
.devenv*
|
||||||
devenv.local.nix
|
devenv.local.nix
|
||||||
@@ -21,3 +24,8 @@ devenv.local.nix
|
|||||||
.pre-commit-config.yaml
|
.pre-commit-config.yaml
|
||||||
# html-router/assets/style.css
|
# html-router/assets/style.css
|
||||||
html-router/node_modules
|
html-router/node_modules
|
||||||
|
.fastembed_cache/
|
||||||
|
|
||||||
|
# insta: pending (unreviewed) snapshots; accepted *.snap files are committed
|
||||||
|
*.snap.new
|
||||||
|
.insta.bak
|
||||||
|
|||||||
@@ -0,0 +1,93 @@
|
|||||||
|
# Changelog
|
||||||
|
## Unreleased
|
||||||
|
|
||||||
|
## 1.0.3 (2026-06-12)
|
||||||
|
- Search: filter results by type — knowledge entities, ingested content, or both
|
||||||
|
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
||||||
|
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
||||||
|
- Performance: lower search and chat latency from backend allocation and retrieval optimizations
|
||||||
|
- Fix: modal dialogs (scratchpad editor, admin prompts, entity creation) open and close more reliably
|
||||||
|
- Fix: improved knowledge-entity relationship suggestions when creating entities manually
|
||||||
|
- Fix: API key revocation now correctly clears the stored key
|
||||||
|
|
||||||
|
## 1.0.2 (2026-02-15)
|
||||||
|
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||||
|
- Fix: chat references now validate and render more reliably
|
||||||
|
- Fix: improved admin access checks for restricted routes
|
||||||
|
- Performance: faster chat sidebar loads from cached conversation archive data
|
||||||
|
- API: harmonized ingest endpoint naming and added configurable ingest safety limits
|
||||||
|
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
||||||
|
|
||||||
|
## 1.0.1 (2026-02-11)
|
||||||
|
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
||||||
|
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
||||||
|
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
|
||||||
|
|
||||||
|
## 1.0.0 (2026-01-02)
|
||||||
|
- **Locally generated embeddings are now the default**. If you want to continue using API embeddings, set `EMBEDDING_BACKEND` to `openai`. The new default downloads an ONNX model and recreates all embeddings, but in most instances it is well worth it because it removes the network-bound call needed to create embeddings. Creating embeddings on my N100 device is extremely fast; a search response is typically provided in less than 50 ms.
|
||||||
|
- Added a benchmarks crate for evaluating the retrieval process
|
||||||
|
- Added FastEmbed embedding support, which enables locally generated CPU embeddings and greatly improves latency if the machine can handle it. Quick search has vastly better accuracy and is much faster: about 50 ms latency in testing, compared to a minimum of 300 ms before.
|
||||||
|
- Embeddings stored on own table.
|
||||||
|
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
|
||||||
|
|
||||||
|
## Version 0.2.7 (2025-12-04)
|
||||||
|
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||||
|
- Fix: timezone aware info in scratchpad
|
||||||
|
|
||||||
|
## Version 0.2.6 (2025-10-29)
|
||||||
|
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||||
|
- Fix: default name for relationships harmonized across application
|
||||||
|
|
||||||
|
## Version 0.2.5 (2025-10-24)
|
||||||
|
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||||
|
- Added a scratchpad feature, with the option to convert scratchpads to content.
|
||||||
|
- Added knowledge entity search results to the global search
|
||||||
|
- Backend fixes for improved performance when ingesting and retrieval
|
||||||
|
|
||||||
|
## Version 0.2.4 (2025-10-15)
|
||||||
|
- Improved retrieval performance. Ingestion and chat now utilize full-text search, vector comparison, and graph traversal.
|
||||||
|
- Ingestion task archive
|
||||||
|
|
||||||
|
## Version 0.2.3 (2025-10-12)
|
||||||
|
- Fix changing vector dimensions on a fresh database (#3)
|
||||||
|
|
||||||
|
## Version 0.2.2 (2025-10-07)
|
||||||
|
- Support for ingestion of PDF files
|
||||||
|
- Improved ingestion speed
|
||||||
|
- Fix: deletion of items now works as expected
|
||||||
|
- Fix enabling GPT-5 use via OpenAI API
|
||||||
|
|
||||||
|
## Version 0.2.1 (2025-09-24)
|
||||||
|
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||||
|
|
||||||
|
## Version 0.2.0 (2025-09-23)
|
||||||
|
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||||
|
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||||
|
- Enabled audio ingestion and merged the new storage backend.
|
||||||
|
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||||
|
|
||||||
|
## Version 0.1.4 (2025-07-01)
|
||||||
|
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||||
|
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||||
|
|
||||||
|
## Version 0.1.3 (2025-06-08)
|
||||||
|
- Added support for AI providers beyond OpenAI.
|
||||||
|
- Made the HTTP port configurable for deployments.
|
||||||
|
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||||
|
|
||||||
|
## Version 0.1.2 (2025-05-26)
|
||||||
|
- Introduced full-text search across indexed knowledge.
|
||||||
|
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||||
|
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||||
|
|
||||||
|
## Version 0.1.1 (2025-05-13)
|
||||||
|
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||||
|
- Made the data storage path configurable.
|
||||||
|
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||||
|
|
||||||
|
## Version 0.1.0 (2025-05-06)
|
||||||
|
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||||
|
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||||
|
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||||
|
- Shipped a Tailwind/daisyUI web UI with responsive layouts, modals, content viewers, and editing flows.
|
||||||
|
- Provided readability-based content ingestion, API/HTML ingress routes, and Docker/Docker Compose tooling.
|
||||||
Generated
+2608
-1065
File diff suppressed because it is too large
Load Diff
+66
-6
@@ -5,14 +5,15 @@ members = [
|
|||||||
"api-router",
|
"api-router",
|
||||||
"html-router",
|
"html-router",
|
||||||
"ingestion-pipeline",
|
"ingestion-pipeline",
|
||||||
"composite-retrieval",
|
"retrieval-pipeline",
|
||||||
"json-stream-parser"
|
"json-stream-parser",
|
||||||
|
"evaluations"
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
anyhow = "1.0.94"
|
anyhow = "1.0.94"
|
||||||
async-openai = "0.24.1"
|
async-openai = "0.29.3"
|
||||||
async-stream = "0.3.6"
|
async-stream = "0.3.6"
|
||||||
async-trait = "0.1.88"
|
async-trait = "0.1.88"
|
||||||
axum-htmx = "0.7.0"
|
axum-htmx = "0.7.0"
|
||||||
@@ -39,9 +40,11 @@ serde_json = "1.0.128"
|
|||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
sha2 = "0.10.8"
|
sha2 = "0.10.8"
|
||||||
surrealdb-migrations = "2.2.2"
|
surrealdb-migrations = "2.2.2"
|
||||||
surrealdb = { version = "2", features = ["kv-mem"] }
|
surrealdb = { version = "2" }
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.12.0"
|
||||||
text-splitter = "0.18.1"
|
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||||
|
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||||
|
unicode-normalization = "0.1.24"
|
||||||
thiserror = "1.0.63"
|
thiserror = "1.0.63"
|
||||||
tokio-util = { version = "0.7.15", features = ["io"] }
|
tokio-util = { version = "0.7.15", features = ["io"] }
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
@@ -53,9 +56,66 @@ url = { version = "2.5.2", features = ["serde"] }
|
|||||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||||
tokio-retry = "0.3.0"
|
tokio-retry = "0.3.0"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
object_store = { version = "0.11.2" }
|
object_store = { version = "0.11.2", features = ["aws"] }
|
||||||
bytes = "1.7.1"
|
bytes = "1.7.1"
|
||||||
|
state-machines = "0.9"
|
||||||
|
pdf-extract = "0.9"
|
||||||
|
lopdf = "0.32"
|
||||||
|
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||||
|
|
||||||
[profile.dist]
|
[profile.dist]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
lto = "thin"
|
lto = "thin"
|
||||||
|
|
||||||
|
[workspace.lints.rust]
|
||||||
|
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(feature, values(\"inspect\"))"] }
|
||||||
|
|
||||||
|
[workspace.lints.clippy]
|
||||||
|
# Performance-focused lints
|
||||||
|
perf = { level = "warn", priority = -1 }
|
||||||
|
vec_init_then_push = "warn"
|
||||||
|
large_stack_frames = "warn"
|
||||||
|
redundant_allocation = "warn"
|
||||||
|
single_char_pattern = "warn"
|
||||||
|
string_extend_chars = "warn"
|
||||||
|
format_in_format_args = "warn"
|
||||||
|
slow_vector_initialization = "warn"
|
||||||
|
inefficient_to_string = "warn"
|
||||||
|
implicit_clone = "warn"
|
||||||
|
redundant_clone = "warn"
|
||||||
|
|
||||||
|
# Security-focused lints
|
||||||
|
arithmetic_side_effects = "warn"
|
||||||
|
indexing_slicing = "warn"
|
||||||
|
unwrap_used = "warn"
|
||||||
|
expect_used = "warn"
|
||||||
|
panic = "warn"
|
||||||
|
unimplemented = "warn"
|
||||||
|
todo = "warn"
|
||||||
|
|
||||||
|
# Async/Network lints
|
||||||
|
async_yields_async = "warn"
|
||||||
|
await_holding_invalid_type = "warn"
|
||||||
|
rc_buffer = "warn"
|
||||||
|
|
||||||
|
# Maintainability-focused lints
|
||||||
|
cargo = { level = "warn", priority = -1 }
|
||||||
|
pedantic = { level = "warn", priority = -1 }
|
||||||
|
clone_on_ref_ptr = "warn"
|
||||||
|
float_cmp = "warn"
|
||||||
|
manual_string_new = "warn"
|
||||||
|
uninlined_format_args = "warn"
|
||||||
|
unused_self = "warn"
|
||||||
|
must_use_candidate = "allow"
|
||||||
|
missing_errors_doc = "allow"
|
||||||
|
missing_panics_doc = "warn"
|
||||||
|
module_name_repetitions = "warn"
|
||||||
|
wildcard_dependencies = "warn"
|
||||||
|
missing_docs_in_private_items = "allow"
|
||||||
|
|
||||||
|
# Allow noisy lints that don't add value for this project
|
||||||
|
needless_raw_string_hashes = "allow"
|
||||||
|
multiple_bound_locations = "allow"
|
||||||
|
cargo_common_metadata = "allow"
|
||||||
|
multiple-crate-versions = "allow"
|
||||||
|
|
||||||
|
|||||||
+34
-34
@@ -1,53 +1,53 @@
|
|||||||
# === Builder Stage ===
|
# === Builder ===
|
||||||
FROM clux/muslrust:1.86.0-stable as builder
|
FROM rust:1.91.1-bookworm AS builder
|
||||||
|
|
||||||
WORKDIR /usr/src/minne
|
WORKDIR /usr/src/minne
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Cache deps
|
||||||
COPY Cargo.toml Cargo.lock ./
|
COPY Cargo.toml Cargo.lock ./
|
||||||
RUN mkdir -p api-router common composite-retrieval html-router ingestion-pipeline json-stream-parser main worker
|
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
|
||||||
COPY api-router/Cargo.toml ./api-router/
|
COPY api-router/Cargo.toml ./api-router/
|
||||||
COPY common/Cargo.toml ./common/
|
COPY common/Cargo.toml ./common/
|
||||||
COPY composite-retrieval/Cargo.toml ./composite-retrieval/
|
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
|
||||||
COPY html-router/Cargo.toml ./html-router/
|
COPY html-router/Cargo.toml ./html-router/
|
||||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
||||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
||||||
COPY main/Cargo.toml ./main/
|
COPY main/Cargo.toml ./main/
|
||||||
|
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
||||||
|
|
||||||
# Build with the MUSL target
|
# Build
|
||||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker || true
|
|
||||||
|
|
||||||
# Copy the rest of the source code
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
||||||
|
|
||||||
# Build the final application binary with the MUSL target
|
# === Runtime ===
|
||||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
# === Runtime Stage ===
|
# Chromium + runtime deps + OpenMP for ORT
|
||||||
FROM alpine:latest
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
||||||
|
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
||||||
|
libgomp1 libstdc++6 curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN apk update && apk add --no-cache \
|
# ONNX Runtime (CPU). Version is read from ort-version (override with --build-arg ORT_VERSION=...).
|
||||||
chromium \
|
COPY ort-version /tmp/ort-version
|
||||||
nss \
|
ARG ORT_VERSION
|
||||||
freetype \
|
RUN ORT_VERSION="${ORT_VERSION:-$(tr -d '[:space:]' < /tmp/ort-version)}" && \
|
||||||
harfbuzz \
|
mkdir -p /opt/onnxruntime && \
|
||||||
ca-certificates \
|
curl -fsSL -o /tmp/ort.tgz \
|
||||||
ttf-freefont \
|
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||||
font-noto-emoji \
|
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||||
&& \
|
|
||||||
rm -rf /var/cache/apk/*
|
|
||||||
|
|
||||||
ENV CHROME_BIN=/usr/bin/chromium-browser \
|
ENV CHROME_BIN=/usr/bin/chromium \
|
||||||
CHROME_PATH=/usr/lib/chromium/ \
|
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
||||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt
|
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
||||||
|
|
||||||
# Create a non-root user to run the application
|
# Non-root
|
||||||
RUN adduser -D -h /home/appuser appuser
|
RUN useradd -m appuser
|
||||||
WORKDIR /home/appuser
|
|
||||||
USER appuser
|
USER appuser
|
||||||
|
WORKDIR /home/appuser
|
||||||
|
|
||||||
# Copy the compiled binary from the builder stage (note the target path)
|
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
||||||
COPY --from=builder /usr/src/minne/target/x86_64-unknown-linux-musl/release/main /usr/local/bin/main
|
|
||||||
|
|
||||||
EXPOSE 3000
|
EXPOSE 3000
|
||||||
# EXPOSE 8000-9000
|
|
||||||
|
|
||||||
CMD ["main"]
|
CMD ["main"]
|
||||||
|
|||||||
@@ -6,200 +6,148 @@
|
|||||||
[](https://www.gnu.org/licenses/agpl-3.0)
|
[](https://www.gnu.org/licenses/agpl-3.0)
|
||||||
[](https://github.com/perstarkse/minne/releases/latest)
|
[](https://github.com/perstarkse/minne/releases/latest)
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## Demo deployment
|
## Demo deployment
|
||||||
|
|
||||||
To test _Minne_ out, enter [this](https://minne-demo.stark.pub) read-only demo deployment to view and test functionality out.
|
To try _Minne_ out, visit [the demo deployment](https://minne.stark.pub) and sign in to a read-only instance to explore its functionality.
|
||||||
|
|
||||||
|
## Noteworthy Features
|
||||||
|
|
||||||
|
- **Search & Chat Interface** - Find content or knowledge instantly with full-text search, or use the chat mode and conversational AI to find and reason about content
|
||||||
|
- **Manual and AI-assisted connections** - Build entities and relationships manually with full control, let AI create entities and relationships automatically, or blend both approaches with AI suggestions for manual approval
|
||||||
|
- **Hybrid Retrieval System** - Search combining vector similarity & full-text search
|
||||||
|
- **Scratchpad Feature** - Quickly capture thoughts and convert them to permanent content when ready
|
||||||
|
- **Visual Graph Explorer** - Interactive D3-based navigation of your knowledge entities and connections
|
||||||
|
- **Multi-Format Support** - Ingest text, URLs, PDFs, audio files, and images into your knowledge base
|
||||||
|
- **Performance Focus** - Built with Rust and server-side rendering for speed and efficiency
|
||||||
|
- **Self-Hosted & Privacy-Focused** - Full control over your data, and compatible with any OpenAI-compatible API that supports structured outputs
|
||||||
|
|
||||||
## The "Why" Behind Minne
|
## The "Why" Behind Minne
|
||||||
|
|
||||||
For a while I've been fascinated by Zettelkasten-style PKM systems. While tools like Logseq and Obsidian are excellent, I found the manual linking process to be a hindrance for me. I also wanted a centralized storage and easy access across devices.
|
For a while I've been fascinated by personal knowledge management systems. I wanted something that made it incredibly easy to capture content - snippets of text, URLs, and other media - while automatically discovering connections between ideas. But I also wanted to maintain control over my knowledge structure.
|
||||||
|
|
||||||
While developing Minne, I discovered [KaraKeep](https://karakeep.com/) (formerly Hoarder), which is an excellent application in a similar space – you probably want to check it out! However, if you're interested in a PKM that builds an automatic network between related concepts using AI, offers search and the **possibility to chat with your knowledge resource**, and provides a blend of manual and AI-driven organization, then Minne might be worth testing.
|
Traditional tools like Logseq and Obsidian are excellent, but the manual linking process often became a hindrance. Meanwhile, fully automated systems sometimes miss important context or create relationships I wouldn't have chosen myself.
|
||||||
|
|
||||||
## Core Philosophy & Features
|
So I built Minne to offer the best of both worlds: effortless content capture with AI-assisted relationship discovery, but with the flexibility to manually curate, edit, or override any connections. You can let AI handle the heavy lifting of extracting entities and finding relationships, take full control yourself, or use a hybrid approach where AI suggests connections that you can approve or modify.
|
||||||
|
|
||||||
Minne is designed to make it incredibly easy to save snippets of text, URLs, and other content (limited, pending demand). Simply send content along with a category tag. Minne then ingests this, leveraging AI to create relevant nodes and relationships within its graph database, alongside your manual categorization. This graph backend allows for discoverable connections between your pieces of knowledge.
|
While developing Minne, I discovered [KaraKeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder), which is an excellent application in a similar space – you probably want to check it out! However, if you're interested in a PKM that offers both intelligent automation and manual curation, with the ability to chat with your knowledge base, then Minne might be worth testing.
|
||||||
|
|
||||||
You can converse with your knowledge base through an LLM-powered chat interface (via an OpenAI-compatible API, such as Ollama or others). For those who like to see the bigger picture, Minne also includes a feature to visually explore your knowledge graph.
|
## Table of Contents
|
||||||
|
|
||||||
You can switch between the models used and have the possibility to change the prompts to your liking. There is also an option to change the embedding length, making it easy to test another embedding model.
|
- [Quick Start](#quick-start)
|
||||||
|
- [Features in Detail](#features-in-detail)
|
||||||
|
- [Configuration](#configuration)
|
||||||
|
- [Tech Stack](#tech-stack)
|
||||||
|
- [Application Architecture](#application-architecture)
|
||||||
|
- [AI Configuration](#ai-configuration--model-selection)
|
||||||
|
- [Roadmap](#roadmap)
|
||||||
|
- [Development](#development)
|
||||||
|
- [Contributing](#contributing)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) makes sending content to your Minne instance a breeze.
|
## Quick Start
|
||||||
|
|
||||||
Minne is open source (AGPL), self-hostable, and can be deployed flexibly: via Nix, Docker Compose, pre-built binaries, or by building from source. It can run as a single `main` binary or as separate `server` and `worker` processes for optimized resource allocation.
|
The fastest way to get Minne running is with Docker Compose:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone https://github.com/perstarkse/minne.git
|
||||||
|
cd minne
|
||||||
|
|
||||||
|
# Start Minne and its database
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Access at http://localhost:3000
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Setup:**
|
||||||
|
- Replace `your_openai_api_key_here` in `docker-compose.yml` with your actual API key
|
||||||
|
- Configure `OPENAI_BASE_URL` if using a custom AI provider (like Ollama)
|
||||||
|
|
||||||
|
For detailed installation options, see [Configuration](#configuration).
|
||||||
|
|
||||||
|
## Features in Detail
|
||||||
|
|
||||||
|
### Search vs. Chat mode
|
||||||
|
|
||||||
|
**Search** - Use when you know roughly what you're looking for. Full-text search finds items quickly by matching your query terms.
|
||||||
|
|
||||||
|
**Chat Mode** - Use when you want to explore concepts, find connections, or reason about your knowledge. The AI analyzes your query and finds relevant context across your entire knowledge base.
|
||||||
|
|
||||||
|
### Content Processing
|
||||||
|
|
||||||
|
Minne automatically processes content you save:
|
||||||
|
1. **Web scraping** extracts readable text from URLs
|
||||||
|
2. **Text analysis** identifies key concepts and relationships
|
||||||
|
3. **Graph creation** builds connections between related content
|
||||||
|
4. **Embedding generation** enables semantic search capabilities
|
||||||
|
|
||||||
|
### Visual Knowledge Graph
|
||||||
|
|
||||||
|
Explore your knowledge as an interactive network with flexible curation options:
|
||||||
|
|
||||||
|
**Manual Curation** - Create knowledge entities and relationships yourself with full control over your graph structure
|
||||||
|
|
||||||
|
**AI Automation** - Let AI automatically extract entities and discover relationships from your content
|
||||||
|
|
||||||
|
**Hybrid Approach** - Get AI-suggested relationships and entities that you can manually review, edit, or approve
|
||||||
|
|
||||||
|
The graph visualization shows:
|
||||||
|
- Knowledge entities as nodes (manually created or AI-extracted)
|
||||||
|
- Relationships as connections (manually defined, AI-discovered, or suggested)
|
||||||
|
- Interactive navigation for discovery and editing
|
||||||
|
|
||||||
|
### Optional FastEmbed Reranking
|
||||||
|
|
||||||
|
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
|
||||||
|
|
||||||
|
⚠️ **Resource notes**
|
||||||
|
- Enabling reranking downloads and caches ~1.1 GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
|
||||||
|
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a single-user setup, but a pool size of 1 could work as well.
|
||||||
|
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if you’re comfortable with the additional footprint.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
reranking_enabled: true
|
||||||
|
reranking_pool_size: 2
|
||||||
|
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to .fastembed_cache
|
||||||
|
```
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
- **Backend:** Rust. Server-Side Rendering (SSR). Axum. Minijinja for templating.
|
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
|
||||||
- **Frontend:** HTML. HTMX and plain JavaScript for interactivity.
|
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
||||||
- **Database:** SurrealDB
|
- **Database:** SurrealDB (graph, document, and vector search)
|
||||||
- **AI Integration:** OpenAI API compatible endpoint (for chat and content processing), with support for structured outputs.
|
- **AI Integration:** OpenAI-compatible API with structured outputs
|
||||||
- **Web Content Processing:** Relies on a Chromium instance for robust webpage fetching/rendering.
|
- **Web Processing:** Headless Chrome for robust webpage content extraction
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
- **For Docker/Nix:** Docker or Nix installed. These methods handle SurrealDB and Chromium dependencies.
|
|
||||||
- **For Binaries/Source:**
|
|
||||||
- A running SurrealDB instance.
|
|
||||||
- Chromium (or a compatible Chrome browser) installed and accessible in your `PATH`.
|
|
||||||
- Git (if cloning and building from source).
|
|
||||||
- Rust toolchain (if building from source).
|
|
||||||
|
|
||||||
## Getting Started
|
|
||||||
|
|
||||||
You have several options to get Minne up and running:
|
|
||||||
|
|
||||||
### 1. Nix (Recommended for ease of dependency management)
|
|
||||||
|
|
||||||
If you have Nix installed, you can run Minne directly:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
nix run 'github:perstarkse/minne#main'
|
|
||||||
```
|
|
||||||
|
|
||||||
This command will fetch Minne and its dependencies (including Chromium) and run the `main` (combined server/worker) application.
|
|
||||||
|
|
||||||
### 2. Docker Compose (Recommended for containerized environments)
|
|
||||||
|
|
||||||
This is a great way to manage Minne and its SurrealDB dependency together.
|
|
||||||
|
|
||||||
1. Clone the repository (or just save the `docker-compose.yml` below).
|
|
||||||
|
|
||||||
1. Create a `docker-compose.yml` file:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
version: "3.8"
|
|
||||||
services:
|
|
||||||
minne:
|
|
||||||
image: ghcr.io/perstarkse/minne:latest # Pulls the latest pre-built image
|
|
||||||
# Or, to build from local source:
|
|
||||||
# build: .
|
|
||||||
container_name: minne_app
|
|
||||||
ports:
|
|
||||||
- "3000:3000" # Exposes Minne on port 3000
|
|
||||||
environment:
|
|
||||||
# These are examples, ensure they match your SurrealDB setup below
|
|
||||||
# and your actual OpenAI key.
|
|
||||||
SURREALDB_ADDRESS: "ws://surrealdb:8000"
|
|
||||||
SURREALDB_USERNAME: "root_user" # Default from SurrealDB service below
|
|
||||||
SURREALDB_PASSWORD: "root_password" # Default from SurrealDB service below
|
|
||||||
SURREALDB_DATABASE: "minne_db"
|
|
||||||
SURREALDB_NAMESPACE: "minne_ns"
|
|
||||||
OPENAI_API_KEY: "your_openai_api_key_here" # IMPORTANT: Replace with your actual key
|
|
||||||
#OPENAI_BASE_URL: "your_ollama_address" # Uncomment this and change it to override the default openai base url
|
|
||||||
HTTP_PORT: 3000
|
|
||||||
DATA_DIR: "/data" # Data directory inside the container
|
|
||||||
RUST_LOG: "minne=info,tower_http=info" # Example logging level
|
|
||||||
volumes:
|
|
||||||
- ./minne_data:/data # Persists Minne's data (e.g., scraped content) on the host
|
|
||||||
depends_on:
|
|
||||||
- surrealdb
|
|
||||||
networks:
|
|
||||||
- minne-net
|
|
||||||
# Waits for SurrealDB to be ready before starting Minne
|
|
||||||
command: >
|
|
||||||
sh -c "
|
|
||||||
echo 'Waiting for SurrealDB to start...' &&
|
|
||||||
# Adjust sleep time if SurrealDB takes longer to initialize in your environment
|
|
||||||
until nc -z surrealdb 8000; do echo 'Waiting for SurrealDB...'; sleep 2; done &&
|
|
||||||
echo 'SurrealDB is up, starting Minne application...' &&
|
|
||||||
/usr/local/bin/main
|
|
||||||
"
|
|
||||||
# For separate server/worker:
|
|
||||||
# command: /usr/local/bin/server # or /usr/local/bin/worker
|
|
||||||
|
|
||||||
surrealdb:
|
|
||||||
image: surrealdb/surrealdb:latest
|
|
||||||
container_name: minne_surrealdb
|
|
||||||
ports:
|
|
||||||
# Exposes SurrealDB on port 8000 (primarily for direct access/debugging if needed,
|
|
||||||
# not strictly required for Minne if only accessed internally by the minne service)
|
|
||||||
- "127.0.0.1:8000:8000" # Bind to localhost only for SurrealDB by default
|
|
||||||
volumes:
|
|
||||||
# Persists SurrealDB data on the host in a 'surreal_database' folder
|
|
||||||
- ./surreal_database:/database
|
|
||||||
command: >
|
|
||||||
start
|
|
||||||
--log info # Consider 'debug' for troubleshooting
|
|
||||||
--user root_user
|
|
||||||
--pass root_password
|
|
||||||
file:/database/minne_v1.db # Using file-based storage for simplicity
|
|
||||||
networks:
|
|
||||||
- minne-net
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
minne_data: {} # Defines a named volume for Minne data (can be managed by Docker)
|
|
||||||
surreal_database: {} # Defines a named volume for SurrealDB data
|
|
||||||
|
|
||||||
networks:
|
|
||||||
minne-net:
|
|
||||||
driver: bridge
|
|
||||||
```
|
|
||||||
|
|
||||||
1. Run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
Minne will be accessible at `http://localhost:3000`.
|
|
||||||
|
|
||||||
### 3. Pre-built Binaries (GitHub Releases)
|
|
||||||
|
|
||||||
Binaries for Windows, macOS, and Linux (combined `main` version) are available on the [GitHub Releases page](https://github.com/perstarkse/minne/releases/latest).
|
|
||||||
|
|
||||||
1. Download the appropriate binary for your system.
|
|
||||||
1. **You will need to provide and run SurrealDB and have Chromium installed and accessible in your PATH separately.**
|
|
||||||
1. Set the required [Configuration](#configuration) environment variables or use a `config.yaml`.
|
|
||||||
1. Run the executable.
|
|
||||||
|
|
||||||
### 4. Build from Source
|
|
||||||
|
|
||||||
1. Clone the repository:
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/perstarkse/minne.git
|
|
||||||
cd minne
|
|
||||||
```
|
|
||||||
1. **You will need to provide and run SurrealDB and have Chromium installed and accessible in your PATH separately.**
|
|
||||||
1. Set the required [Configuration](#configuration) environment variables or use a `config.yaml`.
|
|
||||||
1. Build and run:
|
|
||||||
- For the combined `main` binary:
|
|
||||||
```bash
|
|
||||||
cargo run --release --bin main
|
|
||||||
```
|
|
||||||
- For the `server` binary:
|
|
||||||
```bash
|
|
||||||
cargo run --release --bin server
|
|
||||||
```
|
|
||||||
- For the `worker` binary (if you want to run it separately):
|
|
||||||
```bash
|
|
||||||
cargo run --release --bin worker
|
|
||||||
```
|
|
||||||
The compiled binaries will be in `target/release/`.
|
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
Minne can be configured using environment variables or a `config.yaml` file placed in the working directory where you run the application. Environment variables take precedence over `config.yaml`.
|
Minne can be configured using environment variables or a `config.yaml` file. Environment variables take precedence over `config.yaml`.
|
||||||
|
|
||||||
**Required Configuration:**
|
### Required Configuration
|
||||||
|
|
||||||
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000` or `ws://surrealdb:8000` for Docker).
|
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000`)
|
||||||
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`).
|
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`)
|
||||||
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`).
|
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`)
|
||||||
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`).
|
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`)
|
||||||
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`).
|
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`)
|
||||||
- `OPENAI_API_KEY`: Your API key for OpenAI compatible endpoint (e.g., `sk-YourActualOpenAIKeyGoesHere`).
|
- `OPENAI_API_KEY`: Your API key for OpenAI compatible endpoint
|
||||||
- `HTTP_PORT`: Port for the Minne server to listen on (Default: `3000`).
|
- `HTTP_PORT`: Port for the Minne server (Default: `3000`)
|
||||||
|
|
||||||
**Optional Configuration:**
|
### Optional Configuration
|
||||||
|
|
||||||
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`).
|
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
|
||||||
- `DATA_DIR`: Directory to store local data like fetched webpage content (e.g., `./data`).
|
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
|
||||||
- `OPENAI_BASE_URL`: Base URL to a OpenAI API provider, such as Ollama.
|
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
|
||||||
|
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
|
||||||
|
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
|
||||||
|
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
|
||||||
|
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
|
||||||
|
|
||||||
**Example `config.yaml`:**
|
### Example config.yaml
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
surrealdb_address: "ws://127.0.0.1:8000"
|
surrealdb_address: "ws://127.0.0.1:8000"
|
||||||
@@ -213,66 +161,105 @@ http_port: 3000
|
|||||||
# rust_log: "info"
|
# rust_log: "info"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Application Architecture (Binaries)
|
## Installation Options
|
||||||
|
|
||||||
Minne offers flexibility in deployment:
|
### 1. Docker Compose (Recommended)
|
||||||
|
|
||||||
- **`main`**: A combined binary running both server (API, web UI) and worker (background tasks) in one process. Ideal for simpler setups.
|
```bash
|
||||||
- **`server`**: Runs only the server component.
|
# Clone and run
|
||||||
- **`worker`**: Runs only the worker component, suitable for deployment on a machine with more resources for intensive tasks.
|
git clone https://github.com/perstarkse/minne.git
|
||||||
|
cd minne
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
This modularity allows scaling and resource optimization. The `main` binary or the Docker Compose setup (using `main`) is sufficient for most users.
|
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
|
||||||
|
|
||||||
|
### 2. Nix
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nix run 'github:perstarkse/minne#main'
|
||||||
|
```
|
||||||
|
|
||||||
|
This fetches Minne and all dependencies, including Chromium.
|
||||||
|
|
||||||
|
### 3. Pre-built Binaries
|
||||||
|
|
||||||
|
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||||
|
|
||||||
|
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
|
||||||
|
|
||||||
|
### 4. Build from Source
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/perstarkse/minne.git
|
||||||
|
cd minne
|
||||||
|
cargo run --release --bin main
|
||||||
|
```
|
||||||
|
|
||||||
|
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
|
||||||
|
|
||||||
|
## Application Architecture
|
||||||
|
|
||||||
|
Minne offers flexible deployment options:
|
||||||
|
|
||||||
|
- **`main`**: Combined server and worker in one process (recommended for most users)
|
||||||
|
- **`server`**: Web interface and API only
|
||||||
|
- **`worker`**: Background processing only (for resource optimization)
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Once Minne is running:
|
Once Minne is running at `http://localhost:3000`:
|
||||||
|
|
||||||
1. Access the web interface at `http://localhost:3000` (or your configured port).
|
1. **Web Interface**: Full-featured experience for desktop and mobile
|
||||||
1. On iOS, consider setting up the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/9aa960600ec14329837ba4169f57a166) for effortless content sending. **Add the shortcut, replace the [insert_url] and the [insert_api_key] snippets**.
|
2. **iOS Shortcut**: Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture
|
||||||
1. Add notes, URLs, **audio files**, and explore your growing knowledge graph.
|
3. **Content Types**: Save notes, URLs, audio files, and more
|
||||||
1. Engage with the chat interface to query your saved content.
|
4. **Knowledge Graph**: Explore automatic connections between your content
|
||||||
1. Try the experimental visual graph explorer to see connections.
|
5. **Chat Interface**: Query your knowledge base conversationally
|
||||||
|
|
||||||
## AI Configuration & Model Selection
|
## AI Configuration & Model Selection
|
||||||
|
|
||||||
Minne relies on an OpenAI-compatible API for processing content, generating graph relationships, and powering the chat feature.
|
### Setting Up AI Providers
|
||||||
|
|
||||||
**Environment Variables / `config.yaml` keys:**
|
Minne uses OpenAI-compatible APIs. Configure via environment variables or `config.yaml`:
|
||||||
|
|
||||||
- `OPENAI_API_KEY` (required): Your API key for the chosen AI provider.
|
- `OPENAI_API_KEY` (required): Your API key
|
||||||
- `OPENAI_BASE_URL` (optional): Use this to override the default OpenAI API URL (`https://api.openai.com/v1`). This is essential for using local models via services like Ollama, or other API providers.
|
- `OPENAI_BASE_URL` (optional): Custom provider URL (e.g., Ollama: `http://localhost:11434/v1`)
|
||||||
- **Example for Ollama:** `http://<your-ollama-ip>:11434/v1`
|
|
||||||
|
|
||||||
### Changing Models
|
### Model Selection
|
||||||
|
|
||||||
Once you have configured the `OPENAI_BASE_URL` to point to your desired provider, you can select the specific models Minne should use.
|
1. Access the `/admin` page in your Minne instance
|
||||||
|
2. Select models for content processing and chat from your configured provider
|
||||||
1. Navigate to the `/admin` page in your Minne instance.
|
3. **Content Processing Requirements**: The model must support structured outputs
|
||||||
1. The page will list the models available from your configured endpoint. You can select different models for processing content and for chat.
|
4. **Embedding Dimensions**: Update this setting when changing embedding models (e.g., 1536 for `text-embedding-3-small`, 768 for `nomic-embed-text`)
|
||||||
1. **Important:** For content processing, Minne relies on structured outputs (function calling). The model and provider you select for this task **must** support this feature.
|
|
||||||
1. **Embedding Dimensions:** If you change the embedding model, you **must** update the "Embedding Dimensions" setting in the admin panel to match the output dimensions of your new model (e.g., `text-embedding-3-small` uses 1536, `nomic-embed-text` uses 768). Mismatched dimensions will cause errors. Some newer models will accept a dimension argument, and for these setting the dimensions to whatever should work.
|
|
||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
I've developed Minne primarily for my own use, but having been in the selfhosted space for a long time, and using the efforts by others, I thought I'd share with the community. Feature requests are welcome.
|
Current development focus:
|
||||||
The roadmap as of now is:
|
|
||||||
|
|
||||||
~~- Handle uploaded images wisely.~~
|
- TUI frontend with system editor integration
|
||||||
~~- An updated explorer of the graph database.~~
|
- Enhanced reranking for improved retrieval recall
|
||||||
- A TUI frontend which opens your system default editor for improved writing and document management.
|
- Additional content type support
|
||||||
|
|
||||||
## Contributing
|
Feature requests and contributions are welcome!
|
||||||
|
|
||||||
Contributions are welcome! Whether it's bug reports, feature suggestions, documentation improvements, or code contributions, please feel free to open an issue or submit a pull request.
|
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
Run test with
|
```bash
|
||||||
```rust
|
# Run tests
|
||||||
cargo test
|
cargo test
|
||||||
|
|
||||||
|
# Development build
|
||||||
|
cargo build
|
||||||
|
|
||||||
|
# Comprehensive linting
|
||||||
|
cargo clippy --workspace --all-targets --all-features
|
||||||
```
|
```
|
||||||
There is currently a variety of unit tests for commonly used functions. Additional tests, especially integration tests would be very welcome.
|
|
||||||
|
The codebase includes extensive unit tests. Integration tests and additional contributions are welcome.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
I've developed Minne primarily for my own use, but having been in the selfhosted space for a long time, and using the efforts by others, I thought I'd share with the community. Feature requests are welcome.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details. This means if you run a modified version of Minne as a network service, you must also offer the source code of that modified version to its users.
|
Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details.
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
@@ -17,3 +20,8 @@ futures = { workspace = true }
|
|||||||
axum_typed_multipart = { workspace = true}
|
axum_typed_multipart = { workspace = true}
|
||||||
|
|
||||||
common = { path = "../common" }
|
common = { path = "../common" }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
common = { path = "../common", features = ["test-utils"] }
|
||||||
|
tower = "0.5"
|
||||||
|
uuid = { workspace = true }
|
||||||
|
|||||||
@@ -1,33 +1,13 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use common::{storage::db::SurrealDbClient, utils::config::AppConfig};
|
use common::{
|
||||||
|
storage::{db::SurrealDbClient, store::StorageManager},
|
||||||
|
utils::config::AppConfig,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ApiState {
|
pub struct ApiState {
|
||||||
pub db: Arc<SurrealDbClient>,
|
pub db: Arc<SurrealDbClient>,
|
||||||
pub config: AppConfig,
|
pub config: AppConfig,
|
||||||
}
|
pub storage: StorageManager,
|
||||||
|
|
||||||
impl ApiState {
|
|
||||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
let surreal_db_client = Arc::new(
|
|
||||||
SurrealDbClient::new(
|
|
||||||
&config.surrealdb_address,
|
|
||||||
&config.surrealdb_username,
|
|
||||||
&config.surrealdb_password,
|
|
||||||
&config.surrealdb_namespace,
|
|
||||||
&config.surrealdb_database,
|
|
||||||
)
|
|
||||||
.await?,
|
|
||||||
);
|
|
||||||
|
|
||||||
surreal_db_client.apply_migrations().await?;
|
|
||||||
|
|
||||||
let app_state = ApiState {
|
|
||||||
db: surreal_db_client.clone(),
|
|
||||||
config: config.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(app_state)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+65
-40
@@ -7,66 +7,75 @@ use common::error::AppError;
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Error, Debug, Serialize, Clone)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ApiError {
|
pub enum ApiErr {
|
||||||
#[error("Internal server error")]
|
#[error("internal server error")]
|
||||||
InternalError(String),
|
InternalError(String),
|
||||||
|
|
||||||
#[error("Validation error: {0}")]
|
#[error("validation error: {0}")]
|
||||||
ValidationError(String),
|
ValidationError(String),
|
||||||
|
|
||||||
#[error("Not found: {0}")]
|
#[error("not found: {0}")]
|
||||||
NotFound(String),
|
NotFound(String),
|
||||||
|
|
||||||
#[error("Unauthorized: {0}")]
|
#[error("unauthorized: {0}")]
|
||||||
Unauthorized(String),
|
Unauthorized(String),
|
||||||
|
|
||||||
|
#[error("payload too large: {0}")]
|
||||||
|
PayloadTooLarge(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AppError> for ApiError {
|
impl From<AppError> for ApiErr {
|
||||||
fn from(err: AppError) -> Self {
|
fn from(err: AppError) -> Self {
|
||||||
match err {
|
match err {
|
||||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
AppError::NotFound(msg) => Self::NotFound(msg),
|
||||||
tracing::error!("Internal error: {:?}", err);
|
AppError::Validation(msg) => Self::ValidationError(msg),
|
||||||
ApiError::InternalError("Internal server error".to_string())
|
AppError::Auth(msg) => Self::Unauthorized(msg),
|
||||||
|
other => {
|
||||||
|
tracing::error!("internal API error: {other:?}");
|
||||||
|
Self::InternalError("Internal server error".to_string())
|
||||||
}
|
}
|
||||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
|
||||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
|
||||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
|
||||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiErr {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
let (status, error_response) = match self {
|
let (status, error_response) = match self {
|
||||||
ApiError::InternalError(message) => (
|
Self::InternalError(message) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
ErrorResponse {
|
ErrorResponse {
|
||||||
error: message,
|
error: message,
|
||||||
status: "error".to_string(),
|
status: "error".to_string(),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ApiError::ValidationError(message) => (
|
Self::ValidationError(message) => (
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
ErrorResponse {
|
ErrorResponse {
|
||||||
error: message,
|
error: message,
|
||||||
status: "error".to_string(),
|
status: "error".to_string(),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ApiError::NotFound(message) => (
|
Self::NotFound(message) => (
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
ErrorResponse {
|
ErrorResponse {
|
||||||
error: message,
|
error: message,
|
||||||
status: "error".to_string(),
|
status: "error".to_string(),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ApiError::Unauthorized(message) => (
|
Self::Unauthorized(message) => (
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
ErrorResponse {
|
ErrorResponse {
|
||||||
error: message,
|
error: message,
|
||||||
status: "error".to_string(),
|
status: "error".to_string(),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Self::PayloadTooLarge(message) => (
|
||||||
|
StatusCode::PAYLOAD_TOO_LARGE,
|
||||||
|
ErrorResponse {
|
||||||
|
error: message,
|
||||||
|
status: "error".to_string(),
|
||||||
|
},
|
||||||
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
(status, Json(error_response)).into_response()
|
(status, Json(error_response)).into_response()
|
||||||
@@ -84,6 +93,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use common::error::AppError;
|
use common::error::AppError;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
// Helper to check status code
|
// Helper to check status code
|
||||||
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
||||||
@@ -95,43 +105,58 @@ mod tests {
|
|||||||
fn test_app_error_to_api_error_conversion() {
|
fn test_app_error_to_api_error_conversion() {
|
||||||
// Test NotFound error conversion
|
// Test NotFound error conversion
|
||||||
let not_found = AppError::NotFound("resource not found".to_string());
|
let not_found = AppError::NotFound("resource not found".to_string());
|
||||||
let api_error = ApiError::from(not_found);
|
let api_error = ApiErr::from(not_found);
|
||||||
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found"));
|
assert!(matches!(api_error, ApiErr::NotFound(msg) if msg == "resource not found"));
|
||||||
|
|
||||||
// Test Validation error conversion
|
// Test Validation error conversion
|
||||||
let validation = AppError::Validation("invalid input".to_string());
|
let validation = AppError::Validation("invalid input".to_string());
|
||||||
let api_error = ApiError::from(validation);
|
let api_error = ApiErr::from(validation);
|
||||||
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input"));
|
assert!(matches!(api_error, ApiErr::ValidationError(msg) if msg == "invalid input"));
|
||||||
|
|
||||||
// Test Auth error conversion
|
// Test Auth error conversion
|
||||||
let auth = AppError::Auth("unauthorized".to_string());
|
let auth = AppError::Auth("unauthorized".to_string());
|
||||||
let api_error = ApiError::from(auth);
|
let api_error = ApiErr::from(auth);
|
||||||
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized"));
|
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
|
||||||
|
|
||||||
// Test for internal errors - create a mock error that doesn't require surrealdb
|
// Test for internal errors - create a mock error that doesn't require surrealdb
|
||||||
let internal_error =
|
let internal_error = AppError::Io(io::Error::other("io error"));
|
||||||
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error"));
|
let api_error = ApiErr::from(internal_error);
|
||||||
let api_error = ApiError::from(internal_error);
|
assert!(matches!(
|
||||||
assert!(matches!(api_error, ApiError::InternalError(_)));
|
api_error,
|
||||||
|
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_app_error_internal_error_is_sanitized() {
|
||||||
|
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
|
||||||
|
assert!(matches!(
|
||||||
|
api_error,
|
||||||
|
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_api_error_response_status_codes() {
|
fn test_api_error_response_status_codes() {
|
||||||
// Test internal error status
|
// Test internal error status
|
||||||
let error = ApiError::InternalError("server error".to_string());
|
let error = ApiErr::InternalError("server error".to_string());
|
||||||
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
// Test not found status
|
// Test not found status
|
||||||
let error = ApiError::NotFound("not found".to_string());
|
let error = ApiErr::NotFound("not found".to_string());
|
||||||
assert_status_code(error, StatusCode::NOT_FOUND);
|
assert_status_code(error, StatusCode::NOT_FOUND);
|
||||||
|
|
||||||
// Test validation error status
|
// Test validation error status
|
||||||
let error = ApiError::ValidationError("invalid input".to_string());
|
let error = ApiErr::ValidationError("invalid input".to_string());
|
||||||
assert_status_code(error, StatusCode::BAD_REQUEST);
|
assert_status_code(error, StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
// Test unauthorized status
|
// Test unauthorized status
|
||||||
let error = ApiError::Unauthorized("not allowed".to_string());
|
let error = ApiErr::Unauthorized("not allowed".to_string());
|
||||||
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
|
// Test payload too large status
|
||||||
|
let error = ApiErr::PayloadTooLarge("too big".to_string());
|
||||||
|
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Alternative approach that doesn't try to parse the response body
|
// Alternative approach that doesn't try to parse the response body
|
||||||
@@ -139,15 +164,15 @@ mod tests {
|
|||||||
fn test_error_messages() {
|
fn test_error_messages() {
|
||||||
// For validation errors
|
// For validation errors
|
||||||
let message = "invalid data format";
|
let message = "invalid data format";
|
||||||
let error = ApiError::ValidationError(message.to_string());
|
let error = ApiErr::ValidationError(message.to_string());
|
||||||
|
|
||||||
// Check that the error itself contains the message
|
// Check that the error itself contains the message
|
||||||
assert_eq!(error.to_string(), format!("Validation error: {}", message));
|
assert_eq!(error.to_string(), format!("validation error: {message}"));
|
||||||
|
|
||||||
// For not found errors
|
// For not found errors
|
||||||
let message = "user not found";
|
let message = "user not found";
|
||||||
let error = ApiError::NotFound(message.to_string());
|
let error = ApiErr::NotFound(message.to_string());
|
||||||
assert_eq!(error.to_string(), format!("Not found: {}", message));
|
assert_eq!(error.to_string(), format!("not found: {message}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Alternative approach for internal error test
|
// Alternative approach for internal error test
|
||||||
@@ -156,11 +181,11 @@ mod tests {
|
|||||||
// Create a sensitive error message
|
// Create a sensitive error message
|
||||||
let sensitive_info = "db password incorrect";
|
let sensitive_info = "db password incorrect";
|
||||||
|
|
||||||
// Create ApiError with sensitive info
|
// Create ApiErr with sensitive info
|
||||||
let api_error = ApiError::InternalError(sensitive_info.to_string());
|
let api_error = ApiErr::InternalError(sensitive_info.to_string());
|
||||||
|
|
||||||
// Check the error message is correctly set
|
// Check the error message is correctly set
|
||||||
assert_eq!(api_error.to_string(), "Internal server error");
|
assert_eq!(api_error.to_string(), "internal server error");
|
||||||
|
|
||||||
// Also verify correct status code
|
// Also verify correct status code
|
||||||
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
|
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use axum::{
|
|||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use middleware_api_auth::api_auth;
|
use middleware_api_auth::api_auth;
|
||||||
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready};
|
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
|
||||||
|
|
||||||
pub mod api_state;
|
pub mod api_state;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
@@ -26,9 +26,13 @@ where
|
|||||||
|
|
||||||
// Protected API endpoints (require auth)
|
// Protected API endpoints (require auth)
|
||||||
let protected = Router::new()
|
let protected = Router::new()
|
||||||
.route("/ingress", post(ingest_data))
|
.route(
|
||||||
.route("/categories", get(get_categories))
|
"/ingest",
|
||||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
post(handle).layer(DefaultBodyLimit::max(
|
||||||
|
app_state.config.ingest_max_body_bytes,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.route("/categories", get(list))
|
||||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
||||||
|
|
||||||
public.merge(protected)
|
public.merge(protected)
|
||||||
|
|||||||
@@ -6,28 +6,26 @@ use axum::{
|
|||||||
|
|
||||||
use common::storage::types::user::User;
|
use common::storage::types::user::User;
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
pub async fn api_auth(
|
pub async fn api_auth(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
mut request: Request,
|
mut request: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Result<Response, ApiError> {
|
) -> Result<Response, ApiErr> {
|
||||||
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
let api_key = extract_api_key(&request)
|
||||||
"You have to be authenticated".to_string(),
|
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||||
))?;
|
|
||||||
|
|
||||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
let user = User::find_by_api_key(api_key, &state.db).await?;
|
||||||
let user = user.ok_or(ApiError::Unauthorized(
|
let user =
|
||||||
"You have to be authenticated".to_string(),
|
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||||
))?;
|
|
||||||
|
|
||||||
request.extensions_mut().insert(user);
|
request.extensions_mut().insert(user);
|
||||||
|
|
||||||
Ok(next.run(request).await)
|
Ok(next.run(request).await)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_api_key(request: &Request) -> Option<String> {
|
fn extract_api_key(request: &Request) -> Option<&str> {
|
||||||
request
|
request
|
||||||
.headers()
|
.headers()
|
||||||
.get("X-API-Key")
|
.get("X-API-Key")
|
||||||
@@ -37,7 +35,67 @@ fn extract_api_key(request: &Request) -> Option<String> {
|
|||||||
.headers()
|
.headers()
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(|s| s.trim()))
|
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||||
|
.map(str::trim)
|
||||||
})
|
})
|
||||||
.map(String::from)
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
|
mod tests {
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{HeaderValue, Request};
|
||||||
|
|
||||||
|
use super::extract_api_key;
|
||||||
|
|
||||||
|
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
|
||||||
|
let mut builder = Request::builder().method("GET").uri("/");
|
||||||
|
for (name, value) in headers {
|
||||||
|
builder = builder.header(*name, *value);
|
||||||
|
}
|
||||||
|
builder.body(Body::empty()).expect("test request")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_from_x_api_key_header() {
|
||||||
|
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_from_bearer_authorization() {
|
||||||
|
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_prefers_x_api_key_over_authorization() {
|
||||||
|
let request = request_with_headers(&[
|
||||||
|
("X-API-Key", "sk_header"),
|
||||||
|
("Authorization", "Bearer sk_bearer"),
|
||||||
|
]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_header"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_returns_none_when_missing() {
|
||||||
|
let request = request_with_headers(&[]);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_rejects_non_bearer_authorization() {
|
||||||
|
let request = request_with_headers(&[("Authorization", "Basic abc")]);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_rejects_invalid_header_values() {
|
||||||
|
let mut request = request_with_headers(&[]);
|
||||||
|
request.headers_mut().insert(
|
||||||
|
"X-API-Key",
|
||||||
|
HeaderValue::from_bytes(&[0xFF]).expect("invalid header"),
|
||||||
|
);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
||||||
use common::storage::types::user::User;
|
use common::storage::types::user::User;
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
pub async fn get_categories(
|
pub async fn list(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Extension(user): Extension<User>,
|
Extension(user): Extension<User>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiErr> {
|
||||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||||
|
|
||||||
Ok(Json(categories))
|
Ok(Json(categories))
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||||
|
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||||
|
use common::{
|
||||||
|
error::AppError,
|
||||||
|
storage::types::{
|
||||||
|
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||||
|
user::User,
|
||||||
|
},
|
||||||
|
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
|
||||||
|
};
|
||||||
|
use futures::{future::try_join_all, TryFutureExt};
|
||||||
|
use serde_json::json;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
|
#[derive(Debug, TryFromMultipart)]
|
||||||
|
pub struct Params {
|
||||||
|
pub content: Option<String>,
|
||||||
|
pub context: String,
|
||||||
|
pub category: String,
|
||||||
|
#[form_data(limit = "20000000")]
|
||||||
|
#[form_data(default)]
|
||||||
|
pub files: Vec<FieldData<NamedTempFile>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle(
|
||||||
|
State(state): State<ApiState>,
|
||||||
|
Extension(user): Extension<User>,
|
||||||
|
TypedMultipart(input): TypedMultipart<Params>,
|
||||||
|
) -> Result<impl IntoResponse, ApiErr> {
|
||||||
|
let user_id = user.id;
|
||||||
|
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||||
|
|
||||||
|
match validate_ingest_input(
|
||||||
|
&state.config,
|
||||||
|
input.content.as_deref(),
|
||||||
|
&input.context,
|
||||||
|
&input.category,
|
||||||
|
input.files.len(),
|
||||||
|
) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
||||||
|
return Err(ApiErr::PayloadTooLarge(message));
|
||||||
|
}
|
||||||
|
Err(IngestValidationError::BadRequest(message)) => {
|
||||||
|
return Err(ApiErr::ValidationError(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
user_id = %user_id,
|
||||||
|
has_content,
|
||||||
|
content_len = input.content.as_ref().map_or(0, String::len),
|
||||||
|
context_len = input.context.len(),
|
||||||
|
category_len = input.category.len(),
|
||||||
|
file_count = input.files.len(),
|
||||||
|
"Received ingest request"
|
||||||
|
);
|
||||||
|
|
||||||
|
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||||
|
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)
|
||||||
|
.map_err(AppError::from)
|
||||||
|
}))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let payloads = IngestionPayload::create_ingestion_payload(
|
||||||
|
input.content,
|
||||||
|
input.context,
|
||||||
|
input.category,
|
||||||
|
file_infos,
|
||||||
|
user_id.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
IngestionTask::create_all_and_add_to_db(payloads, &user_id, &state.db).await?;
|
||||||
|
|
||||||
|
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
||||||
|
}
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
|
||||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
|
||||||
use common::{
|
|
||||||
error::AppError,
|
|
||||||
storage::types::{
|
|
||||||
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
|
||||||
user::User,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use futures::{future::try_join_all, TryFutureExt};
|
|
||||||
use serde_json::json;
|
|
||||||
use tempfile::NamedTempFile;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
|
||||||
|
|
||||||
#[derive(Debug, TryFromMultipart)]
|
|
||||||
pub struct IngestParams {
|
|
||||||
pub content: Option<String>,
|
|
||||||
pub context: String,
|
|
||||||
pub category: String,
|
|
||||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
|
||||||
#[form_data(default)]
|
|
||||||
pub files: Vec<FieldData<NamedTempFile>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ingest_data(
|
|
||||||
State(state): State<ApiState>,
|
|
||||||
Extension(user): Extension<User>,
|
|
||||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
|
||||||
info!("Received input: {:?}", input);
|
|
||||||
|
|
||||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
|
||||||
FileInfo::new(file, &state.db, &user.id, &state.config).map_err(AppError::from)
|
|
||||||
}))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let payloads = IngestionPayload::create_ingestion_payload(
|
|
||||||
input.content,
|
|
||||||
input.context,
|
|
||||||
input.category,
|
|
||||||
file_infos,
|
|
||||||
user.id.as_str(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let futures: Vec<_> = payloads
|
|
||||||
.into_iter()
|
|
||||||
.map(|object| {
|
|
||||||
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
try_join_all(futures).await?;
|
|
||||||
|
|
||||||
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
pub mod categories;
|
pub mod categories;
|
||||||
pub mod ingress;
|
pub mod ingest;
|
||||||
pub mod liveness;
|
pub mod liveness;
|
||||||
pub mod readiness;
|
pub mod readiness;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
use crate::api_state::ApiState;
|
use crate::api_state::ApiState;
|
||||||
|
|
||||||
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
|
|||||||
"checks": { "db": "ok" }
|
"checks": { "db": "ok" }
|
||||||
})),
|
})),
|
||||||
),
|
),
|
||||||
Err(e) => (
|
Err(e) => {
|
||||||
|
error!("readiness check failed: {e:?}");
|
||||||
|
(
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"checks": { "db": "fail" },
|
"checks": { "db": "fail" }
|
||||||
"reason": e.to_string()
|
|
||||||
})),
|
})),
|
||||||
),
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
#![allow(clippy::expect_used)]
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use api_router::{api_routes_v1, api_state::ApiState};
|
||||||
|
use axum::{
|
||||||
|
body::{to_bytes, Body},
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use common::{
|
||||||
|
storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
|
||||||
|
utils::config::{AppConfig, StorageKind},
|
||||||
|
};
|
||||||
|
use tower::ServiceExt;
|
||||||
|
|
||||||
|
/// Builds a fresh router plus the database handle backing it for a single test.
///
/// Opens an in-memory `SurrealDbClient` in the `api_router_test` namespace under a
/// newly generated UUID v4 database name, applies migrations, and wires an
/// `ApiState` (using `StorageKind::Memory` storage) into `api_routes_v1`.
///
/// Returns the router together with the shared database handle so that a test can
/// seed records directly before issuing requests.
///
/// # Panics
///
/// Panics via `expect` if the database, the migrations, or the storage manager
/// fail to initialize.
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
|
||||||
|
let namespace = "api_router_test";
|
||||||
|
let database = uuid::Uuid::new_v4().to_string();
|
||||||
|
let db = Arc::new(
|
||||||
|
SurrealDbClient::memory(namespace, &database)
|
||||||
|
.await
|
||||||
|
.expect("in-memory db"),
|
||||||
|
);
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.expect("migrations should apply");
|
||||||
|
|
||||||
|
let config = AppConfig {
|
||||||
|
storage: StorageKind::Memory,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let storage = StorageManager::new(&config).await.expect("storage manager");
|
||||||
|
|
||||||
|
let state = ApiState {
|
||||||
|
db: Arc::clone(&db),
|
||||||
|
config,
|
||||||
|
storage,
|
||||||
|
};
|
||||||
|
|
||||||
|
let router = api_routes_v1(&state).with_state(state);
|
||||||
|
|
||||||
|
(router, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reads the whole body of `response` and returns it as a `String`.
///
/// The body is collected with `to_bytes` using `usize::MAX` as the size limit,
/// i.e. effectively unbounded, which is acceptable for the small test payloads here.
///
/// # Panics
///
/// Panics via `expect` if the body cannot be collected or is not valid UTF-8.
async fn response_body(response: axum::response::Response) -> String {
|
||||||
|
let body = to_bytes(response.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.expect("response body");
|
||||||
|
String::from_utf8(body.to_vec()).expect("utf-8 body")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that `/live` is reachable without an API key.
///
/// Sends a request to `/live` with no `X-API-Key` header and expects `200 OK`
/// with a body containing `"status":"ok"`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn live_probe_is_public() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/live")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("live request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("live response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(response_body(response).await.contains("\"status\":\"ok\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that `/ready` is reachable without an API key and reports a healthy database.
///
/// Expects `200 OK`, a body that reports `"db":"ok"`, and no `reason` text in the
/// body of a successful response.
///
/// NOTE(review): the first alternative of the `||` assertion below is implied by the
/// second one, so the check is effectively just `"db":"ok"` — confirm whether the
/// stricter `"checks":{...}` form was intended.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn ready_probe_is_public_and_reports_db_ok() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/ready")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("ready request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ready response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
let body = response_body(response).await;
|
||||||
|
assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\""));
|
||||||
|
assert!(!body.contains("reason"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that `/categories` rejects a request that carries no API key.
///
/// Sends a request without an `X-API-Key` header and expects `401 Unauthorized`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn protected_route_requires_api_key() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that `/categories` rejects an API key that matches no user.
///
/// No user is created in the fresh test database, so the `sk_invalid` value sent in
/// the `X-API-Key` header cannot match anything; the test expects `401 Unauthorized`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn protected_route_rejects_invalid_api_key() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.header("X-API-Key", "sk_invalid")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that a request with a valid API key can reach `/categories`.
///
/// Creates a user directly in the test database with `User::create_new`, issues an
/// API key for it with `User::set_api_key`, sends that key in the `X-API-Key`
/// header, and expects `200 OK`. Only the status code is checked, not the body.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn authenticated_user_can_list_categories() {
|
||||||
|
let (app, db) = build_test_app().await;
|
||||||
|
|
||||||
|
let user = User::create_new(
|
||||||
|
"api_router_test@example.com".to_string(),
|
||||||
|
"test_password".to_string(),
|
||||||
|
&db,
|
||||||
|
"UTC".to_string(),
|
||||||
|
"system".to_string(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("test user");
|
||||||
|
|
||||||
|
let api_key = User::set_api_key(&user.id, &db).await.expect("api key");
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.header("X-API-Key", api_key)
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
+10
-2
@@ -4,6 +4,9 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Workspace dependencies
|
# Workspace dependencies
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
@@ -13,7 +16,7 @@ tracing = { workspace = true }
|
|||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
surrealdb = { workspace = true, features = ["kv-mem"] }
|
surrealdb = { workspace = true }
|
||||||
async-openai = { workspace = true }
|
async-openai = { workspace = true }
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
@@ -41,7 +44,12 @@ surrealdb-migrations = { workspace = true }
|
|||||||
tokio-retry = { workspace = true }
|
tokio-retry = { workspace = true }
|
||||||
object_store = { workspace = true }
|
object_store = { workspace = true }
|
||||||
bytes = { workspace = true }
|
bytes = { workspace = true }
|
||||||
|
state-machines = { workspace = true }
|
||||||
|
fastembed = { workspace = true }
|
||||||
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
test-utils = []
|
test-utils = ["surrealdb/kv-mem"]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
surrealdb = { workspace = true, features = ["kv-mem"] }
|
||||||
|
|||||||
+4
-1
@@ -14,8 +14,11 @@ CREATE system_settings:current CONTENT {
|
|||||||
query_model: "gpt-4o-mini",
|
query_model: "gpt-4o-mini",
|
||||||
processing_model: "gpt-4o-mini",
|
processing_model: "gpt-4o-mini",
|
||||||
embedding_model: "text-embedding-3-small",
|
embedding_model: "text-embedding-3-small",
|
||||||
|
voice_processing_model: "whisper-1",
|
||||||
|
image_processing_model: "gpt-4o-mini",
|
||||||
|
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
|
||||||
embedding_dimensions: 1536,
|
embedding_dimensions: 1536,
|
||||||
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
||||||
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. 
Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. 
Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
||||||
};
|
};
|
||||||
END;
|
END;
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
-- Runtime-managed: text_content FTS indexes are now created at startup via the shared Surreal helper.
|
||||||
|
-- This migration is intentionally left as a no-op to avoid heavy index builds during migration.
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- No-op: legacy `job` table was superseded by `ingestion_task`; kept for migration order compatibility.
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- Runtime-managed: FTS indexes now built at startup; migration retained as a no-op.
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
-- State machine migration for ingestion_task records
|
||||||
|
|
||||||
|
DEFINE FIELD IF NOT EXISTS state ON TABLE ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS attempts ON TABLE ingestion_task TYPE option<number>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS max_attempts ON TABLE ingestion_task TYPE option<number>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS scheduled_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS locked_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON TABLE ingestion_task TYPE option<number>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS worker_id ON TABLE ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS error_code ON TABLE ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS error_message ON TABLE ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_error_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS priority ON TABLE ingestion_task TYPE option<number>;
|
||||||
|
|
||||||
|
REMOVE FIELD status ON TABLE ingestion_task;
|
||||||
|
DEFINE FIELD status ON TABLE ingestion_task TYPE option<object>;
|
||||||
|
|
||||||
|
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_state_sched ON TABLE ingestion_task FIELDS state, scheduled_at;
|
||||||
|
|
||||||
|
LET $needs_migration = (SELECT count() AS count FROM type::table('ingestion_task') WHERE state = NONE)[0].count;
|
||||||
|
|
||||||
|
IF $needs_migration > 0 THEN {
|
||||||
|
-- Created -> Pending
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "Pending",
|
||||||
|
attempts = 0,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF created_at != NONE THEN created_at ELSE time::now() END,
|
||||||
|
locked_at = NONE,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = NONE,
|
||||||
|
last_error_at = NONE,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE
|
||||||
|
AND status != NONE
|
||||||
|
AND status.name = "Created";
|
||||||
|
|
||||||
|
-- InProgress -> Processing
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "Processing",
|
||||||
|
attempts = IF status.attempts != NONE THEN status.attempts ELSE 1 END,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||||
|
locked_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = NONE,
|
||||||
|
last_error_at = NONE,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE
|
||||||
|
AND status != NONE
|
||||||
|
AND status.name = "InProgress";
|
||||||
|
|
||||||
|
-- Completed -> Succeeded
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "Succeeded",
|
||||||
|
attempts = 1,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||||
|
locked_at = NONE,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = NONE,
|
||||||
|
last_error_at = NONE,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE
|
||||||
|
AND status != NONE
|
||||||
|
AND status.name = "Completed";
|
||||||
|
|
||||||
|
-- Error -> DeadLetter (terminal failure)
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "DeadLetter",
|
||||||
|
attempts = 3,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||||
|
locked_at = NONE,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = status.message,
|
||||||
|
last_error_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE
|
||||||
|
AND status != NONE
|
||||||
|
AND status.name = "Error";
|
||||||
|
|
||||||
|
-- Cancelled -> Cancelled
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "Cancelled",
|
||||||
|
attempts = 0,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||||
|
locked_at = NONE,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = NONE,
|
||||||
|
last_error_at = NONE,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE
|
||||||
|
AND status != NONE
|
||||||
|
AND status.name = "Cancelled";
|
||||||
|
|
||||||
|
-- Fallback for any remaining records missing state
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET
|
||||||
|
state = "Pending",
|
||||||
|
attempts = 0,
|
||||||
|
max_attempts = 3,
|
||||||
|
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||||
|
locked_at = NONE,
|
||||||
|
lease_duration_secs = 300,
|
||||||
|
worker_id = NONE,
|
||||||
|
error_code = NONE,
|
||||||
|
error_message = NONE,
|
||||||
|
last_error_at = NONE,
|
||||||
|
priority = 0
|
||||||
|
WHERE state = NONE;
|
||||||
|
} END;
|
||||||
|
|
||||||
|
-- Ensure defaults for newly added fields
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET max_attempts = 3
|
||||||
|
WHERE max_attempts = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET lease_duration_secs = 300
|
||||||
|
WHERE lease_duration_secs = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET attempts = 0
|
||||||
|
WHERE attempts = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET priority = 0
|
||||||
|
WHERE priority = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END
|
||||||
|
WHERE scheduled_at = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET locked_at = NONE
|
||||||
|
WHERE locked_at = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET worker_id = NONE
|
||||||
|
WHERE worker_id != NONE AND worker_id = "";
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET error_code = NONE
|
||||||
|
WHERE error_code = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET error_message = NONE
|
||||||
|
WHERE error_message = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET last_error_at = NONE
|
||||||
|
WHERE last_error_at = NONE;
|
||||||
|
|
||||||
|
UPDATE type::table('ingestion_task')
|
||||||
|
SET status = NONE
|
||||||
|
WHERE status != NONE;
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
-- Add scratchpad table and schema
|
||||||
|
|
||||||
|
-- Define scratchpad table and schema
|
||||||
|
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||||
|
|
||||||
|
-- Standard fields from stored_object! macro
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||||
|
|
||||||
|
-- Custom fields from the Scratchpad struct
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||||
|
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||||
|
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||||
|
|
||||||
|
-- Indexes based on query patterns
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
-- Remove HNSW indexes from base tables (now created at runtime on *_embedding tables)
|
||||||
|
REMOVE INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
|
||||||
|
REMOVE INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
|
||||||
|
|
||||||
|
-- Remove FTS indexes (now created at runtime via indexes.rs)
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON text_content;
|
||||||
|
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
|
||||||
|
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
|
||||||
|
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
|
||||||
|
|
||||||
|
-- Remove legacy analyzers (recreated at runtime with updated configuration)
|
||||||
|
REMOVE ANALYZER IF EXISTS app_default_fts_analyzer;
|
||||||
|
REMOVE ANALYZER IF EXISTS app_en_fts_analyzer;
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
-- Move chunk/entity embeddings to dedicated tables for index efficiency.
|
||||||
|
|
||||||
|
-- Text chunk embeddings table
|
||||||
|
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||||
|
|
||||||
|
-- Knowledge entity embeddings table
|
||||||
|
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
-- Copy embeddings from base tables to dedicated tables
|
||||||
|
-- This runs BEFORE the field removal migration
|
||||||
|
|
||||||
|
FOR $chunk IN (SELECT * FROM text_chunk WHERE embedding != NONE AND array::len(embedding) > 0) {
|
||||||
|
CREATE text_chunk_embedding CONTENT {
|
||||||
|
chunk_id: $chunk.id,
|
||||||
|
embedding: $chunk.embedding,
|
||||||
|
user_id: $chunk.user_id,
|
||||||
|
source_id: $chunk.source_id,
|
||||||
|
created_at: $chunk.created_at,
|
||||||
|
updated_at: $chunk.updated_at
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
FOR $entity IN (SELECT * FROM knowledge_entity WHERE embedding != NONE AND array::len(embedding) > 0) {
|
||||||
|
CREATE knowledge_entity_embedding CONTENT {
|
||||||
|
entity_id: $entity.id,
|
||||||
|
embedding: $entity.embedding,
|
||||||
|
user_id: $entity.user_id,
|
||||||
|
created_at: $entity.created_at,
|
||||||
|
updated_at: $entity.updated_at
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
-- Drop legacy embedding fields from base tables; embeddings now live in *_embedding tables.
|
||||||
|
REMOVE FIELD IF EXISTS embedding ON TABLE text_chunk;
|
||||||
|
REMOVE FIELD IF EXISTS embedding ON TABLE knowledge_entity;
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
-- Add embedding_backend field to system_settings for visibility of active backend
|
||||||
|
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
|
||||||
|
|
||||||
|
-- Set default to 'openai' for existing installs to preserve backward compatibility
|
||||||
|
UPDATE system_settings:current SET
|
||||||
|
embedding_backend = 'openai'
|
||||||
|
WHERE embedding_backend == NONE;
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
-- Enforce SCHEMAFULL on all tables and define missing fields
|
||||||
|
|
||||||
|
-- 1. Define missing fields for ingestion_task (the table formerly named job)
-- OVERWRITE replaces whatever table definition already exists, so the table ends up
-- SCHEMAFULL regardless of how it was first created. The field definitions below use
-- IF NOT EXISTS, so definitions that are already present are kept as they are.
|
||||||
|
DEFINE TABLE OVERWRITE ingestion_task SCHEMAFULL;
|
||||||
|
|
||||||
|
-- Core Fields
|
||||||
|
DEFINE FIELD IF NOT EXISTS id ON ingestion_task TYPE record<ingestion_task>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;
|
||||||
|
|
||||||
|
-- State Machine Fields
-- The ASSERT restricts state to the seven names listed; any other value is rejected.
|
||||||
|
DEFINE FIELD IF NOT EXISTS state ON ingestion_task TYPE string ASSERT $value IN ['Pending', 'Reserved', 'Processing', 'Succeeded', 'Failed', 'Cancelled', 'DeadLetter'];
|
||||||
|
DEFINE FIELD IF NOT EXISTS attempts ON ingestion_task TYPE int DEFAULT 0;
|
||||||
|
DEFINE FIELD IF NOT EXISTS max_attempts ON ingestion_task TYPE int DEFAULT 3;
|
||||||
|
DEFINE FIELD IF NOT EXISTS scheduled_at ON ingestion_task TYPE datetime DEFAULT time::now();
|
||||||
|
-- Lease bookkeeping: locked_at and worker_id are optional, lease_duration_secs defaults to 300.
DEFINE FIELD IF NOT EXISTS locked_at ON ingestion_task TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON ingestion_task TYPE int DEFAULT 300;
|
||||||
|
DEFINE FIELD IF NOT EXISTS worker_id ON ingestion_task TYPE option<string>;
|
||||||
|
-- Error details are all optional.
DEFINE FIELD IF NOT EXISTS error_code ON ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS error_message ON ingestion_task TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_error_at ON ingestion_task TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS priority ON ingestion_task TYPE int DEFAULT 0;
|
||||||
|
|
||||||
|
-- Content Payload (IngestionPayload Enum)
-- content is a required object; each of its Url / Text / File keys is an optional object.
-- NOTE(review): presumably exactly one of the three keys is set per task (one key per
-- enum variant) — confirm against the IngestionPayload serialization.
|
||||||
|
DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Url ON ingestion_task TYPE option<object>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Text ON ingestion_task TYPE option<object>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File ON ingestion_task TYPE option<object>;
|
||||||
|
|
||||||
|
-- Content: Url Variant
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Url.url ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Url.context ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Url.category ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Url.user_id ON ingestion_task TYPE string;
|
||||||
|
|
||||||
|
-- Content: Text Variant
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Text.text ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Text.context ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Text.category ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.Text.user_id ON ingestion_task TYPE string;
|
||||||
|
|
||||||
|
-- Content: File Variant
-- Shares context / category / user_id with the other variants and carries a nested
-- file_info object instead of a url or text field.
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.context ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.category ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.user_id ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info ON ingestion_task TYPE object;
|
||||||
|
|
||||||
|
-- Content: File.file_info (FileInfo Struct)
-- Same eight sub-fields that this migration declares for text_content.file_info.
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.id ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.created_at ON ingestion_task TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.updated_at ON ingestion_task TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.sha256 ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.path ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.file_name ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.mime_type ON ingestion_task TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content.File.file_info.user_id ON ingestion_task TYPE string;
|
||||||
|
|
||||||
|
-- 2. Enforce SCHEMAFULL on all other tables
-- Each statement overwrites the existing table definition; no field definitions are
-- added for these five tables in this migration.
-- NOTE(review): presumably their fields are already declared by the schema files — confirm.
|
||||||
|
DEFINE TABLE OVERWRITE analytics SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE conversation SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE file SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE knowledge_entity SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE message SCHEMAFULL;
|
||||||
|
-- relates_to is the edge table between knowledge_entity records.
-- The FROM/TO clause is kept when overwriting the definition so the migrated table
-- matches the relates_to schema file (`TYPE RELATION FROM knowledge_entity TO
-- knowledge_entity`); a bare `TYPE RELATION` would redefine the edge without that
-- endpoint restriction.
DEFINE TABLE OVERWRITE relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;

-- Edge endpoints: both sides must be knowledge_entity records.
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;

-- Edge metadata object and its three string sub-fields, declared explicitly for the
-- SCHEMAFULL table.
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||||
|
-- scratchpad, system_settings and text_chunk are switched to SCHEMAFULL by overwriting
-- their table definitions; this migration adds no field definitions for them.
DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE system_settings SCHEMAFULL;
|
||||||
|
DEFINE TABLE OVERWRITE text_chunk SCHEMAFULL;
|
||||||
|
-- text_content needs all of its fields declared once it is SCHEMAFULL, including the
-- sub-fields of the file_info and url_info objects; they are defined right after the
-- table statement below.
|
||||||
|
DEFINE TABLE OVERWRITE text_content SCHEMAFULL;
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
||||||
|
-- file_info and url_info are optional objects; their sub-fields follow.
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
|
||||||
|
-- file_info sub-fields
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
|
||||||
|
|
||||||
|
-- user is switched to SCHEMAFULL as well; this migration adds no field definitions for it.
DEFINE TABLE OVERWRITE user SCHEMAFULL;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- Adds a theme string to user records; "system" is the value used when none is supplied.
DEFINE FIELD IF NOT EXISTS theme ON user TYPE string DEFAULT "system";
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
-- Per-user deduplication: same SHA256 may exist for different users.
|
||||||
|
-- Drop the old index on sha256 alone and replace it with a UNIQUE index on the
-- (user_id, sha256) pair, so uniqueness is enforced per user rather than globally.
REMOVE INDEX IF EXISTS file_sha256_idx ON file;
|
||||||
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
-- Harden knowledge entity embeddings and graph storage invariants.
|
||||||
|
|
||||||
|
-- Denormalized copy of the owning entity's source_id. The field is a non-optional
-- string, so every row must hold a value by the time it is written again.
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;

-- Backfill denormalized source_id from the linked entity.
-- Rows whose entity no longer exists cannot be given a source_id. They are removed
-- here: the re-key step that follows rewrites non-canonical rows with
-- `CONTENT { ..., source_id: $emb.source_id, ... }`, and a NONE value would be
-- rejected by the `TYPE string` field and abort the migration.
FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') {
    -- Selecting from a record id yields at most one row; [0] is NONE when it is missing.
    LET $entity = (SELECT source_id FROM $emb.entity_id)[0];
    IF $entity != NONE {
        UPDATE $emb.id SET source_id = $entity.source_id;
    } ELSE {
        -- Orphaned embedding: no entity left to describe.
        DELETE $emb.id;
    }
};
|
||||||
|
|
||||||
|
-- Move every embedding row onto the id derived from its entity, i.e.
-- knowledge_entity_embedding:<entity key>, so row id and entity id stay 1:1.
FOR $row IN (SELECT * FROM knowledge_entity_embedding) {
    LET $target = type::thing('knowledge_entity_embedding', record::id($row.entity_id));
    -- Rows already stored under the derived id are left untouched.
    IF $row.id != $target {
        -- Write the data under the derived id first, then drop the old row.
        UPSERT $target CONTENT {
            entity_id: $row.entity_id,
            embedding: $row.embedding,
            user_id: $row.user_id,
            source_id: $row.source_id,
            created_at: $row.created_at,
            updated_at: $row.updated_at
        };
        DELETE $row.id;
    }
};
|
||||||
|
|
||||||
|
-- The entity_id index is removed and then defined again as UNIQUE; with IF NOT EXISTS
-- alone, an index already present under that name would be kept unchanged.
REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||||
|
-- Lookup index for the source_id field added by this migration.
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||||
|
-- Composite index on the base table covering (user_id, source_id).
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
-- Harden text chunk embeddings storage invariants.
|
||||||
|
|
||||||
|
-- Move every embedding row onto the id derived from its chunk, i.e.
-- text_chunk_embedding:<chunk key>, so row id and chunk id stay 1:1.
FOR $row IN (SELECT * FROM text_chunk_embedding) {
    LET $target = type::thing('text_chunk_embedding', record::id($row.chunk_id));
    -- Rows already stored under the derived id are left untouched.
    IF $row.id != $target {
        -- Write the data under the derived id first, then drop the old row.
        UPSERT $target CONTENT {
            chunk_id: $row.chunk_id,
            embedding: $row.embedding,
            user_id: $row.user_id,
            source_id: $row.source_id,
            created_at: $row.created_at,
            updated_at: $row.updated_at
        };
        DELETE $row.id;
    }
};
|
||||||
|
|
||||||
|
-- The chunk_id index is removed and then defined again as UNIQUE; with IF NOT EXISTS
-- alone, an index already present under that name would be kept unchanged.
REMOVE INDEX IF EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
-- Align persisted embedding settings when FastEmbed is the recorded backend but the model
|
||||||
|
-- name is still the OpenAI migration default (invalid for FastEmbed `from_str`).
|
||||||
|
|
||||||
|
-- Only the system_settings:current record is targeted, and only when both WHERE
-- conditions hold; model name and dimensions are changed together.
UPDATE system_settings:current SET
|
||||||
|
embedding_model = 'Xenova/bge-small-en-v1.5',
|
||||||
|
embedding_dimensions = 384
|
||||||
|
WHERE embedding_backend = 'fastembed'
|
||||||
|
AND embedding_model = 'text-embedding-3-small';
|
||||||
+1
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -242,7 +242,7 @@\n\n # Defines the schema for the 'text_content' table.\n\n-DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n@@ -254,10 +254,24 @@\n DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;\n # UrlInfo is a struct, store as object\n DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;\n+DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;\n+\n DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;\n DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n+# FileInfo fields\n+DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;\n+\n # Indexes based on query patterns\n DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\n","events":null}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -28,6 +28,7 @@\n # Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\n DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n+DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY\n\n # Defines the schema for the 'file' table (used by FileInfo).\n\n","events":null}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -45,9 +45,8 @@\n DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n-# Indexes based on usage (get_by_sha, potentially user lookups)\n-# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\n-DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\n+# Indexes based on usage (get_by_sha scoped by user_id, user lookups)\n+DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;\n DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n","events":null}
|
||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT 
EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null}
|
||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -237,7 +237,7 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\n+DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n\n","events":null}
|
||||||
File diff suppressed because one or more lines are too long
@@ -13,3 +13,4 @@ DEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;
|
|||||||
# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)
|
# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)
|
||||||
DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;
|
||||||
DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY
|
DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY
|
||||||
|
DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY
|
||||||
@@ -13,7 +13,6 @@ DEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;
|
|||||||
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
||||||
|
|
||||||
# Indexes based on usage (get_by_sha, potentially user lookups)
|
# Indexes based on usage (get_by_sha scoped by user_id, user lookups)
|
||||||
# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;
|
|
||||||
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# Defines the schema for the 'knowledge_entity' table.
|
# Defines the schema for the 'knowledge_entity' table.
|
||||||
|
|
||||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
|
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;
|
||||||
|
|
||||||
# Standard fields
|
# Standard fields
|
||||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
||||||
@@ -15,16 +15,13 @@ DEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;
|
|||||||
# metadata is Option<serde_json::Value>, store as object
|
# metadata is Option<serde_json::Value>, store as object
|
||||||
DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option<object>;
|
DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option<object>;
|
||||||
|
|
||||||
# Define embedding as a standard array of floats for schema definition
|
|
||||||
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array<float>;
|
|
||||||
# The specific vector nature is handled by the index definition below
|
|
||||||
|
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
||||||
|
|
||||||
# Indexes based on build_indexes and query patterns
|
-- Indexes based on build_indexes and query patterns
|
||||||
# The INDEX definition correctly specifies the vector properties
|
-- HNSW index now defined on knowledge_entity_embedding table for better memory usage
|
||||||
DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
-- Defines the schema for the 'knowledge_entity_embedding' table.
|
||||||
|
-- Separate table to optimize HNSW index creation memory usage
|
||||||
|
|
||||||
|
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
||||||
|
|
||||||
|
-- Standard fields
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
-- source_id is a denormalized copy of the linked entity's source_id.
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
|
||||||
|
-- Custom fields
-- entity_id links the row to its knowledge_entity record; embedding holds the vector.
|
||||||
|
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
|
||||||
|
|
||||||
|
-- Indexes
-- NOTE(review): the HNSW index below is intentionally left commented out in this file;
-- presumably it is created elsewhere with the configured dimension — confirm.
|
||||||
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||||
|
-- UNIQUE: at most one embedding row per entity_id.
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Defines the 'relates_to' edge table for KnowledgeRelationships.
|
||||||
|
# Edges connect nodes, in this case knowledge_entity records.
|
||||||
|
|
||||||
|
# SCHEMAFULL relation table whose endpoints are restricted to knowledge_entity records.
DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
||||||
|
|
||||||
|
# Edge endpoints, typed to match the FROM/TO clause above.
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
|
||||||
|
|
||||||
|
# RelationshipMetadata is a struct, store as object
# Its three sub-fields are declared explicitly because the table is SCHEMAFULL.
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||||
|
|
||||||
|
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
|
||||||
|
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
# Defines the schema for the 'scratchpad' table.
|
||||||
|
|
||||||
|
# NOTE(review): declared SCHEMALESS here, while the schemafull-enforcement migration
# runs `DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL`; because of IF NOT EXISTS this
# statement does not change an existing table. Confirm whether SCHEMALESS is still intended.
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||||
|
|
||||||
|
# Standard fields from stored_object! macro
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||||
|
|
||||||
|
# Custom fields from the Scratchpad struct
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||||
|
# Both flags default to false when not supplied.
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||||
|
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||||
|
# Optional timestamps; they may be absent.
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||||
|
|
||||||
|
# Indexes based on query patterns
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||||
|
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||||
@@ -10,14 +10,8 @@ DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;
|
|||||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
|
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;
|
DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;
|
||||||
|
|
||||||
# Define embedding as a standard array of floats for schema definition
|
|
||||||
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array<float>;
|
|
||||||
# The specific vector nature is handled by the index definition below
|
|
||||||
|
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;
|
||||||
|
|
||||||
# Indexes based on build_indexes and query patterns (delete_by_source_id)
|
# Indexes based on build_indexes and query patterns (delete_by_source_id)
|
||||||
# The INDEX definition correctly specifies the vector properties
|
|
||||||
DEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536;
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
-- Defines the schema for the 'text_chunk_embedding' table.
|
||||||
|
-- Separate table to optimize HNSW index creation memory usage
|
||||||
|
|
||||||
|
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
|
||||||
|
|
||||||
|
-- Standard fields
|
||||||
|
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
|
||||||
|
|
||||||
|
-- Custom fields
-- chunk_id links the row to its text_chunk record; embedding holds the vector.
|
||||||
|
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
||||||
|
|
||||||
|
-- Indexes
-- NOTE(review): the HNSW index below is intentionally left commented out in this file;
-- presumably it is created elsewhere with the configured dimension — confirm.
|
||||||
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||||
|
-- UNIQUE: at most one embedding row per chunk_id.
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# Defines the schema for the 'text_content' table.
|
# Defines the schema for the 'text_content' table.
|
||||||
|
|
||||||
DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;
|
DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;
|
||||||
|
|
||||||
# Standard fields
|
# Standard fields
|
||||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
||||||
@@ -12,10 +12,24 @@ DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
|||||||
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
|
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
|
||||||
# UrlInfo is a struct, store as object
|
# UrlInfo is a struct, store as object
|
||||||
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
|
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
|
||||||
|
|
||||||
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
|
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
|
||||||
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
|
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
|
||||||
|
|
||||||
|
# FileInfo fields
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
|
||||||
|
|
||||||
# Indexes based on query patterns
|
# Indexes based on query patterns
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;
|
DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
DEFINE ANALYZER IF NOT EXISTS app_default_fts_analyzer
|
|
||||||
TOKENIZERS class
|
|
||||||
FILTERS lowercase, ascii;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_text_idx ON TABLE text_content
|
|
||||||
FIELDS text
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_category_idx ON TABLE text_content
|
|
||||||
FIELDS category
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_context_idx ON TABLE text_content
|
|
||||||
FIELDS context
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_file_name_idx ON TABLE text_content
|
|
||||||
FIELDS file_info.file_name
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_url_idx ON TABLE text_content
|
|
||||||
FIELDS url_info.url
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
|
|
||||||
DEFINE INDEX IF NOT EXISTS text_content_fts_url_title_idx ON TABLE text_content
|
|
||||||
FIELDS url_info.title
|
|
||||||
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
REMOVE TABLE job;
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null}
|
|
||||||
-1
@@ -1 +0,0 @@
|
|||||||
{"schemas":"--- original\n+++ modified\n@@ -51,23 +51,23 @@\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\n-DEFINE TABLE IF NOT EXISTS job SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n # Custom fields from the IngestionTask struct\n # IngestionPayload is complex, store as object\n-DEFINE FIELD IF NOT EXISTS content ON job TYPE object;\n+DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n # IngestionTaskStatus can hold data (InProgress), store as object\n-DEFINE FIELD IF NOT EXISTS status ON job TYPE object;\n-DEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n+DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n # Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\n-DEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;\n-DEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;\n-DEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS status;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;\n\n # Defines the schema for the 'knowledge_entity' table.\n\n","events":null}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"schemas":"--- original\n+++ modified\n@@ -57,10 +57,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n-# Custom fields from the IngestionTask struct\n-# IngestionPayload is complex, store as object\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n-# IngestionTaskStatus can hold data (InProgress), store as object\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n@@ -157,10 +154,12 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
|
||||||
-1
@@ -1 +0,0 @@
|
|||||||
{"schemas":"--- original\n+++ modified\n@@ -160,6 +160,7 @@\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{"schemas":"--- original\n+++ modified\n@@ -18,8 +18,8 @@\n DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n # Custom fields from the Conversation struct\n DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\n@@ -34,8 +34,8 @@\n DEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n # Custom fields from the FileInfo struct\n DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\n@@ -54,8 +54,8 @@\n DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n@@ -71,8 +71,8 @@\n DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n # Custom fields from the KnowledgeEntity struct\n DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE 
string;\n@@ -102,8 +102,8 @@\n DEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n # Custom fields from the Message struct\n DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n@@ -167,8 +167,8 @@\n DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n # Custom fields from the TextChunk struct\n DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\n@@ -191,8 +191,8 @@\n DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n # Custom fields from the TextContent struct\n DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n@@ -215,8 +215,8 @@\n DEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n # Custom fields from the User struct\n DEFINE FIELD IF NOT EXISTS email ON user TYPE string;\n","events":null}
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -1,19 +0,0 @@
|
|||||||
# Defines the 'relates_to' edge table for KnowledgeRelationships.
|
|
||||||
# Edges connect nodes, in this case knowledge_entity records.
|
|
||||||
|
|
||||||
# Define the edge table itself, enforcing connections between knowledge_entity records
|
|
||||||
# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary
|
|
||||||
DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
|
||||||
|
|
||||||
# Define the metadata field within the edge
|
|
||||||
# RelationshipMetadata is a struct, store as object
|
|
||||||
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
|
||||||
|
|
||||||
# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
|
||||||
|
|
||||||
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
|
|
||||||
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
|
|
||||||
DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;
|
|
||||||
+111
-20
@@ -4,37 +4,128 @@ use tokio::task::JoinError;
|
|||||||
|
|
||||||
use crate::storage::types::file_info::FileError;
|
use crate::storage::types::file_info::FileError;
|
||||||
|
|
||||||
|
/// Errors from embedding provider operations.
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum EmbeddingError {
|
||||||
|
#[error("openai error: {0}")]
|
||||||
|
OpenAI(Box<OpenAIError>),
|
||||||
|
#[error("fastembed error: {0}")]
|
||||||
|
FastEmbed(String),
|
||||||
|
#[error("task join error: {0}")]
|
||||||
|
Join(#[from] JoinError),
|
||||||
|
#[error("fastembed model mutex poisoned: {0}")]
|
||||||
|
MutexPoisoned(String),
|
||||||
|
#[error("no embedding data received")]
|
||||||
|
NoData,
|
||||||
|
#[error("embedding configuration error: {0}")]
|
||||||
|
Config(String),
|
||||||
|
#[error("unknown fastembed model: {0}")]
|
||||||
|
UnknownModel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenAIError> for EmbeddingError {
|
||||||
|
fn from(err: OpenAIError) -> Self {
|
||||||
|
Self::OpenAI(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingError {
|
||||||
|
pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::FastEmbed(err.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::MutexPoisoned(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Core internal errors
|
// Core internal errors
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum AppError {
|
pub enum AppError {
|
||||||
#[error("Database error: {0}")]
|
#[error("database error: {0}")]
|
||||||
Database(#[from] surrealdb::Error),
|
Database(Box<surrealdb::Error>),
|
||||||
#[error("OpenAI error: {0}")]
|
#[error("openai error: {0}")]
|
||||||
OpenAI(#[from] OpenAIError),
|
OpenAI(Box<OpenAIError>),
|
||||||
#[error("File error: {0}")]
|
#[error("embedding error: {0}")]
|
||||||
|
Embedding(#[from] EmbeddingError),
|
||||||
|
#[error("file error: {0}")]
|
||||||
File(#[from] FileError),
|
File(#[from] FileError),
|
||||||
#[error("Not found: {0}")]
|
#[error("not found: {0}")]
|
||||||
NotFound(String),
|
NotFound(String),
|
||||||
#[error("Validation error: {0}")]
|
#[error("validation error: {0}")]
|
||||||
Validation(String),
|
Validation(String),
|
||||||
#[error("Authorization error: {0}")]
|
#[error("authorization error: {0}")]
|
||||||
Auth(String),
|
Auth(String),
|
||||||
#[error("LLM parsing error: {0}")]
|
#[error("llm parsing error: {0}")]
|
||||||
LLMParsing(String),
|
LLMParsing(String),
|
||||||
#[error("Task join error: {0}")]
|
#[error("task join error: {0}")]
|
||||||
Join(#[from] JoinError),
|
Join(#[from] JoinError),
|
||||||
#[error("Graph mapper error: {0}")]
|
#[error("graph mapper error: {0}")]
|
||||||
GraphMapper(String),
|
GraphMapper(String),
|
||||||
#[error("IoError: {0}")]
|
#[error("io error: {0}")]
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
#[error("Reqwest error: {0}")]
|
#[error("reqwest error: {0}")]
|
||||||
Reqwest(#[from] reqwest::Error),
|
Reqwest(Box<reqwest::Error>),
|
||||||
#[error("Anyhow error: {0}")]
|
#[error("storage error: {0}")]
|
||||||
Anyhow(#[from] anyhow::Error),
|
Storage(Box<object_store::Error>),
|
||||||
#[error("Ingestion Processing error: {0}")]
|
#[error("ingestion processing error: {0}")]
|
||||||
Processing(String),
|
Processing(String),
|
||||||
#[error("DOM smoothie error: {0}")]
|
#[error("dom smoothie error: {0}")]
|
||||||
DomSmoothie(#[from] dom_smoothie::ReadabilityError),
|
DomSmoothie(Box<dom_smoothie::ReadabilityError>),
|
||||||
#[error("Internal service error: {0}")]
|
#[error("internal service error: {0}")]
|
||||||
InternalError(String),
|
InternalError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<surrealdb::Error> for AppError {
|
||||||
|
fn from(err: surrealdb::Error) -> Self {
|
||||||
|
Self::Database(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenAIError> for AppError {
|
||||||
|
fn from(err: OpenAIError) -> Self {
|
||||||
|
Self::OpenAI(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<reqwest::Error> for AppError {
|
||||||
|
fn from(err: reqwest::Error) -> Self {
|
||||||
|
Self::Reqwest(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<object_store::Error> for AppError {
|
||||||
|
fn from(err: object_store::Error) -> Self {
|
||||||
|
Self::Storage(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<dom_smoothie::ReadabilityError> for AppError {
|
||||||
|
fn from(err: dom_smoothie::ReadabilityError) -> Self {
|
||||||
|
Self::DomSmoothie(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppError {
|
||||||
|
/// Builds an [`AppError::InternalError`] from a displayable message.
|
||||||
|
#[must_use]
|
||||||
|
pub fn internal(msg: impl std::fmt::Display) -> Self {
|
||||||
|
Self::InternalError(msg.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::AppError;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn app_error_is_reasonably_sized() {
|
||||||
|
assert!(
|
||||||
|
std::mem::size_of::<AppError>() <= 64,
|
||||||
|
"AppError is {} bytes",
|
||||||
|
std::mem::size_of::<AppError>()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
#![allow(clippy::doc_markdown)]
|
||||||
|
//! Shared utilities and storage helpers for the workspace crates.
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-utils"))]
|
||||||
|
pub mod test_utils;
|
||||||
|
|||||||
+185
-61
@@ -7,29 +7,39 @@ use include_dir::{include_dir, Dir};
|
|||||||
use std::{ops::Deref, sync::Arc};
|
use std::{ops::Deref, sync::Arc};
|
||||||
use surrealdb::{
|
use surrealdb::{
|
||||||
engine::any::{connect, Any},
|
engine::any::{connect, Any},
|
||||||
opt::auth::Root,
|
opt::auth::{Namespace, Root},
|
||||||
Error, Notification, Surreal,
|
Error, Notification, Surreal,
|
||||||
};
|
};
|
||||||
use surrealdb_migrations::MigrationRunner;
|
use surrealdb_migrations::MigrationRunner;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
|
/// Embedded SurrealDB project root (`migrations/`, `schemas/`, `.surrealdb`).
|
||||||
|
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/db");
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SurrealDbClient {
|
pub struct SurrealDbClient {
|
||||||
pub client: Surreal<Any>,
|
pub client: Surreal<Any>,
|
||||||
}
|
}
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
pub trait ProvidesDb {
|
pub trait ProvidesDb {
|
||||||
fn db(&self) -> &Arc<SurrealDbClient>;
|
fn db(&self) -> &Arc<SurrealDbClient>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SurrealDbClient {
|
impl SurrealDbClient {
|
||||||
/// # Initialize a new datbase client
|
/// Initialize a new database client.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||||
/// * `SurrealDbClient` initialized
|
/// * `username` — Root username for authentication.
|
||||||
|
/// * `password` — Root password for authentication.
|
||||||
|
/// * `namespace` — SurrealDB namespace to use.
|
||||||
|
/// * `database` — SurrealDB database to use.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||||
|
/// In-memory (`mem://`) connections skip authentication.
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
address: &str,
|
address: &str,
|
||||||
username: &str,
|
username: &str,
|
||||||
@@ -39,8 +49,10 @@ impl SurrealDbClient {
|
|||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
let db = connect(address).await?;
|
let db = connect(address).await?;
|
||||||
|
|
||||||
// Sign in to database
|
// Skip sign-in for in-memory engine (no auth support)
|
||||||
|
if !address.starts_with("mem://") {
|
||||||
db.signin(Root { username, password }).await?;
|
db.signin(Root { username, password }).await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Set namespace
|
// Set namespace
|
||||||
db.use_ns(namespace).use_db(database).await?;
|
db.use_ns(namespace).use_db(database).await?;
|
||||||
@@ -48,6 +60,42 @@ impl SurrealDbClient {
|
|||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialize a new database client using namespace-level authentication.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `address` — Database connection string.
|
||||||
|
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||||
|
/// * `username` — Namespace username for authentication.
|
||||||
|
/// * `password` — Namespace password for authentication.
|
||||||
|
/// * `database` — SurrealDB database to use.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||||
|
pub async fn new_with_namespace_user(
|
||||||
|
address: &str,
|
||||||
|
namespace: &str,
|
||||||
|
username: &str,
|
||||||
|
password: &str,
|
||||||
|
database: &str,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
let db = connect(address).await?;
|
||||||
|
db.signin(Namespace {
|
||||||
|
namespace,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
db.use_ns(namespace).use_db(database).await?;
|
||||||
|
Ok(SurrealDbClient { client: db })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an Axum session store backed by SurrealDB.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||||
pub async fn create_session_store(
|
pub async fn create_session_store(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||||
@@ -55,7 +103,7 @@ impl SurrealDbClient {
|
|||||||
SessionStore::new(
|
SessionStore::new(
|
||||||
Some(self.client.clone().into()),
|
Some(self.client.clone().into()),
|
||||||
SessionConfig::default()
|
SessionConfig::default()
|
||||||
.with_table_name("test_session_table")
|
.with_table_name("session")
|
||||||
.with_secure(true),
|
.with_secure(true),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -66,53 +114,63 @@ impl SurrealDbClient {
|
|||||||
/// This function should be called during application startup, after connecting to
|
/// This function should be called during application startup, after connecting to
|
||||||
/// the database and selecting the appropriate namespace and database, but before
|
/// the database and selecting the appropriate namespace and database, but before
|
||||||
/// the application starts performing operations that rely on the schema.
|
/// the application starts performing operations that rely on the schema.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||||
debug!("Applying migrations");
|
debug!("Applying migrations");
|
||||||
MigrationRunner::new(&self.client)
|
MigrationRunner::new(&self.client)
|
||||||
.load_files(&MIGRATIONS_DIR)
|
.load_files(&MIGRATIONS_DIR)
|
||||||
.up()
|
.up()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
.map_err(AppError::internal)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to rebuild indexes
|
/// Store an object in SurrealDB.
|
||||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
|
||||||
debug!("Rebuilding indexes");
|
|
||||||
self.client
|
|
||||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
|
||||||
.await?;
|
|
||||||
self.client
|
|
||||||
.query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
|
|
||||||
.await?;
|
|
||||||
self.client
|
|
||||||
.query("REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content")
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
|
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `item` - The item to be stored
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||||
/// * `Result` - Item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database create operation fails.
|
||||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: StoredObject + Send + Sync + 'static,
|
T: StoredObject + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
self.client
|
self.client
|
||||||
.create((T::table_name(), item.get_id()))
|
.create((T::table_name(), item.id()))
|
||||||
.content(item)
|
.content(item)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject
|
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// Useful for idempotent ingestion flows.
|
||||||
/// * `Result` - Vec<T> or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database upsert operation fails.
|
||||||
|
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
|
where
|
||||||
|
T: StoredObject + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let id = item.id().to_string();
|
||||||
|
self.client
|
||||||
|
.upsert((T::table_name(), id))
|
||||||
|
.content(item)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve all objects from a table.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database select operation fails.
|
||||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -120,13 +178,16 @@ impl SurrealDbClient {
|
|||||||
self.client.select(T::table_name()).await
|
self.client.select(T::table_name()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject
|
/// Retrieve a single object by its ID.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `id` - The ID of the item to retrieve
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `id` — The ID of the item to retrieve.
|
||||||
/// * `Result<Option<T>, Error>` - The found item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database select operation fails.
|
||||||
|
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -134,13 +195,16 @@ impl SurrealDbClient {
|
|||||||
self.client.select((T::table_name(), id)).await
|
self.client.select((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to delete a single object by its ID, requires the struct to implement StoredObject
|
/// Delete a single object by its ID.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `id` - The ID of the item to delete
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `id` — The ID of the item to delete.
|
||||||
/// * `Result<Option<T>, Error>` - The deleted item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database delete operation fails.
|
||||||
|
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -148,10 +212,11 @@ impl SurrealDbClient {
|
|||||||
self.client.delete((T::table_name(), id)).await
|
self.client.delete((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to listen to a table for updates, requires the struct to implement StoredObject
|
/// Listen to a table for real-time updates via a live query stream.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Errors
|
||||||
/// * `Result<Option<T>, Error>` - The deleted item or Error
|
///
|
||||||
|
/// Returns `Err` if the database live query subscription fails.
|
||||||
pub async fn listen<T>(
|
pub async fn listen<T>(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||||
@@ -184,7 +249,9 @@ impl SurrealDbClient {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -194,19 +261,17 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_initialization_and_crud() {
|
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Call your initialization
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize schema");
|
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||||
|
|
||||||
// Test basic CRUD
|
|
||||||
let dummy = Dummy {
|
let dummy = Dummy {
|
||||||
id: "abc".to_string(),
|
id: "abc".to_string(),
|
||||||
name: "first".to_string(),
|
name: "first".to_string(),
|
||||||
@@ -214,49 +279,108 @@ mod tests {
|
|||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Store
|
let stored = db
|
||||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
.store_item(dummy.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store".to_string())?;
|
||||||
assert!(stored.is_some());
|
assert!(stored.is_some());
|
||||||
|
|
||||||
// Read
|
|
||||||
let fetched = db
|
let fetched = db
|
||||||
.get_item::<Dummy>(&dummy.id)
|
.get_item::<Dummy>(&dummy.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch");
|
.with_context(|| "Failed to fetch".to_string())?;
|
||||||
assert_eq!(fetched, Some(dummy.clone()));
|
assert_eq!(fetched, Some(dummy.clone()));
|
||||||
|
|
||||||
// Read all
|
|
||||||
let all = db
|
let all = db
|
||||||
.get_all_stored_items::<Dummy>()
|
.get_all_stored_items::<Dummy>()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch all");
|
.with_context(|| "Failed to fetch all".to_string())?;
|
||||||
assert!(all.contains(&dummy));
|
assert!(all.contains(&dummy));
|
||||||
|
|
||||||
// Delete
|
|
||||||
let deleted = db
|
let deleted = db
|
||||||
.delete_item::<Dummy>(&dummy.id)
|
.delete_item::<Dummy>(&dummy.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete");
|
.with_context(|| "Failed to delete".to_string())?;
|
||||||
assert_eq!(deleted, Some(dummy));
|
assert_eq!(deleted, Some(dummy));
|
||||||
|
|
||||||
// After delete, should not be present
|
|
||||||
let fetch_post = db
|
let fetch_post = db
|
||||||
.get_item::<Dummy>("abc")
|
.get_item::<Dummy>("abc")
|
||||||
.await
|
.await
|
||||||
.expect("Failed fetch post delete");
|
.with_context(|| "Failed fetch post delete".to_string())?;
|
||||||
assert!(fetch_post.is_none());
|
assert!(fetch_post.is_none());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_applying_migrations() {
|
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to build indexes");
|
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||||
|
|
||||||
|
let mut dummy = Dummy {
|
||||||
|
id: "abc".to_string(),
|
||||||
|
name: "first".to_string(),
|
||||||
|
created_at: Utc::now(),
|
||||||
|
updated_at: Utc::now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
db.store_item(dummy.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store initial record".to_string())?;
|
||||||
|
|
||||||
|
dummy.name = "updated".to_string();
|
||||||
|
let upserted = db
|
||||||
|
.upsert_item(dummy.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to upsert record".to_string())?;
|
||||||
|
assert!(upserted.is_some());
|
||||||
|
|
||||||
|
let fetched: Option<Dummy> = db
|
||||||
|
.get_item(&dummy.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "fetch after upsert".to_string())?;
|
||||||
|
let fetched =
|
||||||
|
fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
|
||||||
|
assert_eq!(fetched.name, "updated");
|
||||||
|
|
||||||
|
let new_record = Dummy {
|
||||||
|
id: "def".to_string(),
|
||||||
|
name: "brand-new".to_string(),
|
||||||
|
created_at: Utc::now(),
|
||||||
|
updated_at: Utc::now(),
|
||||||
|
};
|
||||||
|
db.upsert_item(new_record.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to upsert new record".to_string())?;
|
||||||
|
|
||||||
|
let fetched_new: Option<Dummy> = db
|
||||||
|
.get_item(&new_record.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "fetch inserted via upsert".to_string())?;
|
||||||
|
assert_eq!(fetched_new, Some(new_record));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to build indexes".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,978 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use futures::future::try_join_all;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::{Map, Value};
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
|
|
||||||
|
/// Delay between two `INFO FOR INDEX` polls while waiting for an index build.
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// Upper bound on how long a single index build is polled before giving up (30 minutes).
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// Name of the analyzer referenced by every full-text index defined in this module.
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";

/// HNSW option clause for index creation at runtime; ends with `CONCURRENTLY`.
pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";
/// HNSW option clause for statements issued inside a transaction, where `CONCURRENTLY`
/// is not supported. Identical to [`HNSW_INDEX_OPTIONS`] minus that keyword.
pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8";
|
||||||
|
|
||||||
|
/// Builds a `DEFINE INDEX OVERWRITE ... HNSW` statement matching runtime index options.
|
||||||
|
#[must_use]
|
||||||
|
pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String {
|
||||||
|
format!(
|
||||||
|
"DEFINE INDEX OVERWRITE {index_name} ON TABLE {table} \
|
||||||
|
FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
|
||||||
|
#[must_use]
|
||||||
|
pub fn hnsw_index_redefine_transaction_sql(
|
||||||
|
index_name: &str,
|
||||||
|
table: &str,
|
||||||
|
dimension: usize,
|
||||||
|
) -> String {
|
||||||
|
format!(
|
||||||
|
"BEGIN TRANSACTION;
|
||||||
|
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
|
||||||
|
DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC};
|
||||||
|
COMMIT TRANSACTION;"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Static description of one HNSW vector index managed by this module.
#[derive(Clone, Copy)]
struct HnswIndexSpec {
    /// Index name used in `DEFINE INDEX` / `INFO FOR INDEX` statements.
    index_name: &'static str,
    /// Table the index is defined on.
    table: &'static str,
    /// Option clause appended after `DIMENSION <n>` in the index definition.
    options: &'static str,
}
|
||||||
|
|
||||||
|
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
|
||||||
|
[
|
||||||
|
HnswIndexSpec {
|
||||||
|
index_name: "idx_embedding_text_chunk_embedding",
|
||||||
|
table: "text_chunk_embedding",
|
||||||
|
options: HNSW_INDEX_OPTIONS,
|
||||||
|
},
|
||||||
|
HnswIndexSpec {
|
||||||
|
index_name: "idx_embedding_knowledge_entity_embedding",
|
||||||
|
table: "knowledge_entity_embedding",
|
||||||
|
options: HNSW_INDEX_OPTIONS,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
|
||||||
|
[
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_content_fts_idx",
|
||||||
|
table: "text_content",
|
||||||
|
field: "text",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_content_context_fts_idx",
|
||||||
|
table: "text_content",
|
||||||
|
field: "context",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_content_file_name_fts_idx",
|
||||||
|
table: "text_content",
|
||||||
|
field: "file_info.file_name",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_content_url_fts_idx",
|
||||||
|
table: "text_content",
|
||||||
|
field: "url_info.url",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_content_url_title_fts_idx",
|
||||||
|
table: "text_content",
|
||||||
|
field: "url_info.title",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "knowledge_entity_fts_name_idx",
|
||||||
|
table: "knowledge_entity",
|
||||||
|
field: "name",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "knowledge_entity_fts_description_idx",
|
||||||
|
table: "knowledge_entity",
|
||||||
|
field: "description",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
FtsIndexSpec {
|
||||||
|
index_name: "text_chunk_fts_chunk_idx",
|
||||||
|
table: "text_chunk",
|
||||||
|
field: "chunk",
|
||||||
|
analyzer: Some(FTS_ANALYZER_NAME),
|
||||||
|
method: "BM25",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HnswIndexSpec {
|
||||||
|
fn definition_if_not_exists(&self, dimension: usize) -> String {
|
||||||
|
format!(
|
||||||
|
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} \
|
||||||
|
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||||
|
index = self.index_name,
|
||||||
|
table = self.table,
|
||||||
|
dimension = dimension,
|
||||||
|
options = self.options,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn definition_overwrite(&self, dimension: usize) -> String {
|
||||||
|
format!(
|
||||||
|
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} \
|
||||||
|
FIELDS embedding HNSW DIMENSION {dimension} {options};",
|
||||||
|
index = self.index_name,
|
||||||
|
table = self.table,
|
||||||
|
dimension = dimension,
|
||||||
|
options = self.options,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Static description of one full-text search index managed by this module.
#[derive(Clone, Copy)]
struct FtsIndexSpec {
    /// Index name used in `DEFINE INDEX` statements.
    index_name: &'static str,
    /// Table the index is defined on.
    table: &'static str,
    /// Field (possibly a nested path such as `a.b`) that is indexed.
    field: &'static str,
    /// Analyzer name; when `None`, no `SEARCH ANALYZER` clause is emitted at all.
    analyzer: Option<&'static str>,
    /// Token appended after the analyzer name (only used when `analyzer` is `Some`).
    method: &'static str,
}

impl FtsIndexSpec {
    /// Renders a `DEFINE INDEX IF NOT EXISTS ... CONCURRENTLY` statement for this index.
    fn definition(&self) -> String {
        self.render("IF NOT EXISTS")
    }

    /// Renders a `DEFINE INDEX OVERWRITE ... CONCURRENTLY` statement for this index.
    fn overwrite_definition(&self) -> String {
        self.render("OVERWRITE")
    }

    /// Shared renderer; `mode` is the keyword sequence placed right after `DEFINE INDEX`.
    fn render(&self, mode: &str) -> String {
        let search_clause = match self.analyzer {
            Some(analyzer) => format!(" SEARCH ANALYZER {analyzer} {}", self.method),
            None => String::new(),
        };

        format!(
            "DEFINE INDEX {mode} {} ON TABLE {} FIELDS {}{search_clause} CONCURRENTLY;",
            self.index_name, self.table, self.field,
        )
    }
}
|
||||||
|
|
||||||
|
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||||
|
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if any index definition or polling step fails.
|
||||||
|
pub async fn ensure_runtime(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
embedding_dimension: usize,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
ensure_runtime_inner(db, embedding_dimension)
|
||||||
|
.await
|
||||||
|
.map_err(AppError::internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||||
|
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
rebuild_inner(db).await.map_err(AppError::internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||||
|
///
|
||||||
|
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||||
|
/// vectors and the index together, so it acts as a persisted marker of the embedding space
|
||||||
|
/// actually present in the database. Returns `Ok(None)` when the index has not been created yet
|
||||||
|
/// (for example on a fresh database with no ingested data).
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if the index metadata cannot be read.
|
||||||
|
pub async fn embedding_index_dimension(db: &SurrealDbClient) -> Result<Option<usize>, AppError> {
|
||||||
|
let spec = HnswIndexSpec {
|
||||||
|
index_name: "idx_embedding_text_chunk_embedding",
|
||||||
|
table: "text_chunk_embedding",
|
||||||
|
options: HNSW_INDEX_OPTIONS,
|
||||||
|
};
|
||||||
|
existing_hnsw_dimension(db, &spec)
|
||||||
|
.await
|
||||||
|
.map_err(AppError::internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
|
||||||
|
create_fts_analyzer(db).await?;
|
||||||
|
|
||||||
|
for spec in fts_index_specs() {
|
||||||
|
if index_exists(db, spec.table, spec.index_name).await? {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// We need to create these sequentially otherwise SurrealDB errors with read/write clash
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition(),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||||
|
match hnsw_index_state(db, &spec, embedding_dimension).await? {
|
||||||
|
HnswIndexState::Missing => {
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_if_not_exists(embedding_dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
HnswIndexState::Matches => {
|
||||||
|
let status = get_index_status(db, spec.index_name, spec.table).await?;
|
||||||
|
if status.eq_ignore_ascii_case("error") {
|
||||||
|
warn!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
"HNSW index found in error state; triggering rebuild"
|
||||||
|
);
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_overwrite(embedding_dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
HnswIndexState::Different(existing) => {
|
||||||
|
info!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
existing_dimension = existing,
|
||||||
|
target_dimension = embedding_dimension,
|
||||||
|
"Overwriting HNSW index to match new embedding dimension"
|
||||||
|
);
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_overwrite(embedding_dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try_join_all(hnsw_tasks).await.map(|_| ())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -> Result<String> {
|
||||||
|
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||||
|
let mut info_res = db
|
||||||
|
.client
|
||||||
|
.query(info_query)
|
||||||
|
.await
|
||||||
|
.context("checking index status")?;
|
||||||
|
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
||||||
|
|
||||||
|
let Some(info) = info else {
|
||||||
|
return Ok("unknown".to_string());
|
||||||
|
};
|
||||||
|
|
||||||
|
let parsed: IndexInfoForIndex =
|
||||||
|
serde_json::from_value(info).context("deserializing INFO FOR INDEX response")?;
|
||||||
|
|
||||||
|
Ok(parsed.building_status())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||||
|
debug!("Rebuilding indexes with concurrent definitions");
|
||||||
|
create_fts_analyzer(db).await?;
|
||||||
|
|
||||||
|
for spec in fts_index_specs() {
|
||||||
|
if !index_exists(db, spec.table, spec.index_name).await? {
|
||||||
|
debug!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
"Skipping FTS rebuild because index is missing"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.overwrite_definition(),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
|
||||||
|
if !index_exists(db, spec.table, spec.index_name).await? {
|
||||||
|
debug!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
"Skipping HNSW rebuild because index is missing"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
|
||||||
|
warn!(
|
||||||
|
index = spec.index_name,
|
||||||
|
table = spec.table,
|
||||||
|
"HNSW index missing dimension; skipping rebuild"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
create_index_with_polling(
|
||||||
|
db,
|
||||||
|
spec.definition_overwrite(dimension),
|
||||||
|
spec.index_name,
|
||||||
|
spec.table,
|
||||||
|
Some(spec.table),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
try_join_all(hnsw_tasks).await.map(|_| ())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn existing_hnsw_dimension(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
spec: &HnswIndexSpec,
|
||||||
|
) -> Result<Option<usize>> {
|
||||||
|
let Some(indexes) = table_index_definitions(db, spec.table).await? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(definition) = indexes
|
||||||
|
.get(spec.index_name)
|
||||||
|
.and_then(|details| details.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn hnsw_index_state(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
spec: &HnswIndexSpec,
|
||||||
|
expected_dimension: usize,
|
||||||
|
) -> Result<HnswIndexState> {
|
||||||
|
match existing_hnsw_dimension(db, spec).await? {
|
||||||
|
None => Ok(HnswIndexState::Missing),
|
||||||
|
Some(current_dimension) if current_dimension == expected_dimension => {
|
||||||
|
Ok(HnswIndexState::Matches)
|
||||||
|
}
|
||||||
|
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of comparing an HNSW index in the database with the expected embedding dimension.
enum HnswIndexState {
    /// No dimension could be determined for the index (it is treated as not defined).
    Missing,
    /// The index exists and its dimension equals the expected one.
    Matches,
    /// The index exists with another dimension; the payload is the dimension found.
    Different(u64),
}
|
||||||
|
|
||||||
|
/// Parses the number that follows the first `DIMENSION` keyword in an index definition.
///
/// Only the text between the first and second occurrence of `DIMENSION` is inspected; its
/// first whitespace-separated token, with trailing `;` characters removed, must parse as `u64`.
/// Returns `None` when the keyword is absent or the token is not a number.
fn extract_dimension(definition: &str) -> Option<u64> {
    let mut segments = definition.split("DIMENSION");
    // Discard everything before the keyword; `split` always yields this first segment.
    segments.next()?;
    let after_keyword = segments.next()?;
    let token = after_keyword.split_whitespace().next()?;
    token.trim_end_matches(';').parse::<u64>().ok()
}
|
||||||
|
|
||||||
|
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
||||||
|
// Prefer snowball stemming when supported; fall back to ascii-only when the filter
|
||||||
|
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||||
|
// an existing analyzer definition.
|
||||||
|
let snowball_query = format!(
|
||||||
|
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||||
|
TOKENIZERS class
|
||||||
|
FILTERS lowercase, ascii, snowball(english);"
|
||||||
|
);
|
||||||
|
|
||||||
|
match db.client.query(snowball_query).await {
|
||||||
|
Ok(res) => {
|
||||||
|
if res.check().is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
warn!(
|
||||||
|
"Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})",
|
||||||
|
FTS_ANALYZER_NAME
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
error = %err,
|
||||||
|
"Snowball analyzer creation errored; attempting ascii fallback definition"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fallback_query = format!(
|
||||||
|
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||||
|
TOKENIZERS class
|
||||||
|
FILTERS lowercase, ascii;"
|
||||||
|
);
|
||||||
|
|
||||||
|
let res = db
|
||||||
|
.client
|
||||||
|
.query(fallback_query)
|
||||||
|
.await
|
||||||
|
.context("creating fallback FTS analyzer")?;
|
||||||
|
|
||||||
|
if let Err(err) = res.check() {
|
||||||
|
warn!(
|
||||||
|
error = %err,
|
||||||
|
"Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})",
|
||||||
|
FTS_ANALYZER_NAME
|
||||||
|
);
|
||||||
|
return Err(err).context("failed to create fallback FTS analyzer");
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only",
|
||||||
|
FTS_ANALYZER_NAME
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_index_with_polling(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
definition: String,
|
||||||
|
index_name: &str,
|
||||||
|
table: &str,
|
||||||
|
progress_table: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
const MAX_ATTEMPTS: usize = 3;
|
||||||
|
let expected_total = match progress_table {
|
||||||
|
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||||
|
format!("counting rows in {table} for index {index_name} progress")
|
||||||
|
})?),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut attempts: usize = 0;
|
||||||
|
loop {
|
||||||
|
attempts = attempts.saturating_add(1);
|
||||||
|
let res = db
|
||||||
|
.client
|
||||||
|
.query(definition.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("creating index {index_name} on table {table}"))?;
|
||||||
|
match res.check() {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(err) => {
|
||||||
|
let msg = err.to_string();
|
||||||
|
let conflict = msg.contains("read or write conflict");
|
||||||
|
warn!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
error = ?err,
|
||||||
|
attempt = attempts,
|
||||||
|
definition = %definition,
|
||||||
|
"Index definition failed"
|
||||||
|
);
|
||||||
|
if conflict && attempts < MAX_ATTEMPTS {
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return Err(err).with_context(|| {
|
||||||
|
format!("index definition failed for {index_name} on {table}")
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
expected_rows = ?expected_total,
|
||||||
|
"Index definition submitted; waiting for build to finish"
|
||||||
|
);
|
||||||
|
|
||||||
|
poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn poll_index_build_status(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
index_name: &str,
|
||||||
|
table: &str,
|
||||||
|
total_rows: Option<u64>,
|
||||||
|
poll_every: Duration,
|
||||||
|
) -> Result<()> {
|
||||||
|
let started_at = std::time::Instant::now();
|
||||||
|
let mut last_snapshot: Option<IndexBuildSnapshot> = None;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if started_at.elapsed() >= INDEX_BUILD_TIMEOUT {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"index build timed out after {:?} for {index_name} on {table} (last status: {})",
|
||||||
|
INDEX_BUILD_TIMEOUT,
|
||||||
|
last_snapshot
|
||||||
|
.as_ref()
|
||||||
|
.map_or("unknown", |snapshot| snapshot.status.as_str())
|
||||||
|
))
|
||||||
|
.with_context(|| format!("index {index_name} on table {table} did not become ready"));
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(poll_every).await;
|
||||||
|
|
||||||
|
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||||
|
let mut info_res =
|
||||||
|
db.client.query(info_query).await.with_context(|| {
|
||||||
|
format!("checking index build status for {index_name} on {table}")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let info: Option<Value> = info_res
|
||||||
|
.take(0)
|
||||||
|
.context("failed to deserialize INFO FOR INDEX result")?;
|
||||||
|
|
||||||
|
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"INFO FOR INDEX returned no data for {index_name} on {table}"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
last_snapshot = Some(snapshot.clone());
|
||||||
|
|
||||||
|
if let Some(pct) = snapshot.progress_pct {
|
||||||
|
debug!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
status = snapshot.status,
|
||||||
|
initial = snapshot.initial,
|
||||||
|
pending = snapshot.pending,
|
||||||
|
updated = snapshot.updated,
|
||||||
|
processed = snapshot.processed,
|
||||||
|
total = snapshot.total_rows,
|
||||||
|
progress_pct = format_args!("{pct:.1}"),
|
||||||
|
"Index build status"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
debug!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
status = snapshot.status,
|
||||||
|
initial = snapshot.initial,
|
||||||
|
pending = snapshot.pending,
|
||||||
|
updated = snapshot.updated,
|
||||||
|
processed = snapshot.processed,
|
||||||
|
"Index build status"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if snapshot.is_ready() {
|
||||||
|
debug!(
|
||||||
|
index = %index_name,
|
||||||
|
table = %table,
|
||||||
|
elapsed = ?started_at.elapsed(),
|
||||||
|
processed = snapshot.processed,
|
||||||
|
total = snapshot.total_rows,
|
||||||
|
"Index is ready"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if snapshot.status.eq_ignore_ascii_case("error") {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"index build failed for {index_name} on {table}: status=error, processed={}, total={:?}",
|
||||||
|
snapshot.processed,
|
||||||
|
snapshot.total_rows
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `building` block from SurrealDB `INFO FOR INDEX` (concurrent index builds).
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||||
|
struct IndexBuildingProgress {
|
||||||
|
#[serde(default)]
|
||||||
|
initial: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
pending: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
updated: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Top-level `INFO FOR INDEX` payload shape (SurrealDB v2.x).
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
struct IndexInfoForIndex {
|
||||||
|
#[serde(default)]
|
||||||
|
building: Option<IndexBuildingProgress>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IndexInfoForIndex {
|
||||||
|
fn building_status(&self) -> String {
|
||||||
|
match &self.building {
|
||||||
|
None => "ready".to_string(),
|
||||||
|
Some(progress) if progress.status.is_empty() => "ready".to_string(),
|
||||||
|
Some(progress) => progress.status.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_build_snapshot(self, total_rows: Option<u64>) -> IndexBuildSnapshot {
|
||||||
|
let (initial, pending, updated, status) = match self.building {
|
||||||
|
None => (0, 0, 0, "ready".to_string()),
|
||||||
|
Some(progress) => {
|
||||||
|
let status = if progress.status.is_empty() {
|
||||||
|
"ready".to_string()
|
||||||
|
} else {
|
||||||
|
progress.status
|
||||||
|
};
|
||||||
|
(progress.initial, progress.pending, progress.updated, status)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let processed = initial.saturating_add(updated);
|
||||||
|
let progress_pct = total_rows.map(|total| {
|
||||||
|
if total == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
|
||||||
|
/ f64::from(u32::try_from(total).unwrap_or(1)))
|
||||||
|
.min(1.0))
|
||||||
|
* 100.0
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
IndexBuildSnapshot {
|
||||||
|
status,
|
||||||
|
initial,
|
||||||
|
pending,
|
||||||
|
updated,
|
||||||
|
processed,
|
||||||
|
total_rows,
|
||||||
|
progress_pct,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One observation of an index build, derived from SurrealDB's `INFO FOR INDEX` output.
#[derive(Debug, Clone, PartialEq)]
struct IndexBuildSnapshot {
    /// Build status string (e.g., `"indexing"`, `"ready"`, `"error"`).
    status: String,
    /// Rows that were present when the build started.
    initial: u64,
    /// Rows that still await processing.
    pending: u64,
    /// Rows updated since the build started.
    updated: u64,
    /// Rows processed so far, computed as `initial + updated`.
    processed: u64,
    /// Expected total number of rows (counted before the build), when known.
    total_rows: Option<u64>,
    /// `processed / total_rows` as a percentage; `None` when `total_rows` is unknown.
    progress_pct: Option<f64>,
}

impl IndexBuildSnapshot {
    /// Returns `true` when the status equals `ready`, ignoring ASCII case.
    fn is_ready(&self) -> bool {
        "ready".eq_ignore_ascii_case(&self.status)
    }
}
|
||||||
|
|
||||||
|
fn parse_index_build_info(
|
||||||
|
info: Option<Value>,
|
||||||
|
total_rows: Option<u64>,
|
||||||
|
) -> Option<IndexBuildSnapshot> {
|
||||||
|
let info = info?;
|
||||||
|
let parsed: IndexInfoForIndex = serde_json::from_value(info).ok()?;
|
||||||
|
Some(parsed.into_build_snapshot(total_rows))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct CountRow {
|
||||||
|
count: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
|
||||||
|
let query = format!("SELECT count() AS count FROM {table} GROUP ALL;");
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("counting rows in {table}"))?;
|
||||||
|
let rows: Vec<CountRow> = response
|
||||||
|
.take(0)
|
||||||
|
.context("failed to deserialize count() response")?;
|
||||||
|
Ok(rows.first().map_or(0, |r| r.count))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn table_index_definitions(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
table: &str,
|
||||||
|
) -> Result<Option<Map<String, Value>>> {
|
||||||
|
let info_query = format!("INFO FOR TABLE {table};");
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(info_query)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetching table info for {table}"))?;
|
||||||
|
|
||||||
|
let info: surrealdb::Value = response
|
||||||
|
.take(0)
|
||||||
|
.context("failed to take table info response")?;
|
||||||
|
|
||||||
|
let info_json: Value =
|
||||||
|
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
|
||||||
|
|
||||||
|
Ok(info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|o| o.get("indexes"))
|
||||||
|
.and_then(|i| i.get("Object"))
|
||||||
|
.and_then(|i| i.as_object())
|
||||||
|
.cloned())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
|
||||||
|
let Some(indexes) = table_index_definitions(db, table).await? else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(indexes.contains_key(index_name))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::must_use_candidate)]
    use crate::storage::db::SurrealDbClient;
    use anyhow::{self, Context};
    use serde_json::json;
    use uuid::Uuid;

    use super::*;

    /// Opens an in-memory database under `namespace` with a random database name
    /// and applies the migrations.
    async fn migrated_memory_db(namespace: &str) -> anyhow::Result<SurrealDbClient> {
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .context("in-memory db")?;

        db.apply_migrations()
            .await
            .context("migrations should succeed")?;

        Ok(db)
    }

    #[test]
    fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
        let payload = json!({
            "building": {
                "initial": 56894,
                "pending": 0,
                "status": "indexing",
                "updated": 0
            }
        });
        let expected = IndexBuildSnapshot {
            status: "indexing".to_string(),
            initial: 56894,
            pending: 0,
            updated: 0,
            processed: 56894,
            total_rows: Some(61081),
            progress_pct: Some((56894_f64 / 61081_f64) * 100.0),
        };

        let snapshot = parse_index_build_info(Some(payload), Some(61081)).context("snapshot")?;

        assert_eq!(snapshot, expected);
        assert!(!snapshot.is_ready());
        Ok(())
    }

    #[test]
    fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
        // Surreal returns `{}` when the index exists but isn't building.
        let snapshot = parse_index_build_info(Some(json!({})), Some(10)).context("snapshot")?;

        assert!(snapshot.is_ready());
        assert_eq!(snapshot.processed, 0);
        assert_eq!(snapshot.progress_pct, Some(0.0));
        Ok(())
    }

    #[test]
    fn index_info_for_index_deserializes_ready_status_shape() -> anyhow::Result<()> {
        let payload = json!({
            "building": {
                "status": "ready"
            }
        });

        let parsed: IndexInfoForIndex =
            serde_json::from_value(payload.clone()).context("deserialize ready shape")?;
        assert_eq!(parsed.building_status(), "ready");

        let snapshot = parse_index_build_info(Some(payload), None).context("snapshot")?;
        assert!(snapshot.is_ready());
        assert_eq!(snapshot.initial, 0);
        Ok(())
    }

    #[test]
    fn index_info_for_index_deserializes_indexing_shape_from_surreal_docs() -> anyhow::Result<()> {
        let payload = json!({
            "building": {
                "initial": 8143,
                "pending": 19,
                "status": "indexing",
                "updated": 80
            }
        });

        let parsed: IndexInfoForIndex =
            serde_json::from_value(payload.clone()).context("deserialize indexing shape")?;
        assert_eq!(parsed.building_status(), "indexing");

        let snapshot = parse_index_build_info(Some(payload), None).context("snapshot")?;
        assert_eq!(snapshot.status, "indexing");
        assert_eq!(snapshot.initial, 8143);
        assert_eq!(snapshot.pending, 19);
        assert_eq!(snapshot.updated, 80);
        assert_eq!(snapshot.processed, 8223);
        assert!(!snapshot.is_ready());
        Ok(())
    }

    #[test]
    fn parse_index_build_info_reports_error_status() -> anyhow::Result<()> {
        let payload = json!({
            "building": {
                "initial": 100,
                "pending": 5,
                "status": "error",
                "updated": 10
            }
        });

        let snapshot = parse_index_build_info(Some(payload), Some(200)).context("snapshot")?;
        assert_eq!(snapshot.status, "error");
        assert!(!snapshot.is_ready());
        Ok(())
    }

    #[test]
    fn extract_dimension_parses_value() {
        let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
        assert_eq!(extract_dimension(definition), Some(1536));
    }

    #[tokio::test]
    async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
        let db = migrated_memory_db("indexes_ns").await?;

        ensure_runtime(&db, 1536)
            .await
            .context("first call should succeed")?;
        ensure_runtime(&db, 1536)
            .await
            .context("second index creation")?;
        Ok(())
    }

    #[tokio::test]
    async fn embedding_index_dimension_reflects_runtime_state() -> anyhow::Result<()> {
        let db = migrated_memory_db("indexes_marker").await?;

        // Before any index exists, there is no stored embedding dimension to detect.
        assert_eq!(embedding_index_dimension(&db).await?, None);

        ensure_runtime(&db, 1536)
            .await
            .context("initial index creation")?;
        assert_eq!(embedding_index_dimension(&db).await?, Some(1536));

        // After a dimension change the marker tracks the new index dimension.
        ensure_runtime(&db, 256)
            .await
            .context("overwritten index creation")?;
        assert_eq!(embedding_index_dimension(&db).await?, Some(256));
        Ok(())
    }

    #[tokio::test]
    async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
        let db = migrated_memory_db("indexes_dim").await?;

        ensure_runtime(&db, 1536)
            .await
            .context("initial index creation")?;
        ensure_runtime(&db, 128)
            .await
            .context("overwritten index creation")?;
        Ok(())
    }
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
pub mod db;
|
pub mod db;
|
||||||
|
pub mod indexes;
|
||||||
pub mod store;
|
pub mod store;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|||||||
+1135
-187
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
|
use crate::storage::types::{user::User, StoredObject};
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
@@ -16,61 +17,78 @@ impl StoredObject for Analytics {
|
|||||||
"analytics"
|
"analytics"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Analytics {
|
impl Analytics {
|
||||||
|
const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
|
/// Ensures the singleton analytics record exists (idempotent).
|
||||||
|
///
|
||||||
|
/// Production databases are also seeded by `20250503_215025_initial_setup.surql`;
|
||||||
|
/// this uses an atomic `UPSERT` for tests and recovery.
|
||||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let analytics = db.get_item::<Self>("current").await?;
|
let analytics: Option<Self> = db
|
||||||
|
.client
|
||||||
if analytics.is_none() {
|
.query(
|
||||||
let created_analytics = Analytics {
|
"UPSERT type::thing('analytics', $id) SET visitors = visitors ?? 0, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
id: "current".to_string(),
|
)
|
||||||
visitors: 0,
|
.bind(("id", Self::RECORD_ID))
|
||||||
page_loads: 0,
|
.await?
|
||||||
};
|
.take(0)?;
|
||||||
|
|
||||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
|
||||||
return stored.ok_or(AppError::Validation(
|
|
||||||
"Failed to initialize analytics".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
analytics.ok_or(AppError::Validation(
|
analytics.ok_or(AppError::Validation(
|
||||||
"Failed to initialize analytics".into(),
|
"failed to initialize analytics".into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let analytics: Option<Self> = db.get_item("current").await?;
|
let analytics: Option<Self> = db.get_item("current").await?;
|
||||||
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
|
analytics.ok_or(AppError::NotFound("analytics not found".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET visitors = (visitors ?? 0) + 1, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
updated.ok_or(AppError::Validation("failed to update analytics".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
|
Self::record_page_view(db, false).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Records a page view, optionally counting the visitor as new.
|
||||||
|
pub async fn record_page_view(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
is_new_visitor: bool,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
let visitor_delta = i64::from(is_new_visitor);
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET page_loads = (page_loads ?? 0) + 1, visitors = (visitors ?? 0) + $visitor_delta RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.bind(("visitor_delta", visitor_delta))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
updated.ok_or(AppError::Validation("failed to update analytics".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
||||||
// We need to use a direct query for COUNT aggregation
|
// We need to use a direct query for COUNT aggregation
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct CountResult {
|
struct CountResult {
|
||||||
|
/// Total user count.
|
||||||
count: i64,
|
count: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,14 +99,16 @@ impl Analytics {
|
|||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
Ok(result.map_or(0, |r| r.count))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
use anyhow::{self};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
stored_object!(TestUser, "user", {
|
stored_object!(TestUser, "user", {
|
||||||
@@ -98,18 +118,14 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_analytics_initialization() {
|
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Test initialization of analytics
|
// Test initialization of analytics
|
||||||
let analytics = Analytics::ensure_initialized(&db)
|
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Verify initial state after initialization
|
// Verify initial state after initialization
|
||||||
assert_eq!(analytics.id, "current");
|
assert_eq!(analytics.id, "current");
|
||||||
@@ -117,159 +133,198 @@ mod tests {
|
|||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
// Test idempotency - ensure calling it again doesn't change anything
|
// Test idempotency - ensure calling it again doesn't change anything
|
||||||
let analytics_again = Analytics::ensure_initialized(&db)
|
let analytics_again = Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get analytics after initialization");
|
|
||||||
|
|
||||||
assert_eq!(analytics.id, analytics_again.id);
|
assert_eq!(analytics.id, analytics_again.id);
|
||||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_analytics() {
|
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test get_current method
|
// Test get_current method
|
||||||
let analytics = Analytics::get_current(&db)
|
let analytics = Analytics::get_current(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get current analytics");
|
|
||||||
|
|
||||||
assert_eq!(analytics.id, "current");
|
assert_eq!(analytics.id, "current");
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_visitors() {
|
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test increment_visitors method
|
// Test increment_visitors method
|
||||||
let analytics = Analytics::increment_visitors(&db)
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment visitors");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 1);
|
assert_eq!(analytics.visitors, 1);
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
// Increment again and check
|
// Increment again and check
|
||||||
let analytics = Analytics::increment_visitors(&db)
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment visitors again");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 2);
|
assert_eq!(analytics.visitors, 2);
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_page_loads() {
|
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test increment_page_loads method
|
// Test increment_page_loads method
|
||||||
let analytics = Analytics::increment_page_loads(&db)
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment page loads");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
assert_eq!(analytics.page_loads, 1);
|
assert_eq!(analytics.page_loads, 1);
|
||||||
|
|
||||||
// Increment again and check
|
// Increment again and check
|
||||||
let analytics = Analytics::increment_page_loads(&db)
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment page loads again");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
assert_eq!(analytics.page_loads, 2);
|
assert_eq!(analytics.page_loads, 2);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_users_amount() {
|
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Test with no users
|
// Test with no users
|
||||||
let count = Analytics::get_users_amount(&db)
|
let count = Analytics::get_users_amount(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get users amount");
|
|
||||||
assert_eq!(count, 0);
|
assert_eq!(count, 0);
|
||||||
|
|
||||||
// Create a few test users
|
// Create a few test users
|
||||||
for i in 0..3 {
|
for i in 0..3 {
|
||||||
let user = TestUser {
|
let user = TestUser {
|
||||||
id: format!("user{}", i),
|
id: format!("user{i}"),
|
||||||
email: format!("user{}@example.com", i),
|
email: format!("user{i}@example.com"),
|
||||||
password: "password".to_string(),
|
password: "password".to_string(),
|
||||||
user_id: format!("uid{}", i),
|
user_id: format!("uid{i}"),
|
||||||
created_at: Utc::now(),
|
created_at: Utc::now(),
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
db.store_item(user)
|
db.store_item(user).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to create test user");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test users amount after adding users
|
// Test users amount after adding users
|
||||||
let count = Analytics::get_users_amount(&db)
|
let count = Analytics::get_users_amount(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get users amount after adding users");
|
|
||||||
assert_eq!(count, 3);
|
assert_eq!(count, 3);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() {
|
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(analytics.visitors, 1);
|
||||||
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(analytics.page_loads, 1);
|
||||||
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_visitors.visitors, 1);
|
||||||
|
assert_eq!(after_visitors.page_loads, 0);
|
||||||
|
|
||||||
|
let after_page_load = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(after_page_load.visitors, 1);
|
||||||
|
assert_eq!(after_page_load.page_loads, 1);
|
||||||
|
|
||||||
|
let after_second_visitor = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_second_visitor.visitors, 2);
|
||||||
|
assert_eq!(after_second_visitor.page_loads, 1);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_record_page_view() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let first_view = Analytics::record_page_view(&db, true).await?;
|
||||||
|
assert_eq!(first_view.visitors, 1);
|
||||||
|
assert_eq!(first_view.page_loads, 1);
|
||||||
|
|
||||||
|
let returning_view = Analytics::record_page_view(&db, false).await?;
|
||||||
|
assert_eq!(returning_view.visitors, 1);
|
||||||
|
assert_eq!(returning_view.page_loads, 2);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Don't initialize analytics and try to get it
|
// Don't initialize analytics and try to get it
|
||||||
let result = Analytics::get_current(&db).await;
|
let result = Analytics::get_current(&db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
if let Err(err) = result {
|
match result {
|
||||||
match err {
|
Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
|
||||||
AppError::NotFound(_) => {
|
Err(AppError::NotFound(_)) => {}
|
||||||
// Expected error
|
Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
|
||||||
}
|
|
||||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,57 @@ stored_object!(Conversation, "conversation", {
|
|||||||
title: String
|
title: String
|
||||||
});
|
});
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
pub struct SidebarConversation {
|
||||||
|
#[serde(deserialize_with = "deserialize_sidebar_id")]
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SidebarIdVisitor;
|
||||||
|
|
||||||
|
impl<'de> serde::de::Visitor<'de> for SidebarIdVisitor {
|
||||||
|
type Value = String;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("a string id or a SurrealDB Thing")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(value.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: serde::de::MapAccess<'de>,
|
||||||
|
{
|
||||||
|
let thing = <surrealdb::sql::Thing as serde::Deserialize>::deserialize(
|
||||||
|
serde::de::value::MapAccessDeserializer::new(map),
|
||||||
|
)?;
|
||||||
|
Ok(thing.id.to_raw())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_sidebar_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_any(SidebarIdVisitor)
|
||||||
|
}
|
||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
|
#[must_use]
|
||||||
pub fn new(user_id: String, title: String) -> Self {
|
pub fn new(user_id: String, title: String) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
@@ -30,7 +80,7 @@ impl Conversation {
|
|||||||
let conversation: Conversation = db
|
let conversation: Conversation = db
|
||||||
.get_item(conversation_id)
|
.get_item(conversation_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||||
|
|
||||||
if conversation.user_id != user_id {
|
if conversation.user_id != user_id {
|
||||||
return Err(AppError::Auth(
|
return Err(AppError::Auth(
|
||||||
@@ -38,10 +88,15 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let messages:Vec<Message> = db.client.
|
let messages: Vec<Message> = db
|
||||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
.client
|
||||||
bind(("table_name", Message::table_name())).
|
.query(
|
||||||
bind(("conversation_id", conversation_id.to_string()))
|
"SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at",
|
||||||
|
)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Self::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
@@ -56,7 +111,7 @@ impl Conversation {
|
|||||||
// First verify ownership by getting conversation user_id
|
// First verify ownership by getting conversation user_id
|
||||||
let conversation: Option<Conversation> = db.get_item(id).await?;
|
let conversation: Option<Conversation> = db.get_item(id).await?;
|
||||||
let conversation =
|
let conversation =
|
||||||
conversation.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
conversation.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||||
|
|
||||||
if conversation.user_id != user_id {
|
if conversation.user_id != user_id {
|
||||||
return Err(AppError::Auth(
|
return Err(AppError::Auth(
|
||||||
@@ -64,7 +119,7 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let _updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.update((Self::table_name(), id))
|
.update((Self::table_name(), id))
|
||||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||||
.patch(PatchOp::replace(
|
.patch(PatchOp::replace(
|
||||||
@@ -73,82 +128,118 @@ impl Conversation {
|
|||||||
))
|
))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if updated.is_none() {
|
||||||
|
return Err(AppError::NotFound("conversation not found".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_user_sidebar_conversations(
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Vec<SidebarConversation>, AppError> {
|
||||||
|
let conversations: Vec<SidebarConversation> = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT id, title, updated_at FROM type::table($table_name) WHERE user_id = $user_id ORDER BY updated_at DESC",
|
||||||
|
)
|
||||||
|
.bind(("table_name", Self::table_name()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
|
||||||
|
Ok(conversations)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use crate::storage::types::message::MessageRole;
|
use crate::storage::types::message::MessageRole;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
const MESSAGE_QUERY_FOR_OWNER: &str = "SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at";
|
||||||
|
|
||||||
|
async fn fetch_messages_for_owner(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
conversation_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<Message>, AppError> {
|
||||||
|
db.client
|
||||||
|
.query(MESSAGE_QUERY_FOR_OWNER)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Conversation::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::from)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_conversation() {
|
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create a new conversation
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let title = "Test Conversation";
|
let title = "Test Conversation";
|
||||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||||
|
|
||||||
// Verify conversation properties
|
|
||||||
assert_eq!(conversation.user_id, user_id);
|
assert_eq!(conversation.user_id, user_id);
|
||||||
assert_eq!(conversation.title, title);
|
assert_eq!(conversation.title, title);
|
||||||
assert!(!conversation.id.is_empty());
|
assert!(!conversation.id.is_empty());
|
||||||
|
|
||||||
// Store the conversation
|
|
||||||
let result = db.store_item(conversation.clone()).await;
|
let result = db.store_item(conversation.clone()).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Verify it can be retrieved
|
|
||||||
let retrieved: Option<Conversation> = db
|
let retrieved: Option<Conversation> = db
|
||||||
.get_item(&conversation.id)
|
.get_item(&conversation.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve conversation");
|
.with_context(|| "Failed to retrieve conversation".to_string())?;
|
||||||
assert!(retrieved.is_some());
|
|
||||||
|
|
||||||
let retrieved = retrieved.unwrap();
|
let retrieved =
|
||||||
|
retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
|
||||||
assert_eq!(retrieved.id, conversation.id);
|
assert_eq!(retrieved.id, conversation.id);
|
||||||
assert_eq!(retrieved.user_id, user_id);
|
assert_eq!(retrieved.user_id, user_id);
|
||||||
assert_eq!(retrieved.title, title);
|
assert_eq!(retrieved.title, title);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_not_found() {
|
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Try to get a conversation that doesn't exist
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
Err(AppError::NotFound(_)) => {}
|
||||||
_ => panic!("Expected NotFound error"),
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_unauthorized() {
|
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create and store a conversation for user_id_1
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation =
|
let conversation =
|
||||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||||
@@ -156,27 +247,28 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Try to access with a different user
|
|
||||||
let user_id_2 = "user_2";
|
let user_id_2 = "user_2";
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_success() {
|
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let user_id = "user_1";
|
let user_id = "user_1";
|
||||||
let original_title = "Original Title";
|
let original_title = "Original Title";
|
||||||
@@ -185,49 +277,50 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
let new_title = "Updated Title";
|
let new_title = "Updated Title";
|
||||||
|
|
||||||
// Patch title successfully
|
|
||||||
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Retrieve from DB to verify
|
|
||||||
let updated_conversation = db
|
let updated_conversation = db
|
||||||
.get_item::<Conversation>(&conversation_id)
|
.get_item::<Conversation>(&conversation_id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get conversation")
|
.with_context(|| "Failed to get conversation".to_string())?
|
||||||
.expect("Conversation missing");
|
.ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
|
||||||
assert_eq!(updated_conversation.title, new_title);
|
assert_eq!(updated_conversation.title, new_title);
|
||||||
assert_eq!(updated_conversation.user_id, user_id);
|
assert_eq!(updated_conversation.user_id, user_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_not_found() {
|
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Try to patch non-existing conversation
|
|
||||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::NotFound(_)) => {}
|
Err(AppError::NotFound(_)) => {}
|
||||||
_ => panic!("Expected NotFound error"),
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_unauthorized() {
|
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user_id = "intruder";
|
let other_user_id = "intruder";
|
||||||
@@ -236,38 +329,131 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Attempt patch with unauthorized user
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => {}
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_with_messages() {
|
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
|
||||||
// Create and store a conversation for user_id_1
|
let user_id = "sidebar_user";
|
||||||
|
let other_user_id = "other_user";
|
||||||
|
let base = Utc::now();
|
||||||
|
|
||||||
|
let mut oldest = Conversation::new(user_id.to_string(), "Oldest".to_string());
|
||||||
|
oldest.updated_at = base - chrono::Duration::minutes(30);
|
||||||
|
|
||||||
|
let mut newest = Conversation::new(user_id.to_string(), "Newest".to_string());
|
||||||
|
newest.updated_at = base - chrono::Duration::minutes(5);
|
||||||
|
|
||||||
|
let mut middle = Conversation::new(user_id.to_string(), "Middle".to_string());
|
||||||
|
middle.updated_at = base - chrono::Duration::minutes(15);
|
||||||
|
|
||||||
|
let mut other_user = Conversation::new(other_user_id.to_string(), "Other".to_string());
|
||||||
|
other_user.updated_at = base;
|
||||||
|
|
||||||
|
db.store_item(oldest.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store oldest conversation");
|
||||||
|
db.store_item(newest.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store newest conversation");
|
||||||
|
db.store_item(middle.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store middle conversation");
|
||||||
|
db.store_item(other_user)
|
||||||
|
.await
|
||||||
|
.expect("Failed to store other-user conversation");
|
||||||
|
|
||||||
|
let sidebar_items = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to get sidebar conversations");
|
||||||
|
|
||||||
|
assert_eq!(sidebar_items.len(), 3);
|
||||||
|
let s0 = sidebar_items.first().expect("expected 3 items");
|
||||||
|
let s1 = sidebar_items.get(1).expect("expected 3 items");
|
||||||
|
let s2 = sidebar_items.get(2).expect("expected 3 items");
|
||||||
|
assert_eq!(s0.id, newest.id);
|
||||||
|
assert_eq!(s0.title, "Newest");
|
||||||
|
assert_eq!(s1.id, middle.id);
|
||||||
|
assert_eq!(s1.title, "Middle");
|
||||||
|
assert_eq!(s2.id, oldest.id);
|
||||||
|
assert_eq!(s2.title, "Oldest");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
|
||||||
|
let user_id = "sidebar_patch_user";
|
||||||
|
let base = Utc::now();
|
||||||
|
|
||||||
|
let mut first = Conversation::new(user_id.to_string(), "First".to_string());
|
||||||
|
first.updated_at = base - chrono::Duration::minutes(20);
|
||||||
|
|
||||||
|
let mut second = Conversation::new(user_id.to_string(), "Second".to_string());
|
||||||
|
second.updated_at = base - chrono::Duration::minutes(10);
|
||||||
|
|
||||||
|
db.store_item(first.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store first conversation");
|
||||||
|
db.store_item(second.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to store second conversation");
|
||||||
|
|
||||||
|
let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to get sidebar conversations before patch");
|
||||||
|
let before = before_patch.first().expect("expected at least 1 item");
|
||||||
|
assert_eq!(before.id, second.id);
|
||||||
|
|
||||||
|
Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to patch conversation title");
|
||||||
|
|
||||||
|
let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to get sidebar conversations after patch");
|
||||||
|
let after = after_patch.first().expect("expected at least 1 item");
|
||||||
|
assert_eq!(after.id, first.id);
|
||||||
|
assert_eq!(after.title, "First (renamed)");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||||
let conversation_id = conversation.id.clone();
|
let conversation_id = conversation.id.clone();
|
||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Create messages
|
|
||||||
let message1 = Message::new(
|
let message1 = Message::new(
|
||||||
conversation_id.clone(),
|
conversation_id.clone(),
|
||||||
MessageRole::User,
|
MessageRole::User,
|
||||||
@@ -287,46 +473,200 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Store messages
|
|
||||||
db.store_item(message1)
|
db.store_item(message1)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message1");
|
.with_context(|| "Failed to store message1".to_string())?;
|
||||||
db.store_item(message2)
|
db.store_item(message2)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message2");
|
.with_context(|| "Failed to store message2".to_string())?;
|
||||||
db.store_item(message3)
|
db.store_item(message3)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message3");
|
.with_context(|| "Failed to store message3".to_string())?;
|
||||||
|
|
||||||
// Retrieve the complete conversation
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||||
|
|
||||||
let (retrieved_conversation, messages) = result.unwrap();
|
let (retrieved_conversation, retrieved_messages) =
|
||||||
|
result.with_context(|| "Failed to retrieve complete conversation".to_string())?;
|
||||||
|
|
||||||
// Verify conversation data
|
|
||||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||||
|
|
||||||
// Verify messages
|
assert_eq!(retrieved_messages.len(), 3);
|
||||||
assert_eq!(messages.len(), 3);
|
|
||||||
|
|
||||||
// Verify messages are sorted by updated_at
|
let message_contents: Vec<&str> = retrieved_messages
|
||||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
.iter()
|
||||||
|
.map(|m| m.content.as_str())
|
||||||
|
.collect();
|
||||||
assert!(message_contents.contains(&"Hello, AI!"));
|
assert!(message_contents.contains(&"Hello, AI!"));
|
||||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||||
|
|
||||||
// Make sure we can't access with different user
|
|
||||||
let user_id_2 = "user_2";
|
let user_id_2 = "user_2";
|
||||||
let unauthorized_result =
|
let unauthorized_result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||||
assert!(unauthorized_result.is_err());
|
assert!(unauthorized_result.is_err());
|
||||||
match unauthorized_result {
|
match unauthorized_result {
|
||||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||||
|
let item: SidebarConversation =
|
||||||
|
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
|
||||||
|
.expect("valid sidebar conversation json");
|
||||||
|
assert_eq!(item.id, "conv-plain");
|
||||||
|
assert_eq!(item.title, "My chat");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
|
||||||
|
let owner = "sidebar_owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||||
|
let expected_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation)
|
||||||
|
.await
|
||||||
|
.expect("Failed to store conversation");
|
||||||
|
|
||||||
|
let items = Conversation::get_user_sidebar_conversations(owner, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to load sidebar");
|
||||||
|
assert_eq!(items.len(), 1);
|
||||||
|
let item = items.first().expect("expected one sidebar item");
|
||||||
|
assert_eq!(item.id, expected_id);
|
||||||
|
assert_eq!(item.title, "Sidebar title");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner_user";
|
||||||
|
let intruder = "intruder_user";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Private".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.store_item(Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"secret message".to_string(),
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let owner_messages = fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||||
|
assert_eq!(owner_messages.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
owner_messages
|
||||||
|
.first()
|
||||||
|
.expect("expected owner message")
|
||||||
|
.content,
|
||||||
|
"secret message"
|
||||||
|
);
|
||||||
|
|
||||||
|
let intruder_messages = fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||||
|
assert!(
|
||||||
|
intruder_messages.is_empty(),
|
||||||
|
"SQL owner filter must not return messages for a non-owner user_id"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let user_id = "order_user";
|
||||||
|
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
|
||||||
|
let base = Utc::now();
|
||||||
|
let mut first = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"first".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
first.updated_at = base - chrono::Duration::minutes(20);
|
||||||
|
|
||||||
|
let mut second = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::AI,
|
||||||
|
"second".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
second.updated_at = base - chrono::Duration::minutes(5);
|
||||||
|
|
||||||
|
db.store_item(first).await?;
|
||||||
|
db.store_item(second).await?;
|
||||||
|
|
||||||
|
let (_, messages) =
|
||||||
|
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
|
||||||
|
|
||||||
|
assert_eq!(messages.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
messages.first().expect("expected first message").content,
|
||||||
|
"first"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
messages.get(1).expect("expected second message").content,
|
||||||
|
"second"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.delete_item::<Conversation>(&conversation_id).await?;
|
||||||
|
|
||||||
|
let result = Conversation::patch_title(&conversation_id, owner, "New title", &db).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(_)) => {}
|
||||||
|
other => anyhow::bail!("expected NotFound, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_conversation_new_initializes_timestamps_and_id() {
|
||||||
|
let before = Utc::now();
|
||||||
|
let conversation = Conversation::new("user".to_string(), "Title".to_string());
|
||||||
|
let after = Utc::now();
|
||||||
|
|
||||||
|
assert!(!conversation.id.is_empty());
|
||||||
|
assert!(conversation.created_at >= before && conversation.created_at <= after);
|
||||||
|
assert_eq!(conversation.created_at, conversation.updated_at);
|
||||||
|
assert_eq!(conversation.user_id, "user");
|
||||||
|
assert_eq!(conversation.title, "Title");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::result_large_err)]
|
||||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -25,77 +26,150 @@ pub enum IngestionPayload {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for IngestionPayload {
|
||||||
|
/// An empty text payload, used as a cheap placeholder when the real content
|
||||||
|
/// has been moved out of a task (see [`crate::storage::types::ingestion_task::IngestionTask::take_content`]).
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Text {
|
||||||
|
text: String::new(),
|
||||||
|
context: String::new(),
|
||||||
|
category: String::new(),
|
||||||
|
user_id: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared ingest metadata moved or cloned into each payload variant.
|
||||||
|
struct IngestFields {
|
||||||
|
context: String,
|
||||||
|
category: String,
|
||||||
|
user_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of parsing optional ingest content before file payloads are built.
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum ParsedContent {
|
||||||
|
/// No URL or text payload should be appended.
|
||||||
|
Skip,
|
||||||
|
Url(String),
|
||||||
|
Text(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParsedContent {
|
||||||
|
#[must_use]
|
||||||
|
fn follows(&self) -> bool {
|
||||||
|
!matches!(self, Self::Skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl IngestionPayload {
|
impl IngestionPayload {
|
||||||
/// Creates ingestion payloads from the provided content, context, and files.
|
/// Creates ingestion payloads from the provided content, context, and files.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// Files are emitted first. When both files and content are present, shared
|
||||||
/// * `content` - Optional textual content to be ingressed
|
/// metadata is cloned per file; otherwise the last file-only payload moves
|
||||||
/// * `context` - context for processing the ingress content
|
/// `context`, `category`, and `user_id` without cloning.
|
||||||
/// * `category` - Category to classify the ingressed content
|
|
||||||
/// * `files` - Vector of `FileInfo` objects containing information about uploaded files
|
|
||||||
/// * `user_id` - Identifier of the user performing the ingress operation
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Errors
|
||||||
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
|
///
|
||||||
/// (one per file/content type). On failure, returns an `AppError`.
|
/// Returns [`AppError::NotFound`] when no valid files or content are provided.
|
||||||
|
#[allow(clippy::similar_names)]
|
||||||
pub fn create_ingestion_payload(
|
pub fn create_ingestion_payload(
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
context: String,
|
context: String,
|
||||||
category: String,
|
category: String,
|
||||||
files: Vec<FileInfo>,
|
files: Vec<FileInfo>,
|
||||||
user_id: &str,
|
user_id: String,
|
||||||
) -> Result<Vec<IngestionPayload>, AppError> {
|
) -> Result<Vec<IngestionPayload>, AppError> {
|
||||||
// Initialize list
|
let parsed = Self::parse_content(content);
|
||||||
let mut object_list = Vec::new();
|
let content_follows = parsed.follows();
|
||||||
|
let file_count = files.len();
|
||||||
// Create a IngestionPayload from content if it exists, checking for URL or text
|
#[allow(clippy::arithmetic_side_effects)]
|
||||||
if let Some(input_content) = content {
|
let capacity = file_count + usize::from(content_follows);
|
||||||
match Url::parse(&input_content) {
|
let mut object_list = Vec::with_capacity(capacity);
|
||||||
Ok(url) => {
|
let mut fields = Some(IngestFields {
|
||||||
info!("Detected URL: {}", url);
|
context,
|
||||||
object_list.push(IngestionPayload::Url {
|
category,
|
||||||
url: url.to_string(),
|
user_id,
|
||||||
context: context.clone(),
|
|
||||||
category: category.clone(),
|
|
||||||
user_id: user_id.into(),
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
if input_content.len() > 2 {
|
|
||||||
info!("Treating input as plain text");
|
|
||||||
object_list.push(IngestionPayload::Text {
|
|
||||||
text: input_content.to_string(),
|
|
||||||
context: context.clone(),
|
|
||||||
category: category.clone(),
|
|
||||||
user_id: user_id.into(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for file in files {
|
for (index, file) in files.into_iter().enumerate() {
|
||||||
object_list.push(IngestionPayload::File {
|
let is_last_file = index.saturating_add(1) == file_count;
|
||||||
|
if content_follows || !is_last_file {
|
||||||
|
let Some(shared) = fields.as_ref() else {
|
||||||
|
return Err(AppError::internal("shared ingest fields consumed early"));
|
||||||
|
};
|
||||||
|
object_list.push(Self::File {
|
||||||
file_info: file,
|
file_info: file,
|
||||||
context: context.clone(),
|
context: shared.context.clone(),
|
||||||
category: category.clone(),
|
category: shared.category.clone(),
|
||||||
user_id: user_id.into(),
|
user_id: shared.user_id.clone(),
|
||||||
})
|
});
|
||||||
|
} else {
|
||||||
|
let Some(shared) = fields.take() else {
|
||||||
|
return Err(AppError::internal("shared ingest fields missing for file"));
|
||||||
|
};
|
||||||
|
object_list.push(Self::File {
|
||||||
|
file_info: file,
|
||||||
|
context: shared.context,
|
||||||
|
category: shared.category,
|
||||||
|
user_id: shared.user_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let ParsedContent::Url(url) = parsed {
|
||||||
|
info!("Detected URL: {url}");
|
||||||
|
let Some(shared) = fields.take() else {
|
||||||
|
return Err(AppError::internal("shared ingest fields missing for url"));
|
||||||
|
};
|
||||||
|
object_list.push(Self::Url {
|
||||||
|
url,
|
||||||
|
context: shared.context,
|
||||||
|
category: shared.category,
|
||||||
|
user_id: shared.user_id,
|
||||||
|
});
|
||||||
|
} else if let ParsedContent::Text(text) = parsed {
|
||||||
|
info!("Treating input as plain text");
|
||||||
|
let Some(shared) = fields.take() else {
|
||||||
|
return Err(AppError::internal("shared ingest fields missing for text"));
|
||||||
|
};
|
||||||
|
object_list.push(Self::Text {
|
||||||
|
text,
|
||||||
|
context: shared.context,
|
||||||
|
category: shared.category,
|
||||||
|
user_id: shared.user_id,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no objects are constructed, we return Err
|
|
||||||
if object_list.is_empty() {
|
if object_list.is_empty() {
|
||||||
return Err(AppError::NotFound(
|
return Err(AppError::NotFound(
|
||||||
"No valid content or files provided".into(),
|
"no valid content or files provided".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(object_list)
|
Ok(object_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_content(content: Option<String>) -> ParsedContent {
|
||||||
|
let Some(input_content) = content else {
|
||||||
|
return ParsedContent::Skip;
|
||||||
|
};
|
||||||
|
|
||||||
|
if input_content.len() <= 2 {
|
||||||
|
return ParsedContent::Skip;
|
||||||
|
}
|
||||||
|
|
||||||
|
match Url::parse(&input_content) {
|
||||||
|
Ok(url) => ParsedContent::Url(url.to_string()),
|
||||||
|
Err(_) => ParsedContent::Text(input_content),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -124,24 +198,23 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_url() {
|
fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
|
||||||
let url = "https://example.com";
|
let url = "https://example.com";
|
||||||
let context = "Process this URL";
|
let context = "Process this URL";
|
||||||
let category = "websites";
|
let category = "websites";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(url.to_string()),
|
Some(url.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::Url {
|
IngestionPayload::Url {
|
||||||
url: payload_url,
|
url: payload_url,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
@@ -149,34 +222,34 @@ mod tests {
|
|||||||
user_id: payload_user_id,
|
user_id: payload_user_id,
|
||||||
} => {
|
} => {
|
||||||
// URL parser may normalize the URL by adding a trailing slash
|
// URL parser may normalize the URL by adding a trailing slash
|
||||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||||
assert_eq!(payload_context, &context);
|
assert_eq!(payload_context, &context);
|
||||||
assert_eq!(payload_category, &category);
|
assert_eq!(payload_category, &category);
|
||||||
assert_eq!(payload_user_id, &user_id);
|
assert_eq!(payload_user_id, &user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected Url variant"),
|
_ => anyhow::bail!("Expected Url variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_text() {
|
fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
|
||||||
let text = "This is some text content";
|
let text = "This is some text content";
|
||||||
let context = "Process this text";
|
let context = "Process this text";
|
||||||
let category = "notes";
|
let category = "notes";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(text.to_string()),
|
Some(text.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::Text {
|
IngestionPayload::Text {
|
||||||
text: payload_text,
|
text: payload_text,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
@@ -188,12 +261,13 @@ mod tests {
|
|||||||
assert_eq!(payload_category, category);
|
assert_eq!(payload_category, category);
|
||||||
assert_eq!(payload_user_id, user_id);
|
assert_eq!(payload_user_id, user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected Text variant"),
|
_ => anyhow::bail!("Expected Text variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_file() {
|
fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
|
||||||
let context = "Process this file";
|
let context = "Process this file";
|
||||||
let category = "documents";
|
let category = "documents";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
@@ -204,36 +278,36 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let file_info: FileInfo = mock_file.into();
|
let file_info: FileInfo = mock_file.into();
|
||||||
let files = vec![file_info.clone()];
|
let file_id = file_info.id.clone();
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
None,
|
None,
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![file_info],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::File {
|
IngestionPayload::File {
|
||||||
file_info: payload_file_info,
|
file_info: payload_file_info,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
category: payload_category,
|
category: payload_category,
|
||||||
user_id: payload_user_id,
|
user_id: payload_user_id,
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(payload_file_info.id, file_info.id);
|
assert_eq!(payload_file_info.id, file_id);
|
||||||
assert_eq!(payload_context, context);
|
assert_eq!(payload_context, context);
|
||||||
assert_eq!(payload_category, category);
|
assert_eq!(payload_category, category);
|
||||||
assert_eq!(payload_user_id, user_id);
|
assert_eq!(payload_user_id, user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected File variant"),
|
_ => anyhow::bail!("Expected File variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_url_and_file() {
|
fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
|
||||||
let url = "https://example.com";
|
let url = "https://example.com";
|
||||||
let context = "Process this data";
|
let context = "Process this data";
|
||||||
let category = "mixed";
|
let category = "mixed";
|
||||||
@@ -245,88 +319,207 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let file_info: FileInfo = mock_file.into();
|
let file_info: FileInfo = mock_file.into();
|
||||||
let files = vec![file_info.clone()];
|
let file_id = file_info.id.clone();
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(url.to_string()),
|
Some(url.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![file_info],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
|
|
||||||
// Check first item is URL
|
// Check first item is File (files processed first to minimize clones)
|
||||||
match &result[0] {
|
match result.first().context("expected first item")? {
|
||||||
IngestionPayload::Url {
|
|
||||||
url: payload_url, ..
|
|
||||||
} => {
|
|
||||||
// URL parser may normalize the URL by adding a trailing slash
|
|
||||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
|
||||||
}
|
|
||||||
_ => panic!("Expected first item to be Url variant"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check second item is File
|
|
||||||
match &result[1] {
|
|
||||||
IngestionPayload::File {
|
IngestionPayload::File {
|
||||||
file_info: payload_file_info,
|
file_info: payload_file_info,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(payload_file_info.id, file_info.id);
|
assert_eq!(payload_file_info.id, file_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected second item to be File variant"),
|
_ => anyhow::bail!("Expected first item to be File variant"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check second item is URL
|
||||||
|
match result.get(1).context("expected second item")? {
|
||||||
|
IngestionPayload::Url {
|
||||||
|
url: payload_url, ..
|
||||||
|
} => {
|
||||||
|
// URL parser may normalize the URL by adding a trailing slash
|
||||||
|
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected second item to be Url variant"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_empty_input() {
|
fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
|
||||||
let context = "Process something";
|
let context = "Process something";
|
||||||
let category = "empty";
|
let category = "empty";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
None,
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![],
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(msg)) => {
|
||||||
|
assert_eq!(msg, "no valid content or files provided");
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
|
||||||
|
let text = ""; // Empty text
|
||||||
|
let context = "Process this";
|
||||||
|
let category = "notes";
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some(text.to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![],
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(msg)) => {
|
||||||
|
assert_eq!(msg, "no valid content or files provided");
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_with_file_and_text() -> anyhow::Result<()> {
|
||||||
|
let text = "plain notes";
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
let file_info: FileInfo = MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some(text.to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![file_info],
|
||||||
|
user_id.to_string(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 2);
|
||||||
|
let first = result.first().expect("expected first payload");
|
||||||
|
let second = result.get(1).expect("expected second payload");
|
||||||
|
match (first, second) {
|
||||||
|
(
|
||||||
|
IngestionPayload::File {
|
||||||
|
file_info: payload_file,
|
||||||
|
context: file_context,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
IngestionPayload::Text {
|
||||||
|
text: payload_text,
|
||||||
|
context: text_context,
|
||||||
|
category: text_category,
|
||||||
|
user_id: text_user_id,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
assert_eq!(payload_file.id, "file1");
|
||||||
|
assert_eq!(file_context, context);
|
||||||
|
assert_eq!(payload_text, text);
|
||||||
|
assert_eq!(text_context, context);
|
||||||
|
assert_eq!(text_category, category);
|
||||||
|
assert_eq!(text_user_id, user_id);
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("expected File then Text"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()>
|
||||||
|
{
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
let file_info: FileInfo = MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some("ab".to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![file_info],
|
||||||
|
user_id.to_string(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
match result.first().context("expected one file payload")? {
|
||||||
|
IngestionPayload::File {
|
||||||
|
file_info,
|
||||||
|
context: payload_context,
|
||||||
|
category: payload_category,
|
||||||
|
user_id: payload_user_id,
|
||||||
|
} => {
|
||||||
|
assert_eq!(file_info.id, "file1");
|
||||||
|
assert_eq!(payload_context, context);
|
||||||
|
assert_eq!(payload_category, category);
|
||||||
|
assert_eq!(payload_user_id, user_id);
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("expected File variant only"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_two_files_without_content() -> anyhow::Result<()> {
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let files = vec![
|
||||||
|
MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
MockFileInfo {
|
||||||
|
id: "file2".to_string(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
];
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
None,
|
None,
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
files,
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
);
|
)?;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert_eq!(result.len(), 2);
|
||||||
match result {
|
assert!(matches!(
|
||||||
Err(AppError::NotFound(msg)) => {
|
result.first(),
|
||||||
assert_eq!(msg, "No valid content or files provided");
|
Some(IngestionPayload::File { .. })
|
||||||
}
|
));
|
||||||
_ => panic!("Expected NotFound error"),
|
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
|
||||||
}
|
Ok(())
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_ingestion_payload_with_empty_text() {
|
|
||||||
let text = ""; // Empty text
|
|
||||||
let context = "Process this";
|
|
||||||
let category = "notes";
|
|
||||||
let user_id = "user123";
|
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
|
||||||
Some(text.to_string()),
|
|
||||||
context.to_string(),
|
|
||||||
category.to_string(),
|
|
||||||
files,
|
|
||||||
user_id,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(result.is_err());
|
|
||||||
match result {
|
|
||||||
Err(AppError::NotFound(msg)) => {
|
|
||||||
assert_eq!(msg, "No valid content or files provided");
|
|
||||||
}
|
|
||||||
_ => panic!("Expected NotFound error"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,478 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use surrealdb::RecordId;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||||
|
stored_object,
|
||||||
|
};
|
||||||
|
|
||||||
|
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||||
|
entity_id: RecordId,
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
/// Denormalized source id for bulk deletes
|
||||||
|
source_id: String,
|
||||||
|
/// Denormalized user id for query scoping
|
||||||
|
user_id: String
|
||||||
|
});
|
||||||
|
|
||||||
|
impl KnowledgeEntityEmbedding {
|
||||||
|
/// Recreate the HNSW index with a new embedding dimension.
|
||||||
|
pub async fn redefine_hnsw_index(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
dimension: usize,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = hnsw_index_redefine_transaction_sql(
|
||||||
|
"idx_embedding_knowledge_entity_embedding",
|
||||||
|
Self::table_name(),
|
||||||
|
dimension,
|
||||||
|
);
|
||||||
|
|
||||||
|
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||||
|
res.check().map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||||
|
if embedding.len() != expected {
|
||||||
|
return Err(AppError::Validation(format!(
|
||||||
|
"embedding dimension mismatch: got {}, expected {expected}",
|
||||||
|
embedding.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new knowledge entity embedding.
|
||||||
|
///
|
||||||
|
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||||
|
let now = Utc::now();
|
||||||
|
Self {
|
||||||
|
id: entity_id.to_owned(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||||
|
embedding,
|
||||||
|
source_id,
|
||||||
|
user_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get embedding by entity ID
|
||||||
|
pub async fn get_by_entity_id(
|
||||||
|
entity_id: &RecordId,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Option<Self>, AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
let mut result = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("entity_id", entity_id.clone()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
|
Ok(embeddings.into_iter().next())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get embeddings for multiple entities in batch
|
||||||
|
pub async fn get_by_entity_ids(
|
||||||
|
entity_ids: &[RecordId],
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<HashMap<String, Vec<f32>>, AppError> {
|
||||||
|
if entity_ids.is_empty() {
|
||||||
|
return Ok(HashMap::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let query = format!(
|
||||||
|
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
let mut result = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("entity_ids", entity_ids.to_vec()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(embeddings
|
||||||
|
.into_iter()
|
||||||
|
.map(|e| (e.entity_id.key().to_string(), e.embedding))
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete embedding by entity ID
|
||||||
|
pub async fn delete_by_entity_id(
|
||||||
|
entity_id: &RecordId,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE entity_id = $entity_id",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("entity_id", entity_id.clone()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete all embeddings with the given denormalized `source_id`.
|
||||||
|
pub async fn delete_by_source_id(
|
||||||
|
source_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE source_id = $source_id",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("source_id", source_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use super::*;
|
||||||
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
|
use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db};
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
use chrono::Utc;
|
||||||
|
use surrealdb::Value as SurrealValue;
|
||||||
|
|
||||||
|
fn build_knowledge_entity_with_id(
|
||||||
|
key: &str,
|
||||||
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> KnowledgeEntity {
|
||||||
|
KnowledgeEntity {
|
||||||
|
id: key.to_owned(),
|
||||||
|
created_at: Utc::now(),
|
||||||
|
updated_at: Utc::now(),
|
||||||
|
source_id: source_id.to_owned(),
|
||||||
|
name: "Test entity".to_owned(),
|
||||||
|
description: "Desc".to_owned(),
|
||||||
|
entity_type: KnowledgeEntityType::Document,
|
||||||
|
metadata: None,
|
||||||
|
user_id: user_id.to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_uses_entity_id_as_record_id() {
|
||||||
|
let emb = KnowledgeEntityEmbedding::new(
|
||||||
|
"entity-abc",
|
||||||
|
"source-1".to_owned(),
|
||||||
|
vec![0.1, 0.2],
|
||||||
|
"user-1".to_owned(),
|
||||||
|
);
|
||||||
|
assert_eq!(emb.id, "entity-abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_dimension_rejects_mismatch() {
|
||||||
|
let err = KnowledgeEntityEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
|
||||||
|
.expect_err("expected dimension mismatch");
|
||||||
|
assert!(matches!(err, AppError::Validation(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
let user_id = "user_ke";
|
||||||
|
let entity_key = "entity-1";
|
||||||
|
let source_id = "source-ke";
|
||||||
|
|
||||||
|
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
|
||||||
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
|
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||||
|
|
||||||
|
assert_eq!(fetched.id, entity_key);
|
||||||
|
assert_eq!(fetched.user_id, user_id);
|
||||||
|
assert_eq!(fetched.source_id, source_id);
|
||||||
|
assert_eq!(fetched.entity_id, entity_rid);
|
||||||
|
assert_eq!(fetched.embedding, embedding_vec);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
let user_id = "user_ke";
|
||||||
|
let entity_key = "entity-delete";
|
||||||
|
let source_id = "source-del";
|
||||||
|
|
||||||
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
|
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
|
assert!(existing.is_some());
|
||||||
|
|
||||||
|
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||||
|
|
||||||
|
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
|
assert!(after.is_none());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
let user_id = "user_store";
|
||||||
|
let source_id = "source_store";
|
||||||
|
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||||
|
|
||||||
|
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
|
let stored_entity: Option<KnowledgeEntity> = db
|
||||||
|
.get_item(&entity.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get entity".to_string())?;
|
||||||
|
assert!(stored_entity.is_some());
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||||
|
let stored_embedding =
|
||||||
|
stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||||
|
assert_eq!(stored_embedding.id, entity.id);
|
||||||
|
assert_eq!(stored_embedding.user_id, user_id);
|
||||||
|
assert_eq!(stored_embedding.source_id, source_id);
|
||||||
|
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
|
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||||
|
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
let user_id = "user_ke";
|
||||||
|
let source_id = "shared-ke";
|
||||||
|
let other_source = "other-ke";
|
||||||
|
|
||||||
|
let entity1 = build_knowledge_entity_with_id("entity-s1", source_id, user_id);
|
||||||
|
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
|
||||||
|
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||||
|
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||||
|
let other_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity_other.id);
|
||||||
|
|
||||||
|
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||||
|
.is_none()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||||
|
.is_none()
|
||||||
|
);
|
||||||
|
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get other embedding after delete".to_string())?
|
||||||
|
.is_some());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
|
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
||||||
|
.await
|
||||||
|
.with_context(|| "failed to redefine index".to_string())?;
|
||||||
|
|
||||||
|
let mut info_res = db
|
||||||
|
.client
|
||||||
|
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
||||||
|
.await
|
||||||
|
.with_context(|| "info query failed".to_string())?;
|
||||||
|
let info: SurrealValue = info_res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to take info result".to_string())?;
|
||||||
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
|
.with_context(|| "failed to convert info to json".to_string())?;
|
||||||
|
let idx_sql = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
|
||||||
|
.and_then(|v| v.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIMENSION 16"),
|
||||||
|
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIST COSINE"),
|
||||||
|
"expected index definition to use cosine distance, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Row {
|
||||||
|
entity_id: KnowledgeEntity,
|
||||||
|
}
|
||||||
|
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
let user_id = "user_ke";
|
||||||
|
let entity_key = "entity-fetch";
|
||||||
|
let source_id = "source-fetch";
|
||||||
|
|
||||||
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
|
let mut res = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;",
|
||||||
|
)
|
||||||
|
.bind(("id", entity_rid.clone()))
|
||||||
|
.await
|
||||||
|
.with_context(|| "failed to fetch embedding with FETCH".to_string())?;
|
||||||
|
let rows: Vec<Row> = res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to deserialize fetch rows".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let fetched_entity = &rows
|
||||||
|
.first()
|
||||||
|
.context("Expected at least one result")?
|
||||||
|
.entity_id;
|
||||||
|
assert_eq!(fetched_entity.id, entity_key);
|
||||||
|
assert_eq!(fetched_entity.name, "Test entity");
|
||||||
|
assert_eq!(fetched_entity.user_id, user_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user-upsert";
|
||||||
|
let source_id = "source-upsert";
|
||||||
|
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "initial store".to_string())?;
|
||||||
|
|
||||||
|
let replacement = KnowledgeEntityEmbedding::new(
|
||||||
|
&entity.id,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![0.0, 1.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(replacement)
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert replacement embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
let rows: Vec<KnowledgeEntityEmbedding> = db
|
||||||
|
.client
|
||||||
|
.query(format!(
|
||||||
|
"SELECT * FROM {} WHERE entity_id = $entity_id",
|
||||||
|
KnowledgeEntityEmbedding::table_name()
|
||||||
|
))
|
||||||
|
.bind(("entity_id", entity_rid))
|
||||||
|
.await
|
||||||
|
.with_context(|| "count embeddings".to_string())?
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "take embeddings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected one embedding row");
|
||||||
|
assert_eq!(row.id, entity.id);
|
||||||
|
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
use crate::storage::types::user::User;
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -21,6 +22,7 @@ pub struct KnowledgeRelationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl KnowledgeRelationship {
|
impl KnowledgeRelationship {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
in_: String,
|
in_: String,
|
||||||
out: String,
|
out: String,
|
||||||
@@ -39,64 +41,143 @@ impl KnowledgeRelationship {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
|
||||||
let query = format!(
|
|
||||||
r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
|
|
||||||
SET
|
|
||||||
metadata.user_id = '{}',
|
|
||||||
metadata.source_id = '{}',
|
|
||||||
metadata.relationship_type = '{}'"#,
|
|
||||||
self.in_,
|
|
||||||
self.id,
|
|
||||||
self.out,
|
|
||||||
self.metadata.user_id,
|
|
||||||
self.metadata.source_id,
|
|
||||||
self.metadata.relationship_type
|
|
||||||
);
|
|
||||||
|
|
||||||
db_client.query(query).await?;
|
pub async fn store_relationship(self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
||||||
|
.await?;
|
||||||
|
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let Self {
|
||||||
|
id,
|
||||||
|
in_,
|
||||||
|
out,
|
||||||
|
metadata:
|
||||||
|
RelationshipMetadata {
|
||||||
|
user_id,
|
||||||
|
source_id,
|
||||||
|
relationship_type,
|
||||||
|
},
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
db_client
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
r#"BEGIN TRANSACTION;
|
||||||
|
LET $in_entity = type::thing('knowledge_entity', $in_id);
|
||||||
|
LET $out_entity = type::thing('knowledge_entity', $out_id);
|
||||||
|
LET $relation = type::thing('relates_to', $rel_id);
|
||||||
|
DELETE type::thing('relates_to', $rel_id);
|
||||||
|
RELATE $in_entity->$relation->$out_entity SET
|
||||||
|
metadata.user_id = $user_id,
|
||||||
|
metadata.source_id = $source_id,
|
||||||
|
metadata.relationship_type = $relationship_type;
|
||||||
|
COMMIT TRANSACTION;"#,
|
||||||
|
)
|
||||||
|
.bind(("rel_id", id))
|
||||||
|
.bind(("in_id", in_))
|
||||||
|
.bind(("out_id", out))
|
||||||
|
.bind(("user_id", user_id))
|
||||||
|
.bind(("source_id", source_id))
|
||||||
|
.bind(("relationship_type", relationship_type))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_relationships_by_source_id(
|
pub async fn delete_relationships_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let query = format!(
|
db_client
|
||||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
|
.client
|
||||||
source_id
|
.query(
|
||||||
);
|
"DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id",
|
||||||
|
)
|
||||||
db_client.query(query).await?;
|
.bind(("source_id", source_id.to_owned()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_relationship_by_id(
|
pub async fn delete_relationship_by_id(
|
||||||
id: &str,
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let query = format!("DELETE relates_to:`{}`", id);
|
let mut delete_result = db_client
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;",
|
||||||
|
)
|
||||||
|
.bind(("id", id.to_owned()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
let deleted: Vec<KnowledgeRelationship> = delete_result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
db_client.query(query).await?;
|
if !deleted.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
let mut exists_result = db_client
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||||
|
.bind(("id", id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
let existing: Option<KnowledgeRelationship> =
|
||||||
|
exists_result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
if existing.is_some() {
|
||||||
|
Err(AppError::Auth(
|
||||||
|
"Not authorized to delete relationship".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
// Helper function to create a test knowledge entity for the relationship tests
|
use crate::test_utils::setup_test_db;
|
||||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
|
||||||
|
async fn get_relationship_by_id(
|
||||||
|
relationship_id: &str,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> Option<KnowledgeRelationship> {
|
||||||
|
let mut result = db_client
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||||
|
.bind(("id", relationship_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.expect("relationship query by id failed");
|
||||||
|
|
||||||
|
result.take(0).expect("failed to take relationship by id")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_test_entity(
|
||||||
|
name: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let description = format!("Description for {}", name);
|
let description = format!("Description for {name}");
|
||||||
let entity_type = KnowledgeEntityType::Document;
|
let entity_type = KnowledgeEntityType::Document;
|
||||||
let embedding = vec![0.1, 0.2, 0.3];
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
|
|
||||||
let entity = KnowledgeEntity::new(
|
let entity = KnowledgeEntity::new(
|
||||||
source_id,
|
source_id,
|
||||||
@@ -104,19 +185,20 @@ mod tests {
|
|||||||
description,
|
description,
|
||||||
entity_type,
|
entity_type,
|
||||||
None,
|
None,
|
||||||
embedding,
|
user_id.to_string(),
|
||||||
user_id,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let stored: Option<KnowledgeEntity> = db_client
|
let stored: Option<KnowledgeEntity> = db_client
|
||||||
.store_item(entity)
|
.store_item(entity)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity");
|
.with_context(|| "Failed to store entity".to_string())?;
|
||||||
stored.unwrap().id
|
stored
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
|
||||||
|
.map(|e| e.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_relationship_creation() {
|
async fn test_relationship_creation() -> anyhow::Result<()> {
|
||||||
let in_id = "entity1".to_string();
|
let in_id = "entity1".to_string();
|
||||||
let out_id = "entity2".to_string();
|
let out_id = "entity2".to_string();
|
||||||
let user_id = "user123".to_string();
|
let user_id = "user123".to_string();
|
||||||
@@ -131,133 +213,261 @@ mod tests {
|
|||||||
relationship_type.clone(),
|
relationship_type.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify fields are correctly set
|
|
||||||
assert_eq!(relationship.in_, in_id);
|
assert_eq!(relationship.in_, in_id);
|
||||||
assert_eq!(relationship.out, out_id);
|
assert_eq!(relationship.out, out_id);
|
||||||
assert_eq!(relationship.metadata.user_id, user_id);
|
assert_eq!(relationship.metadata.user_id, user_id);
|
||||||
assert_eq!(relationship.metadata.source_id, source_id);
|
assert_eq!(relationship.metadata.source_id, source_id);
|
||||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||||
assert!(!relationship.id.is_empty());
|
assert!(!relationship.id.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship() {
|
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
let user_id = "user123";
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Create two entities to relate
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
|
||||||
|
|
||||||
// Create relationship
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let relationship_type = "references".to_string();
|
let relationship_type = "references".to_string();
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
// Store the relationship
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
let persisted = get_relationship_by_id(&relationship_id, &db)
|
||||||
// This approach is more reliable than trying to look up by ID
|
.await
|
||||||
let check_query = format!(
|
.expect("Relationship should be retrievable by id");
|
||||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
assert_eq!(persisted.in_, entity1_id);
|
||||||
source_id
|
assert_eq!(persisted.out, entity2_id);
|
||||||
);
|
assert_eq!(persisted.metadata.user_id, user_id);
|
||||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
assert_eq!(persisted.metadata.source_id, source_id);
|
||||||
|
|
||||||
|
let mut check_result = db
|
||||||
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||||
|
.bind(("source_id", source_id.clone()))
|
||||||
|
.await
|
||||||
|
.expect("Check query failed");
|
||||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||||
|
|
||||||
// Just verify that a relationship was created
|
assert_eq!(
|
||||||
assert!(
|
check_results.len(),
|
||||||
!check_results.is_empty(),
|
1,
|
||||||
"Relationship should exist in the database"
|
"Expected one relationship for source_id"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationship_by_id() {
|
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
||||||
|
|
||||||
|
let relationship = KnowledgeRelationship::new(
|
||||||
|
owner_entity,
|
||||||
|
other_entity,
|
||||||
|
"owner-user".to_string(),
|
||||||
|
"source123".to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = relationship.store_relationship(&db).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Auth(_))));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
|
||||||
|
let relationship = KnowledgeRelationship::new(
|
||||||
|
entity1_id,
|
||||||
|
entity2_id,
|
||||||
|
user_id.to_string(),
|
||||||
|
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||||
|
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||||
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
|
relationship
|
||||||
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.expect("store relationship should safely handle quote-containing values");
|
||||||
|
|
||||||
// Create two entities to relate
|
let mut res = db
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
.client
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||||
|
.bind(("id", relationship_id))
|
||||||
|
.await
|
||||||
|
.expect("query relationship by id failed");
|
||||||
|
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected 1 row");
|
||||||
|
assert_eq!(
|
||||||
|
row.metadata.source_id,
|
||||||
|
"source123'; DELETE FROM relates_to; --"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
|
||||||
// Create relationship
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let relationship_type = "references".to_string();
|
let relationship_type = "references".to_string();
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
// Store the relationship
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
// Delete the relationship by ID
|
let mut existing_before_delete = db
|
||||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.bind(("source_id", source_id.clone()))
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete relationship by ID");
|
.with_context(|| "Query failed".to_string())?;
|
||||||
|
let before_results: Vec<KnowledgeRelationship> =
|
||||||
|
existing_before_delete.take(0).unwrap_or_default();
|
||||||
|
assert!(
|
||||||
|
!before_results.is_empty(),
|
||||||
|
"Relationship should exist before deletion"
|
||||||
|
);
|
||||||
|
|
||||||
// Query to verify the relationship was deleted
|
KnowledgeRelationship::delete_relationship_by_id(&relationship_id, user_id, &db)
|
||||||
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
|
.await
|
||||||
let mut result = db.query(query).await.expect("Query failed");
|
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||||
|
|
||||||
|
let mut result = db
|
||||||
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.bind(("source_id", source_id))
|
||||||
|
.await
|
||||||
|
.with_context(|| "Query failed".to_string())?;
|
||||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||||
|
|
||||||
// Verify the relationship no longer exists
|
|
||||||
assert!(results.is_empty(), "Relationship should be deleted");
|
assert!(results.is_empty(), "Relationship should be deleted");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationships_by_source_id() {
|
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let namespace = "test_ns";
|
let owner_user_id = "owner-user";
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?;
|
||||||
|
|
||||||
|
let source_id = "source123".to_string();
|
||||||
|
|
||||||
|
let relationship = KnowledgeRelationship::new(
|
||||||
|
entity1_id.clone(),
|
||||||
|
entity2_id.clone(),
|
||||||
|
owner_user_id.to_string(),
|
||||||
|
source_id,
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
|
relationship
|
||||||
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
// Create entities to relate
|
let mut before_attempt = db
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
.bind(("user_id", owner_user_id.to_string()))
|
||||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
.await
|
||||||
|
.with_context(|| "Query failed".to_string())?;
|
||||||
|
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||||
|
assert!(
|
||||||
|
!before_results.is_empty(),
|
||||||
|
"Relationship should exist before unauthorized delete attempt"
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||||
|
&relationship_id,
|
||||||
|
"different-user",
|
||||||
|
&db,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Err(AppError::Auth(_)) => {}
|
||||||
|
_ => anyhow::bail!(
|
||||||
|
"Expected authorization error when deleting someone else's relationship"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut after_attempt = db
|
||||||
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||||
|
.bind(("user_id", owner_user_id.to_string()))
|
||||||
|
.await
|
||||||
|
.with_context(|| "Query failed".to_string())?;
|
||||||
|
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!results.is_empty(),
|
||||||
|
"Relationship should still exist after unauthorized delete attempt"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||||
|
|
||||||
// Create relationships with the same source_id
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let different_source_id = "different_source".to_string();
|
let different_source_id = "different_source".to_string();
|
||||||
|
|
||||||
// Create two relationships with the same source_id
|
|
||||||
let relationship1 = KnowledgeRelationship::new(
|
let relationship1 = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
@@ -265,77 +475,170 @@ mod tests {
|
|||||||
let relationship2 = KnowledgeRelationship::new(
|
let relationship2 = KnowledgeRelationship::new(
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
entity3_id.clone(),
|
entity3_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
"contains".to_string(),
|
"contains".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create a relationship with a different source_id
|
|
||||||
let different_relationship = KnowledgeRelationship::new(
|
let different_relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity3_id.clone(),
|
entity3_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
different_source_id.clone(),
|
different_source_id.clone(),
|
||||||
"mentions".to_string(),
|
"mentions".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship1_id = relationship1.id.clone();
|
||||||
|
let relationship2_id = relationship2.id.clone();
|
||||||
|
let different_relationship_id = different_relationship.id.clone();
|
||||||
|
|
||||||
// Store all relationships
|
|
||||||
relationship1
|
relationship1
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship 1");
|
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||||
relationship2
|
relationship2
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship 2");
|
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||||
different_relationship
|
different_relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store different relationship");
|
.with_context(|| "Failed to store different relationship".to_string())?;
|
||||||
|
|
||||||
// Delete relationships by source_id
|
let mut before_delete = db
|
||||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||||
|
.bind(("source_id", source_id.clone()))
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete relationships by source_id");
|
.expect("before delete query failed");
|
||||||
|
let before_delete_rows: Vec<KnowledgeRelationship> =
|
||||||
|
before_delete.take(0).unwrap_or_default();
|
||||||
|
assert_eq!(before_delete_rows.len(), 2);
|
||||||
|
|
||||||
// Query to verify the relationships with source_id were deleted
|
let mut before_delete_different = db
|
||||||
let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id);
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||||
let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id);
|
.bind(("source_id", different_source_id.clone()))
|
||||||
let different_query = format!(
|
.await
|
||||||
"SELECT * FROM relates_to WHERE id = '{}'",
|
.expect("before delete different query failed");
|
||||||
different_relationship.id
|
let before_delete_different_rows: Vec<KnowledgeRelationship> =
|
||||||
|
before_delete_different.take(0).unwrap_or_default();
|
||||||
|
assert_eq!(before_delete_different_rows.len(), 1);
|
||||||
|
|
||||||
|
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||||
|
|
||||||
|
let result1 = get_relationship_by_id(&relationship1_id, &db).await;
|
||||||
|
let result2 = get_relationship_by_id(&relationship2_id, &db).await;
|
||||||
|
let different_result = get_relationship_by_id(&different_relationship_id, &db).await;
|
||||||
|
|
||||||
|
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||||
|
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||||
|
let remaining =
|
||||||
|
different_result.expect("Relationship with different source_id should remain");
|
||||||
|
assert_eq!(remaining.metadata.source_id, different_source_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
|
let user_a = "user-a";
|
||||||
|
let user_b = "user-b";
|
||||||
|
let shared_source = "shared-source";
|
||||||
|
|
||||||
|
let a1 = create_test_entity("A1", user_a, &db).await?;
|
||||||
|
let a2 = create_test_entity("A2", user_a, &db).await?;
|
||||||
|
let b1 = create_test_entity("B1", user_b, &db).await?;
|
||||||
|
let b2 = create_test_entity("B2", user_b, &db).await?;
|
||||||
|
|
||||||
|
let rel_a = KnowledgeRelationship::new(
|
||||||
|
a1,
|
||||||
|
a2,
|
||||||
|
user_a.to_string(),
|
||||||
|
shared_source.to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
let rel_b = KnowledgeRelationship::new(
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
|
user_b.to_string(),
|
||||||
|
shared_source.to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
let owner_relationship_id = rel_a.id.clone();
|
||||||
|
let other_relationship_id = rel_b.id.clone();
|
||||||
|
|
||||||
|
rel_a.store_relationship(&db).await?;
|
||||||
|
rel_b.store_relationship(&db).await?;
|
||||||
|
|
||||||
|
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(get_relationship_by_id(&owner_relationship_id, &db)
|
||||||
|
.await
|
||||||
|
.is_none());
|
||||||
|
assert!(get_relationship_by_id(&other_relationship_id, &db)
|
||||||
|
.await
|
||||||
|
.is_some());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
||||||
|
{
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||||
|
|
||||||
|
let safe_relationship = KnowledgeRelationship::new(
|
||||||
|
entity1_id.clone(),
|
||||||
|
entity2_id.clone(),
|
||||||
|
user_id.to_string(),
|
||||||
|
"safe_source".to_string(),
|
||||||
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut result1 = db.query(query1).await.expect("Query 1 failed");
|
let other_relationship = KnowledgeRelationship::new(
|
||||||
let results1: Vec<KnowledgeRelationship> = result1.take(0).unwrap_or_default();
|
entity2_id,
|
||||||
|
entity3_id,
|
||||||
let mut result2 = db.query(query2).await.expect("Query 2 failed");
|
user_id.to_string(),
|
||||||
let results2: Vec<KnowledgeRelationship> = result2.take(0).unwrap_or_default();
|
"other_source".to_string(),
|
||||||
|
"contains".to_string(),
|
||||||
let mut different_result = db
|
|
||||||
.query(different_query)
|
|
||||||
.await
|
|
||||||
.expect("Different query failed");
|
|
||||||
let _different_results: Vec<KnowledgeRelationship> =
|
|
||||||
different_result.take(0).unwrap_or_default();
|
|
||||||
|
|
||||||
// Verify relationships with the source_id are deleted
|
|
||||||
assert!(results1.is_empty(), "Relationship 1 should be deleted");
|
|
||||||
assert!(results2.is_empty(), "Relationship 2 should be deleted");
|
|
||||||
|
|
||||||
// For the relationship with different source ID, we need to check differently
|
|
||||||
// Let's just verify we have a relationship where the source_id matches different_source_id
|
|
||||||
let check_query = format!(
|
|
||||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
|
||||||
different_source_id
|
|
||||||
);
|
);
|
||||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
let safe_relationship_id = safe_relationship.id.clone();
|
||||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
let other_relationship_id = other_relationship.id.clone();
|
||||||
|
|
||||||
// Verify the relationship with a different source_id still exists
|
safe_relationship
|
||||||
|
.store_relationship(&db)
|
||||||
|
.await
|
||||||
|
.expect("store safe relationship");
|
||||||
|
other_relationship
|
||||||
|
.store_relationship(&db)
|
||||||
|
.await
|
||||||
|
.expect("store other relationship");
|
||||||
|
|
||||||
|
KnowledgeRelationship::delete_relationships_by_source_id(
|
||||||
|
"safe_source' OR 1=1 --",
|
||||||
|
user_id,
|
||||||
|
&db,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("delete call should succeed");
|
||||||
|
|
||||||
|
let remaining_safe = get_relationship_by_id(&safe_relationship_id, &db).await;
|
||||||
|
let remaining_other = get_relationship_by_id(&other_relationship_id, &db).await;
|
||||||
|
|
||||||
|
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
||||||
assert!(
|
assert!(
|
||||||
!check_results.is_empty(),
|
remaining_other.is_some(),
|
||||||
"Relationship with different source_id should still exist"
|
"Other relationship should remain"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
|
#![allow(clippy::module_name_repetitions)]
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone, Serialize, PartialEq)]
|
#[derive(Deserialize, Debug, Clone, Copy, Serialize, PartialEq)]
|
||||||
pub enum MessageRole {
|
pub enum MessageRole {
|
||||||
User,
|
User,
|
||||||
AI,
|
AI,
|
||||||
@@ -17,6 +21,7 @@ stored_object!(Message, "message", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
conversation_id: String,
|
conversation_id: String,
|
||||||
role: MessageRole,
|
role: MessageRole,
|
||||||
@@ -53,22 +58,31 @@ impl fmt::Display for Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// helper function to format a vector of messages
|
// helper function to format a vector of messages
|
||||||
|
#[must_use]
|
||||||
pub fn format_history(history: &[Message]) -> String {
|
pub fn format_history(history: &[Message]) -> String {
|
||||||
history
|
let estimated: usize = history
|
||||||
.iter()
|
.iter()
|
||||||
.map(|msg| format!("{}", msg))
|
.map(|m| m.content.len().saturating_add(10))
|
||||||
.collect::<Vec<String>>()
|
.sum();
|
||||||
.join("\n")
|
let mut out = String::with_capacity(estimated);
|
||||||
|
for (i, msg) in history.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
out.push('\n');
|
||||||
|
}
|
||||||
|
let _ = write!(out, "{msg}");
|
||||||
|
}
|
||||||
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_creation() {
|
async fn test_message_creation() -> anyhow::Result<()> {
|
||||||
// Test basic message creation
|
|
||||||
let conversation_id = "test_conversation";
|
let conversation_id = "test_conversation";
|
||||||
let content = "This is a test message";
|
let content = "This is a test message";
|
||||||
let role = MessageRole::User;
|
let role = MessageRole::User;
|
||||||
@@ -76,29 +90,28 @@ mod tests {
|
|||||||
|
|
||||||
let message = Message::new(
|
let message = Message::new(
|
||||||
conversation_id.to_string(),
|
conversation_id.to_string(),
|
||||||
role.clone(),
|
role,
|
||||||
content.to_string(),
|
content.to_string(),
|
||||||
references.clone(),
|
references.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify message properties
|
|
||||||
assert_eq!(message.conversation_id, conversation_id);
|
assert_eq!(message.conversation_id, conversation_id);
|
||||||
assert_eq!(message.content, content);
|
assert_eq!(message.content, content);
|
||||||
assert_eq!(message.role, role);
|
assert_eq!(message.role, role);
|
||||||
assert_eq!(message.references, references);
|
assert_eq!(message.references, references);
|
||||||
assert!(!message.id.is_empty());
|
assert!(!message.id.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_persistence() {
|
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &uuid::Uuid::new_v4().to_string();
|
let database = &uuid::Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create and store a message
|
|
||||||
let conversation_id = "test_conversation";
|
let conversation_id = "test_conversation";
|
||||||
let message = Message::new(
|
let message = Message::new(
|
||||||
conversation_id.to_string(),
|
conversation_id.to_string(),
|
||||||
@@ -108,39 +121,37 @@ mod tests {
|
|||||||
);
|
);
|
||||||
let message_id = message.id.clone();
|
let message_id = message.id.clone();
|
||||||
|
|
||||||
// Store the message
|
|
||||||
db.store_item(message.clone())
|
db.store_item(message.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message");
|
.with_context(|| "Failed to store message".to_string())?;
|
||||||
|
|
||||||
// Retrieve the message
|
|
||||||
let retrieved: Option<Message> = db
|
let retrieved: Option<Message> = db
|
||||||
.get_item(&message_id)
|
.get_item(&message_id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve message");
|
.with_context(|| "Failed to retrieve message".to_string())?;
|
||||||
|
|
||||||
assert!(retrieved.is_some());
|
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
|
||||||
let retrieved = retrieved.unwrap();
|
|
||||||
|
|
||||||
// Verify retrieved properties match original
|
|
||||||
assert_eq!(retrieved.id, message.id);
|
assert_eq!(retrieved.id, message.id);
|
||||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||||
assert_eq!(retrieved.role, message.role);
|
assert_eq!(retrieved.role, message.role);
|
||||||
assert_eq!(retrieved.content, message.content);
|
assert_eq!(retrieved.content, message.content);
|
||||||
assert_eq!(retrieved.references, message.references);
|
assert_eq!(retrieved.references, message.references);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_role_display() {
|
async fn test_message_role_display() -> anyhow::Result<()> {
|
||||||
// Test the Display implementation for MessageRole
|
|
||||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_display() {
|
async fn test_message_display() -> anyhow::Result<()> {
|
||||||
// Test the Display implementation for Message
|
|
||||||
let message = Message {
|
let message = Message {
|
||||||
id: "test_id".to_string(),
|
id: "test_id".to_string(),
|
||||||
created_at: Utc::now(),
|
created_at: Utc::now(),
|
||||||
@@ -151,12 +162,13 @@ mod tests {
|
|||||||
references: None,
|
references: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(format!("{}", message), "User: Hello world");
|
assert_eq!(format!("{message}"), "User: Hello world");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_format_history() {
|
async fn test_format_history() -> anyhow::Result<()> {
|
||||||
// Create a vector of messages
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message {
|
||||||
id: "1".to_string(),
|
id: "1".to_string(),
|
||||||
@@ -178,10 +190,10 @@ mod tests {
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
// Format the history
|
|
||||||
let formatted = format_history(&messages);
|
let formatted = format_history(&messages);
|
||||||
|
|
||||||
// Verify the formatting
|
|
||||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::unsafe_derive_deserialize)]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
pub mod analytics;
|
pub mod analytics;
|
||||||
pub mod conversation;
|
pub mod conversation;
|
||||||
@@ -5,85 +6,35 @@ pub mod file_info;
|
|||||||
pub mod ingestion_payload;
|
pub mod ingestion_payload;
|
||||||
pub mod ingestion_task;
|
pub mod ingestion_task;
|
||||||
pub mod knowledge_entity;
|
pub mod knowledge_entity;
|
||||||
|
pub mod knowledge_entity_embedding;
|
||||||
pub mod knowledge_relationship;
|
pub mod knowledge_relationship;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
|
pub mod scratchpad;
|
||||||
pub mod system_prompts;
|
pub mod system_prompts;
|
||||||
pub mod system_settings;
|
pub mod system_settings;
|
||||||
pub mod text_chunk;
|
pub mod text_chunk;
|
||||||
|
pub mod text_chunk_embedding;
|
||||||
pub mod text_content;
|
pub mod text_content;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|
||||||
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||||
fn table_name() -> &'static str;
|
fn table_name() -> &'static str;
|
||||||
fn get_id(&self) -> &str;
|
fn id(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! stored_object {
|
macro_rules! stored_object {
|
||||||
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
|
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||||
use serde::{Deserialize, Deserializer, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use surrealdb::sql::Thing;
|
|
||||||
use $crate::storage::types::StoredObject;
|
use $crate::storage::types::StoredObject;
|
||||||
use serde::de::{self, Visitor};
|
#[allow(unused_imports)]
|
||||||
use std::fmt;
|
use $crate::utils::serde_helpers::{
|
||||||
|
deserialize_flexible_id, serialize_datetime, deserialize_datetime,
|
||||||
|
serialize_option_datetime, deserialize_option_datetime,
|
||||||
|
};
|
||||||
use chrono::{DateTime, Utc };
|
use chrono::{DateTime, Utc };
|
||||||
|
|
||||||
struct FlexibleIdVisitor;
|
$(#[$struct_attr])*
|
||||||
|
|
||||||
impl<'de> Visitor<'de> for FlexibleIdVisitor {
|
|
||||||
type Value = String;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
formatter.write_str("a string or a Thing")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: de::Error,
|
|
||||||
{
|
|
||||||
Ok(value.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: de::Error,
|
|
||||||
{
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
|
||||||
where
|
|
||||||
A: de::MapAccess<'de>,
|
|
||||||
{
|
|
||||||
// Try to deserialize as Thing
|
|
||||||
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
|
|
||||||
Ok(thing.id.to_raw())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
deserializer.deserialize_any(FlexibleIdVisitor)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
|
|
||||||
Ok(DateTime::<Utc>::from(dt))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct $name {
|
pub struct $name {
|
||||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
@@ -92,7 +43,7 @@ macro_rules! stored_object {
|
|||||||
pub created_at: DateTime<Utc>,
|
pub created_at: DateTime<Utc>,
|
||||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||||
pub updated_at: DateTime<Utc>,
|
pub updated_at: DateTime<Utc>,
|
||||||
$(pub $field: $ty),*
|
$( $(#[$field_attr])* pub $field: $ty),*
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StoredObject for $name {
|
impl StoredObject for $name {
|
||||||
@@ -100,7 +51,7 @@ macro_rules! stored_object {
|
|||||||
$table
|
$table
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,547 @@
|
|||||||
|
use chrono::Utc as ChronoUtc;
|
||||||
|
use surrealdb::opt::PatchOp;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||||
|
|
||||||
|
stored_object!(Scratchpad, "scratchpad", {
|
||||||
|
user_id: String,
|
||||||
|
title: String,
|
||||||
|
content: String,
|
||||||
|
#[serde(serialize_with = "serialize_datetime", deserialize_with="deserialize_datetime")]
|
||||||
|
last_saved_at: DateTime<Utc>,
|
||||||
|
is_dirty: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
is_archived: bool,
|
||||||
|
#[serde(
|
||||||
|
serialize_with = "serialize_option_datetime",
|
||||||
|
deserialize_with = "deserialize_option_datetime",
|
||||||
|
default
|
||||||
|
)]
|
||||||
|
archived_at: Option<DateTime<Utc>>,
|
||||||
|
#[serde(
|
||||||
|
serialize_with = "serialize_option_datetime",
|
||||||
|
deserialize_with = "deserialize_option_datetime",
|
||||||
|
default
|
||||||
|
)]
|
||||||
|
ingested_at: Option<DateTime<Utc>>
|
||||||
|
});
|
||||||
|
|
||||||
|
impl Scratchpad {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(user_id: String, title: String) -> Self {
|
||||||
|
let now = ChronoUtc::now();
|
||||||
|
Self {
|
||||||
|
id: Uuid::new_v4().to_string(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
user_id,
|
||||||
|
title,
|
||||||
|
content: String::new(),
|
||||||
|
last_saved_at: now,
|
||||||
|
is_dirty: false,
|
||||||
|
is_archived: false,
|
||||||
|
archived_at: None,
|
||||||
|
ingested_at: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_by_user(user_id: &str, db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||||
|
let scratchpads: Vec<Scratchpad> = db.client
|
||||||
|
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND (is_archived = false OR is_archived IS NONE) ORDER BY updated_at DESC")
|
||||||
|
.bind(("table_name", Self::table_name()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
|
||||||
|
Ok(scratchpads)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_archived_by_user(
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Vec<Self>, AppError> {
|
||||||
|
let scratchpads: Vec<Scratchpad> = db.client
|
||||||
|
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND is_archived = true ORDER BY archived_at DESC, updated_at DESC")
|
||||||
|
.bind(("table_name", Self::table_name()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
|
||||||
|
Ok(scratchpads)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_by_id(
|
||||||
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
let scratchpad: Option<Scratchpad> = db.get_item(id).await?;
|
||||||
|
|
||||||
|
let scratchpad =
|
||||||
|
scratchpad.ok_or_else(|| AppError::NotFound("scratchpad not found".to_string()))?;
|
||||||
|
|
||||||
|
if scratchpad.user_id != user_id {
|
||||||
|
return Err(AppError::Auth(
|
||||||
|
"You don't have access to this scratchpad".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(scratchpad)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_content(
|
||||||
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
new_content: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
// First verify ownership
|
||||||
|
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||||
|
|
||||||
|
if scratchpad.is_archived {
|
||||||
|
return Ok(scratchpad);
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = ChronoUtc::now();
|
||||||
|
let _updated: Option<Self> = db
|
||||||
|
.update((Self::table_name(), id))
|
||||||
|
.patch(PatchOp::replace("/content", new_content.to_string()))
|
||||||
|
.patch(PatchOp::replace(
|
||||||
|
"/updated_at",
|
||||||
|
surrealdb::Datetime::from(now),
|
||||||
|
))
|
||||||
|
.patch(PatchOp::replace(
|
||||||
|
"/last_saved_at",
|
||||||
|
surrealdb::Datetime::from(now),
|
||||||
|
))
|
||||||
|
.patch(PatchOp::replace("/is_dirty", false))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Return the updated scratchpad
|
||||||
|
Self::get_by_id(id, user_id, db).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_title(
|
||||||
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
new_title: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
// First verify ownership
|
||||||
|
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||||
|
|
||||||
|
let _updated: Option<Self> = db
|
||||||
|
.update((Self::table_name(), id))
|
||||||
|
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||||
|
.patch(PatchOp::replace(
|
||||||
|
"/updated_at",
|
||||||
|
surrealdb::Datetime::from(ChronoUtc::now()),
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
// First verify ownership
|
||||||
|
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||||
|
|
||||||
|
let _: Option<Self> = db.client.delete((Self::table_name(), id)).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn archive(
|
||||||
|
id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
mark_ingested: bool,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
// Verify ownership
|
||||||
|
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||||
|
|
||||||
|
if scratchpad.is_archived {
|
||||||
|
if mark_ingested && scratchpad.ingested_at.is_none() {
|
||||||
|
// Ensure ingested_at is set if required
|
||||||
|
let surreal_now = surrealdb::Datetime::from(ChronoUtc::now());
|
||||||
|
let _updated: Option<Self> = db
|
||||||
|
.update((Self::table_name(), id))
|
||||||
|
.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||||
|
.await?;
|
||||||
|
return Self::get_by_id(id, user_id, db).await;
|
||||||
|
}
|
||||||
|
return Ok(scratchpad);
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = ChronoUtc::now();
|
||||||
|
let surreal_now = surrealdb::Datetime::from(now);
|
||||||
|
let mut update = db
|
||||||
|
.update((Self::table_name(), id))
|
||||||
|
.patch(PatchOp::replace("/is_archived", true))
|
||||||
|
.patch(PatchOp::replace("/archived_at", surreal_now.clone()))
|
||||||
|
.patch(PatchOp::replace("/updated_at", surreal_now.clone()));
|
||||||
|
|
||||||
|
update = if mark_ingested {
|
||||||
|
update.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||||
|
} else {
|
||||||
|
update.patch(PatchOp::remove("/ingested_at"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let _updated: Option<Self> = update.await?;
|
||||||
|
|
||||||
|
Self::get_by_id(id, user_id, db).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn restore(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
|
// Verify ownership
|
||||||
|
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||||
|
|
||||||
|
if !scratchpad.is_archived {
|
||||||
|
return Ok(scratchpad);
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = ChronoUtc::now();
|
||||||
|
let surreal_now = surrealdb::Datetime::from(now);
|
||||||
|
let _updated: Option<Self> = db
|
||||||
|
.update((Self::table_name(), id))
|
||||||
|
.patch(PatchOp::replace("/is_archived", false))
|
||||||
|
.patch(PatchOp::remove("/archived_at"))
|
||||||
|
.patch(PatchOp::remove("/ingested_at"))
|
||||||
|
.patch(PatchOp::replace("/updated_at", surreal_now))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Self::get_by_id(id, user_id, db).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||||
|
// Setup in-memory database for testing
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
// Create a new scratchpad
|
||||||
|
let user_id = "test_user";
|
||||||
|
let title = "Test Scratchpad";
|
||||||
|
let scratchpad = Scratchpad::new(user_id.to_string(), title.to_string());
|
||||||
|
|
||||||
|
// Verify scratchpad properties
|
||||||
|
assert_eq!(scratchpad.user_id, user_id);
|
||||||
|
assert_eq!(scratchpad.title, title);
|
||||||
|
assert_eq!(scratchpad.content, "");
|
||||||
|
assert!(!scratchpad.is_dirty);
|
||||||
|
assert!(!scratchpad.is_archived);
|
||||||
|
assert!(scratchpad.archived_at.is_none());
|
||||||
|
assert!(scratchpad.ingested_at.is_none());
|
||||||
|
assert!(!scratchpad.id.is_empty());
|
||||||
|
|
||||||
|
// Store the scratchpad
|
||||||
|
let result = db.store_item(scratchpad.clone()).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify it can be retrieved
|
||||||
|
let retrieved: Option<Scratchpad> = db
|
||||||
|
.get_item(&scratchpad.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to retrieve scratchpad".to_string())?;
|
||||||
|
let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
|
||||||
|
assert_eq!(retrieved.id, scratchpad.id);
|
||||||
|
assert_eq!(retrieved.user_id, user_id);
|
||||||
|
assert_eq!(retrieved.title, title);
|
||||||
|
assert!(!retrieved.is_archived);
|
||||||
|
assert!(retrieved.archived_at.is_none());
|
||||||
|
assert!(retrieved.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "test_user";
|
||||||
|
|
||||||
|
// Create multiple scratchpads
|
||||||
|
let scratchpad1 = Scratchpad::new(user_id.to_string(), "First".to_string());
|
||||||
|
let scratchpad2 = Scratchpad::new(user_id.to_string(), "Second".to_string());
|
||||||
|
let scratchpad3 = Scratchpad::new("other_user".to_string(), "Other".to_string());
|
||||||
|
|
||||||
|
// Store them
|
||||||
|
let scratchpad1_id = scratchpad1.id.clone();
|
||||||
|
let scratchpad2_id = scratchpad2.id.clone();
|
||||||
|
db.store_item(scratchpad1)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad1".to_string())?;
|
||||||
|
db.store_item(scratchpad2)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad2".to_string())?;
|
||||||
|
db.store_item(scratchpad3)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad3".to_string())?;
|
||||||
|
|
||||||
|
// Archive one of the user's scratchpads
|
||||||
|
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||||
|
.await
|
||||||
|
.with_context(|| "archive".to_string())?;
|
||||||
|
|
||||||
|
// Get scratchpads for user_id
|
||||||
|
let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_by_user".to_string())?;
|
||||||
|
assert_eq!(user_scratchpads.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
user_scratchpads.first().map(|s| &s.id),
|
||||||
|
Some(&scratchpad1_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify they belong to the user
|
||||||
|
for scratchpad in &user_scratchpads {
|
||||||
|
assert_eq!(scratchpad.user_id, user_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_archived_by_user".to_string())?;
|
||||||
|
assert_eq!(archived.len(), 1);
|
||||||
|
assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
|
||||||
|
assert!(archived.first().is_some_and(|s| s.is_archived));
|
||||||
|
assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "test_user";
|
||||||
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to archive".to_string())?;
|
||||||
|
assert!(archived.is_archived);
|
||||||
|
assert!(archived.archived_at.is_some());
|
||||||
|
assert!(archived.ingested_at.is_some());
|
||||||
|
|
||||||
|
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to restore".to_string())?;
|
||||||
|
assert!(!restored.is_archived);
|
||||||
|
assert!(restored.archived_at.is_none());
|
||||||
|
assert!(restored.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_content() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "test_user";
|
||||||
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
let new_content = "Updated content";
|
||||||
|
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "update_content".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(updated.content, new_content);
|
||||||
|
assert!(!updated.is_dirty);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let owner_id = "owner";
|
||||||
|
let other_user = "other_user";
|
||||||
|
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::Auth(_)) => {}
|
||||||
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "test_user";
|
||||||
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
// Delete should succeed
|
||||||
|
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
let retrieved: Option<Scratchpad> = db
|
||||||
|
.get_item(&scratchpad_id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_item".to_string())?;
|
||||||
|
assert!(retrieved.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let owner_id = "owner";
|
||||||
|
let other_user = "other_user";
|
||||||
|
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::Auth(_)) => {}
|
||||||
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it still exists
|
||||||
|
let retrieved: Option<Scratchpad> = db
|
||||||
|
.get_item(&scratchpad_id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_item".to_string())?;
|
||||||
|
assert!(retrieved.is_some());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to create test database".to_string())?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "test_user_123";
|
||||||
|
let scratchpad =
|
||||||
|
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||||
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
|
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_by_id".to_string())?;
|
||||||
|
|
||||||
|
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||||
|
assert!(retrieved.created_at.timestamp() > 0);
|
||||||
|
assert!(retrieved.updated_at.timestamp() > 0);
|
||||||
|
assert!(retrieved.last_saved_at.timestamp() > 0);
|
||||||
|
|
||||||
|
// Test that optional datetime fields work correctly
|
||||||
|
assert!(retrieved.archived_at.is_none());
|
||||||
|
assert!(retrieved.ingested_at.is_none());
|
||||||
|
|
||||||
|
// Archive the scratchpad to test optional datetime handling
|
||||||
|
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||||
|
.await
|
||||||
|
.with_context(|| "archive".to_string())?;
|
||||||
|
|
||||||
|
assert!(archived.archived_at.is_some());
|
||||||
|
assert!(
|
||||||
|
archived
|
||||||
|
.archived_at
|
||||||
|
.with_context(|| "expected archived_at".to_string())?
|
||||||
|
.timestamp()
|
||||||
|
> 0
|
||||||
|
);
|
||||||
|
assert!(archived.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
pub static DEFAULT_QUERY_SYSTEM_PROMPT: &str = r#"You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
pub const DEFAULT_QUERY_SYSTEM_PROMPT: &str = r#"You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
||||||
|
|
||||||
Your task is to:
|
Your task is to:
|
||||||
1. Carefully analyze the provided knowledge entities in the context
|
1. Carefully analyze the provided knowledge entities in the context
|
||||||
@@ -20,7 +20,7 @@ Example response formats:
|
|||||||
"I found relevant information in multiple entries: [explanation...]"
|
"I found relevant information in multiple entries: [explanation...]"
|
||||||
"I apologize, but the provided context doesn't contain information about [topic]""#;
|
"I apologize, but the provided context doesn't contain information about [topic]""#;
|
||||||
|
|
||||||
pub static DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT: &str = r#"You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
pub const DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT: &str = r#"You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
||||||
|
|
||||||
The JSON should have the following structure:
|
The JSON should have the following structure:
|
||||||
|
|
||||||
@@ -49,13 +49,13 @@ Guidelines:
|
|||||||
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
||||||
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
||||||
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
||||||
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
|
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.
|
||||||
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
||||||
7. Only create relationships between existing KnowledgeEntities.
|
7. Only create relationships between existing KnowledgeEntities.
|
||||||
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
||||||
9. A new relationship MUST include a newly created KnowledgeEntity."#;
|
9. A new relationship MUST include a newly created KnowledgeEntity."#;
|
||||||
|
|
||||||
pub static DEFAULT_IMAGE_PROCESSING_PROMPT: &str = r#"Analyze this image and respond based on its primary content:
|
pub const DEFAULT_IMAGE_PROCESSING_PROMPT: &str = r#"Analyze this image and respond based on its primary content:
|
||||||
- If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.
|
- If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.
|
||||||
- If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.
|
- If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.
|
||||||
- For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a "Text:" heading.
|
- For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a "Text:" heading.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
use crate::utils::config::EmbeddingBackend;
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
||||||
@@ -13,6 +14,9 @@ pub struct SystemSettings {
|
|||||||
pub processing_model: String,
|
pub processing_model: String,
|
||||||
pub embedding_model: String,
|
pub embedding_model: String,
|
||||||
pub embedding_dimensions: u32,
|
pub embedding_dimensions: u32,
|
||||||
|
/// Active embedding backend. Read-only for admin updates; synced from config at startup.
|
||||||
|
#[serde(default)]
|
||||||
|
pub embedding_backend: Option<EmbeddingBackend>,
|
||||||
pub query_system_prompt: String,
|
pub query_system_prompt: String,
|
||||||
pub ingestion_system_prompt: String,
|
pub ingestion_system_prompt: String,
|
||||||
pub image_processing_model: String,
|
pub image_processing_model: String,
|
||||||
@@ -20,85 +24,334 @@ pub struct SystemSettings {
|
|||||||
pub voice_processing_model: String,
|
pub voice_processing_model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Partial update for singleton system settings without cloning unchanged fields.
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
pub struct SystemSettingsPatch {
|
||||||
|
pub registrations_enabled: Option<bool>,
|
||||||
|
pub require_email_verification: Option<bool>,
|
||||||
|
pub query_model: Option<String>,
|
||||||
|
pub processing_model: Option<String>,
|
||||||
|
pub embedding_model: Option<String>,
|
||||||
|
pub embedding_dimensions: Option<u32>,
|
||||||
|
pub query_system_prompt: Option<String>,
|
||||||
|
pub ingestion_system_prompt: Option<String>,
|
||||||
|
pub image_processing_model: Option<String>,
|
||||||
|
pub image_processing_prompt: Option<String>,
|
||||||
|
pub voice_processing_model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum UpdateMode {
|
||||||
|
User,
|
||||||
|
EmbeddingSync,
|
||||||
|
}
|
||||||
|
|
||||||
impl StoredObject for SystemSettings {
|
impl StoredObject for SystemSettings {
|
||||||
fn table_name() -> &'static str {
|
fn table_name() -> &'static str {
|
||||||
"system_settings"
|
"system_settings"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SystemSettingsPatch {
|
||||||
|
pub fn apply_to(self, settings: &mut SystemSettings) {
|
||||||
|
if let Some(value) = self.registrations_enabled {
|
||||||
|
settings.registrations_enabled = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.require_email_verification {
|
||||||
|
settings.require_email_verification = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.query_model {
|
||||||
|
settings.query_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.processing_model {
|
||||||
|
settings.processing_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.embedding_model {
|
||||||
|
settings.embedding_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.embedding_dimensions {
|
||||||
|
settings.embedding_dimensions = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.query_system_prompt {
|
||||||
|
settings.query_system_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.ingestion_system_prompt {
|
||||||
|
settings.ingestion_system_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.image_processing_model {
|
||||||
|
settings.image_processing_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.image_processing_prompt {
|
||||||
|
settings.image_processing_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.voice_processing_model {
|
||||||
|
settings.voice_processing_model = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn apply(self, db: &SurrealDbClient) -> Result<SystemSettings, AppError> {
|
||||||
|
let mut current = SystemSettings::get_current(db).await?;
|
||||||
|
self.apply_to(&mut current);
|
||||||
|
SystemSettings::update(db, current).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SystemSettings {
|
impl SystemSettings {
|
||||||
|
pub const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
fn validate(&self) -> Result<(), AppError> {
|
||||||
|
if self.embedding_dimensions == 0 {
|
||||||
|
return Err(AppError::Validation(
|
||||||
|
"embedding_dimensions must be greater than 0".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_fields = [
|
||||||
|
("query_model", &self.query_model),
|
||||||
|
("processing_model", &self.processing_model),
|
||||||
|
("embedding_model", &self.embedding_model),
|
||||||
|
("image_processing_model", &self.image_processing_model),
|
||||||
|
("voice_processing_model", &self.voice_processing_model),
|
||||||
|
];
|
||||||
|
for (name, value) in model_fields {
|
||||||
|
if value.trim().is_empty() {
|
||||||
|
return Err(AppError::Validation(format!("{name} must not be empty")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let prompt_fields = [
|
||||||
|
("query_system_prompt", &self.query_system_prompt),
|
||||||
|
("ingestion_system_prompt", &self.ingestion_system_prompt),
|
||||||
|
("image_processing_prompt", &self.image_processing_prompt),
|
||||||
|
];
|
||||||
|
for (name, value) in prompt_fields {
|
||||||
|
if value.trim().is_empty() {
|
||||||
|
return Err(AppError::Validation(format!("{name} must not be empty")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let settings: Option<Self> = db.get_item("current").await?;
|
let settings: Option<Self> = db.get_item(Self::RECORD_ID).await?;
|
||||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
settings.ok_or(AppError::NotFound("system settings not found".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
||||||
// We need to use a direct query for the update with MERGE
|
Self::update_with_mode(db, changes, UpdateMode::User).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_with_mode(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
mut changes: Self,
|
||||||
|
mode: UpdateMode,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
let current = Self::get_current(db).await?;
|
||||||
|
if matches!(mode, UpdateMode::User) {
|
||||||
|
changes.embedding_backend = current.embedding_backend;
|
||||||
|
}
|
||||||
|
changes.id = Self::RECORD_ID.to_string();
|
||||||
|
changes.validate()?;
|
||||||
|
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
|
.query("UPDATE type::thing('system_settings', $id) MERGE $changes RETURN AFTER")
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.bind(("changes", changes))
|
.bind(("changes", changes))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation(
|
updated.ok_or(AppError::NotFound(
|
||||||
"Something went wrong updating the settings".into(),
|
"system settings record missing after update".into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Syncs SystemSettings with the active embedding provider's properties.
|
||||||
|
/// Updates embedding_backend, embedding_model, and embedding_dimensions if they differ.
|
||||||
|
/// Returns true if any settings were changed.
|
||||||
|
pub async fn sync_from_embedding_provider(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
provider: &crate::utils::embedding::EmbeddingProvider,
|
||||||
|
) -> Result<(Self, bool), AppError> {
|
||||||
|
let mut settings = Self::get_current(db).await?;
|
||||||
|
let mut needs_update = false;
|
||||||
|
|
||||||
|
let provider_backend = provider
|
||||||
|
.backend_label()
|
||||||
|
.parse::<EmbeddingBackend>()
|
||||||
|
.map_err(|e| AppError::Validation(e.to_string()))?;
|
||||||
|
let provider_dimensions = u32::try_from(provider.dimension()).map_err(|_| {
|
||||||
|
AppError::Validation(format!(
|
||||||
|
"embedding provider dimension {} exceeds u32::MAX",
|
||||||
|
provider.dimension()
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
let provider_model = provider.model_code();
|
||||||
|
|
||||||
|
if settings.embedding_backend != Some(provider_backend) {
|
||||||
|
settings.embedding_backend = Some(provider_backend);
|
||||||
|
needs_update = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.embedding_dimensions != provider_dimensions {
|
||||||
|
tracing::info!(
|
||||||
|
old_dimensions = settings.embedding_dimensions,
|
||||||
|
new_dimensions = provider_dimensions,
|
||||||
|
"Embedding dimensions changed, updating SystemSettings"
|
||||||
|
);
|
||||||
|
settings.embedding_dimensions = provider_dimensions;
|
||||||
|
needs_update = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(model) = provider_model {
|
||||||
|
if settings.embedding_model != model {
|
||||||
|
tracing::info!(
|
||||||
|
old_model = %settings.embedding_model,
|
||||||
|
new_model = %model,
|
||||||
|
"Embedding model changed, updating SystemSettings"
|
||||||
|
);
|
||||||
|
settings.embedding_model = model;
|
||||||
|
needs_update = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needs_update {
|
||||||
|
settings = Self::update_with_mode(db, settings, UpdateMode::EmbeddingSync).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((settings, needs_update))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::storage::types::text_chunk::TextChunk;
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use crate::storage::indexes::ensure_runtime;
|
||||||
|
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
async fn get_hnsw_index_dimension(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
table_name: &str,
|
||||||
|
index_name: &str,
|
||||||
|
) -> anyhow::Result<u32> {
|
||||||
|
let query = format!("INFO FOR TABLE {table_name};");
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to fetch table info".to_string())?;
|
||||||
|
|
||||||
|
let info: surrealdb::Value = response
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "Failed to extract table info response".to_string())?;
|
||||||
|
|
||||||
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
|
.with_context(|| "Failed to convert info to json".to_string())?;
|
||||||
|
|
||||||
|
let indexes = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.as_object())
|
||||||
|
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
|
||||||
|
|
||||||
|
let definition = indexes
|
||||||
|
.get(index_name)
|
||||||
|
.and_then(|definition| definition.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
|
||||||
|
|
||||||
|
let dimension_part = definition
|
||||||
|
.split("DIMENSION")
|
||||||
|
.nth(1)
|
||||||
|
.with_context(|| "Index definition missing DIMENSION clause".to_string())?;
|
||||||
|
|
||||||
|
let dimension_token = dimension_part
|
||||||
|
.split_whitespace()
|
||||||
|
.next()
|
||||||
|
.with_context(|| "Dimension value missing in definition".to_string())?
|
||||||
|
.trim_end_matches(';');
|
||||||
|
|
||||||
|
dimension_token
|
||||||
|
.parse::<u32>()
|
||||||
|
.with_context(|| "Dimension value is not a valid number".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn simulate_reembedding(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
target_dimension: usize,
|
||||||
|
initial_chunk: TextChunk,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
db.query(
|
||||||
|
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.with_context(|| "remove index".to_string())?;
|
||||||
|
let define_index_query = format!(
|
||||||
|
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
|
||||||
|
);
|
||||||
|
db.query(define_index_query)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Re-defining index should succeed".to_string())?;
|
||||||
|
|
||||||
|
let new_embedding = vec![0.5; target_dimension];
|
||||||
|
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(sql)
|
||||||
|
.bind(("id", initial_chunk.id.clone()))
|
||||||
|
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||||
|
.bind(("embedding", new_embedding))
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert embedding".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_settings_initialization() {
|
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Test initialization of system settings
|
// Test initialization of system settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get system settings");
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
|
||||||
// Verify initial state after initialization
|
// Verify initial state after initialization
|
||||||
assert_eq!(settings.id, "current");
|
assert_eq!(settings.id, "current");
|
||||||
assert_eq!(settings.registrations_enabled, true);
|
assert!(settings.registrations_enabled);
|
||||||
assert_eq!(settings.require_email_verification, false);
|
assert!(!settings.require_email_verification);
|
||||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||||
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
||||||
// Dont test these for now, having a hard time getting the formatting exactly the same
|
assert!(!settings.ingestion_system_prompt.contains("entity\"\n6."));
|
||||||
// assert_eq!(
|
assert!(settings.ingestion_system_prompt.contains("related entity."));
|
||||||
// settings.query_system_prompt,
|
|
||||||
// crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
|
||||||
// );
|
|
||||||
// assert_eq!(
|
|
||||||
// settings.ingestion_system_prompt,
|
|
||||||
// crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
|
||||||
// );
|
|
||||||
|
|
||||||
// Test idempotency - ensure calling it again doesn't change anything
|
// Test idempotency - ensure calling it again doesn't change anything
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
let settings_again = SystemSettings::get_current(&db)
|
let settings_again = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get settings after initialization");
|
.with_context(|| "Failed to get settings after initialization".to_string())?;
|
||||||
|
|
||||||
assert_eq!(settings.id, settings_again.id);
|
assert_eq!(settings.id, settings_again.id);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -109,49 +362,52 @@ mod tests {
|
|||||||
settings.require_email_verification,
|
settings.require_email_verification,
|
||||||
settings_again.require_email_verification
|
settings_again.require_email_verification
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_settings() {
|
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Initialize settings
|
// Initialize settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
// Test get_current method
|
// Test get_current method
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get current settings");
|
.with_context(|| "Failed to get current settings".to_string())?;
|
||||||
|
|
||||||
assert_eq!(settings.id, "current");
|
assert_eq!(settings.id, "current");
|
||||||
assert_eq!(settings.registrations_enabled, true);
|
assert!(settings.registrations_enabled);
|
||||||
assert_eq!(settings.require_email_verification, false);
|
assert!(!settings.require_email_verification);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_settings() {
|
async fn test_update_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Initialize settings
|
// Initialize settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
// Create updated settings
|
// Create updated settings
|
||||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
let mut updated_settings = SystemSettings::get_current(&db)
|
||||||
updated_settings.id = "current".to_string();
|
.await
|
||||||
|
.with_context(|| "get_current".to_string())?;
|
||||||
updated_settings.registrations_enabled = false;
|
updated_settings.registrations_enabled = false;
|
||||||
updated_settings.require_email_verification = true;
|
updated_settings.require_email_verification = true;
|
||||||
updated_settings.query_model = "gpt-4".to_string();
|
updated_settings.query_model = "gpt-4".to_string();
|
||||||
@@ -159,31 +415,32 @@ mod tests {
|
|||||||
// Test update method
|
// Test update method
|
||||||
let result = SystemSettings::update(&db, updated_settings)
|
let result = SystemSettings::update(&db, updated_settings)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to update settings");
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.id, "current");
|
assert_eq!(result.id, "current");
|
||||||
assert_eq!(result.registrations_enabled, false);
|
assert!(!result.registrations_enabled);
|
||||||
assert_eq!(result.require_email_verification, true);
|
assert!(result.require_email_verification);
|
||||||
assert_eq!(result.query_model, "gpt-4");
|
assert_eq!(result.query_model, "gpt-4");
|
||||||
|
|
||||||
// Verify changes persisted by getting current settings
|
// Verify changes persisted by getting current settings
|
||||||
let current = SystemSettings::get_current(&db)
|
let current = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get current settings after update");
|
.with_context(|| "Failed to get current settings after update".to_string())?;
|
||||||
|
|
||||||
assert_eq!(current.registrations_enabled, false);
|
assert!(!current.registrations_enabled);
|
||||||
assert_eq!(current.require_email_verification, true);
|
assert!(current.require_email_verification);
|
||||||
assert_eq!(current.query_model, "gpt-4");
|
assert_eq!(current.query_model, "gpt-4");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() {
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Don't initialize settings and try to get them
|
// Don't initialize settings and try to get them
|
||||||
let result = SystemSettings::get_current(&db).await;
|
let result = SystemSettings::get_current(&db).await;
|
||||||
@@ -193,66 +450,356 @@ mod tests {
|
|||||||
Err(AppError::NotFound(_)) => {
|
Err(AppError::NotFound(_)) => {
|
||||||
// Expected error
|
// Expected error
|
||||||
}
|
}
|
||||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
|
||||||
Ok(_) => panic!("Expected error but got Ok"),
|
Ok(_) => anyhow::bail!("Expected error but got Ok"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_migration_after_changing_embedding_length() {
|
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
invalid_settings.embedding_dimensions = 0;
|
||||||
|
|
||||||
|
let result = SystemSettings::update(&db, invalid_settings).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let updated = SystemSettingsPatch {
|
||||||
|
registrations_enabled: Some(false),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.apply(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to patch settings".to_string())?;
|
||||||
|
|
||||||
|
assert!(!updated.registrations_enabled);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let original = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
let sentinel = "custom-query-prompt-sentinel".to_string();
|
||||||
|
|
||||||
|
let patched = SystemSettingsPatch {
|
||||||
|
query_system_prompt: Some(sentinel.clone()),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.apply(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to patch query prompt".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(patched.query_system_prompt, sentinel);
|
||||||
|
assert_eq!(
|
||||||
|
patched.ingestion_system_prompt,
|
||||||
|
original.ingestion_system_prompt
|
||||||
|
);
|
||||||
|
assert_eq!(patched.query_model, original.query_model);
|
||||||
|
assert_eq!(
|
||||||
|
patched.registrations_enabled,
|
||||||
|
original.registrations_enabled
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
invalid_settings.query_model = " ".to_string();
|
||||||
|
|
||||||
|
let result = SystemSettings::update(&db, invalid_settings).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
settings.id = "wrong-id".to_string();
|
||||||
|
|
||||||
|
let updated = SystemSettings::update(&db, settings)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
assert_eq!(updated.id, SystemSettings::RECORD_ID);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to sync embedding provider".to_string())?;
|
||||||
|
|
||||||
|
let synced = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get synced settings".to_string())?;
|
||||||
|
assert_eq!(synced.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
|
||||||
|
let mut tampered = synced;
|
||||||
|
tampered.embedding_backend = Some(EmbeddingBackend::OpenAI);
|
||||||
|
let updated = SystemSettings::update(&db, tampered)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(updated.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
let (settings, changed) = SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to sync embedding provider".to_string())?;
|
||||||
|
|
||||||
|
assert!(changed);
|
||||||
|
assert_eq!(settings.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
assert_eq!(settings.embedding_dimensions, 384);
|
||||||
|
|
||||||
|
let persisted = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to reload synced settings".to_string())?;
|
||||||
|
assert_eq!(persisted.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
assert_eq!(persisted.embedding_dimensions, 384);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to initial sync".to_string())?;
|
||||||
|
|
||||||
|
let (_, changed) = SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to repeat sync".to_string())?;
|
||||||
|
assert!(!changed);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
||||||
|
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
||||||
|
let result = SystemSettings::sync_from_embedding_provider(&db, &provider).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start DB");
|
.with_context(|| "Failed to start DB".to_string())?;
|
||||||
|
|
||||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Initial migration failed");
|
.with_context(|| "Initial migration failed".to_string())?;
|
||||||
|
|
||||||
let initial_chunk = TextChunk::new(
|
let initial_chunk = TextChunk::new(
|
||||||
"source1".into(),
|
"source1".into(),
|
||||||
"This chunk has the original dimension".into(),
|
"This chunk has the original dimension".into(),
|
||||||
vec![0.1; 1536],
|
|
||||||
"user1".into(),
|
"user1".into(),
|
||||||
);
|
);
|
||||||
|
|
||||||
db.store_item(initial_chunk.clone())
|
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store initial chunk");
|
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||||
|
|
||||||
async fn simulate_reembedding(
|
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||||
db: &SurrealDbClient,
|
let target_dimension = 1536usize;
|
||||||
target_dimension: usize,
|
simulate_reembedding(&db, target_dimension, initial_chunk).await?;
|
||||||
initial_chunk: TextChunk,
|
|
||||||
) {
|
|
||||||
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let define_index_query = format!(
|
|
||||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
|
||||||
target_dimension
|
|
||||||
);
|
|
||||||
db.query(define_index_query)
|
|
||||||
.await
|
|
||||||
.expect("Re-defining index should succeed");
|
|
||||||
|
|
||||||
let new_embedding = vec![0.5; target_dimension];
|
|
||||||
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
|
|
||||||
|
|
||||||
let update_result = db
|
|
||||||
.client
|
|
||||||
.query(sql)
|
|
||||||
.bind(("id", initial_chunk.id.clone()))
|
|
||||||
.bind(("embedding", new_embedding))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
assert!(update_result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
simulate_reembedding(&db, 768, initial_chunk).await;
|
|
||||||
|
|
||||||
let migration_result = db.apply_migrations().await;
|
let migration_result = db.apply_migrations().await;
|
||||||
|
|
||||||
assert!(migration_result.is_ok(), "Migrations should not fail");
|
assert!(
|
||||||
|
migration_result.is_ok(),
|
||||||
|
"Migrations should not fail: {:?}",
|
||||||
|
migration_result.err()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_should_change_embedding_length_on_indexes_when_switching_length(
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start DB".to_string())?;
|
||||||
|
|
||||||
|
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Initial migration failed".to_string())?;
|
||||||
|
|
||||||
|
let mut current_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to load current settings".to_string())?;
|
||||||
|
|
||||||
|
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
||||||
|
ensure_runtime(&db, current_settings.embedding_dimensions as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| "failed to build runtime indexes".to_string())?;
|
||||||
|
|
||||||
|
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||||
|
&db,
|
||||||
|
"text_chunk_embedding",
|
||||||
|
"idx_embedding_text_chunk_embedding",
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||||
|
"embedding size should match initial system settings"
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_dimension = 768;
|
||||||
|
let new_model = "new-test-embedding-model".to_string();
|
||||||
|
|
||||||
|
current_settings.embedding_dimensions = new_dimension;
|
||||||
|
current_settings.embedding_model = new_model.clone();
|
||||||
|
|
||||||
|
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
updated_settings.embedding_dimensions, new_dimension,
|
||||||
|
"Settings should reflect the new embedding dimension"
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(new_dimension as usize)
|
||||||
|
.map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||||
|
|
||||||
|
TextChunk::update_all_embeddings(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
|
||||||
|
KnowledgeEntity::update_all_embeddings(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
"KnowledgeEntity re-embedding should succeed on fresh DB".to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||||
|
&db,
|
||||||
|
"text_chunk_embedding",
|
||||||
|
"idx_embedding_text_chunk_embedding",
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let knowledge_dimension = get_hnsw_index_dimension(
|
||||||
|
&db,
|
||||||
|
"knowledge_entity_embedding",
|
||||||
|
"idx_embedding_knowledge_entity_embedding",
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
text_chunk_dimension, new_dimension,
|
||||||
|
"text_chunk index dimension should update"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
knowledge_dimension, new_dimension,
|
||||||
|
"knowledge_entity index dimension should update"
|
||||||
|
);
|
||||||
|
|
||||||
|
let persisted_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to reload updated settings".to_string())?;
|
||||||
|
assert_eq!(
|
||||||
|
persisted_settings.embedding_dimensions, new_dimension,
|
||||||
|
"Settings should persist new embedding dimension"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,431 @@
|
|||||||
|
use surrealdb::RecordId;
|
||||||
|
|
||||||
|
use crate::storage::types::text_chunk::TextChunk;
|
||||||
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||||
|
stored_object,
|
||||||
|
};
|
||||||
|
|
||||||
|
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||||
|
/// Record link to the owning text_chunk
|
||||||
|
chunk_id: RecordId,
|
||||||
|
/// Denormalized source id for bulk deletes
|
||||||
|
source_id: String,
|
||||||
|
/// Embedding vector
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
/// Denormalized user id (for scoping + permissions)
|
||||||
|
user_id: String
|
||||||
|
});
|
||||||
|
|
||||||
|
impl TextChunkEmbedding {
|
||||||
|
/// Recreate the HNSW index with a new embedding dimension.
|
||||||
|
///
|
||||||
|
/// This is useful when the embedding length changes; Surreal requires the
|
||||||
|
/// index definition to be recreated with the updated dimension.
|
||||||
|
pub async fn redefine_hnsw_index(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
dimension: usize,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = hnsw_index_redefine_transaction_sql(
|
||||||
|
"idx_embedding_text_chunk_embedding",
|
||||||
|
Self::table_name(),
|
||||||
|
dimension,
|
||||||
|
);
|
||||||
|
|
||||||
|
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||||
|
res.check().map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||||
|
if embedding.len() != expected {
|
||||||
|
return Err(AppError::Validation(format!(
|
||||||
|
"embedding dimension mismatch: got {}, expected {expected}",
|
||||||
|
embedding.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new text chunk embedding.
|
||||||
|
///
|
||||||
|
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
||||||
|
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
id: chunk_id.to_owned(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||||
|
source_id,
|
||||||
|
embedding,
|
||||||
|
user_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a single embedding by its chunk RecordId
|
||||||
|
pub async fn get_by_chunk_id(
|
||||||
|
chunk_id: &RecordId,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Option<Self>, AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut result = db
|
||||||
|
.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("chunk_id", chunk_id.clone()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(embeddings.into_iter().next())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete embeddings for a given chunk RecordId
|
||||||
|
pub async fn delete_by_chunk_id(
|
||||||
|
chunk_id: &RecordId,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE chunk_id = $chunk_id",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("chunk_id", chunk_id.clone()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete all embeddings that belong to chunks with a given `source_id`
|
||||||
|
///
|
||||||
|
/// This uses the denormalized `source_id` on the embedding table.
|
||||||
|
pub async fn delete_by_source_id(
|
||||||
|
source_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<(), AppError> {
|
||||||
|
let query = format!(
|
||||||
|
"DELETE FROM {} WHERE source_id = $source_id",
|
||||||
|
Self::table_name()
|
||||||
|
);
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(query)
|
||||||
|
.bind(("source_id", source_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||||
|
use surrealdb::Value as SurrealValue;
|
||||||
|
|
||||||
|
async fn create_text_chunk_with_id(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
key: &str,
|
||||||
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> anyhow::Result<RecordId> {
|
||||||
|
let chunk = TextChunk {
|
||||||
|
id: key.to_owned(),
|
||||||
|
created_at: Utc::now(),
|
||||||
|
updated_at: Utc::now(),
|
||||||
|
source_id: source_id.to_owned(),
|
||||||
|
chunk: "Some test chunk text".to_owned(),
|
||||||
|
user_id: user_id.to_owned(),
|
||||||
|
};
|
||||||
|
|
||||||
|
db.store_item(chunk)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to create text_chunk".to_string())?;
|
||||||
|
|
||||||
|
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||||
|
let mut info_res = db
|
||||||
|
.client
|
||||||
|
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||||
|
.await
|
||||||
|
.with_context(|| "info query failed".to_string())?;
|
||||||
|
let info: SurrealValue = info_res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to take info result".to_string())?;
|
||||||
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
|
.with_context(|| "failed to convert info to json".to_string())?;
|
||||||
|
let idx_sql = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||||
|
.and_then(|v| v.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_string();
|
||||||
|
Ok(idx_sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_uses_chunk_id_as_record_id() {
|
||||||
|
let emb = TextChunkEmbedding::new(
|
||||||
|
"chunk-abc",
|
||||||
|
"source-1".to_owned(),
|
||||||
|
vec![0.1, 0.2],
|
||||||
|
"user-1".to_owned(),
|
||||||
|
);
|
||||||
|
assert_eq!(emb.id, "chunk-abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_dimension_rejects_mismatch() {
|
||||||
|
let err = TextChunkEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
|
||||||
|
.expect_err("expected dimension mismatch");
|
||||||
|
assert!(matches!(err, AppError::Validation(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user_a";
|
||||||
|
let chunk_key = "chunk-123";
|
||||||
|
let source_id = "source-1";
|
||||||
|
|
||||||
|
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
|
|
||||||
|
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||||
|
let emb = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_string(),
|
||||||
|
embedding_vec.clone(),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.upsert_item(emb)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
|
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||||
|
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(fetched.id, chunk_key);
|
||||||
|
assert_eq!(fetched.user_id, user_id);
|
||||||
|
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||||
|
assert_eq!(fetched.embedding, embedding_vec);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user_b";
|
||||||
|
let chunk_key = "chunk-delete";
|
||||||
|
let source_id = "source-del";
|
||||||
|
|
||||||
|
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
|
|
||||||
|
let emb = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_string(),
|
||||||
|
vec![0.4_f32, 0.5, 0.6],
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.upsert_item(emb)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
|
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
|
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||||
|
|
||||||
|
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||||
|
|
||||||
|
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
|
assert!(after.is_none(), "Embedding should have been deleted");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_text_chunk_test_db(1).await?;
|
||||||
|
|
||||||
|
let user_id = "user_c";
|
||||||
|
let source_id = "shared-source";
|
||||||
|
let other_source = "other-source";
|
||||||
|
|
||||||
|
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
|
||||||
|
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
|
||||||
|
let chunk_other_rid =
|
||||||
|
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
|
||||||
|
|
||||||
|
for (key, src, vec) in [
|
||||||
|
("chunk-s1", source_id, vec![0.1]),
|
||||||
|
("chunk-s2", source_id, vec![0.2]),
|
||||||
|
("chunk-other", other_source, vec![0.3]),
|
||||||
|
] {
|
||||||
|
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||||
|
db.upsert_item(emb)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("store embedding for {key}"))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get chunk1".to_string())?
|
||||||
|
.is_some());
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get chunk2".to_string())?
|
||||||
|
.is_some());
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get chunk_other".to_string())?
|
||||||
|
.is_some());
|
||||||
|
|
||||||
|
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "check chunk1".to_string())?
|
||||||
|
.is_none());
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "check chunk2".to_string())?
|
||||||
|
.is_none());
|
||||||
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "check chunk_other".to_string())?
|
||||||
|
.is_some());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user-upsert";
|
||||||
|
let source_id = "source-upsert";
|
||||||
|
let chunk_key = "chunk-upsert";
|
||||||
|
|
||||||
|
create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
|
|
||||||
|
let initial = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![1.0_f32, 0.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(initial)
|
||||||
|
.await
|
||||||
|
.with_context(|| "initial upsert".to_string())?;
|
||||||
|
|
||||||
|
let replacement = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![0.0, 1.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(replacement)
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert replacement embedding".to_string())?;
|
||||||
|
|
||||||
|
let chunk_rid = RecordId::from_table_key(TextChunk::table_name(), chunk_key);
|
||||||
|
let rows: Vec<TextChunkEmbedding> = db
|
||||||
|
.client
|
||||||
|
.query(format!(
|
||||||
|
"SELECT * FROM {} WHERE chunk_id = $chunk_id",
|
||||||
|
TextChunkEmbedding::table_name()
|
||||||
|
))
|
||||||
|
.bind(("chunk_id", chunk_rid))
|
||||||
|
.await
|
||||||
|
.with_context(|| "count embeddings".to_string())?
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "take embeddings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected one embedding row");
|
||||||
|
assert_eq!(row.id, chunk_key);
|
||||||
|
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||||
|
.await
|
||||||
|
.with_context(|| "failed to redefine index".to_string())?;
|
||||||
|
|
||||||
|
let idx_sql = get_idx_sql(&db).await?;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIMENSION 8"),
|
||||||
|
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIST COSINE"),
|
||||||
|
"expected index definition to use cosine distance, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||||
|
.await
|
||||||
|
.with_context(|| "first redefine failed".to_string())?;
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||||
|
.await
|
||||||
|
.with_context(|| "second redefine failed".to_string())?;
|
||||||
|
|
||||||
|
let idx_sql = get_idx_sql(&db).await?;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIMENSION 4"),
|
||||||
|
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,15 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
use surrealdb::opt::PatchOp;
|
use surrealdb::opt::PatchOp;
|
||||||
|
use surrealdb::RecordId;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||||
|
|
||||||
use super::file_info::FileInfo;
|
use super::file_info::FileInfo;
|
||||||
|
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
pub struct TextContentSearchResult {
|
pub struct TextContentSearchResult {
|
||||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
@@ -50,8 +55,11 @@ pub struct TextContentSearchResult {
|
|||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||||
pub struct UrlInfo {
|
pub struct UrlInfo {
|
||||||
|
#[serde(default)]
|
||||||
pub url: String,
|
pub url: String,
|
||||||
|
#[serde(default)]
|
||||||
pub title: String,
|
pub title: String,
|
||||||
|
#[serde(default)]
|
||||||
pub image_id: String,
|
pub image_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +73,7 @@ stored_object!(TextContent, "text_content", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl TextContent {
|
impl TextContent {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
text: String,
|
text: String,
|
||||||
context: Option<String>,
|
context: Option<String>,
|
||||||
@@ -96,7 +105,7 @@ impl TextContent {
|
|||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
let _res: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.update((Self::table_name(), id))
|
.update((Self::table_name(), id))
|
||||||
.patch(PatchOp::replace("/context", context))
|
.patch(PatchOp::replace("/context", context))
|
||||||
.patch(PatchOp::replace("/category", category))
|
.patch(PatchOp::replace("/category", category))
|
||||||
@@ -105,18 +114,45 @@ impl TextContent {
|
|||||||
"/updated_at",
|
"/updated_at",
|
||||||
surrealdb::Datetime::from(now),
|
surrealdb::Datetime::from(now),
|
||||||
))
|
))
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
if updated.is_none() {
|
||||||
|
return Err(AppError::NotFound(format!("text content {id} not found")));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn has_other_with_file(
|
||||||
|
file_id: &str,
|
||||||
|
exclude_id: &str,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<bool, AppError> {
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT VALUE id FROM type::table($table_name) WHERE file_info.id = $file_id AND id != type::thing($table_name, $exclude_id) LIMIT 1",
|
||||||
|
)
|
||||||
|
.bind(("table_name", TextContent::table_name()))
|
||||||
|
.bind(("file_id", file_id.to_owned()))
|
||||||
|
.bind(("exclude_id", exclude_id.to_owned()))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
Ok(existing.is_some())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
search_terms: &str,
|
search_terms: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<TextContentSearchResult>, AppError> {
|
) -> Result<Vec<TextContentSearchResult>, AppError> {
|
||||||
let sql = r#"
|
let sql = format!(
|
||||||
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
*,
|
*,
|
||||||
search::highlight('<b>', '</b>', 0) AS highlighted_text,
|
search::highlight('<b>', '</b>', 0) AS highlighted_text,
|
||||||
@@ -126,14 +162,14 @@ impl TextContent {
|
|||||||
search::highlight('<b>', '</b>', 4) AS highlighted_url,
|
search::highlight('<b>', '</b>', 4) AS highlighted_url,
|
||||||
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
|
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
|
||||||
(
|
(
|
||||||
search::score(0) +
|
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
|
||||||
search::score(1) +
|
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END +
|
||||||
search::score(2) +
|
IF search::score(2) != NONE THEN search::score(2) ELSE 0 END +
|
||||||
search::score(3) +
|
IF search::score(3) != NONE THEN search::score(3) ELSE 0 END +
|
||||||
search::score(4) +
|
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
|
||||||
search::score(5)
|
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
|
||||||
) AS score
|
) AS score
|
||||||
FROM text_content
|
FROM {table}
|
||||||
WHERE
|
WHERE
|
||||||
(
|
(
|
||||||
text @0@ $terms OR
|
text @0@ $terms OR
|
||||||
@@ -146,25 +182,192 @@ impl TextContent {
|
|||||||
AND user_id = $user_id
|
AND user_id = $user_id
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit;
|
LIMIT $limit;
|
||||||
"#;
|
"#,
|
||||||
|
table = Self::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(db
|
db.client
|
||||||
.client
|
|
||||||
.query(sql)
|
.query(sql)
|
||||||
.bind(("terms", search_terms.to_owned()))
|
.bind(("terms", search_terms.to_owned()))
|
||||||
.bind(("user_id", user_id.to_owned()))
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.bind(("limit", limit))
|
.bind(("limit", limit))
|
||||||
.await?
|
.await
|
||||||
.take(0)?)
|
.map_err(AppError::from)?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a fallback display label for a source id when no matching content row exists.
|
||||||
|
#[must_use]
|
||||||
|
pub fn fallback_source_label(source_id: &str) -> String {
|
||||||
|
format!("Text snippet: {}", source_id_suffix(source_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolves human-readable labels for the given source ids owned by `user_id`.
|
||||||
|
pub async fn resolve_source_labels(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
source_ids: impl IntoIterator<Item = impl AsRef<str>>,
|
||||||
|
) -> Result<HashMap<String, String>, AppError> {
|
||||||
|
let source_ids: HashSet<String> = source_ids
|
||||||
|
.into_iter()
|
||||||
|
.map(|id| id.as_ref().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if source_ids.is_empty() {
|
||||||
|
return Ok(HashMap::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let record_ids: Vec<RecordId> = source_ids
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| {
|
||||||
|
if id.contains(':') {
|
||||||
|
RecordId::from_str(id).ok()
|
||||||
|
} else {
|
||||||
|
Some(RecordId::from_table_key(Self::table_name(), id))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||||
|
)
|
||||||
|
.bind(("table_name", Self::table_name()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.bind(("record_ids", record_ids))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let contents: Vec<SourceLabelRow> = response.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
source_id_count = source_ids.len(),
|
||||||
|
label_row_count = contents.len(),
|
||||||
|
"resolved source labels"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut labels = HashMap::new();
|
||||||
|
for content in contents {
|
||||||
|
let label = build_source_label(&content);
|
||||||
|
labels.insert(content.id.clone(), label.clone());
|
||||||
|
labels.insert(format!("{}:{}", Self::table_name(), content.id), label);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(labels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const SOURCE_LABEL_MAX_CHARS: usize = 80;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct SourceLabelRow {
|
||||||
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
|
id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
url_info: Option<UrlInfo>,
|
||||||
|
#[serde(default)]
|
||||||
|
file_info: Option<FileInfo>,
|
||||||
|
#[serde(default)]
|
||||||
|
context: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
category: String,
|
||||||
|
#[serde(default)]
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the tail of `source_id`, at most eight bytes long.
///
/// Ids of eight bytes or fewer are returned whole. The cut position is moved
/// forward to the next UTF-8 character boundary when it would otherwise land
/// inside a multi-byte character, so the function never panics on non-ASCII
/// ids; in that case the result is shorter than eight bytes.
fn source_id_suffix(source_id: &str) -> String {
    let mut start = source_id.len().saturating_sub(8);
    // `is_char_boundary(len)` is always true, so this loop terminates.
    while !source_id.is_char_boundary(start) {
        start += 1;
    }
    source_id[start..].to_string()
}
|
||||||
|
|
||||||
|
/// Shortens `value` to its first `max_chars` characters and appends `…`.
///
/// Values with at most `max_chars` characters are returned unchanged. Counting is
/// per `char`, so the cut never splits a multi-byte character. The ellipsis is
/// added on top of the kept characters, so a truncated result has
/// `max_chars + 1` characters. With `max_chars == 0`, an empty value stays empty
/// and any other value becomes just `…`.
fn truncate_with_ellipsis(value: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "…";

    // Byte offset of the first character that does not fit, if there is one.
    match value.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}{}", &value[..cut], ELLIPSIS),
        None => value.to_string(),
    }
}
|
||||||
|
|
||||||
|
fn first_non_empty_line(text: &str, max_chars: usize) -> Option<String> {
|
||||||
|
text.lines().find_map(|line| {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(truncate_with_ellipsis(trimmed, max_chars))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_source_label(row: &SourceLabelRow) -> String {
|
||||||
|
if let Some(url_info) = row.url_info.as_ref() {
|
||||||
|
let title = url_info.title.trim();
|
||||||
|
if !title.is_empty() {
|
||||||
|
return title.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = url_info.url.trim();
|
||||||
|
if !url.is_empty() {
|
||||||
|
return url.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(file_info) = row.file_info.as_ref() {
|
||||||
|
let name = file_info.file_name.trim();
|
||||||
|
if !name.is_empty() {
|
||||||
|
return name.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(context) = row.context.as_ref() {
|
||||||
|
let trimmed = context.trim();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
return truncate_with_ellipsis(trimmed, SOURCE_LABEL_MAX_CHARS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(text_label) = first_non_empty_line(&row.text, SOURCE_LABEL_MAX_CHARS) {
|
||||||
|
return text_label;
|
||||||
|
}
|
||||||
|
|
||||||
|
let category = row.category.trim();
|
||||||
|
if !category.is_empty() {
|
||||||
|
return truncate_with_ellipsis(category, SOURCE_LABEL_MAX_CHARS);
|
||||||
|
}
|
||||||
|
|
||||||
|
TextContent::fallback_source_label(&row.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_creation() {
|
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||||
// Test basic object creation
|
// Test basic object creation
|
||||||
let text = "Test content text".to_string();
|
let text = "Test content text".to_string();
|
||||||
let context = "Test context".to_string();
|
let context = "Test context".to_string();
|
||||||
@@ -188,10 +391,11 @@ mod tests {
|
|||||||
assert!(text_content.file_info.is_none());
|
assert!(text_content.file_info.is_none());
|
||||||
assert!(text_content.url_info.is_none());
|
assert!(text_content.url_info.is_none());
|
||||||
assert!(!text_content.id.is_empty());
|
assert!(!text_content.id.is_empty());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_with_url() {
|
async fn test_text_content_with_url() -> anyhow::Result<()> {
|
||||||
// Test creating with URL
|
// Test creating with URL
|
||||||
let text = "Content with URL".to_string();
|
let text = "Content with URL".to_string();
|
||||||
let context = "URL context".to_string();
|
let context = "URL context".to_string();
|
||||||
@@ -208,26 +412,27 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let text_content = TextContent::new(
|
let text_content = TextContent::new(
|
||||||
text.clone(),
|
text,
|
||||||
Some(context.clone()),
|
Some(context),
|
||||||
category.clone(),
|
category,
|
||||||
None,
|
None,
|
||||||
url_info.clone(),
|
url_info.clone(),
|
||||||
user_id.clone(),
|
user_id,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check URL field is set
|
// Check URL field is set
|
||||||
assert_eq!(text_content.url_info, url_info);
|
assert_eq!(text_content.url_info, url_info);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_patch() {
|
async fn test_text_content_patch() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create initial text content
|
// Create initial text content
|
||||||
let initial_text = "Initial text".to_string();
|
let initial_text = "Initial text".to_string();
|
||||||
@@ -248,7 +453,7 @@ mod tests {
|
|||||||
let stored: Option<TextContent> = db
|
let stored: Option<TextContent> = db
|
||||||
.store_item(text_content.clone())
|
.store_item(text_content.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store text content");
|
.with_context(|| "Failed to store text content".to_string())?;
|
||||||
assert!(stored.is_some());
|
assert!(stored.is_some());
|
||||||
|
|
||||||
// New values for patch
|
// New values for patch
|
||||||
@@ -259,21 +464,178 @@ mod tests {
|
|||||||
// Apply the patch
|
// Apply the patch
|
||||||
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to patch text content");
|
.with_context(|| "Failed to patch text content".to_string())?;
|
||||||
|
|
||||||
// Retrieve the updated content
|
// Retrieve the updated content
|
||||||
let updated: Option<TextContent> = db
|
let updated: Option<TextContent> = db
|
||||||
.get_item(&text_content.id)
|
.get_item(&text_content.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get updated text content");
|
.with_context(|| "Failed to get updated text content".to_string())?;
|
||||||
assert!(updated.is_some());
|
let updated_content = updated.with_context(|| "expected updated content".to_string())?;
|
||||||
|
|
||||||
let updated_content = updated.unwrap();
|
|
||||||
|
|
||||||
// Verify the updates
|
// Verify the updates
|
||||||
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
||||||
assert_eq!(updated_content.category, new_category);
|
assert_eq!(updated_content.category, new_category);
|
||||||
assert_eq!(updated_content.text, new_text);
|
assert_eq!(updated_content.text, new_text);
|
||||||
assert!(updated_content.updated_at > text_content.updated_at);
|
assert!(updated_content.updated_at > text_content.updated_at);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
|
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
|
||||||
|
.await
|
||||||
|
.expect_err("expected not found");
|
||||||
|
|
||||||
|
assert!(matches!(err, AppError::NotFound(_)));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
|
let user_id = "user123".to_string();
|
||||||
|
let file_info = FileInfo {
|
||||||
|
id: "file-1".to_string(),
|
||||||
|
created_at: chrono::Utc::now(),
|
||||||
|
updated_at: chrono::Utc::now(),
|
||||||
|
sha256: "sha-test".to_string(),
|
||||||
|
path: "user123/file-1/test.txt".to_string(),
|
||||||
|
file_name: "test.txt".to_string(),
|
||||||
|
mime_type: "text/plain".to_string(),
|
||||||
|
user_id: user_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let content_a = TextContent::new(
|
||||||
|
"First".to_string(),
|
||||||
|
Some("ctx-a".to_string()),
|
||||||
|
"category".to_string(),
|
||||||
|
Some(file_info.clone()),
|
||||||
|
None,
|
||||||
|
user_id.clone(),
|
||||||
|
);
|
||||||
|
let content_b = TextContent::new(
|
||||||
|
"Second".to_string(),
|
||||||
|
Some("ctx-b".to_string()),
|
||||||
|
"category".to_string(),
|
||||||
|
Some(file_info.clone()),
|
||||||
|
None,
|
||||||
|
user_id.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.store_item(content_a.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store first content".to_string())?;
|
||||||
|
db.store_item(content_b.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store second content".to_string())?;
|
||||||
|
|
||||||
|
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to check for shared file usage".to_string())?;
|
||||||
|
assert!(has_other);
|
||||||
|
|
||||||
|
let _removed: Option<TextContent> = db
|
||||||
|
.delete_item(&content_b.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to delete second content".to_string())?;
|
||||||
|
|
||||||
|
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to check shared usage after delete".to_string())?;
|
||||||
|
assert!(!has_other_after);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
|
let results = TextContent::search(&db, "hello", "user", 5)
|
||||||
|
.await
|
||||||
|
.with_context(|| "search".to_string())?;
|
||||||
|
|
||||||
|
assert!(results.is_empty());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
let user_id = "search_user";
|
||||||
|
|
||||||
|
let matching = TextContent::new(
|
||||||
|
"rust programming language".to_string(),
|
||||||
|
Some("context".to_string()),
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
let other_user = TextContent::new(
|
||||||
|
"rust programming language".to_string(),
|
||||||
|
None,
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
"other_user".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.store_item(matching.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store matching".to_string())?;
|
||||||
|
db.store_item(other_user)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store other user".to_string())?;
|
||||||
|
|
||||||
|
let results = TextContent::search(&db, "rust", user_id, 5)
|
||||||
|
.await
|
||||||
|
.with_context(|| "search".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
let row = results.first().context("expected one result")?;
|
||||||
|
assert_eq!(row.id, matching.id);
|
||||||
|
assert_eq!(row.user_id, user_id);
|
||||||
|
assert!(row.score.is_finite());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_resolve_source_labels_uses_url_title() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
let user_id = "label_user";
|
||||||
|
|
||||||
|
let content = TextContent::new(
|
||||||
|
"body".to_string(),
|
||||||
|
None,
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
Some(UrlInfo {
|
||||||
|
url: "https://example.com/doc".to_string(),
|
||||||
|
title: "Example Document".to_string(),
|
||||||
|
image_id: String::new(),
|
||||||
|
}),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
db.store_item(content.clone()).await?;
|
||||||
|
|
||||||
|
let labels = TextContent::resolve_source_labels(&db, user_id, [content.id.clone()]).await?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
labels.get(&content.id),
|
||||||
|
Some(&"Example Document".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
labels.get(&format!("text_content:{}", content.id)),
|
||||||
|
Some(&"Example Document".to_string())
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+408
-141
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
|||||||
|
//! Shared helpers for in-memory SurrealDB tests.
|
||||||
|
#![cfg(any(test, feature = "test-utils"))]
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
indexes::{ensure_runtime, rebuild},
|
||||||
|
types::{
|
||||||
|
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||||
|
text_chunk_embedding::TextChunkEmbedding,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const TEST_NAMESPACE: &str = "test_ns";
|
||||||
|
|
||||||
|
/// Starts an in-memory database, applies migrations, and returns a client.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the database cannot be started or migrations fail.
|
||||||
|
pub async fn setup_test_db() -> Result<SurrealDbClient> {
|
||||||
|
let database = Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(TEST_NAMESPACE, &database)
|
||||||
|
.await
|
||||||
|
.context("start in-memory surrealdb")?;
|
||||||
|
|
||||||
|
db.apply_migrations().await.context("apply migrations")?;
|
||||||
|
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates singleton [`SystemSettings`] embedding dimensions for tests.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if settings cannot be loaded or updated.
|
||||||
|
pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> {
|
||||||
|
let mut settings = SystemSettings::get_current(db).await?;
|
||||||
|
settings.embedding_dimensions = dimension;
|
||||||
|
SystemSettings::update(db, settings).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and sets the embedding dimension in system settings.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup or settings update fails.
|
||||||
|
pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
configure_embedding_dimension(&db, dimension).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for text-chunk embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set text chunk index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for knowledge-entity embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set knowledge entity index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and ensures runtime FTS/HNSW indexes are ready.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, index creation, or rebuild fails.
|
||||||
|
pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
ensure_runtime(&db, 1536).await?;
|
||||||
|
rebuild(&db).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
+236
-2
@@ -1,16 +1,89 @@
|
|||||||
use config::{Config, ConfigError, Environment, File};
|
use config::{Config, ConfigError, Environment, File};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{env, str::FromStr, sync::Once};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Debug)]
|
/// Error returned when parsing an embedding backend name.
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
|
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||||
|
pub struct ParseEmbeddingBackendError {
|
||||||
|
/// The unrecognized input string.
|
||||||
|
pub input: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Selects the embedding backend for vector generation.
|
||||||
|
#[derive(Clone, Copy, Deserialize, Serialize, Debug, Default, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum EmbeddingBackend {
|
||||||
|
/// Use OpenAI-compatible API for embeddings.
|
||||||
|
OpenAI,
|
||||||
|
/// Use FastEmbed local embeddings (default).
|
||||||
|
#[default]
|
||||||
|
FastEmbed,
|
||||||
|
/// Use deterministic hashed embeddings (for testing).
|
||||||
|
Hashed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingBackend {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::OpenAI => "openai",
|
||||||
|
Self::FastEmbed => "fastembed",
|
||||||
|
Self::Hashed => "hashed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for EmbeddingBackend {
|
||||||
|
type Err = ParseEmbeddingBackendError;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
match s.to_ascii_lowercase().as_str() {
|
||||||
|
"openai" => Ok(Self::OpenAI),
|
||||||
|
"hashed" => Ok(Self::Hashed),
|
||||||
|
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||||
|
other => Err(ParseEmbeddingBackendError {
|
||||||
|
input: other.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Deserialize, Debug, PartialEq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum StorageKind {
|
pub enum StorageKind {
|
||||||
Local,
|
Local,
|
||||||
|
Memory,
|
||||||
|
S3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Default storage backend when none is configured.
|
||||||
fn default_storage_kind() -> StorageKind {
|
fn default_storage_kind() -> StorageKind {
|
||||||
StorageKind::Local
|
StorageKind::Local
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::s3_region`] when the setting is not configured.
fn default_s3_region() -> String {
    String::from("us-east-1")
}
|
||||||
|
|
||||||
|
/// Selects the strategy used for PDF ingestion.
|
||||||
|
#[derive(Clone, Copy, Deserialize, Debug)]
|
||||||
|
#[serde(rename_all = "kebab-case")]
|
||||||
|
pub enum PdfIngestMode {
|
||||||
|
/// Only rely on classic text extraction (no LLM fallbacks).
|
||||||
|
Classic,
|
||||||
|
/// Prefer fast text extraction, but fall back to the LLM rendering path when needed.
|
||||||
|
LlmFirst,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default PDF ingestion mode when unset.
|
||||||
|
fn default_pdf_ingest_mode() -> PdfIngestMode {
|
||||||
|
PdfIngestMode::LlmFirst
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Application configuration loaded from files and environment variables.
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Clone, Deserialize, Debug)]
|
#[derive(Clone, Deserialize, Debug)]
|
||||||
pub struct AppConfig {
|
pub struct AppConfig {
|
||||||
pub openai_api_key: String,
|
pub openai_api_key: String,
|
||||||
@@ -26,17 +99,154 @@ pub struct AppConfig {
|
|||||||
pub openai_base_url: String,
|
pub openai_base_url: String,
|
||||||
#[serde(default = "default_storage_kind")]
|
#[serde(default = "default_storage_kind")]
|
||||||
pub storage: StorageKind,
|
pub storage: StorageKind,
|
||||||
|
#[serde(default)]
|
||||||
|
pub s3_bucket: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub s3_endpoint: Option<String>,
|
||||||
|
#[serde(default = "default_s3_region")]
|
||||||
|
pub s3_region: String,
|
||||||
|
#[serde(default = "default_pdf_ingest_mode")]
|
||||||
|
pub pdf_ingest_mode: PdfIngestMode,
|
||||||
|
#[serde(default = "default_reranking_enabled")]
|
||||||
|
pub reranking_enabled: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub reranking_pool_size: Option<usize>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub fastembed_cache_dir: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub fastembed_show_download_progress: Option<bool>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub fastembed_max_length: Option<usize>,
|
||||||
|
/// HuggingFace-style FastEmbed `model_code` (e.g. `Xenova/bge-small-en-v1.5`). Overrides
|
||||||
|
/// `system_settings.embedding_model` when `embedding_backend` is `fastembed`.
|
||||||
|
#[serde(default)]
|
||||||
|
pub fastembed_model: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub embedding_backend: EmbeddingBackend,
|
||||||
|
#[serde(default)]
|
||||||
|
pub embedding_pool_size: Option<usize>,
|
||||||
|
#[serde(default = "default_ingest_max_body_bytes")]
|
||||||
|
pub ingest_max_body_bytes: usize,
|
||||||
|
#[serde(default = "default_ingest_max_files")]
|
||||||
|
pub ingest_max_files: usize,
|
||||||
|
#[serde(default = "default_ingest_max_content_bytes")]
|
||||||
|
pub ingest_max_content_bytes: usize,
|
||||||
|
#[serde(default = "default_ingest_max_context_bytes")]
|
||||||
|
pub ingest_max_context_bytes: usize,
|
||||||
|
#[serde(default = "default_ingest_max_category_bytes")]
|
||||||
|
pub ingest_max_category_bytes: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Directory used for persisted assets when none is configured.
fn default_data_dir() -> String {
    String::from("./data")
}
|
||||||
|
|
||||||
|
/// Base URL for OpenAI-compatible APIs when none is configured.
fn default_base_url() -> String {
    String::from("https://api.openai.com/v1")
}
|
||||||
|
|
||||||
|
/// Reranking is off unless the configuration turns it on.
fn default_reranking_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = false;
    ENABLED_BY_DEFAULT
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::ingest_max_body_bytes`] when the setting is not configured.
fn default_ingest_max_body_bytes() -> usize {
    const DEFAULT_LIMIT: usize = 20 * 1_000_000;
    DEFAULT_LIMIT
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::ingest_max_files`] when the setting is not configured.
fn default_ingest_max_files() -> usize {
    const DEFAULT_LIMIT: usize = 5;
    DEFAULT_LIMIT
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::ingest_max_content_bytes`] when the setting is not configured.
fn default_ingest_max_content_bytes() -> usize {
    // 256 KiB.
    const DEFAULT_LIMIT: usize = 256 * 1024;
    DEFAULT_LIMIT
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::ingest_max_context_bytes`] when the setting is not configured.
fn default_ingest_max_context_bytes() -> usize {
    // 16 KiB.
    const DEFAULT_LIMIT: usize = 16 * 1024;
    DEFAULT_LIMIT
}
|
||||||
|
|
||||||
|
/// Fallback for [`AppConfig::ingest_max_category_bytes`] when the setting is not configured.
fn default_ingest_max_category_bytes() -> usize {
    const DEFAULT_LIMIT: usize = 128;
    DEFAULT_LIMIT
}
|
||||||
|
|
||||||
|
/// Guards [`ensure_ort_path`] so the probing below runs at most once per process.
static ORT_PATH_INIT: Once = Once::new();

/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
///
/// Nothing happens when `ORT_DYLIB_PATH` is already set or when the path of the
/// current executable cannot be determined. Otherwise the directory containing the
/// executable is probed for the platform's library:
///
/// - Windows: `onnxruntime.dll`, then `lib/onnxruntime.dll`
/// - macOS: `lib/libonnxruntime.dylib`
/// - other targets: `lib/libonnxruntime.so`
///
/// The first candidate that exists is exported through `ORT_DYLIB_PATH`.
pub fn ensure_ort_path() {
    ORT_PATH_INIT.call_once(|| {
        // Respect an explicit override from the environment.
        if env::var_os("ORT_DYLIB_PATH").is_some() {
            return;
        }

        let Ok(mut exe_dir) = env::current_exe() else {
            return;
        };
        // Drop the executable's file name, leaving its directory.
        exe_dir.pop();

        // Each platform only probes its own library name; Windows no longer falls
        // through to the `.so` lookup when no DLL is present.
        let candidates = if cfg!(target_os = "windows") {
            vec![
                exe_dir.join("onnxruntime.dll"),
                exe_dir.join("lib").join("onnxruntime.dll"),
            ]
        } else if cfg!(target_os = "macos") {
            vec![exe_dir.join("lib").join("libonnxruntime.dylib")]
        } else {
            vec![exe_dir.join("lib").join("libonnxruntime.so")]
        };

        if let Some(library) = candidates.into_iter().find(|path| path.exists()) {
            env::set_var("ORT_DYLIB_PATH", library);
        }
    });
}
|
||||||
|
|
||||||
|
impl Default for AppConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
openai_api_key: String::new(),
|
||||||
|
surrealdb_address: String::new(),
|
||||||
|
surrealdb_username: String::new(),
|
||||||
|
surrealdb_password: String::new(),
|
||||||
|
surrealdb_namespace: String::new(),
|
||||||
|
surrealdb_database: String::new(),
|
||||||
|
data_dir: default_data_dir(),
|
||||||
|
http_port: 0,
|
||||||
|
openai_base_url: default_base_url(),
|
||||||
|
storage: default_storage_kind(),
|
||||||
|
s3_bucket: None,
|
||||||
|
s3_endpoint: None,
|
||||||
|
s3_region: default_s3_region(),
|
||||||
|
pdf_ingest_mode: default_pdf_ingest_mode(),
|
||||||
|
reranking_enabled: default_reranking_enabled(),
|
||||||
|
reranking_pool_size: None,
|
||||||
|
fastembed_cache_dir: None,
|
||||||
|
fastembed_show_download_progress: None,
|
||||||
|
fastembed_max_length: None,
|
||||||
|
fastembed_model: None,
|
||||||
|
embedding_backend: EmbeddingBackend::default(),
|
||||||
|
embedding_pool_size: None,
|
||||||
|
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||||
|
ingest_max_files: default_ingest_max_files(),
|
||||||
|
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||||
|
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
||||||
|
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loads the application configuration from the environment and optional config file.
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||||
|
ensure_ort_path();
|
||||||
|
|
||||||
let config = Config::builder()
|
let config = Config::builder()
|
||||||
.add_source(File::with_name("config").required(false))
|
.add_source(File::with_name("config").required(false))
|
||||||
.add_source(Environment::default())
|
.add_source(Environment::default())
|
||||||
@@ -44,3 +254,27 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
|||||||
|
|
||||||
config.try_deserialize()
|
config.try_deserialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used)]

    use super::EmbeddingBackend;

    /// The derived `Default` must select the FastEmbed backend.
    #[test]
    fn embedding_backend_defaults_to_fastembed() {
        let backend = EmbeddingBackend::default();
        assert_eq!(backend, EmbeddingBackend::FastEmbed);
    }

    /// Both a canonical name and a short alias must parse to their backend.
    #[test]
    fn embedding_backend_parses_aliases() {
        let cases = [
            ("openai", EmbeddingBackend::OpenAI),
            ("fast", EmbeddingBackend::FastEmbed),
        ];

        for (raw, expected) in cases {
            let parsed: EmbeddingBackend = raw.parse().expect(raw);
            assert_eq!(parsed, expected);
        }
    }
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user