mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-12 17:24:26 +02:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 28e8ede478 | |||
| 00453fdcbe | |||
| c53ec8c0a1 | |||
| 60cf63292a | |||
| ac0d34bfbd | |||
| 4e20da538d | |||
| 15c9f18f6e | |||
| 7b850769c9 | |||
| 2a28243213 | |||
| b22c351785 | |||
| 3897345ab3 | |||
| 5c2d2e24d3 | |||
| c70141de35 | |||
| 2aa92b6ad7 | |||
| d3443d4153 | |||
| e3bb2935d0 | |||
| 93d11b66eb | |||
| 125b856c49 | |||
| bc41a619ce | |||
| ba8c36da1e | |||
| 5724f11dc1 | |||
| 189adb1a5f | |||
| 97beb91710 | |||
| 85336d77a3 | |||
| 9d5e7cd794 | |||
| 30bb59f243 | |||
| 224a7db451 | |||
| 4579725130 | |||
| 0b08801c90 | |||
| 45d13230a6 | |||
| 0acdba4f54 | |||
| 9609880cff | |||
| 31d585b59f | |||
| 890a4b381d | |||
| 2d630e2af9 | |||
| 9ec11e1f79 | |||
| c60db0fb56 | |||
| f5f0454904 | |||
| 18aadab8ee | |||
| 414d2f5b34 | |||
| 293440b0ee | |||
| 041d9bd81f | |||
| b4383bb227 | |||
| 6c7b586fc5 | |||
| 1927149ce9 | |||
| a52dc802de | |||
| 000852c94c | |||
| 6a5d631287 | |||
| b965c5a2e6 | |||
| 79e46e9c09 | |||
| f22a1e5ba4 |
@@ -1,49 +0,0 @@
|
|||||||
- name: Prepare lib dir
|
|
||||||
run: mkdir -p lib
|
|
||||||
|
|
||||||
# Linux
|
|
||||||
- name: Fetch ONNX Runtime (Linux)
|
|
||||||
if: runner.os == 'Linux'
|
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
|
||||||
set -euo pipefail
|
|
||||||
ARCH="$(uname -m)"
|
|
||||||
case "$ARCH" in
|
|
||||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
|
||||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
|
||||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
|
||||||
esac
|
|
||||||
curl -fsSL -o ort.tgz "$URL"
|
|
||||||
tar -xzf ort.tgz
|
|
||||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
|
||||||
|
|
||||||
# macOS
|
|
||||||
- name: Fetch ONNX Runtime (macOS)
|
|
||||||
if: runner.os == 'macOS'
|
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
|
||||||
set -euo pipefail
|
|
||||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
|
||||||
tar -xzf ort.tgz
|
|
||||||
# copy the main dylib; rename to stable name if needed
|
|
||||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
|
||||||
# optional: ensure a stable name
|
|
||||||
if [ ! -f lib/libonnxruntime.dylib ]; then
|
|
||||||
cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Windows
|
|
||||||
- name: Fetch ONNX Runtime (Windows)
|
|
||||||
if: runner.os == 'Windows'
|
|
||||||
shell: pwsh
|
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
|
||||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
|
||||||
Invoke-WebRequest $url -OutFile ort.zip
|
|
||||||
Expand-Archive ort.zip -DestinationPath ort
|
|
||||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
|
||||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
|
||||||
|
|
||||||
@@ -24,9 +24,18 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install Nix
|
||||||
|
uses: cachix/install-nix-action@v27
|
||||||
|
with:
|
||||||
|
extra_nix_config: |
|
||||||
|
experimental-features = nix-command flakes
|
||||||
|
|
||||||
|
- name: Verify ort-version matches nixpkgs onnxruntime
|
||||||
|
run: nix flake check --system x86_64-linux -L
|
||||||
|
|
||||||
- name: Install dist
|
- name: Install dist
|
||||||
shell: bash
|
shell: bash
|
||||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.0/cargo-dist-installer.sh | sh"
|
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.3/cargo-dist-installer.sh | sh"
|
||||||
|
|
||||||
- name: Cache dist
|
- name: Cache dist
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
@@ -67,6 +76,10 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Load ONNX Runtime version
|
||||||
|
shell: bash
|
||||||
|
run: echo "ORT_VER=$(tr -d '[:space:]' < ort-version)" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
- name: Install Rust non-interactively if not already installed
|
- name: Install Rust non-interactively if not already installed
|
||||||
if: ${{ matrix.container }}
|
if: ${{ matrix.container }}
|
||||||
run: |
|
run: |
|
||||||
@@ -107,8 +120,6 @@ jobs:
|
|||||||
|
|
||||||
- name: Fetch ONNX Runtime (Linux)
|
- name: Fetch ONNX Runtime (Linux)
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
ARCH="$(uname -m)"
|
ARCH="$(uname -m)"
|
||||||
@@ -125,8 +136,6 @@ jobs:
|
|||||||
|
|
||||||
- name: Fetch ONNX Runtime (macOS)
|
- name: Fetch ONNX Runtime (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||||
@@ -137,8 +146,6 @@ jobs:
|
|||||||
- name: Fetch ONNX Runtime (Windows)
|
- name: Fetch ONNX Runtime (Windows)
|
||||||
if: runner.os == 'Windows'
|
if: runner.os == 'Windows'
|
||||||
shell: pwsh
|
shell: pwsh
|
||||||
env:
|
|
||||||
ORT_VER: 1.22.0
|
|
||||||
run: |
|
run: |
|
||||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||||
Invoke-WebRequest $url -OutFile ort.zip
|
Invoke-WebRequest $url -OutFile ort.zip
|
||||||
|
|||||||
@@ -25,3 +25,7 @@ devenv.local.nix
|
|||||||
# html-router/assets/style.css
|
# html-router/assets/style.css
|
||||||
html-router/node_modules
|
html-router/node_modules
|
||||||
.fastembed_cache/
|
.fastembed_cache/
|
||||||
|
|
||||||
|
# insta: pending (unreviewed) snapshots; accepted *.snap files are committed
|
||||||
|
*.snap.new
|
||||||
|
.insta.bak
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
|
## 1.0.3 (2026-06-12)
|
||||||
|
- Search: filter results by type — knowledge entities, ingested content, or both
|
||||||
|
- Admin: choose the local FastEmbed model from the admin UI; changes save immediately and apply after restart (re-embeds when the vector dimension changes)
|
||||||
|
- Performance: pooled FastEmbed workers and batched embedding generation for faster ingestion and search
|
||||||
|
- Performance: lower search and chat latency from backend allocation and retrieval optimizations
|
||||||
|
- Fix: modal dialogs (scratchpad editor, admin prompts, entity creation) open and close more reliably
|
||||||
|
- Fix: improved knowledge-entity relationship suggestions when creating entities manually
|
||||||
|
- Fix: API key revocation now correctly clears the stored key
|
||||||
|
|
||||||
## 1.0.2 (2026-02-15)
|
## 1.0.2 (2026-02-15)
|
||||||
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||||
- Fix: chat references now validate and render more reliably
|
- Fix: chat references now validate and render more reliably
|
||||||
|
|||||||
Generated
+48
-18
@@ -247,7 +247,9 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tower",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1466,6 +1468,17 @@ dependencies = [
|
|||||||
"windows-sys 0.59.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "console"
|
||||||
|
version = "0.16.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87"
|
||||||
|
dependencies = [
|
||||||
|
"encode_unicode",
|
||||||
|
"libc",
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "const-random"
|
name = "const-random"
|
||||||
version = "0.1.18"
|
version = "0.1.18"
|
||||||
@@ -2965,6 +2978,7 @@ dependencies = [
|
|||||||
"common",
|
"common",
|
||||||
"futures",
|
"futures",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
|
"insta",
|
||||||
"json-stream-parser",
|
"json-stream-parser",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
"minijinja-autoreload",
|
"minijinja-autoreload",
|
||||||
@@ -2978,6 +2992,7 @@ dependencies = [
|
|||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tower-serve-static",
|
"tower-serve-static",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -3339,7 +3354,7 @@ version = "0.17.11"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
|
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"console",
|
"console 0.15.11",
|
||||||
"number_prefix",
|
"number_prefix",
|
||||||
"portable-atomic",
|
"portable-atomic",
|
||||||
"unicode-width 0.2.2",
|
"unicode-width 0.2.2",
|
||||||
@@ -3409,6 +3424,19 @@ dependencies = [
|
|||||||
"generic-array 0.14.7",
|
"generic-array 0.14.7",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "insta"
|
||||||
|
version = "1.47.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b4a6248eb93a4401ed2f37dfe8ea592d3cf05b7cf4f8efa867b6895af7e094e"
|
||||||
|
dependencies = [
|
||||||
|
"console 0.16.3",
|
||||||
|
"once_cell",
|
||||||
|
"regex",
|
||||||
|
"similar",
|
||||||
|
"tempfile",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "instant"
|
name = "instant"
|
||||||
version = "0.1.13"
|
version = "0.1.13"
|
||||||
@@ -3533,6 +3561,7 @@ name = "json-stream-parser"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"thiserror 1.0.69",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3803,12 +3832,13 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "main"
|
name = "main"
|
||||||
version = "1.0.2"
|
version = "1.0.3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"api-router",
|
"api-router",
|
||||||
"async-openai",
|
"async-openai",
|
||||||
"axum",
|
"axum",
|
||||||
|
"chrono",
|
||||||
"common",
|
"common",
|
||||||
"futures",
|
"futures",
|
||||||
"html-router",
|
"html-router",
|
||||||
@@ -5070,9 +5100,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quinn-proto"
|
name = "quinn-proto"
|
||||||
version = "0.11.13"
|
version = "0.11.14"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"getrandom 0.3.4",
|
"getrandom 0.3.4",
|
||||||
@@ -5421,15 +5451,10 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-openai",
|
"async-openai",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
|
||||||
"clap",
|
|
||||||
"common",
|
"common",
|
||||||
"fastembed",
|
"fastembed",
|
||||||
"futures",
|
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"surrealdb",
|
|
||||||
"thiserror 1.0.69",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"uuid",
|
"uuid",
|
||||||
@@ -5762,9 +5787,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-webpki"
|
name = "rustls-webpki"
|
||||||
version = "0.103.9"
|
version = "0.103.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53"
|
checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ring",
|
"ring",
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
@@ -6164,6 +6189,12 @@ version = "0.1.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
|
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "similar"
|
||||||
|
version = "2.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "simple_asn1"
|
name = "simple_asn1"
|
||||||
version = "0.6.4"
|
version = "0.6.4"
|
||||||
@@ -6305,9 +6336,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "state-machines"
|
name = "state-machines"
|
||||||
version = "0.2.0"
|
version = "0.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "806ba0bf43ae158b229036d8a84601649a58d9761e718b5e0e07c2953803f4c1"
|
checksum = "e6a3c439e93b084079d81f1ccec41c64edf5a7348484db5228344372e634b92f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"state-machines-core",
|
"state-machines-core",
|
||||||
"state-machines-macro",
|
"state-machines-macro",
|
||||||
@@ -6315,19 +6346,18 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "state-machines-core"
|
name = "state-machines-core"
|
||||||
version = "0.2.0"
|
version = "0.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "949cc50e84bed6234117f28a0ba2980dc35e9c17984ffe4e0a3364fba3e77540"
|
checksum = "b53079921cf97a990334cd0296c1efa4f16631ed3f30a3010bb2f2d5c76cb37b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "state-machines-macro"
|
name = "state-machines-macro"
|
||||||
version = "0.2.0"
|
version = "0.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8322f5aa92d31b3c05faa1ec3231b82da479a20706836867d67ae89ce74927bd"
|
checksum = "c7158fc1607004ff2bfba01ef8ca59d0446c374cd52a25a8726bba0cbb0d5c74"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"state-machines-core",
|
|
||||||
"syn 2.0.115",
|
"syn 2.0.115",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
+8
-3
@@ -58,13 +58,18 @@ tokio-retry = "0.3.0"
|
|||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
object_store = { version = "0.11.2", features = ["aws"] }
|
object_store = { version = "0.11.2", features = ["aws"] }
|
||||||
bytes = "1.7.1"
|
bytes = "1.7.1"
|
||||||
state-machines = "0.2.0"
|
state-machines = "0.9"
|
||||||
|
pdf-extract = "0.9"
|
||||||
|
lopdf = "0.32"
|
||||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||||
|
|
||||||
[profile.dist]
|
[profile.dist]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
lto = "thin"
|
lto = "thin"
|
||||||
|
|
||||||
|
[workspace.lints.rust]
|
||||||
|
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(feature, values(\"inspect\"))"] }
|
||||||
|
|
||||||
[workspace.lints.clippy]
|
[workspace.lints.clippy]
|
||||||
# Performance-focused lints
|
# Performance-focused lints
|
||||||
perf = { level = "warn", priority = -1 }
|
perf = { level = "warn", priority = -1 }
|
||||||
@@ -106,11 +111,11 @@ missing_errors_doc = "allow"
|
|||||||
missing_panics_doc = "warn"
|
missing_panics_doc = "warn"
|
||||||
module_name_repetitions = "warn"
|
module_name_repetitions = "warn"
|
||||||
wildcard_dependencies = "warn"
|
wildcard_dependencies = "warn"
|
||||||
missing_docs_in_private_items = "warn"
|
missing_docs_in_private_items = "allow"
|
||||||
|
|
||||||
# Allow noisy lints that don't add value for this project
|
# Allow noisy lints that don't add value for this project
|
||||||
needless_raw_string_hashes = "allow"
|
needless_raw_string_hashes = "allow"
|
||||||
multiple_bound_locations = "allow"
|
multiple_bound_locations = "allow"
|
||||||
cargo_common_metadata = "allow"
|
cargo_common_metadata = "allow"
|
||||||
multiple-crate-versions = "allow"
|
multiple-crate-versions = "allow"
|
||||||
module_name_repetition = "allow"
|
|
||||||
|
|||||||
+6
-4
@@ -1,5 +1,5 @@
|
|||||||
# === Builder ===
|
# === Builder ===
|
||||||
FROM rust:1.89-bookworm AS builder
|
FROM rust:1.91.1-bookworm AS builder
|
||||||
WORKDIR /usr/src/minne
|
WORKDIR /usr/src/minne
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||||
@@ -30,9 +30,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
libgomp1 libstdc++6 curl \
|
libgomp1 libstdc++6 curl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# ONNX Runtime (CPU). Keep in sync with ort crate requirements.
|
# ONNX Runtime (CPU). Version is read from ort-version (override with --build-arg ORT_VERSION=...).
|
||||||
ARG ORT_VERSION=1.23.2
|
COPY ort-version /tmp/ort-version
|
||||||
RUN mkdir -p /opt/onnxruntime && \
|
ARG ORT_VERSION
|
||||||
|
RUN ORT_VERSION="${ORT_VERSION:-$(tr -d '[:space:]' < /tmp/ort-version)}" && \
|
||||||
|
mkdir -p /opt/onnxruntime && \
|
||||||
curl -fsSL -o /tmp/ort.tgz \
|
curl -fsSL -o /tmp/ort.tgz \
|
||||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||||
|
|||||||
@@ -20,3 +20,8 @@ futures = { workspace = true }
|
|||||||
axum_typed_multipart = { workspace = true}
|
axum_typed_multipart = { workspace = true}
|
||||||
|
|
||||||
common = { path = "../common" }
|
common = { path = "../common" }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
common = { path = "../common", features = ["test-utils"] }
|
||||||
|
tower = "0.5"
|
||||||
|
uuid = { workspace = true }
|
||||||
|
|||||||
@@ -11,31 +11,3 @@ pub struct ApiState {
|
|||||||
pub config: AppConfig,
|
pub config: AppConfig,
|
||||||
pub storage: StorageManager,
|
pub storage: StorageManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiState {
|
|
||||||
pub async fn new(
|
|
||||||
config: &AppConfig,
|
|
||||||
storage: StorageManager,
|
|
||||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
|
||||||
let surreal_db_client = Arc::new(
|
|
||||||
SurrealDbClient::new(
|
|
||||||
&config.surrealdb_address,
|
|
||||||
&config.surrealdb_username,
|
|
||||||
&config.surrealdb_password,
|
|
||||||
&config.surrealdb_namespace,
|
|
||||||
&config.surrealdb_database,
|
|
||||||
)
|
|
||||||
.await?,
|
|
||||||
);
|
|
||||||
|
|
||||||
surreal_db_client.apply_migrations().await?;
|
|
||||||
|
|
||||||
let app_state = Self {
|
|
||||||
db: surreal_db_client.clone(),
|
|
||||||
config: config.clone(),
|
|
||||||
storage,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(app_state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+47
-36
@@ -7,39 +7,38 @@ use common::error::AppError;
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Error, Debug, Serialize, Clone)]
|
#[derive(Error, Debug)]
|
||||||
pub enum ApiError {
|
pub enum ApiErr {
|
||||||
#[error("Internal server error")]
|
#[error("internal server error")]
|
||||||
InternalError(String),
|
InternalError(String),
|
||||||
|
|
||||||
#[error("Validation error: {0}")]
|
#[error("validation error: {0}")]
|
||||||
ValidationError(String),
|
ValidationError(String),
|
||||||
|
|
||||||
#[error("Not found: {0}")]
|
#[error("not found: {0}")]
|
||||||
NotFound(String),
|
NotFound(String),
|
||||||
|
|
||||||
#[error("Unauthorized: {0}")]
|
#[error("unauthorized: {0}")]
|
||||||
Unauthorized(String),
|
Unauthorized(String),
|
||||||
|
|
||||||
#[error("Payload too large: {0}")]
|
#[error("payload too large: {0}")]
|
||||||
PayloadTooLarge(String),
|
PayloadTooLarge(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AppError> for ApiError {
|
impl From<AppError> for ApiErr {
|
||||||
fn from(err: AppError) -> Self {
|
fn from(err: AppError) -> Self {
|
||||||
match err {
|
match err {
|
||||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
|
||||||
tracing::error!("Internal error: {:?}", err);
|
|
||||||
Self::InternalError("Internal server error".to_string())
|
|
||||||
}
|
|
||||||
AppError::NotFound(msg) => Self::NotFound(msg),
|
AppError::NotFound(msg) => Self::NotFound(msg),
|
||||||
AppError::Validation(msg) => Self::ValidationError(msg),
|
AppError::Validation(msg) => Self::ValidationError(msg),
|
||||||
AppError::Auth(msg) => Self::Unauthorized(msg),
|
AppError::Auth(msg) => Self::Unauthorized(msg),
|
||||||
_ => Self::InternalError("Internal server error".to_string()),
|
other => {
|
||||||
|
tracing::error!("internal API error: {other:?}");
|
||||||
|
Self::InternalError("Internal server error".to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiErr {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
let (status, error_response) = match self {
|
let (status, error_response) = match self {
|
||||||
Self::InternalError(message) => (
|
Self::InternalError(message) => (
|
||||||
@@ -94,6 +93,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use common::error::AppError;
|
use common::error::AppError;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
// Helper to check status code
|
// Helper to check status code
|
||||||
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
fn assert_status_code<T: IntoResponse + Debug>(response: T, expected_status: StatusCode) {
|
||||||
@@ -105,46 +105,57 @@ mod tests {
|
|||||||
fn test_app_error_to_api_error_conversion() {
|
fn test_app_error_to_api_error_conversion() {
|
||||||
// Test NotFound error conversion
|
// Test NotFound error conversion
|
||||||
let not_found = AppError::NotFound("resource not found".to_string());
|
let not_found = AppError::NotFound("resource not found".to_string());
|
||||||
let api_error = ApiError::from(not_found);
|
let api_error = ApiErr::from(not_found);
|
||||||
assert!(matches!(api_error, ApiError::NotFound(msg) if msg == "resource not found"));
|
assert!(matches!(api_error, ApiErr::NotFound(msg) if msg == "resource not found"));
|
||||||
|
|
||||||
// Test Validation error conversion
|
// Test Validation error conversion
|
||||||
let validation = AppError::Validation("invalid input".to_string());
|
let validation = AppError::Validation("invalid input".to_string());
|
||||||
let api_error = ApiError::from(validation);
|
let api_error = ApiErr::from(validation);
|
||||||
assert!(matches!(api_error, ApiError::ValidationError(msg) if msg == "invalid input"));
|
assert!(matches!(api_error, ApiErr::ValidationError(msg) if msg == "invalid input"));
|
||||||
|
|
||||||
// Test Auth error conversion
|
// Test Auth error conversion
|
||||||
let auth = AppError::Auth("unauthorized".to_string());
|
let auth = AppError::Auth("unauthorized".to_string());
|
||||||
let api_error = ApiError::from(auth);
|
let api_error = ApiErr::from(auth);
|
||||||
assert!(matches!(api_error, ApiError::Unauthorized(msg) if msg == "unauthorized"));
|
assert!(matches!(api_error, ApiErr::Unauthorized(msg) if msg == "unauthorized"));
|
||||||
|
|
||||||
// Test for internal errors - create a mock error that doesn't require surrealdb
|
// Test for internal errors - create a mock error that doesn't require surrealdb
|
||||||
let internal_error =
|
let internal_error = AppError::Io(io::Error::other("io error"));
|
||||||
AppError::Io(std::io::Error::new(std::io::ErrorKind::Other, "io error"));
|
let api_error = ApiErr::from(internal_error);
|
||||||
let api_error = ApiError::from(internal_error);
|
assert!(matches!(
|
||||||
assert!(matches!(api_error, ApiError::InternalError(_)));
|
api_error,
|
||||||
|
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_app_error_internal_error_is_sanitized() {
|
||||||
|
let api_error = ApiErr::from(AppError::internal("db password incorrect"));
|
||||||
|
assert!(matches!(
|
||||||
|
api_error,
|
||||||
|
ApiErr::InternalError(msg) if msg == "Internal server error"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_api_error_response_status_codes() {
|
fn test_api_error_response_status_codes() {
|
||||||
// Test internal error status
|
// Test internal error status
|
||||||
let error = ApiError::InternalError("server error".to_string());
|
let error = ApiErr::InternalError("server error".to_string());
|
||||||
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
assert_status_code(error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
// Test not found status
|
// Test not found status
|
||||||
let error = ApiError::NotFound("not found".to_string());
|
let error = ApiErr::NotFound("not found".to_string());
|
||||||
assert_status_code(error, StatusCode::NOT_FOUND);
|
assert_status_code(error, StatusCode::NOT_FOUND);
|
||||||
|
|
||||||
// Test validation error status
|
// Test validation error status
|
||||||
let error = ApiError::ValidationError("invalid input".to_string());
|
let error = ApiErr::ValidationError("invalid input".to_string());
|
||||||
assert_status_code(error, StatusCode::BAD_REQUEST);
|
assert_status_code(error, StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
// Test unauthorized status
|
// Test unauthorized status
|
||||||
let error = ApiError::Unauthorized("not allowed".to_string());
|
let error = ApiErr::Unauthorized("not allowed".to_string());
|
||||||
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
// Test payload too large status
|
// Test payload too large status
|
||||||
let error = ApiError::PayloadTooLarge("too big".to_string());
|
let error = ApiErr::PayloadTooLarge("too big".to_string());
|
||||||
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
|
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,15 +164,15 @@ mod tests {
|
|||||||
fn test_error_messages() {
|
fn test_error_messages() {
|
||||||
// For validation errors
|
// For validation errors
|
||||||
let message = "invalid data format";
|
let message = "invalid data format";
|
||||||
let error = ApiError::ValidationError(message.to_string());
|
let error = ApiErr::ValidationError(message.to_string());
|
||||||
|
|
||||||
// Check that the error itself contains the message
|
// Check that the error itself contains the message
|
||||||
assert_eq!(error.to_string(), format!("Validation error: {}", message));
|
assert_eq!(error.to_string(), format!("validation error: {message}"));
|
||||||
|
|
||||||
// For not found errors
|
// For not found errors
|
||||||
let message = "user not found";
|
let message = "user not found";
|
||||||
let error = ApiError::NotFound(message.to_string());
|
let error = ApiErr::NotFound(message.to_string());
|
||||||
assert_eq!(error.to_string(), format!("Not found: {}", message));
|
assert_eq!(error.to_string(), format!("not found: {message}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Alternative approach for internal error test
|
// Alternative approach for internal error test
|
||||||
@@ -170,11 +181,11 @@ mod tests {
|
|||||||
// Create a sensitive error message
|
// Create a sensitive error message
|
||||||
let sensitive_info = "db password incorrect";
|
let sensitive_info = "db password incorrect";
|
||||||
|
|
||||||
// Create ApiError with sensitive info
|
// Create ApiErr with sensitive info
|
||||||
let api_error = ApiError::InternalError(sensitive_info.to_string());
|
let api_error = ApiErr::InternalError(sensitive_info.to_string());
|
||||||
|
|
||||||
// Check the error message is correctly set
|
// Check the error message is correctly set
|
||||||
assert_eq!(api_error.to_string(), "Internal server error");
|
assert_eq!(api_error.to_string(), "internal server error");
|
||||||
|
|
||||||
// Also verify correct status code
|
// Also verify correct status code
|
||||||
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
|
assert_status_code(api_error, StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use axum::{
|
|||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use middleware_api_auth::api_auth;
|
use middleware_api_auth::api_auth;
|
||||||
use routes::{categories::get_categories, ingest::ingest_data, liveness::live, readiness::ready};
|
use routes::{categories::list, ingest::handle, liveness::live, readiness::ready};
|
||||||
|
|
||||||
pub mod api_state;
|
pub mod api_state;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
@@ -28,11 +28,11 @@ where
|
|||||||
let protected = Router::new()
|
let protected = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/ingest",
|
"/ingest",
|
||||||
post(ingest_data).layer(DefaultBodyLimit::max(
|
post(handle).layer(DefaultBodyLimit::max(
|
||||||
app_state.config.ingest_max_body_bytes,
|
app_state.config.ingest_max_body_bytes,
|
||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
.route("/categories", get(get_categories))
|
.route("/categories", get(list))
|
||||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
||||||
|
|
||||||
public.merge(protected)
|
public.merge(protected)
|
||||||
|
|||||||
@@ -6,26 +6,26 @@ use axum::{
|
|||||||
|
|
||||||
use common::storage::types::user::User;
|
use common::storage::types::user::User;
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
pub async fn api_auth(
|
pub async fn api_auth(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
mut request: Request,
|
mut request: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Result<Response, ApiError> {
|
) -> Result<Response, ApiErr> {
|
||||||
let api_key = extract_api_key(&request)
|
let api_key = extract_api_key(&request)
|
||||||
.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||||
|
|
||||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
let user = User::find_by_api_key(api_key, &state.db).await?;
|
||||||
let user =
|
let user =
|
||||||
user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
user.ok_or_else(|| ApiErr::Unauthorized("You have to be authenticated".to_string()))?;
|
||||||
|
|
||||||
request.extensions_mut().insert(user);
|
request.extensions_mut().insert(user);
|
||||||
|
|
||||||
Ok(next.run(request).await)
|
Ok(next.run(request).await)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_api_key(request: &Request) -> Option<String> {
|
fn extract_api_key(request: &Request) -> Option<&str> {
|
||||||
request
|
request
|
||||||
.headers()
|
.headers()
|
||||||
.get("X-API-Key")
|
.get("X-API-Key")
|
||||||
@@ -35,7 +35,67 @@ fn extract_api_key(request: &Request) -> Option<String> {
|
|||||||
.headers()
|
.headers()
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
|
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||||
|
.map(str::trim)
|
||||||
})
|
})
|
||||||
.map(String::from)
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
|
mod tests {
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{HeaderValue, Request};
|
||||||
|
|
||||||
|
use super::extract_api_key;
|
||||||
|
|
||||||
|
fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
|
||||||
|
let mut builder = Request::builder().method("GET").uri("/");
|
||||||
|
for (name, value) in headers {
|
||||||
|
builder = builder.header(*name, *value);
|
||||||
|
}
|
||||||
|
builder.body(Body::empty()).expect("test request")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_from_x_api_key_header() {
|
||||||
|
let request = request_with_headers(&[("X-API-Key", "sk_test_key")]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_test_key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_from_bearer_authorization() {
|
||||||
|
let request = request_with_headers(&[("Authorization", "Bearer sk_bearer_key")]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_bearer_key"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_prefers_x_api_key_over_authorization() {
|
||||||
|
let request = request_with_headers(&[
|
||||||
|
("X-API-Key", "sk_header"),
|
||||||
|
("Authorization", "Bearer sk_bearer"),
|
||||||
|
]);
|
||||||
|
assert_eq!(extract_api_key(&request), Some("sk_header"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_returns_none_when_missing() {
|
||||||
|
let request = request_with_headers(&[]);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_rejects_non_bearer_authorization() {
|
||||||
|
let request = request_with_headers(&[("Authorization", "Basic abc")]);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_api_key_rejects_invalid_header_values() {
|
||||||
|
let mut request = request_with_headers(&[]);
|
||||||
|
request.headers_mut().insert(
|
||||||
|
"X-API-Key",
|
||||||
|
HeaderValue::from_bytes(&[0xFF]).expect("invalid header"),
|
||||||
|
);
|
||||||
|
assert_eq!(extract_api_key(&request), None);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
use axum::{extract::State, response::IntoResponse, Extension, Json};
|
||||||
use common::storage::types::user::User;
|
use common::storage::types::user::User;
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
pub async fn get_categories(
|
pub async fn list(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Extension(user): Extension<User>,
|
Extension(user): Extension<User>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiErr> {
|
||||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||||
|
|
||||||
Ok(Json(categories))
|
Ok(Json(categories))
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ use serde_json::json;
|
|||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{api_state::ApiState, error::ApiError};
|
use crate::{api_state::ApiState, error::ApiErr};
|
||||||
|
|
||||||
#[derive(Debug, TryFromMultipart)]
|
#[derive(Debug, TryFromMultipart)]
|
||||||
pub struct IngestParams {
|
pub struct Params {
|
||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
pub context: String,
|
pub context: String,
|
||||||
pub category: String,
|
pub category: String,
|
||||||
@@ -25,41 +25,37 @@ pub struct IngestParams {
|
|||||||
pub files: Vec<FieldData<NamedTempFile>>,
|
pub files: Vec<FieldData<NamedTempFile>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn ingest_data(
|
pub async fn handle(
|
||||||
State(state): State<ApiState>,
|
State(state): State<ApiState>,
|
||||||
Extension(user): Extension<User>,
|
Extension(user): Extension<User>,
|
||||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
TypedMultipart(input): TypedMultipart<Params>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiErr> {
|
||||||
let user_id = user.id;
|
let user_id = user.id;
|
||||||
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
|
|
||||||
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||||
let context_bytes = input.context.len();
|
|
||||||
let category_bytes = input.category.len();
|
|
||||||
let file_count = input.files.len();
|
|
||||||
|
|
||||||
match validate_ingest_input(
|
match validate_ingest_input(
|
||||||
&state.config,
|
&state.config,
|
||||||
input.content.as_deref(),
|
input.content.as_deref(),
|
||||||
&input.context,
|
&input.context,
|
||||||
&input.category,
|
&input.category,
|
||||||
file_count,
|
input.files.len(),
|
||||||
) {
|
) {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
||||||
return Err(ApiError::PayloadTooLarge(message));
|
return Err(ApiErr::PayloadTooLarge(message));
|
||||||
}
|
}
|
||||||
Err(IngestValidationError::BadRequest(message)) => {
|
Err(IngestValidationError::BadRequest(message)) => {
|
||||||
return Err(ApiError::ValidationError(message));
|
return Err(ApiErr::ValidationError(message));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
user_id = %user_id,
|
user_id = %user_id,
|
||||||
has_content,
|
has_content,
|
||||||
content_bytes,
|
content_len = input.content.as_ref().map_or(0, String::len),
|
||||||
context_bytes,
|
context_len = input.context.len(),
|
||||||
category_bytes,
|
category_len = input.category.len(),
|
||||||
file_count,
|
file_count = input.files.len(),
|
||||||
"Received ingest request"
|
"Received ingest request"
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -74,15 +70,10 @@ pub async fn ingest_data(
|
|||||||
input.context,
|
input.context,
|
||||||
input.category,
|
input.category,
|
||||||
file_infos,
|
file_infos,
|
||||||
&user_id,
|
user_id.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let futures: Vec<_> = payloads
|
IngestionTask::create_all_and_add_to_db(payloads, &user_id, &state.db).await?;
|
||||||
.into_iter()
|
|
||||||
.map(|object| IngestionTask::create_and_add_to_db(object, user_id.clone(), &state.db))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
try_join_all(futures).await?;
|
|
||||||
|
|
||||||
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
use crate::api_state::ApiState;
|
use crate::api_state::ApiState;
|
||||||
|
|
||||||
@@ -13,13 +14,15 @@ pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
|
|||||||
"checks": { "db": "ok" }
|
"checks": { "db": "ok" }
|
||||||
})),
|
})),
|
||||||
),
|
),
|
||||||
Err(e) => (
|
Err(e) => {
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
error!("readiness check failed: {e:?}");
|
||||||
Json(json!({
|
(
|
||||||
"status": "error",
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
"checks": { "db": "fail" },
|
Json(json!({
|
||||||
"reason": e.to_string()
|
"status": "error",
|
||||||
})),
|
"checks": { "db": "fail" }
|
||||||
),
|
})),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
#![allow(clippy::expect_used)]
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use api_router::{api_routes_v1, api_state::ApiState};
|
||||||
|
use axum::{
|
||||||
|
body::{to_bytes, Body},
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use common::{
|
||||||
|
storage::{db::SurrealDbClient, store::StorageManager, types::user::User},
|
||||||
|
utils::config::{AppConfig, StorageKind},
|
||||||
|
};
|
||||||
|
use tower::ServiceExt;
|
||||||
|
|
||||||
|
async fn build_test_app() -> (Router, Arc<SurrealDbClient>) {
|
||||||
|
let namespace = "api_router_test";
|
||||||
|
let database = uuid::Uuid::new_v4().to_string();
|
||||||
|
let db = Arc::new(
|
||||||
|
SurrealDbClient::memory(namespace, &database)
|
||||||
|
.await
|
||||||
|
.expect("in-memory db"),
|
||||||
|
);
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.expect("migrations should apply");
|
||||||
|
|
||||||
|
let config = AppConfig {
|
||||||
|
storage: StorageKind::Memory,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let storage = StorageManager::new(&config).await.expect("storage manager");
|
||||||
|
|
||||||
|
let state = ApiState {
|
||||||
|
db: Arc::clone(&db),
|
||||||
|
config,
|
||||||
|
storage,
|
||||||
|
};
|
||||||
|
|
||||||
|
let router = api_routes_v1(&state).with_state(state);
|
||||||
|
|
||||||
|
(router, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn response_body(response: axum::response::Response) -> String {
|
||||||
|
let body = to_bytes(response.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.expect("response body");
|
||||||
|
String::from_utf8(body.to_vec()).expect("utf-8 body")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn live_probe_is_public() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/live")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("live request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("live response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(response_body(response).await.contains("\"status\":\"ok\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn ready_probe_is_public_and_reports_db_ok() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/ready")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("ready request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ready response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
let body = response_body(response).await;
|
||||||
|
assert!(body.contains("\"checks\":{\"db\":\"ok\"}") || body.contains("\"db\":\"ok\""));
|
||||||
|
assert!(!body.contains("reason"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn protected_route_requires_api_key() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn protected_route_rejects_invalid_api_key() {
|
||||||
|
let (app, _db) = build_test_app().await;
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.header("X-API-Key", "sk_invalid")
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn authenticated_user_can_list_categories() {
|
||||||
|
let (app, db) = build_test_app().await;
|
||||||
|
|
||||||
|
let user = User::create_new(
|
||||||
|
"api_router_test@example.com".to_string(),
|
||||||
|
"test_password".to_string(),
|
||||||
|
&db,
|
||||||
|
"UTC".to_string(),
|
||||||
|
"system".to_string(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("test user");
|
||||||
|
|
||||||
|
let api_key = User::set_api_key(&user.id, &db).await.expect("api key");
|
||||||
|
|
||||||
|
let response = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(
|
||||||
|
Request::builder()
|
||||||
|
.uri("/categories")
|
||||||
|
.header("X-API-Key", api_key)
|
||||||
|
.body(Body::empty())
|
||||||
|
.expect("categories request"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("categories response");
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
+1
-1
@@ -19,6 +19,6 @@ CREATE system_settings:current CONTENT {
|
|||||||
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
|
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
|
||||||
embedding_dimensions: 1536,
|
embedding_dimensions: 1536,
|
||||||
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
|
||||||
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."
|
||||||
};
|
};
|
||||||
END;
|
END;
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
-- Per-user deduplication: same SHA256 may exist for different users.
|
||||||
|
REMOVE INDEX IF EXISTS file_sha256_idx ON file;
|
||||||
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
-- Harden knowledge entity embeddings and graph storage invariants.
|
||||||
|
|
||||||
|
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
|
||||||
|
-- Backfill denormalized source_id from the linked entity.
|
||||||
|
FOR $emb IN (SELECT * FROM knowledge_entity_embedding WHERE source_id = NONE OR source_id = '') {
|
||||||
|
LET $entity = (SELECT source_id FROM $emb.entity_id)[0];
|
||||||
|
IF $entity != NONE {
|
||||||
|
UPDATE $emb.id SET source_id = $entity.source_id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
-- Re-key embeddings so record id matches entity id (stable 1:1 identity).
|
||||||
|
FOR $emb IN (SELECT * FROM knowledge_entity_embedding) {
|
||||||
|
LET $entity_key = record::id($emb.entity_id);
|
||||||
|
LET $canonical = type::thing('knowledge_entity_embedding', $entity_key);
|
||||||
|
IF $emb.id != $canonical {
|
||||||
|
UPSERT $canonical CONTENT {
|
||||||
|
entity_id: $emb.entity_id,
|
||||||
|
embedding: $emb.embedding,
|
||||||
|
user_id: $emb.user_id,
|
||||||
|
source_id: $emb.source_id,
|
||||||
|
created_at: $emb.created_at,
|
||||||
|
updated_at: $emb.updated_at
|
||||||
|
};
|
||||||
|
DELETE $emb.id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REMOVE INDEX IF EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
-- Harden text chunk embeddings storage invariants.
|
||||||
|
|
||||||
|
-- Re-key embeddings so record id matches chunk id (stable 1:1 identity).
|
||||||
|
FOR $emb IN (SELECT * FROM text_chunk_embedding) {
|
||||||
|
LET $chunk_key = record::id($emb.chunk_id);
|
||||||
|
LET $canonical = type::thing('text_chunk_embedding', $chunk_key);
|
||||||
|
IF $emb.id != $canonical {
|
||||||
|
UPSERT $canonical CONTENT {
|
||||||
|
chunk_id: $emb.chunk_id,
|
||||||
|
embedding: $emb.embedding,
|
||||||
|
user_id: $emb.user_id,
|
||||||
|
source_id: $emb.source_id,
|
||||||
|
created_at: $emb.created_at,
|
||||||
|
updated_at: $emb.updated_at
|
||||||
|
};
|
||||||
|
DELETE $emb.id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REMOVE INDEX IF EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding;
|
||||||
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
-- Align persisted embedding settings when FastEmbed is the recorded backend but the model
|
||||||
|
-- name is still the OpenAI migration default (invalid for FastEmbed `from_str`).
|
||||||
|
|
||||||
|
UPDATE system_settings:current SET
|
||||||
|
embedding_model = 'Xenova/bge-small-en-v1.5',
|
||||||
|
embedding_dimensions = 384
|
||||||
|
WHERE embedding_backend = 'fastembed'
|
||||||
|
AND embedding_model = 'text-embedding-3-small';
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -45,9 +45,8 @@\n DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n-# Indexes based on usage (get_by_sha, potentially user lookups)\n-# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\n-DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\n+# Indexes based on usage (get_by_sha scoped by user_id, user lookups)\n+DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;\n DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n","events":null}
|
||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -68,7 +68,7 @@\n\n # Defines the schema for the 'knowledge_entity' table.\n\n-DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n@@ -90,6 +90,7 @@\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n@@ -102,6 +103,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;\n DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;\n+DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;\n\n -- Custom fields\n DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;\n@@ -109,8 +111,9 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;\n\n # Defines the schema for the 'message' table.\n\n@@ -135,19 +138,17 @@\n # Defines the 'relates_to' edge table for KnowledgeRelationships.\n # Edges connect nodes, in this case knowledge_entity records.\n\n-# Define the edge table itself, enforcing connections between knowledge_entity records\n-# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary\n-DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;\n+\n+DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;\n+DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;\n\n-# Define the metadata field within the edge\n # RelationshipMetadata is a struct, store as object\n DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;\n+DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n+DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n\n-# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)\n-# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;\n-# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;\n-\n # Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n","events":null}
|
||||||
+1
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -237,7 +237,7 @@\n\n -- Indexes\n -- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;\n-DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;\n+DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;\n\n","events":null}
|
||||||
@@ -13,7 +13,6 @@ DEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;
|
|||||||
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
||||||
|
|
||||||
# Indexes based on usage (get_by_sha, potentially user lookups)
|
# Indexes based on usage (get_by_sha scoped by user_id, user lookups)
|
||||||
# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;
|
|
||||||
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# Defines the schema for the 'knowledge_entity' table.
|
# Defines the schema for the 'knowledge_entity' table.
|
||||||
|
|
||||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
|
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMAFULL;
|
||||||
|
|
||||||
# Standard fields
|
# Standard fields
|
||||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
||||||
@@ -22,5 +22,6 @@ DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
|
|||||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_source_idx ON knowledge_entity FIELDS user_id, source_id;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;
|
||||||
+3
-1
@@ -7,6 +7,7 @@ DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
|
|||||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity_embedding TYPE string;
|
||||||
|
|
||||||
-- Custom fields
|
-- Custom fields
|
||||||
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
|
||||||
@@ -14,5 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<fl
|
|||||||
|
|
||||||
-- Indexes
|
-- Indexes
|
||||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id UNIQUE;
|
||||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_source_id_idx ON knowledge_entity_embedding FIELDS source_id;
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Defines the 'relates_to' edge table for KnowledgeRelationships.
|
||||||
|
# Edges connect nodes, in this case knowledge_entity records.
|
||||||
|
|
||||||
|
DEFINE TABLE IF NOT EXISTS relates_to SCHEMAFULL TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
||||||
|
|
||||||
|
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
|
||||||
|
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
|
||||||
|
|
||||||
|
# RelationshipMetadata is a struct, store as object
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
||||||
|
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
||||||
|
|
||||||
|
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
|
||||||
|
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
|
||||||
|
DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;
|
||||||
+1
-1
@@ -15,6 +15,6 @@ DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
|
|||||||
|
|
||||||
-- Indexes
|
-- Indexes
|
||||||
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
|
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id UNIQUE;
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
|
||||||
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# Defines the 'relates_to' edge table for KnowledgeRelationships.
|
|
||||||
# Edges connect nodes, in this case knowledge_entity records.
|
|
||||||
|
|
||||||
# Define the edge table itself, enforcing connections between knowledge_entity records
|
|
||||||
# SCHEMAFULL requires all fields to be defined, maybe start with SCHEMALESS if metadata might vary
|
|
||||||
DEFINE TABLE IF NOT EXISTS relates_to SCHEMALESS TYPE RELATION FROM knowledge_entity TO knowledge_entity;
|
|
||||||
|
|
||||||
# Define the metadata field within the edge
|
|
||||||
# RelationshipMetadata is a struct, store as object
|
|
||||||
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
|
|
||||||
|
|
||||||
# Optionally, define fields within the metadata object for stricter schema (requires SCHEMAFULL on table)
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
|
|
||||||
# DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
|
|
||||||
|
|
||||||
# Add indexes based on query patterns (delete_relationships_by_source_id, get_knowledge_relationships)
|
|
||||||
DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;
|
|
||||||
DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;
|
|
||||||
+110
-20
@@ -4,38 +4,128 @@ use tokio::task::JoinError;
|
|||||||
|
|
||||||
use crate::storage::types::file_info::FileError;
|
use crate::storage::types::file_info::FileError;
|
||||||
|
|
||||||
|
/// Errors from embedding provider operations.
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum EmbeddingError {
|
||||||
|
#[error("openai error: {0}")]
|
||||||
|
OpenAI(Box<OpenAIError>),
|
||||||
|
#[error("fastembed error: {0}")]
|
||||||
|
FastEmbed(String),
|
||||||
|
#[error("task join error: {0}")]
|
||||||
|
Join(#[from] JoinError),
|
||||||
|
#[error("fastembed model mutex poisoned: {0}")]
|
||||||
|
MutexPoisoned(String),
|
||||||
|
#[error("no embedding data received")]
|
||||||
|
NoData,
|
||||||
|
#[error("embedding configuration error: {0}")]
|
||||||
|
Config(String),
|
||||||
|
#[error("unknown fastembed model: {0}")]
|
||||||
|
UnknownModel(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenAIError> for EmbeddingError {
|
||||||
|
fn from(err: OpenAIError) -> Self {
|
||||||
|
Self::OpenAI(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingError {
|
||||||
|
pub(crate) fn fastembed(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::FastEmbed(err.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn mutex_poisoned(err: impl std::fmt::Display) -> Self {
|
||||||
|
Self::MutexPoisoned(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Core internal errors
|
// Core internal errors
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum AppError {
|
pub enum AppError {
|
||||||
#[error("Database error: {0}")]
|
#[error("database error: {0}")]
|
||||||
Database(#[from] surrealdb::Error),
|
Database(Box<surrealdb::Error>),
|
||||||
#[error("OpenAI error: {0}")]
|
#[error("openai error: {0}")]
|
||||||
OpenAI(#[from] OpenAIError),
|
OpenAI(Box<OpenAIError>),
|
||||||
#[error("File error: {0}")]
|
#[error("embedding error: {0}")]
|
||||||
|
Embedding(#[from] EmbeddingError),
|
||||||
|
#[error("file error: {0}")]
|
||||||
File(#[from] FileError),
|
File(#[from] FileError),
|
||||||
#[error("Not found: {0}")]
|
#[error("not found: {0}")]
|
||||||
NotFound(String),
|
NotFound(String),
|
||||||
#[error("Validation error: {0}")]
|
#[error("validation error: {0}")]
|
||||||
Validation(String),
|
Validation(String),
|
||||||
#[error("Authorization error: {0}")]
|
#[error("authorization error: {0}")]
|
||||||
Auth(String),
|
Auth(String),
|
||||||
#[error("LLM parsing error: {0}")]
|
#[error("llm parsing error: {0}")]
|
||||||
LLMParsing(String),
|
LLMParsing(String),
|
||||||
#[error("Task join error: {0}")]
|
#[error("task join error: {0}")]
|
||||||
Join(#[from] JoinError),
|
Join(#[from] JoinError),
|
||||||
#[error("Graph mapper error: {0}")]
|
#[error("graph mapper error: {0}")]
|
||||||
GraphMapper(String),
|
GraphMapper(String),
|
||||||
#[error("IoError: {0}")]
|
#[error("io error: {0}")]
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
#[error("Reqwest error: {0}")]
|
#[error("reqwest error: {0}")]
|
||||||
Reqwest(#[from] reqwest::Error),
|
Reqwest(Box<reqwest::Error>),
|
||||||
#[error("Anyhow error: {0}")]
|
#[error("storage error: {0}")]
|
||||||
Anyhow(#[from] anyhow::Error),
|
Storage(Box<object_store::Error>),
|
||||||
#[error("Ingestion Processing error: {0}")]
|
#[error("ingestion processing error: {0}")]
|
||||||
Processing(String),
|
Processing(String),
|
||||||
#[error("DOM smoothie error: {0}")]
|
#[error("dom smoothie error: {0}")]
|
||||||
DomSmoothie(#[from] dom_smoothie::ReadabilityError),
|
DomSmoothie(Box<dom_smoothie::ReadabilityError>),
|
||||||
#[error("Internal service error: {0}")]
|
#[error("internal service error: {0}")]
|
||||||
InternalError(String),
|
InternalError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<surrealdb::Error> for AppError {
|
||||||
|
fn from(err: surrealdb::Error) -> Self {
|
||||||
|
Self::Database(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenAIError> for AppError {
|
||||||
|
fn from(err: OpenAIError) -> Self {
|
||||||
|
Self::OpenAI(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<reqwest::Error> for AppError {
|
||||||
|
fn from(err: reqwest::Error) -> Self {
|
||||||
|
Self::Reqwest(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<object_store::Error> for AppError {
|
||||||
|
fn from(err: object_store::Error) -> Self {
|
||||||
|
Self::Storage(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<dom_smoothie::ReadabilityError> for AppError {
|
||||||
|
fn from(err: dom_smoothie::ReadabilityError) -> Self {
|
||||||
|
Self::DomSmoothie(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppError {
|
||||||
|
/// Builds an [`AppError::InternalError`] from a displayable message.
|
||||||
|
#[must_use]
|
||||||
|
pub fn internal(msg: impl std::fmt::Display) -> Self {
|
||||||
|
Self::InternalError(msg.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::AppError;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn app_error_is_reasonably_sized() {
|
||||||
|
assert!(
|
||||||
|
std::mem::size_of::<AppError>() <= 64,
|
||||||
|
"AppError is {} bytes",
|
||||||
|
std::mem::size_of::<AppError>()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,3 +3,6 @@
|
|||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-utils"))]
|
||||||
|
pub mod test_utils;
|
||||||
|
|||||||
+115
-59
@@ -13,8 +13,8 @@ use surrealdb::{
|
|||||||
use surrealdb_migrations::MigrationRunner;
|
use surrealdb_migrations::MigrationRunner;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
/// Embedded SurrealDB migration directory packaged with the crate.
|
/// Embedded SurrealDB project root (`migrations/`, `schemas/`, `.surrealdb`).
|
||||||
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
|
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/db");
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SurrealDbClient {
|
pub struct SurrealDbClient {
|
||||||
@@ -26,12 +26,20 @@ pub trait ProvidesDb {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SurrealDbClient {
|
impl SurrealDbClient {
|
||||||
/// # Initialize a new datbase client
|
/// Initialize a new database client.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `address` — Database connection string (e.g. `ws://localhost:8000` or `mem://`).
|
||||||
/// * `SurrealDbClient` initialized
|
/// * `username` — Root username for authentication.
|
||||||
|
/// * `password` — Root password for authentication.
|
||||||
|
/// * `namespace` — SurrealDB namespace to use.
|
||||||
|
/// * `database` — SurrealDB database to use.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the connection, authentication, or namespace/database selection fails.
|
||||||
|
/// In-memory (`mem://`) connections skip authentication.
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
address: &str,
|
address: &str,
|
||||||
username: &str,
|
username: &str,
|
||||||
@@ -41,8 +49,10 @@ impl SurrealDbClient {
|
|||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
let db = connect(address).await?;
|
let db = connect(address).await?;
|
||||||
|
|
||||||
// Sign in to database
|
// Skip sign-in for in-memory engine (no auth support)
|
||||||
db.signin(Root { username, password }).await?;
|
if !address.starts_with("mem://") {
|
||||||
|
db.signin(Root { username, password }).await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Set namespace
|
// Set namespace
|
||||||
db.use_ns(namespace).use_db(database).await?;
|
db.use_ns(namespace).use_db(database).await?;
|
||||||
@@ -50,6 +60,19 @@ impl SurrealDbClient {
|
|||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialize a new database client using namespace-level authentication.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `address` — Database connection string.
|
||||||
|
/// * `namespace` — SurrealDB namespace to use (also used for auth).
|
||||||
|
/// * `username` — Namespace username for authentication.
|
||||||
|
/// * `password` — Namespace password for authentication.
|
||||||
|
/// * `database` — SurrealDB database to use.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the connection, namespace authentication, or namespace/database selection fails.
|
||||||
pub async fn new_with_namespace_user(
|
pub async fn new_with_namespace_user(
|
||||||
address: &str,
|
address: &str,
|
||||||
namespace: &str,
|
namespace: &str,
|
||||||
@@ -68,6 +91,11 @@ impl SurrealDbClient {
|
|||||||
Ok(SurrealDbClient { client: db })
|
Ok(SurrealDbClient { client: db })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an Axum session store backed by SurrealDB.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `SessionError` if the session store configuration or table creation fails.
|
||||||
pub async fn create_session_store(
|
pub async fn create_session_store(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
|
||||||
@@ -75,7 +103,7 @@ impl SurrealDbClient {
|
|||||||
SessionStore::new(
|
SessionStore::new(
|
||||||
Some(self.client.clone().into()),
|
Some(self.client.clone().into()),
|
||||||
SessionConfig::default()
|
SessionConfig::default()
|
||||||
.with_table_name("test_session_table")
|
.with_table_name("session")
|
||||||
.with_secure(true),
|
.with_secure(true),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -86,51 +114,63 @@ impl SurrealDbClient {
|
|||||||
/// This function should be called during application startup, after connecting to
|
/// This function should be called during application startup, after connecting to
|
||||||
/// the database and selecting the appropriate namespace and database, but before
|
/// the database and selecting the appropriate namespace and database, but before
|
||||||
/// the application starts performing operations that rely on the schema.
|
/// the application starts performing operations that rely on the schema.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if the migration runner fails to apply any migration.
|
||||||
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
pub async fn apply_migrations(&self) -> Result<(), AppError> {
|
||||||
debug!("Applying migrations");
|
debug!("Applying migrations");
|
||||||
MigrationRunner::new(&self.client)
|
MigrationRunner::new(&self.client)
|
||||||
.load_files(&MIGRATIONS_DIR)
|
.load_files(&MIGRATIONS_DIR)
|
||||||
.up()
|
.up()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
.map_err(AppError::internal)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
|
/// Store an object in SurrealDB.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `item` - The item to be stored
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `item` — The item to store. Must implement `StoredObject`.
|
||||||
/// * `Result` - Item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database create operation fails.
|
||||||
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
pub async fn store_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: StoredObject + Send + Sync + 'static,
|
T: StoredObject + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
self.client
|
self.client
|
||||||
.create((T::table_name(), item.get_id()))
|
.create((T::table_name(), item.id()))
|
||||||
.content(item)
|
.content(item)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to upsert an object in SurrealDB, replacing any existing record
|
/// Upsert an object in SurrealDB, replacing any existing record with the same ID.
|
||||||
/// with the same ID. Useful for idempotent ingestion flows.
|
///
|
||||||
|
/// Useful for idempotent ingestion flows.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database upsert operation fails.
|
||||||
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: StoredObject + Send + Sync + 'static,
|
T: StoredObject + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
let id = item.get_id().to_string();
|
let id = item.id().to_string();
|
||||||
self.client
|
self.client
|
||||||
.upsert((T::table_name(), id))
|
.upsert((T::table_name(), id))
|
||||||
.content(item)
|
.content(item)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to retrieve all objects from a certain table, requires the struct to implement StoredObject
|
/// Retrieve all objects from a table.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Errors
|
||||||
/// * `Result` - Vec<T> or Error
|
///
|
||||||
|
/// Returns `Err` if the database select operation fails.
|
||||||
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
pub async fn get_all_stored_items<T>(&self) -> Result<Vec<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -138,13 +178,16 @@ impl SurrealDbClient {
|
|||||||
self.client.select(T::table_name()).await
|
self.client.select(T::table_name()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to retrieve a single object by its ID, requires the struct to implement StoredObject
|
/// Retrieve a single object by its ID.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `id` - The ID of the item to retrieve
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `id` — The ID of the item to retrieve.
|
||||||
/// * `Result<Option<T>, Error>` - The found item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database select operation fails.
|
||||||
|
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||||
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn get_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -152,13 +195,16 @@ impl SurrealDbClient {
|
|||||||
self.client.select((T::table_name(), id)).await
|
self.client.select((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to delete a single object by its ID, requires the struct to implement StoredObject
|
/// Delete a single object by its ID.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `id` - The ID of the item to delete
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// * `id` — The ID of the item to delete.
|
||||||
/// * `Result<Option<T>, Error>` - The deleted item or Error
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the database delete operation fails.
|
||||||
|
/// Returns `Ok(None)` if no record with the given ID exists.
|
||||||
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
pub async fn delete_item<T>(&self, id: &str) -> Result<Option<T>, Error>
|
||||||
where
|
where
|
||||||
T: for<'de> StoredObject,
|
T: for<'de> StoredObject,
|
||||||
@@ -166,10 +212,11 @@ impl SurrealDbClient {
|
|||||||
self.client.delete((T::table_name(), id)).await
|
self.client.delete((T::table_name(), id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Operation to listen to a table for updates, requires the struct to implement StoredObject
|
/// Listen to a table for real-time updates via a live query stream.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Errors
|
||||||
/// * `Result<Option<T>, Error>` - The deleted item or Error
|
///
|
||||||
|
/// Returns `Err` if the database live query subscription fails.
|
||||||
pub async fn listen<T>(
|
pub async fn listen<T>(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
) -> Result<impl Stream<Item = Result<Notification<T>, Error>>, Error>
|
||||||
@@ -202,7 +249,9 @@ impl SurrealDbClient {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -212,19 +261,17 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_initialization_and_crud() {
|
async fn test_initialization_and_crud() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string(); // ensures isolation per test run
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Call your initialization
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize schema");
|
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||||
|
|
||||||
// Test basic CRUD
|
|
||||||
let dummy = Dummy {
|
let dummy = Dummy {
|
||||||
id: "abc".to_string(),
|
id: "abc".to_string(),
|
||||||
name: "first".to_string(),
|
name: "first".to_string(),
|
||||||
@@ -232,50 +279,50 @@ mod tests {
|
|||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Store
|
let stored = db
|
||||||
let stored = db.store_item(dummy.clone()).await.expect("Failed to store");
|
.store_item(dummy.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to store".to_string())?;
|
||||||
assert!(stored.is_some());
|
assert!(stored.is_some());
|
||||||
|
|
||||||
// Read
|
|
||||||
let fetched = db
|
let fetched = db
|
||||||
.get_item::<Dummy>(&dummy.id)
|
.get_item::<Dummy>(&dummy.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch");
|
.with_context(|| "Failed to fetch".to_string())?;
|
||||||
assert_eq!(fetched, Some(dummy.clone()));
|
assert_eq!(fetched, Some(dummy.clone()));
|
||||||
|
|
||||||
// Read all
|
|
||||||
let all = db
|
let all = db
|
||||||
.get_all_stored_items::<Dummy>()
|
.get_all_stored_items::<Dummy>()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch all");
|
.with_context(|| "Failed to fetch all".to_string())?;
|
||||||
assert!(all.contains(&dummy));
|
assert!(all.contains(&dummy));
|
||||||
|
|
||||||
// Delete
|
|
||||||
let deleted = db
|
let deleted = db
|
||||||
.delete_item::<Dummy>(&dummy.id)
|
.delete_item::<Dummy>(&dummy.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete");
|
.with_context(|| "Failed to delete".to_string())?;
|
||||||
assert_eq!(deleted, Some(dummy));
|
assert_eq!(deleted, Some(dummy));
|
||||||
|
|
||||||
// After delete, should not be present
|
|
||||||
let fetch_post = db
|
let fetch_post = db
|
||||||
.get_item::<Dummy>("abc")
|
.get_item::<Dummy>("abc")
|
||||||
.await
|
.await
|
||||||
.expect("Failed fetch post delete");
|
.with_context(|| "Failed fetch post delete".to_string())?;
|
||||||
assert!(fetch_post.is_none());
|
assert!(fetch_post.is_none());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn upsert_item_overwrites_existing_records() {
|
async fn upsert_item_overwrites_existing_records() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize schema");
|
.with_context(|| "Failed to initialize schema".to_string())?;
|
||||||
|
|
||||||
let mut dummy = Dummy {
|
let mut dummy = Dummy {
|
||||||
id: "abc".to_string(),
|
id: "abc".to_string(),
|
||||||
@@ -286,17 +333,22 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(dummy.clone())
|
db.store_item(dummy.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store initial record");
|
.with_context(|| "Failed to store initial record".to_string())?;
|
||||||
|
|
||||||
dummy.name = "updated".to_string();
|
dummy.name = "updated".to_string();
|
||||||
let upserted = db
|
let upserted = db
|
||||||
.upsert_item(dummy.clone())
|
.upsert_item(dummy.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to upsert record");
|
.with_context(|| "Failed to upsert record".to_string())?;
|
||||||
assert!(upserted.is_some());
|
assert!(upserted.is_some());
|
||||||
|
|
||||||
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
|
let fetched: Option<Dummy> = db
|
||||||
assert_eq!(fetched.unwrap().name, "updated");
|
.get_item(&dummy.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "fetch after upsert".to_string())?;
|
||||||
|
let fetched =
|
||||||
|
fetched.ok_or_else(|| anyhow::anyhow!("Expected record to exist after upsert"))?;
|
||||||
|
assert_eq!(fetched.name, "updated");
|
||||||
|
|
||||||
let new_record = Dummy {
|
let new_record = Dummy {
|
||||||
id: "def".to_string(),
|
id: "def".to_string(),
|
||||||
@@ -306,25 +358,29 @@ mod tests {
|
|||||||
};
|
};
|
||||||
db.upsert_item(new_record.clone())
|
db.upsert_item(new_record.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to upsert new record");
|
.with_context(|| "Failed to upsert new record".to_string())?;
|
||||||
|
|
||||||
let fetched_new: Option<Dummy> = db
|
let fetched_new: Option<Dummy> = db
|
||||||
.get_item(&new_record.id)
|
.get_item(&new_record.id)
|
||||||
.await
|
.await
|
||||||
.expect("fetch inserted via upsert");
|
.with_context(|| "fetch inserted via upsert".to_string())?;
|
||||||
assert_eq!(fetched_new, Some(new_record));
|
assert_eq!(fetched_new, Some(new_record));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_applying_migrations() {
|
async fn test_applying_migrations() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to build indexes");
|
.with_context(|| "Failed to build indexes".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+307
-124
@@ -9,8 +9,38 @@ use tracing::{debug, info, warn};
|
|||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
|
|
||||||
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
|
||||||
|
const INDEX_BUILD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
||||||
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
|
||||||
|
|
||||||
|
/// HNSW index options used by runtime index creation (includes CONCURRENTLY).
|
||||||
|
pub const HNSW_INDEX_OPTIONS: &str = "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY";
|
||||||
|
/// HNSW index options for use inside transactions (CONCURRENTLY not supported).
|
||||||
|
pub const HNSW_INDEX_OPTIONS_SYNC: &str = "DIST COSINE TYPE F32 EFC 100 M 8";
|
||||||
|
|
||||||
|
/// Builds a `DEFINE INDEX OVERWRITE ... HNSW` statement matching runtime index options.
|
||||||
|
#[must_use]
|
||||||
|
pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize) -> String {
|
||||||
|
format!(
|
||||||
|
"DEFINE INDEX OVERWRITE {index_name} ON TABLE {table} \
|
||||||
|
FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
|
||||||
|
#[must_use]
|
||||||
|
pub fn hnsw_index_redefine_transaction_sql(
|
||||||
|
index_name: &str,
|
||||||
|
table: &str,
|
||||||
|
dimension: usize,
|
||||||
|
) -> String {
|
||||||
|
format!(
|
||||||
|
"BEGIN TRANSACTION;
|
||||||
|
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
|
||||||
|
DEFINE INDEX {index_name} ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension} {HNSW_INDEX_OPTIONS_SYNC};
|
||||||
|
COMMIT TRANSACTION;"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
struct HnswIndexSpec {
|
struct HnswIndexSpec {
|
||||||
index_name: &'static str,
|
index_name: &'static str,
|
||||||
@@ -23,12 +53,12 @@ const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
|
|||||||
HnswIndexSpec {
|
HnswIndexSpec {
|
||||||
index_name: "idx_embedding_text_chunk_embedding",
|
index_name: "idx_embedding_text_chunk_embedding",
|
||||||
table: "text_chunk_embedding",
|
table: "text_chunk_embedding",
|
||||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
options: HNSW_INDEX_OPTIONS,
|
||||||
},
|
},
|
||||||
HnswIndexSpec {
|
HnswIndexSpec {
|
||||||
index_name: "idx_embedding_knowledge_entity_embedding",
|
index_name: "idx_embedding_knowledge_entity_embedding",
|
||||||
table: "knowledge_entity_embedding",
|
table: "knowledge_entity_embedding",
|
||||||
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
|
options: HNSW_INDEX_OPTIONS,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -159,26 +189,50 @@ impl FtsIndexSpec {
|
|||||||
|
|
||||||
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
|
||||||
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
|
||||||
pub async fn ensure_runtime_indexes(
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if any index definition or polling step fails.
|
||||||
|
pub async fn ensure_runtime(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
embedding_dimension: usize,
|
embedding_dimension: usize,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
ensure_runtime_indexes_inner(db, embedding_dimension)
|
ensure_runtime_inner(db, embedding_dimension)
|
||||||
.await
|
.await
|
||||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
.map_err(AppError::internal)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
|
||||||
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
|
///
|
||||||
rebuild_indexes_inner(db)
|
/// # Errors
|
||||||
.await
|
///
|
||||||
.map_err(|err| AppError::InternalError(err.to_string()))
|
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||||
|
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
rebuild_inner(db).await.map_err(AppError::internal)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ensure_runtime_indexes_inner(
|
/// Returns the dimension of the currently defined chunk-embedding HNSW index, if any.
|
||||||
db: &SurrealDbClient,
|
///
|
||||||
embedding_dimension: usize,
|
/// Stored embeddings always share this index's dimension because re-embedding rewrites the
|
||||||
) -> Result<()> {
|
/// vectors and the index together, so it acts as a persisted marker of the embedding space
|
||||||
|
/// actually present in the database. Returns `Ok(None)` when the index has not been created yet
|
||||||
|
/// (for example on a fresh database with no ingested data).
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::InternalError` if the index metadata cannot be read.
|
||||||
|
pub async fn embedding_index_dimension(db: &SurrealDbClient) -> Result<Option<usize>, AppError> {
|
||||||
|
let spec = HnswIndexSpec {
|
||||||
|
index_name: "idx_embedding_text_chunk_embedding",
|
||||||
|
table: "text_chunk_embedding",
|
||||||
|
options: HNSW_INDEX_OPTIONS,
|
||||||
|
};
|
||||||
|
existing_hnsw_dimension(db, &spec)
|
||||||
|
.await
|
||||||
|
.map_err(AppError::internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
|
||||||
create_fts_analyzer(db).await?;
|
create_fts_analyzer(db).await?;
|
||||||
|
|
||||||
for spec in fts_index_specs() {
|
for spec in fts_index_specs() {
|
||||||
@@ -262,22 +316,17 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
|||||||
.context("checking index status")?;
|
.context("checking index status")?;
|
||||||
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
|
||||||
|
|
||||||
let info = match info {
|
let Some(info) = info else {
|
||||||
Some(i) => i,
|
return Ok("unknown".to_string());
|
||||||
None => return Ok("unknown".to_string()),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let building = info.get("building");
|
let parsed: IndexInfoForIndex =
|
||||||
let status = building
|
serde_json::from_value(info).context("deserializing INFO FOR INDEX response")?;
|
||||||
.and_then(|b| b.get("status"))
|
|
||||||
.and_then(|s| s.as_str())
|
|
||||||
.unwrap_or("ready")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
Ok(status)
|
Ok(parsed.building_status())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
|
async fn rebuild_inner(db: &SurrealDbClient) -> Result<()> {
|
||||||
debug!("Rebuilding indexes with concurrent definitions");
|
debug!("Rebuilding indexes with concurrent definitions");
|
||||||
create_fts_analyzer(db).await?;
|
create_fts_analyzer(db).await?;
|
||||||
|
|
||||||
@@ -385,10 +434,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
|||||||
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
|
||||||
// an existing analyzer definition.
|
// an existing analyzer definition.
|
||||||
let snowball_query = format!(
|
let snowball_query = format!(
|
||||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||||
TOKENIZERS class
|
TOKENIZERS class
|
||||||
FILTERS lowercase, ascii, snowball(english);",
|
FILTERS lowercase, ascii, snowball(english);"
|
||||||
analyzer = FTS_ANALYZER_NAME
|
|
||||||
);
|
);
|
||||||
|
|
||||||
match db.client.query(snowball_query).await {
|
match db.client.query(snowball_query).await {
|
||||||
@@ -410,10 +458,9 @@ async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let fallback_query = format!(
|
let fallback_query = format!(
|
||||||
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
|
"DEFINE ANALYZER IF NOT EXISTS {FTS_ANALYZER_NAME}
|
||||||
TOKENIZERS class
|
TOKENIZERS class
|
||||||
FILTERS lowercase, ascii;",
|
FILTERS lowercase, ascii;"
|
||||||
analyzer = FTS_ANALYZER_NAME
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = db
|
let res = db
|
||||||
@@ -446,6 +493,7 @@ async fn create_index_with_polling(
|
|||||||
table: &str,
|
table: &str,
|
||||||
progress_table: Option<&str>,
|
progress_table: Option<&str>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
const MAX_ATTEMPTS: usize = 3;
|
||||||
let expected_total = match progress_table {
|
let expected_total = match progress_table {
|
||||||
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
|
||||||
format!("counting rows in {table} for index {index_name} progress")
|
format!("counting rows in {table} for index {index_name} progress")
|
||||||
@@ -453,10 +501,9 @@ async fn create_index_with_polling(
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts: usize = 0;
|
||||||
const MAX_ATTEMPTS: usize = 3;
|
|
||||||
loop {
|
loop {
|
||||||
attempts += 1;
|
attempts = attempts.saturating_add(1);
|
||||||
let res = db
|
let res = db
|
||||||
.client
|
.client
|
||||||
.query(definition.clone())
|
.query(definition.clone())
|
||||||
@@ -504,8 +551,20 @@ async fn poll_index_build_status(
|
|||||||
poll_every: Duration,
|
poll_every: Duration,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let started_at = std::time::Instant::now();
|
let started_at = std::time::Instant::now();
|
||||||
|
let mut last_snapshot: Option<IndexBuildSnapshot> = None;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
if started_at.elapsed() >= INDEX_BUILD_TIMEOUT {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"index build timed out after {:?} for {index_name} on {table} (last status: {})",
|
||||||
|
INDEX_BUILD_TIMEOUT,
|
||||||
|
last_snapshot
|
||||||
|
.as_ref()
|
||||||
|
.map_or("unknown", |snapshot| snapshot.status.as_str())
|
||||||
|
))
|
||||||
|
.with_context(|| format!("index {index_name} on table {table} did not become ready"));
|
||||||
|
}
|
||||||
|
|
||||||
tokio::time::sleep(poll_every).await;
|
tokio::time::sleep(poll_every).await;
|
||||||
|
|
||||||
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
|
||||||
@@ -519,16 +578,15 @@ async fn poll_index_build_status(
|
|||||||
.context("failed to deserialize INFO FOR INDEX result")?;
|
.context("failed to deserialize INFO FOR INDEX result")?;
|
||||||
|
|
||||||
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
|
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
|
||||||
warn!(
|
return Err(anyhow::anyhow!(
|
||||||
index = %index_name,
|
"INFO FOR INDEX returned no data for {index_name} on {table}"
|
||||||
table = %table,
|
));
|
||||||
"INFO FOR INDEX returned no data; assuming index definition might be missing"
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match snapshot.progress_pct {
|
last_snapshot = Some(snapshot.clone());
|
||||||
Some(pct) => debug!(
|
|
||||||
|
if let Some(pct) = snapshot.progress_pct {
|
||||||
|
debug!(
|
||||||
index = %index_name,
|
index = %index_name,
|
||||||
table = %table,
|
table = %table,
|
||||||
status = snapshot.status,
|
status = snapshot.status,
|
||||||
@@ -539,8 +597,9 @@ async fn poll_index_build_status(
|
|||||||
total = snapshot.total_rows,
|
total = snapshot.total_rows,
|
||||||
progress_pct = format_args!("{pct:.1}"),
|
progress_pct = format_args!("{pct:.1}"),
|
||||||
"Index build status"
|
"Index build status"
|
||||||
),
|
);
|
||||||
None => debug!(
|
} else {
|
||||||
|
debug!(
|
||||||
index = %index_name,
|
index = %index_name,
|
||||||
table = %table,
|
table = %table,
|
||||||
status = snapshot.status,
|
status = snapshot.status,
|
||||||
@@ -549,7 +608,7 @@ async fn poll_index_build_status(
|
|||||||
updated = snapshot.updated,
|
updated = snapshot.updated,
|
||||||
processed = snapshot.processed,
|
processed = snapshot.processed,
|
||||||
"Index build status"
|
"Index build status"
|
||||||
),
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if snapshot.is_ready() {
|
if snapshot.is_ready() {
|
||||||
@@ -561,31 +620,101 @@ async fn poll_index_build_status(
|
|||||||
total = snapshot.total_rows,
|
total = snapshot.total_rows,
|
||||||
"Index is ready"
|
"Index is ready"
|
||||||
);
|
);
|
||||||
break;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
if snapshot.status.eq_ignore_ascii_case("error") {
|
if snapshot.status.eq_ignore_ascii_case("error") {
|
||||||
warn!(
|
return Err(anyhow::anyhow!(
|
||||||
index = %index_name,
|
"index build failed for {index_name} on {table}: status=error, processed={}, total={:?}",
|
||||||
table = %table,
|
snapshot.processed,
|
||||||
status = snapshot.status,
|
snapshot.total_rows
|
||||||
"Index build reported error status; stopping polling"
|
));
|
||||||
);
|
}
|
||||||
break;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `building` block from SurrealDB `INFO FOR INDEX` (concurrent index builds).
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||||
|
struct IndexBuildingProgress {
|
||||||
|
#[serde(default)]
|
||||||
|
initial: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
pending: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
updated: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Top-level `INFO FOR INDEX` payload shape (SurrealDB v2.x).
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
struct IndexInfoForIndex {
|
||||||
|
#[serde(default)]
|
||||||
|
building: Option<IndexBuildingProgress>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IndexInfoForIndex {
|
||||||
|
fn building_status(&self) -> String {
|
||||||
|
match &self.building {
|
||||||
|
None => "ready".to_string(),
|
||||||
|
Some(progress) if progress.status.is_empty() => "ready".to_string(),
|
||||||
|
Some(progress) => progress.status.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
fn into_build_snapshot(self, total_rows: Option<u64>) -> IndexBuildSnapshot {
|
||||||
|
let (initial, pending, updated, status) = match self.building {
|
||||||
|
None => (0, 0, 0, "ready".to_string()),
|
||||||
|
Some(progress) => {
|
||||||
|
let status = if progress.status.is_empty() {
|
||||||
|
"ready".to_string()
|
||||||
|
} else {
|
||||||
|
progress.status
|
||||||
|
};
|
||||||
|
(progress.initial, progress.pending, progress.updated, status)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let processed = initial.saturating_add(updated);
|
||||||
|
let progress_pct = total_rows.map(|total| {
|
||||||
|
if total == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
((f64::from(u32::try_from(processed).unwrap_or(u32::MAX))
|
||||||
|
/ f64::from(u32::try_from(total).unwrap_or(1)))
|
||||||
|
.min(1.0))
|
||||||
|
* 100.0
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
IndexBuildSnapshot {
|
||||||
|
status,
|
||||||
|
initial,
|
||||||
|
pending,
|
||||||
|
updated,
|
||||||
|
processed,
|
||||||
|
total_rows,
|
||||||
|
progress_pct,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
/// Snapshot of an index build progress as reported by SurrealDB's `INFO FOR INDEX`.
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
struct IndexBuildSnapshot {
|
struct IndexBuildSnapshot {
|
||||||
|
/// Current build status string (e.g., `"indexing"`, `"ready"`, `"error"`).
|
||||||
status: String,
|
status: String,
|
||||||
|
/// Number of rows present when the build started.
|
||||||
initial: u64,
|
initial: u64,
|
||||||
|
/// Number of rows still pending processing.
|
||||||
pending: u64,
|
pending: u64,
|
||||||
|
/// Number of rows updated since the build started.
|
||||||
updated: u64,
|
updated: u64,
|
||||||
|
/// Total rows processed so far (`initial + updated`).
|
||||||
processed: u64,
|
processed: u64,
|
||||||
|
/// Total rows expected (from `SELECT count()` before the build), if available.
|
||||||
total_rows: Option<u64>,
|
total_rows: Option<u64>,
|
||||||
|
/// Progress as a percentage of `processed / total_rows`, if `total_rows` is known.
|
||||||
progress_pct: Option<f64>,
|
progress_pct: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -600,50 +729,8 @@ fn parse_index_build_info(
|
|||||||
total_rows: Option<u64>,
|
total_rows: Option<u64>,
|
||||||
) -> Option<IndexBuildSnapshot> {
|
) -> Option<IndexBuildSnapshot> {
|
||||||
let info = info?;
|
let info = info?;
|
||||||
let building = info.get("building");
|
let parsed: IndexInfoForIndex = serde_json::from_value(info).ok()?;
|
||||||
|
Some(parsed.into_build_snapshot(total_rows))
|
||||||
let status = building
|
|
||||||
.and_then(|b| b.get("status"))
|
|
||||||
.and_then(|s| s.as_str())
|
|
||||||
// If there's no `building` block at all, treat as "ready" (index not building anymore)
|
|
||||||
.unwrap_or("ready")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let initial = building
|
|
||||||
.and_then(|b| b.get("initial"))
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let pending = building
|
|
||||||
.and_then(|b| b.get("pending"))
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let updated = building
|
|
||||||
.and_then(|b| b.get("updated"))
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
|
|
||||||
let processed = initial.saturating_add(updated);
|
|
||||||
|
|
||||||
let progress_pct = total_rows.map(|total| {
|
|
||||||
if total == 0 {
|
|
||||||
0.0
|
|
||||||
} else {
|
|
||||||
((processed as f64 / total as f64).min(1.0)) * 100.0
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Some(IndexBuildSnapshot {
|
|
||||||
status,
|
|
||||||
initial,
|
|
||||||
pending,
|
|
||||||
updated,
|
|
||||||
processed,
|
|
||||||
total_rows,
|
|
||||||
progress_pct,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -673,7 +760,7 @@ async fn table_index_definitions(
|
|||||||
.client
|
.client
|
||||||
.query(info_query)
|
.query(info_query)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("fetching table info for {}", table))?;
|
.with_context(|| format!("fetching table info for {table}"))?;
|
||||||
|
|
||||||
let info: surrealdb::Value = response
|
let info: surrealdb::Value = response
|
||||||
.take(0)
|
.take(0)
|
||||||
@@ -700,12 +787,16 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use anyhow::{self, Context};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_index_build_info_reports_progress() {
|
fn parse_index_build_info_reports_progress() -> anyhow::Result<()> {
|
||||||
let info = json!({
|
let info = json!({
|
||||||
"building": {
|
"building": {
|
||||||
"initial": 56894,
|
"initial": 56894,
|
||||||
@@ -715,7 +806,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
|
let snapshot = parse_index_build_info(Some(info), Some(61081)).context("snapshot")?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
snapshot,
|
snapshot,
|
||||||
IndexBuildSnapshot {
|
IndexBuildSnapshot {
|
||||||
@@ -729,16 +820,84 @@ mod tests {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
assert!(!snapshot.is_ready());
|
assert!(!snapshot.is_ready());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
|
fn parse_index_build_info_defaults_to_ready_when_no_building_block() -> anyhow::Result<()> {
|
||||||
// Surreal returns `{}` when the index exists but isn't building.
|
// Surreal returns `{}` when the index exists but isn't building.
|
||||||
let info = json!({});
|
let info = json!({});
|
||||||
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
|
let snapshot = parse_index_build_info(Some(info), Some(10)).context("snapshot")?;
|
||||||
assert!(snapshot.is_ready());
|
assert!(snapshot.is_ready());
|
||||||
assert_eq!(snapshot.processed, 0);
|
assert_eq!(snapshot.processed, 0);
|
||||||
assert_eq!(snapshot.progress_pct, Some(0.0));
|
assert_eq!(snapshot.progress_pct, Some(0.0));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn index_info_for_index_deserializes_ready_status_shape() -> anyhow::Result<()> {
|
||||||
|
let info = json!({
|
||||||
|
"building": {
|
||||||
|
"status": "ready"
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let parsed: IndexInfoForIndex =
|
||||||
|
serde_json::from_value(info).context("deserialize ready shape")?;
|
||||||
|
assert_eq!(parsed.building_status(), "ready");
|
||||||
|
|
||||||
|
let snapshot = parse_index_build_info(
|
||||||
|
Some(json!({
|
||||||
|
"building": { "status": "ready" }
|
||||||
|
})),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.context("snapshot")?;
|
||||||
|
assert!(snapshot.is_ready());
|
||||||
|
assert_eq!(snapshot.initial, 0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn index_info_for_index_deserializes_indexing_shape_from_surreal_docs() -> anyhow::Result<()> {
|
||||||
|
let info = json!({
|
||||||
|
"building": {
|
||||||
|
"initial": 8143,
|
||||||
|
"pending": 19,
|
||||||
|
"status": "indexing",
|
||||||
|
"updated": 80
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let parsed: IndexInfoForIndex =
|
||||||
|
serde_json::from_value(info.clone()).context("deserialize indexing shape")?;
|
||||||
|
assert_eq!(parsed.building_status(), "indexing");
|
||||||
|
|
||||||
|
let snapshot = parse_index_build_info(Some(info), None).context("snapshot")?;
|
||||||
|
assert_eq!(snapshot.status, "indexing");
|
||||||
|
assert_eq!(snapshot.initial, 8143);
|
||||||
|
assert_eq!(snapshot.pending, 19);
|
||||||
|
assert_eq!(snapshot.updated, 80);
|
||||||
|
assert_eq!(snapshot.processed, 8223);
|
||||||
|
assert!(!snapshot.is_ready());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_index_build_info_reports_error_status() -> anyhow::Result<()> {
|
||||||
|
let info = json!({
|
||||||
|
"building": {
|
||||||
|
"initial": 100,
|
||||||
|
"pending": 5,
|
||||||
|
"status": "error",
|
||||||
|
"updated": 10
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let snapshot = parse_index_build_info(Some(info), Some(200)).context("snapshot")?;
|
||||||
|
assert_eq!(snapshot.status, "error");
|
||||||
|
assert!(!snapshot.is_ready());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -748,48 +907,72 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn ensure_runtime_indexes_is_idempotent() {
|
async fn ensure_runtime_is_idempotent() -> anyhow::Result<()> {
|
||||||
let namespace = "indexes_ns";
|
let namespace = "indexes_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("in-memory db");
|
.context("in-memory db")?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("migrations should succeed");
|
.context("migrations should succeed")?;
|
||||||
|
|
||||||
// First run creates everything
|
ensure_runtime(&db, 1536)
|
||||||
ensure_runtime_indexes(&db, 1536)
|
|
||||||
.await
|
.await
|
||||||
.expect("initial index creation");
|
.context("first call should succeed")?;
|
||||||
|
ensure_runtime(&db, 1536)
|
||||||
// Second run should be a no-op and still succeed
|
|
||||||
ensure_runtime_indexes(&db, 1536)
|
|
||||||
.await
|
.await
|
||||||
.expect("second index creation");
|
.context("second index creation")?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn ensure_hnsw_index_overwrites_dimension() {
|
async fn embedding_index_dimension_reflects_runtime_state() -> anyhow::Result<()> {
|
||||||
|
let namespace = "indexes_marker";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.context("in-memory db")?;
|
||||||
|
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.context("migrations should succeed")?;
|
||||||
|
|
||||||
|
// Before any index exists, there is no stored embedding dimension to detect.
|
||||||
|
assert_eq!(embedding_index_dimension(&db).await?, None);
|
||||||
|
|
||||||
|
ensure_runtime(&db, 1536)
|
||||||
|
.await
|
||||||
|
.context("initial index creation")?;
|
||||||
|
assert_eq!(embedding_index_dimension(&db).await?, Some(1536));
|
||||||
|
|
||||||
|
// After a dimension change the marker tracks the new index dimension.
|
||||||
|
ensure_runtime(&db, 256)
|
||||||
|
.await
|
||||||
|
.context("overwritten index creation")?;
|
||||||
|
assert_eq!(embedding_index_dimension(&db).await?, Some(256));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ensure_hnsw_index_overwrites_dimension() -> anyhow::Result<()> {
|
||||||
let namespace = "indexes_dim";
|
let namespace = "indexes_dim";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("in-memory db");
|
.context("in-memory db")?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("migrations should succeed");
|
.context("migrations should succeed")?;
|
||||||
|
|
||||||
// Create initial index with default dimension
|
ensure_runtime(&db, 1536)
|
||||||
ensure_runtime_indexes(&db, 1536)
|
|
||||||
.await
|
.await
|
||||||
.expect("initial index creation");
|
.context("initial index creation")?;
|
||||||
|
ensure_runtime(&db, 128)
|
||||||
// Change dimension and ensure overwrite path is exercised
|
|
||||||
ensure_runtime_indexes(&db, 128)
|
|
||||||
.await
|
.await
|
||||||
.expect("overwritten index creation");
|
.context("overwritten index creation")?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+286
-125
@@ -2,7 +2,7 @@ use std::io::ErrorKind;
|
|||||||
use std::path::{Component, Path, PathBuf};
|
use std::path::{Component, Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result as AnyResult};
|
use anyhow::{anyhow, Context, Result as AnyResult};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use futures::{StreamExt, TryStreamExt};
|
use futures::{StreamExt, TryStreamExt};
|
||||||
@@ -13,13 +13,13 @@ use object_store::{path::Path as ObjPath, ObjectStore};
|
|||||||
|
|
||||||
use crate::utils::config::{AppConfig, StorageKind};
|
use crate::utils::config::{AppConfig, StorageKind};
|
||||||
|
|
||||||
pub type DynStore = Arc<dyn ObjectStore>;
|
pub type DynStorage = Arc<dyn ObjectStore>;
|
||||||
|
|
||||||
/// Storage manager with persistent state and proper lifecycle management.
|
/// Storage manager with persistent state and proper lifecycle management.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StorageManager {
|
pub struct StorageManager {
|
||||||
// Store from objectstore wrapped as dyn
|
// Store from objectstore wrapped as dyn
|
||||||
store: DynStore,
|
store: DynStorage,
|
||||||
// Simple enum to track which kind
|
// Simple enum to track which kind
|
||||||
backend_kind: StorageKind,
|
backend_kind: StorageKind,
|
||||||
// Where on disk
|
// Where on disk
|
||||||
@@ -31,8 +31,13 @@ impl StorageManager {
|
|||||||
///
|
///
|
||||||
/// This method validates the configuration and creates the appropriate
|
/// This method validates the configuration and creates the appropriate
|
||||||
/// storage backend with proper initialization.
|
/// storage backend with proper initialization.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the storage backend cannot be created or initialised
|
||||||
|
/// (e.g. missing S3 bucket, local filesystem permission error).
|
||||||
pub async fn new(cfg: &AppConfig) -> object_store::Result<Self> {
|
pub async fn new(cfg: &AppConfig) -> object_store::Result<Self> {
|
||||||
let backend_kind = cfg.storage.clone();
|
let backend_kind = cfg.storage;
|
||||||
let (store, local_base) = create_storage_backend(cfg).await?;
|
let (store, local_base) = create_storage_backend(cfg).await?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@@ -46,7 +51,8 @@ impl StorageManager {
|
|||||||
///
|
///
|
||||||
/// This method is useful for testing scenarios where you want to inject
|
/// This method is useful for testing scenarios where you want to inject
|
||||||
/// a specific storage backend.
|
/// a specific storage backend.
|
||||||
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
|
#[must_use]
|
||||||
|
pub fn with_backend(store: DynStorage, backend_kind: StorageKind) -> Self {
|
||||||
Self {
|
Self {
|
||||||
store,
|
store,
|
||||||
backend_kind,
|
backend_kind,
|
||||||
@@ -55,11 +61,13 @@ impl StorageManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the storage backend kind.
|
/// Get the storage backend kind.
|
||||||
|
#[must_use]
|
||||||
pub fn backend_kind(&self) -> &StorageKind {
|
pub fn backend_kind(&self) -> &StorageKind {
|
||||||
&self.backend_kind
|
&self.backend_kind
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Access the resolved local base directory when using the local backend.
|
/// Access the resolved local base directory when using the local backend.
|
||||||
|
#[must_use]
|
||||||
pub fn local_base_path(&self) -> Option<&Path> {
|
pub fn local_base_path(&self) -> Option<&Path> {
|
||||||
self.local_base.as_deref()
|
self.local_base.as_deref()
|
||||||
}
|
}
|
||||||
@@ -68,6 +76,7 @@ impl StorageManager {
|
|||||||
///
|
///
|
||||||
/// Returns `None` when the backend is not local or when the provided location includes
|
/// Returns `None` when the backend is not local or when the provided location includes
|
||||||
/// unsupported components (absolute paths or parent traversals).
|
/// unsupported components (absolute paths or parent traversals).
|
||||||
|
#[must_use]
|
||||||
pub fn resolve_local_path(&self, location: &str) -> Option<PathBuf> {
|
pub fn resolve_local_path(&self, location: &str) -> Option<PathBuf> {
|
||||||
let base = self.local_base_path()?;
|
let base = self.local_base_path()?;
|
||||||
let relative = Path::new(location);
|
let relative = Path::new(location);
|
||||||
@@ -86,6 +95,10 @@ impl StorageManager {
|
|||||||
///
|
///
|
||||||
/// This operation persists data using the underlying storage backend.
|
/// This operation persists data using the underlying storage backend.
|
||||||
/// For memory backends, data persists for the lifetime of the StorageManager.
|
/// For memory backends, data persists for the lifetime of the StorageManager.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the underlying storage backend fails to persist the data.
|
||||||
pub async fn put(&self, location: &str, data: Bytes) -> object_store::Result<()> {
|
pub async fn put(&self, location: &str, data: Bytes) -> object_store::Result<()> {
|
||||||
let path = ObjPath::from(location);
|
let path = ObjPath::from(location);
|
||||||
let payload = object_store::PutPayload::from_bytes(data);
|
let payload = object_store::PutPayload::from_bytes(data);
|
||||||
@@ -94,16 +107,27 @@ impl StorageManager {
|
|||||||
|
|
||||||
/// Retrieve bytes from the specified location.
|
/// Retrieve bytes from the specified location.
|
||||||
///
|
///
|
||||||
/// Returns the full contents buffered in memory.
|
/// Reads via [`Self::get_stream`] and buffers the full object in memory.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the location does not exist or the underlying backend fails.
|
||||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||||
let path = ObjPath::from(location);
|
let mut stream = self.get_stream(location).await?;
|
||||||
let result = self.store.get(&path).await?;
|
let mut collected = Vec::new();
|
||||||
result.bytes().await
|
while let Some(chunk) = stream.next().await {
|
||||||
|
collected.extend_from_slice(&chunk?);
|
||||||
|
}
|
||||||
|
Ok(Bytes::from(collected))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a streaming handle for large objects.
|
/// Get a streaming handle for large objects.
|
||||||
///
|
///
|
||||||
/// Returns a fallible stream of Bytes chunks suitable for large file processing.
|
/// Returns a fallible stream of Bytes chunks suitable for large file processing.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the location does not exist or the underlying backend fails.
|
||||||
pub async fn get_stream(
|
pub async fn get_stream(
|
||||||
&self,
|
&self,
|
||||||
location: &str,
|
location: &str,
|
||||||
@@ -116,6 +140,10 @@ impl StorageManager {
|
|||||||
/// Delete all objects below the specified prefix.
|
/// Delete all objects below the specified prefix.
|
||||||
///
|
///
|
||||||
/// For local filesystem backends, this also attempts to clean up empty directories.
|
/// For local filesystem backends, this also attempts to clean up empty directories.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the underlying backend fails during deletion.
|
||||||
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
||||||
let prefix_path = ObjPath::from(prefix);
|
let prefix_path = ObjPath::from(prefix);
|
||||||
let locations = self
|
let locations = self
|
||||||
@@ -137,6 +165,10 @@ impl StorageManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// List all objects below the specified prefix.
|
/// List all objects below the specified prefix.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the underlying backend fails to list objects.
|
||||||
pub async fn list(
|
pub async fn list(
|
||||||
&self,
|
&self,
|
||||||
prefix: Option<&str>,
|
prefix: Option<&str>,
|
||||||
@@ -146,6 +178,10 @@ impl StorageManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Check if an object exists at the specified location.
|
/// Check if an object exists at the specified location.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` if the underlying backend returns a non-NotFound error.
|
||||||
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
||||||
let path = ObjPath::from(location);
|
let path = ObjPath::from(location);
|
||||||
self.store
|
self.store
|
||||||
@@ -216,10 +252,13 @@ impl StorageManager {
|
|||||||
/// storage backends with proper error handling and validation.
|
/// storage backends with proper error handling and validation.
|
||||||
async fn create_storage_backend(
|
async fn create_storage_backend(
|
||||||
cfg: &AppConfig,
|
cfg: &AppConfig,
|
||||||
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
|
) -> object_store::Result<(DynStorage, Option<PathBuf>)> {
|
||||||
match cfg.storage {
|
match cfg.storage {
|
||||||
StorageKind::Local => {
|
StorageKind::Local => {
|
||||||
let base = resolve_base_dir(cfg);
|
let base = resolve_base_dir(cfg).map_err(|err| object_store::Error::Generic {
|
||||||
|
store: "LocalFileSystem",
|
||||||
|
source: err.into(),
|
||||||
|
})?;
|
||||||
if !base.exists() {
|
if !base.exists() {
|
||||||
tokio::fs::create_dir_all(&base).await.map_err(|e| {
|
tokio::fs::create_dir_all(&base).await.map_err(|e| {
|
||||||
object_store::Error::Generic {
|
object_store::Error::Generic {
|
||||||
@@ -261,9 +300,7 @@ async fn create_storage_backend(
|
|||||||
builder = builder.with_endpoint(endpoint);
|
builder = builder.with_endpoint(endpoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(region) = &cfg.s3_region {
|
builder = builder.with_region(&cfg.s3_region);
|
||||||
builder = builder.with_region(region);
|
|
||||||
}
|
|
||||||
|
|
||||||
let store = builder.build()?;
|
let store = builder.build()?;
|
||||||
Ok((Arc::new(store), None))
|
Ok((Arc::new(store), None))
|
||||||
@@ -277,6 +314,7 @@ async fn create_storage_backend(
|
|||||||
/// automatic memory backend setup and proper test isolation.
|
/// automatic memory backend setup and proper test isolation.
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod testing {
|
pub mod testing {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::utils::config::{AppConfig, PdfIngestMode};
|
use crate::utils::config::{AppConfig, PdfIngestMode};
|
||||||
use uuid;
|
use uuid;
|
||||||
@@ -342,7 +380,7 @@ pub mod testing {
|
|||||||
surrealdb_password: "test".into(),
|
surrealdb_password: "test".into(),
|
||||||
surrealdb_namespace: "test".into(),
|
surrealdb_namespace: "test".into(),
|
||||||
surrealdb_database: "test".into(),
|
surrealdb_database: "test".into(),
|
||||||
data_dir: base.into(),
|
data_dir: base,
|
||||||
http_port: 0,
|
http_port: 0,
|
||||||
openai_base_url: "..".into(),
|
openai_base_url: "..".into(),
|
||||||
storage: StorageKind::Local,
|
storage: StorageKind::Local,
|
||||||
@@ -369,7 +407,7 @@ pub mod testing {
|
|||||||
storage: StorageKind::S3,
|
storage: StorageKind::S3,
|
||||||
s3_bucket: Some(configured_test_s3_bucket()),
|
s3_bucket: Some(configured_test_s3_bucket()),
|
||||||
s3_endpoint: Some(configured_test_s3_endpoint()),
|
s3_endpoint: Some(configured_test_s3_endpoint()),
|
||||||
s3_region: Some("us-east-1".into()),
|
s3_region: "us-east-1".into(),
|
||||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
@@ -382,7 +420,7 @@ pub mod testing {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct TestStorageManager {
|
pub struct TestStorageManager {
|
||||||
storage: StorageManager,
|
storage: StorageManager,
|
||||||
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestStorageManager {
|
impl TestStorageManager {
|
||||||
@@ -396,7 +434,7 @@ pub mod testing {
|
|||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
storage,
|
storage,
|
||||||
_temp_dir: None,
|
temp_dir: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -413,7 +451,7 @@ pub mod testing {
|
|||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
storage,
|
storage,
|
||||||
_temp_dir: resolved,
|
temp_dir: resolved,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -437,7 +475,7 @@ pub mod testing {
|
|||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
storage,
|
storage,
|
||||||
_temp_dir: None,
|
temp_dir: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,10 +490,7 @@ pub mod testing {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self { storage, temp_dir })
|
||||||
storage,
|
|
||||||
_temp_dir: temp_dir,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a reference to the underlying StorageManager.
|
/// Get a reference to the underlying StorageManager.
|
||||||
@@ -508,7 +543,7 @@ pub mod testing {
|
|||||||
impl Drop for TestStorageManager {
|
impl Drop for TestStorageManager {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
// Clean up temporary directories for local storage
|
// Clean up temporary directories for local storage
|
||||||
if let Some((_, path)) = &self._temp_dir {
|
if let Some((_, path)) = &self.temp_dir {
|
||||||
if path.exists() {
|
if path.exists() {
|
||||||
let _ = std::fs::remove_dir_all(path);
|
let _ = std::fs::remove_dir_all(path);
|
||||||
}
|
}
|
||||||
@@ -547,14 +582,22 @@ pub mod testing {
|
|||||||
|
|
||||||
/// Resolve the absolute base directory used for local storage from config.
|
/// Resolve the absolute base directory used for local storage from config.
|
||||||
///
|
///
|
||||||
/// If `data_dir` is relative, it is resolved against the current working directory.
|
/// If `data_dir` is relative, it is resolved against the process current working directory.
|
||||||
pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf {
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Err` when `data_dir` is relative and the current working directory cannot be read.
|
||||||
|
pub fn resolve_base_dir(cfg: &AppConfig) -> AnyResult<PathBuf> {
|
||||||
if cfg.data_dir.starts_with('/') {
|
if cfg.data_dir.starts_with('/') {
|
||||||
PathBuf::from(&cfg.data_dir)
|
Ok(PathBuf::from(&cfg.data_dir))
|
||||||
} else {
|
} else {
|
||||||
std::env::current_dir()
|
let cwd = std::env::current_dir().with_context(|| {
|
||||||
.unwrap_or_else(|_| PathBuf::from("."))
|
format!(
|
||||||
.join(&cfg.data_dir)
|
"failed to resolve relative data_dir '{}' against the current working directory",
|
||||||
|
cfg.data_dir
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
Ok(cwd.join(&cfg.data_dir))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -583,8 +626,10 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||||
|
use anyhow::Context;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -623,11 +668,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_memory_basic_operations() {
|
async fn test_storage_manager_memory_basic_operations() -> anyhow::Result<()> {
|
||||||
let cfg = test_config_memory();
|
let cfg = test_config_memory();
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
assert!(storage.local_base_path().is_none());
|
assert!(storage.local_base_path().is_none());
|
||||||
|
|
||||||
let location = "test/data/file.txt";
|
let location = "test/data/file.txt";
|
||||||
@@ -637,31 +682,42 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data.to_vec()))
|
.put(location, Bytes::from(data.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put");
|
.with_context(|| "put".to_string())?;
|
||||||
let retrieved = storage.get(location).await.expect("get");
|
let retrieved = storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
// Test exists
|
// Test exists
|
||||||
assert!(storage.exists(location).await.expect("exists check"));
|
assert!(storage
|
||||||
|
.exists(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists check".to_string())?);
|
||||||
|
|
||||||
// Test delete
|
// Test delete
|
||||||
storage.delete_prefix("test/data/").await.expect("delete");
|
storage
|
||||||
|
.delete_prefix("test/data/")
|
||||||
|
.await
|
||||||
|
.with_context(|| "delete".to_string())?;
|
||||||
assert!(!storage
|
assert!(!storage
|
||||||
.exists(location)
|
.exists(location)
|
||||||
.await
|
.await
|
||||||
.expect("exists check after delete"));
|
.with_context(|| "exists check after delete".to_string())?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_local_basic_operations() {
|
async fn test_storage_manager_local_basic_operations() -> anyhow::Result<()> {
|
||||||
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
||||||
let cfg = test_config(&base);
|
let cfg = test_config(&base);
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
let resolved_base = storage
|
let resolved_base = storage
|
||||||
.local_base_path()
|
.local_base_path()
|
||||||
.expect("resolved base dir")
|
.with_context(|| "resolved base dir".to_string())?
|
||||||
.to_path_buf();
|
.to_path_buf();
|
||||||
assert_eq!(resolved_base, PathBuf::from(&base));
|
assert_eq!(resolved_base, PathBuf::from(&base));
|
||||||
|
|
||||||
@@ -672,42 +728,53 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data.to_vec()))
|
.put(location, Bytes::from(data.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put");
|
.with_context(|| "put".to_string())?;
|
||||||
let retrieved = storage.get(location).await.expect("get");
|
let retrieved = storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
let object_dir = resolved_base.join("test/data");
|
let object_dir = resolved_base.join("test/data");
|
||||||
tokio::fs::metadata(&object_dir)
|
tokio::fs::metadata(&object_dir)
|
||||||
.await
|
.await
|
||||||
.expect("object directory exists after write");
|
.with_context(|| "object directory exists after write".to_string())?;
|
||||||
|
|
||||||
// Test exists
|
// Test exists
|
||||||
assert!(storage.exists(location).await.expect("exists check"));
|
assert!(storage
|
||||||
|
.exists(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists check".to_string())?);
|
||||||
|
|
||||||
// Test delete
|
// Test delete
|
||||||
storage.delete_prefix("test/data/").await.expect("delete");
|
storage
|
||||||
|
.delete_prefix("test/data/")
|
||||||
|
.await
|
||||||
|
.with_context(|| "delete".to_string())?;
|
||||||
assert!(!storage
|
assert!(!storage
|
||||||
.exists(location)
|
.exists(location)
|
||||||
.await
|
.await
|
||||||
.expect("exists check after delete"));
|
.with_context(|| "exists check after delete".to_string())?);
|
||||||
assert!(
|
assert!(
|
||||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||||
"object directory should be removed"
|
"object directory should be removed"
|
||||||
);
|
);
|
||||||
tokio::fs::metadata(&resolved_base)
|
tokio::fs::metadata(&resolved_base)
|
||||||
.await
|
.await
|
||||||
.expect("base directory remains intact");
|
.with_context(|| "base directory remains intact".to_string())?;
|
||||||
|
|
||||||
// Clean up
|
// Clean up
|
||||||
let _ = tokio::fs::remove_dir_all(&base).await;
|
let _ = tokio::fs::remove_dir_all(&base).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_memory_persistence() {
|
async fn test_storage_manager_memory_persistence() -> anyhow::Result<()> {
|
||||||
let cfg = test_config_memory();
|
let cfg = test_config_memory();
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
|
|
||||||
let location = "persistence/test.txt";
|
let location = "persistence/test.txt";
|
||||||
let data1 = b"first data";
|
let data1 = b"first data";
|
||||||
@@ -717,32 +784,40 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data1.to_vec()))
|
.put(location, Bytes::from(data1.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put first");
|
.with_context(|| "put first".to_string())?;
|
||||||
|
|
||||||
// Retrieve and verify first data
|
// Retrieve and verify first data
|
||||||
let retrieved1 = storage.get(location).await.expect("get first");
|
let retrieved1 = storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get first".to_string())?;
|
||||||
assert_eq!(retrieved1.as_ref(), data1);
|
assert_eq!(retrieved1.as_ref(), data1);
|
||||||
|
|
||||||
// Overwrite with second data
|
// Overwrite with second data
|
||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data2.to_vec()))
|
.put(location, Bytes::from(data2.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put second");
|
.with_context(|| "put second".to_string())?;
|
||||||
|
|
||||||
// Retrieve and verify second data
|
// Retrieve and verify second data
|
||||||
let retrieved2 = storage.get(location).await.expect("get second");
|
let retrieved2 = storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get second".to_string())?;
|
||||||
assert_eq!(retrieved2.as_ref(), data2);
|
assert_eq!(retrieved2.as_ref(), data2);
|
||||||
|
|
||||||
// Data persists across multiple operations using the same StorageManager
|
// Data persists across multiple operations using the same StorageManager
|
||||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_list_operations() {
|
async fn test_storage_manager_list_operations() -> anyhow::Result<()> {
|
||||||
let cfg = test_config_memory();
|
let cfg = test_config_memory();
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
|
|
||||||
// Create multiple files
|
// Create multiple files
|
||||||
let files = vec![
|
let files = vec![
|
||||||
@@ -755,15 +830,21 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data.to_vec()))
|
.put(location, Bytes::from(data.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put");
|
.with_context(|| "put".to_string())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test listing without prefix
|
// Test listing without prefix
|
||||||
let all_files = storage.list(None).await.expect("list all");
|
let all_files = storage
|
||||||
|
.list(None)
|
||||||
|
.await
|
||||||
|
.with_context(|| "list all".to_string())?;
|
||||||
assert_eq!(all_files.len(), 3);
|
assert_eq!(all_files.len(), 3);
|
||||||
|
|
||||||
// Test listing with prefix
|
// Test listing with prefix
|
||||||
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
|
let dir1_files = storage
|
||||||
|
.list(Some("dir1/"))
|
||||||
|
.await
|
||||||
|
.with_context(|| "list dir1".to_string())?;
|
||||||
assert_eq!(dir1_files.len(), 2);
|
assert_eq!(dir1_files.len(), 2);
|
||||||
assert!(dir1_files
|
assert!(dir1_files
|
||||||
.iter()
|
.iter()
|
||||||
@@ -776,16 +857,18 @@ mod tests {
|
|||||||
let empty_files = storage
|
let empty_files = storage
|
||||||
.list(Some("nonexistent/"))
|
.list(Some("nonexistent/"))
|
||||||
.await
|
.await
|
||||||
.expect("list nonexistent");
|
.with_context(|| "list nonexistent".to_string())?;
|
||||||
assert_eq!(empty_files.len(), 0);
|
assert_eq!(empty_files.len(), 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_stream_operations() {
|
async fn test_storage_manager_stream_operations() -> anyhow::Result<()> {
|
||||||
let cfg = test_config_memory();
|
let cfg = test_config_memory();
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
|
|
||||||
let location = "stream/test.bin";
|
let location = "stream/test.bin";
|
||||||
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
||||||
@@ -794,22 +877,27 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(content.clone()))
|
.put(location, Bytes::from(content.clone()))
|
||||||
.await
|
.await
|
||||||
.expect("put large data");
|
.with_context(|| "put large data".to_string())?;
|
||||||
|
|
||||||
// Get as stream
|
// Get as stream
|
||||||
let mut stream = storage.get_stream(location).await.expect("get stream");
|
let mut stream = storage
|
||||||
|
.get_stream(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get stream".to_string())?;
|
||||||
let mut collected = Vec::new();
|
let mut collected = Vec::new();
|
||||||
|
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
let chunk = chunk.expect("stream chunk");
|
let chunk = chunk.with_context(|| "stream chunk".to_string())?;
|
||||||
collected.extend_from_slice(&chunk);
|
collected.extend_from_slice(&chunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(collected, content);
|
assert_eq!(collected, content);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_with_custom_backend() {
|
async fn test_storage_manager_with_custom_backend() -> anyhow::Result<()> {
|
||||||
use object_store::memory::InMemory;
|
use object_store::memory::InMemory;
|
||||||
|
|
||||||
// Create custom memory backend
|
// Create custom memory backend
|
||||||
@@ -823,20 +911,28 @@ mod tests {
|
|||||||
storage
|
storage
|
||||||
.put(location, Bytes::from(data.to_vec()))
|
.put(location, Bytes::from(data.to_vec()))
|
||||||
.await
|
.await
|
||||||
.expect("put");
|
.with_context(|| "put".to_string())?;
|
||||||
let retrieved = storage.get(location).await.expect("get");
|
let retrieved = storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
assert!(storage.exists(location).await.expect("exists"));
|
assert!(storage
|
||||||
|
.exists(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists".to_string())?);
|
||||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_error_handling() {
|
async fn test_storage_manager_error_handling() -> anyhow::Result<()> {
|
||||||
let cfg = test_config_memory();
|
let cfg = test_config_memory();
|
||||||
let storage = StorageManager::new(&cfg)
|
let storage = StorageManager::new(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create storage manager");
|
.with_context(|| "create storage manager".to_string())?;
|
||||||
|
|
||||||
// Test getting non-existent file
|
// Test getting non-existent file
|
||||||
let result = storage.get("nonexistent.txt").await;
|
let result = storage.get("nonexistent.txt").await;
|
||||||
@@ -846,124 +942,163 @@ mod tests {
|
|||||||
let exists = storage
|
let exists = storage
|
||||||
.exists("nonexistent.txt")
|
.exists("nonexistent.txt")
|
||||||
.await
|
.await
|
||||||
.expect("exists check");
|
.with_context(|| "exists check".to_string())?;
|
||||||
assert!(!exists);
|
assert!(!exists);
|
||||||
|
|
||||||
// Test listing with invalid location (should not panic)
|
// Test listing with invalid location (should not panic)
|
||||||
let _result = storage.get("").await;
|
let _result = storage.get("").await;
|
||||||
// This may or may not error depending on the backend implementation
|
// This may or may not error depending on the backend implementation
|
||||||
// The important thing is that it doesn't panic
|
// The important thing is that it doesn't panic
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestStorageManager tests
|
// TestStorageManager tests
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_test_storage_manager_memory() {
|
async fn test_test_storage_manager_memory() -> anyhow::Result<()> {
|
||||||
let test_storage = testing::TestStorageManager::new_memory()
|
let test_storage = testing::TestStorageManager::new_memory()
|
||||||
.await
|
.await
|
||||||
.expect("create test storage");
|
.with_context(|| "create test storage".to_string())?;
|
||||||
|
|
||||||
let location = "test/storage/file.txt";
|
let location = "test/storage/file.txt";
|
||||||
let data = b"test data with TestStorageManager";
|
let data = b"test data with TestStorageManager";
|
||||||
|
|
||||||
// Test put and get
|
// Test put and get
|
||||||
test_storage.put(location, data).await.expect("put");
|
test_storage
|
||||||
let retrieved = test_storage.get(location).await.expect("get");
|
.put(location, data)
|
||||||
|
.await
|
||||||
|
.with_context(|| "put".to_string())?;
|
||||||
|
let retrieved = test_storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
// Test existence check
|
// Test existence check
|
||||||
assert!(test_storage.exists(location).await.expect("exists"));
|
assert!(test_storage
|
||||||
|
.exists(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists".to_string())?);
|
||||||
|
|
||||||
// Test list
|
// Test list
|
||||||
let files = test_storage
|
let files = test_storage
|
||||||
.list(Some("test/storage/"))
|
.list(Some("test/storage/"))
|
||||||
.await
|
.await
|
||||||
.expect("list");
|
.with_context(|| "list".to_string())?;
|
||||||
assert_eq!(files.len(), 1);
|
assert_eq!(files.len(), 1);
|
||||||
|
|
||||||
// Test delete
|
// Test delete
|
||||||
test_storage
|
test_storage
|
||||||
.delete_prefix("test/storage/")
|
.delete_prefix("test/storage/")
|
||||||
.await
|
.await
|
||||||
.expect("delete");
|
.with_context(|| "delete".to_string())?;
|
||||||
assert!(!test_storage
|
assert!(!test_storage
|
||||||
.exists(location)
|
.exists(location)
|
||||||
.await
|
.await
|
||||||
.expect("exists after delete"));
|
.with_context(|| "exists after delete".to_string())?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_test_storage_manager_local() {
|
async fn test_test_storage_manager_local() -> anyhow::Result<()> {
|
||||||
let test_storage = testing::TestStorageManager::new_local()
|
let test_storage = testing::TestStorageManager::new_local()
|
||||||
.await
|
.await
|
||||||
.expect("create test storage");
|
.with_context(|| "create test storage".to_string())?;
|
||||||
|
|
||||||
let location = "test/local/file.txt";
|
let location = "test/local/file.txt";
|
||||||
let data = b"test data with local TestStorageManager";
|
let data = b"test data with local TestStorageManager";
|
||||||
|
|
||||||
// Test put and get
|
test_storage
|
||||||
test_storage.put(location, data).await.expect("put");
|
.put(location, data)
|
||||||
let retrieved = test_storage.get(location).await.expect("get");
|
.await
|
||||||
|
.with_context(|| "put".to_string())?;
|
||||||
|
let retrieved = test_storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
// Test existence check
|
assert!(test_storage
|
||||||
assert!(test_storage.exists(location).await.expect("exists"));
|
.exists(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists".to_string())?);
|
||||||
|
|
||||||
// The storage should be automatically cleaned up when test_storage is dropped
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_test_storage_manager_isolation() {
|
async fn test_test_storage_manager_isolation() -> anyhow::Result<()> {
|
||||||
let storage1 = testing::TestStorageManager::new_memory()
|
let storage1 = testing::TestStorageManager::new_memory()
|
||||||
.await
|
.await
|
||||||
.expect("create test storage 1");
|
.with_context(|| "create test storage 1".to_string())?;
|
||||||
let storage2 = testing::TestStorageManager::new_memory()
|
let storage2 = testing::TestStorageManager::new_memory()
|
||||||
.await
|
.await
|
||||||
.expect("create test storage 2");
|
.with_context(|| "create test storage 2".to_string())?;
|
||||||
|
|
||||||
let location = "isolation/test.txt";
|
let location = "isolation/test.txt";
|
||||||
let data1 = b"storage 1 data";
|
let data1 = b"storage 1 data";
|
||||||
let data2 = b"storage 2 data";
|
let data2 = b"storage 2 data";
|
||||||
|
|
||||||
// Put different data in each storage
|
storage1
|
||||||
storage1.put(location, data1).await.expect("put storage 1");
|
.put(location, data1)
|
||||||
storage2.put(location, data2).await.expect("put storage 2");
|
.await
|
||||||
|
.with_context(|| "put storage 1".to_string())?;
|
||||||
|
storage2
|
||||||
|
.put(location, data2)
|
||||||
|
.await
|
||||||
|
.with_context(|| "put storage 2".to_string())?;
|
||||||
|
|
||||||
// Verify isolation
|
let retrieved1 = storage1
|
||||||
let retrieved1 = storage1.get(location).await.expect("get storage 1");
|
.get(location)
|
||||||
let retrieved2 = storage2.get(location).await.expect("get storage 2");
|
.await
|
||||||
|
.with_context(|| "get storage 1".to_string())?;
|
||||||
|
let retrieved2 = storage2
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get storage 2".to_string())?;
|
||||||
|
|
||||||
assert_eq!(retrieved1.as_ref(), data1);
|
assert_eq!(retrieved1.as_ref(), data1);
|
||||||
assert_eq!(retrieved2.as_ref(), data2);
|
assert_eq!(retrieved2.as_ref(), data2);
|
||||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_test_storage_manager_config() {
|
async fn test_test_storage_manager_config() -> anyhow::Result<()> {
|
||||||
let cfg = testing::test_config_memory();
|
let cfg = testing::test_config_memory();
|
||||||
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
||||||
.await
|
.await
|
||||||
.expect("create test storage with config");
|
.with_context(|| "create test storage with config".to_string())?;
|
||||||
|
|
||||||
let location = "config/test.txt";
|
let location = "config/test.txt";
|
||||||
let data = b"test data with custom config";
|
let data = b"test data with custom config";
|
||||||
|
|
||||||
test_storage.put(location, data).await.expect("put");
|
test_storage
|
||||||
let retrieved = test_storage.get(location).await.expect("get");
|
.put(location, data)
|
||||||
|
.await
|
||||||
|
.with_context(|| "put".to_string())?;
|
||||||
|
let retrieved = test_storage
|
||||||
|
.get(location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
// Verify it's using memory backend
|
|
||||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
|
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
|
||||||
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
|
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_s3_basic_operations() {
|
async fn test_storage_manager_s3_basic_operations() -> anyhow::Result<()> {
|
||||||
// Skip if S3 connection fails (e.g. no MinIO)
|
// Skip if S3 connection fails (e.g. no MinIO)
|
||||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||||
eprintln!("Skipping S3 test (setup failed)");
|
eprintln!("Skipping S3 test (setup failed)");
|
||||||
return;
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
let prefix = format!("test-basic-{}", Uuid::new_v4());
|
let prefix = format!("test-basic-{}", Uuid::new_v4());
|
||||||
@@ -973,31 +1108,39 @@ mod tests {
|
|||||||
// Test put
|
// Test put
|
||||||
if let Err(e) = storage.put(&location, data).await {
|
if let Err(e) = storage.put(&location, data).await {
|
||||||
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
|
eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
|
||||||
return;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test get
|
// Test get
|
||||||
let retrieved = storage.get(&location).await.expect("get");
|
let retrieved = storage
|
||||||
|
.get(&location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get".to_string())?;
|
||||||
assert_eq!(retrieved.as_ref(), data);
|
assert_eq!(retrieved.as_ref(), data);
|
||||||
|
|
||||||
// Test exists
|
// Test exists
|
||||||
assert!(storage.exists(&location).await.expect("exists"));
|
assert!(storage
|
||||||
|
.exists(&location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "exists".to_string())?);
|
||||||
|
|
||||||
// Test delete
|
// Test delete
|
||||||
storage
|
storage
|
||||||
.delete_prefix(&format!("{prefix}/"))
|
.delete_prefix(&format!("{prefix}/"))
|
||||||
.await
|
.await
|
||||||
.expect("delete");
|
.with_context(|| "delete".to_string())?;
|
||||||
assert!(!storage
|
assert!(!storage
|
||||||
.exists(&location)
|
.exists(&location)
|
||||||
.await
|
.await
|
||||||
.expect("exists after delete"));
|
.with_context(|| "exists after delete".to_string())?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_s3_list_operations() {
|
async fn test_storage_manager_s3_list_operations() -> anyhow::Result<()> {
|
||||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||||
return;
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
let prefix = format!("test-list-{}", Uuid::new_v4());
|
let prefix = format!("test-list-{}", Uuid::new_v4());
|
||||||
@@ -1009,23 +1152,31 @@ mod tests {
|
|||||||
|
|
||||||
for (loc, data) in &files {
|
for (loc, data) in &files {
|
||||||
if storage.put(loc, *data).await.is_err() {
|
if storage.put(loc, *data).await.is_err() {
|
||||||
return; // Abort if put fails
|
return Ok(()); // Abort if put fails
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List with prefix
|
// List with prefix
|
||||||
let list_prefix = format!("{prefix}/");
|
let list_prefix = format!("{prefix}/");
|
||||||
let items = storage.list(Some(&list_prefix)).await.expect("list");
|
let items = storage
|
||||||
|
.list(Some(&list_prefix))
|
||||||
|
.await
|
||||||
|
.with_context(|| "list".to_string())?;
|
||||||
assert_eq!(items.len(), 3);
|
assert_eq!(items.len(), 3);
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
storage.delete_prefix(&list_prefix).await.expect("cleanup");
|
storage
|
||||||
|
.delete_prefix(&list_prefix)
|
||||||
|
.await
|
||||||
|
.with_context(|| "cleanup".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_s3_stream_operations() {
|
async fn test_storage_manager_s3_stream_operations() -> anyhow::Result<()> {
|
||||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||||
return;
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
let prefix = format!("test-stream-{}", Uuid::new_v4());
|
let prefix = format!("test-stream-{}", Uuid::new_v4());
|
||||||
@@ -1033,38 +1184,48 @@ mod tests {
|
|||||||
let content = vec![42u8; 1024 * 10]; // 10KB
|
let content = vec![42u8; 1024 * 10]; // 10KB
|
||||||
|
|
||||||
if storage.put(&location, &content).await.is_err() {
|
if storage.put(&location, &content).await.is_err() {
|
||||||
return;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut stream = storage.get_stream(&location).await.expect("get stream");
|
let mut stream = storage
|
||||||
|
.get_stream(&location)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get stream".to_string())?;
|
||||||
let mut collected = Vec::new();
|
let mut collected = Vec::new();
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
collected.extend_from_slice(&chunk.expect("chunk"));
|
collected.extend_from_slice(&chunk.with_context(|| "chunk".to_string())?);
|
||||||
}
|
}
|
||||||
assert_eq!(collected, content);
|
assert_eq!(collected, content);
|
||||||
|
|
||||||
storage
|
storage
|
||||||
.delete_prefix(&format!("{prefix}/"))
|
.delete_prefix(&format!("{prefix}/"))
|
||||||
.await
|
.await
|
||||||
.expect("cleanup");
|
.with_context(|| "cleanup".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_s3_backend_kind() {
|
async fn test_storage_manager_s3_backend_kind() -> anyhow::Result<()> {
|
||||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||||
return;
|
return Ok(());
|
||||||
};
|
};
|
||||||
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
|
assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_storage_manager_s3_error_handling() {
|
async fn test_storage_manager_s3_error_handling() -> anyhow::Result<()> {
|
||||||
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
let Ok(storage) = testing::TestStorageManager::new_s3().await else {
|
||||||
return;
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
|
||||||
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
|
let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
|
||||||
assert!(storage.get(&location).await.is_err());
|
assert!(storage.get(&location).await.is_err());
|
||||||
assert!(!storage.exists(&location).await.expect("exists check"));
|
// exists may fail if S3 is unavailable; treat error as false
|
||||||
|
assert!(!storage.exists(&location).await.unwrap_or(false));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::{file_info::deserialize_flexible_id, user::User, StoredObject};
|
use crate::storage::types::{user::User, StoredObject};
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
@@ -16,55 +17,71 @@ impl StoredObject for Analytics {
|
|||||||
"analytics"
|
"analytics"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Analytics {
|
impl Analytics {
|
||||||
|
const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
|
/// Ensures the singleton analytics record exists (idempotent).
|
||||||
|
///
|
||||||
|
/// Production databases are also seeded by `20250503_215025_initial_setup.surql`;
|
||||||
|
/// this uses an atomic `UPSERT` for tests and recovery.
|
||||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let analytics = db.get_item::<Self>("current").await?;
|
let analytics: Option<Self> = db
|
||||||
|
.client
|
||||||
if analytics.is_none() {
|
.query(
|
||||||
let created_analytics = Analytics {
|
"UPSERT type::thing('analytics', $id) SET visitors = visitors ?? 0, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
id: "current".to_string(),
|
)
|
||||||
visitors: 0,
|
.bind(("id", Self::RECORD_ID))
|
||||||
page_loads: 0,
|
.await?
|
||||||
};
|
.take(0)?;
|
||||||
|
|
||||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
|
||||||
return stored.ok_or(AppError::Validation(
|
|
||||||
"Failed to initialize analytics".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
analytics.ok_or(AppError::Validation(
|
analytics.ok_or(AppError::Validation(
|
||||||
"Failed to initialize analytics".into(),
|
"failed to initialize analytics".into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let analytics: Option<Self> = db.get_item("current").await?;
|
let analytics: Option<Self> = db.get_item("current").await?;
|
||||||
analytics.ok_or(AppError::NotFound("Analytics not found".into()))
|
analytics.ok_or(AppError::NotFound("analytics not found".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET visitors = (visitors ?? 0) + 1, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
updated.ok_or(AppError::Validation("failed to update analytics".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
|
Self::record_page_view(db, false).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Records a page view, optionally counting the visitor as new.
|
||||||
|
pub async fn record_page_view(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
is_new_visitor: bool,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
let visitor_delta = i64::from(is_new_visitor);
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET page_loads = (page_loads ?? 0) + 1, visitors = (visitors ?? 0) + $visitor_delta RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
|
.bind(("visitor_delta", visitor_delta))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation("Failed to update analytics".into()))
|
updated.ok_or(AppError::Validation("failed to update analytics".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
pub async fn get_users_amount(db: &SurrealDbClient) -> Result<i64, AppError> {
|
||||||
@@ -88,8 +105,10 @@ impl Analytics {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
use anyhow::{self};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
stored_object!(TestUser, "user", {
|
stored_object!(TestUser, "user", {
|
||||||
@@ -99,18 +118,14 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_analytics_initialization() {
|
async fn test_analytics_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Test initialization of analytics
|
// Test initialization of analytics
|
||||||
let analytics = Analytics::ensure_initialized(&db)
|
let analytics = Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Verify initial state after initialization
|
// Verify initial state after initialization
|
||||||
assert_eq!(analytics.id, "current");
|
assert_eq!(analytics.id, "current");
|
||||||
@@ -118,159 +133,198 @@ mod tests {
|
|||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
// Test idempotency - ensure calling it again doesn't change anything
|
// Test idempotency - ensure calling it again doesn't change anything
|
||||||
let analytics_again = Analytics::ensure_initialized(&db)
|
let analytics_again = Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get analytics after initialization");
|
|
||||||
|
|
||||||
assert_eq!(analytics.id, analytics_again.id);
|
assert_eq!(analytics.id, analytics_again.id);
|
||||||
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
assert_eq!(analytics.page_loads, analytics_again.page_loads);
|
||||||
assert_eq!(analytics.visitors, analytics_again.visitors);
|
assert_eq!(analytics.visitors, analytics_again.visitors);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_analytics() {
|
async fn test_get_current_analytics() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test get_current method
|
// Test get_current method
|
||||||
let analytics = Analytics::get_current(&db)
|
let analytics = Analytics::get_current(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get current analytics");
|
|
||||||
|
|
||||||
assert_eq!(analytics.id, "current");
|
assert_eq!(analytics.id, "current");
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_visitors() {
|
async fn test_increment_visitors() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test increment_visitors method
|
// Test increment_visitors method
|
||||||
let analytics = Analytics::increment_visitors(&db)
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment visitors");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 1);
|
assert_eq!(analytics.visitors, 1);
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
// Increment again and check
|
// Increment again and check
|
||||||
let analytics = Analytics::increment_visitors(&db)
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment visitors again");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 2);
|
assert_eq!(analytics.visitors, 2);
|
||||||
assert_eq!(analytics.page_loads, 0);
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_increment_page_loads() {
|
async fn test_increment_page_loads() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Initialize analytics
|
// Initialize analytics
|
||||||
Analytics::ensure_initialized(&db)
|
Analytics::ensure_initialized(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to initialize analytics");
|
|
||||||
|
|
||||||
// Test increment_page_loads method
|
// Test increment_page_loads method
|
||||||
let analytics = Analytics::increment_page_loads(&db)
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment page loads");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
assert_eq!(analytics.page_loads, 1);
|
assert_eq!(analytics.page_loads, 1);
|
||||||
|
|
||||||
// Increment again and check
|
// Increment again and check
|
||||||
let analytics = Analytics::increment_page_loads(&db)
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to increment page loads again");
|
|
||||||
|
|
||||||
assert_eq!(analytics.visitors, 0);
|
assert_eq!(analytics.visitors, 0);
|
||||||
assert_eq!(analytics.page_loads, 2);
|
assert_eq!(analytics.page_loads, 2);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_users_amount() {
|
async fn test_get_users_amount() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Test with no users
|
// Test with no users
|
||||||
let count = Analytics::get_users_amount(&db)
|
let count = Analytics::get_users_amount(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get users amount");
|
|
||||||
assert_eq!(count, 0);
|
assert_eq!(count, 0);
|
||||||
|
|
||||||
// Create a few test users
|
// Create a few test users
|
||||||
for i in 0..3 {
|
for i in 0..3 {
|
||||||
let user = TestUser {
|
let user = TestUser {
|
||||||
id: format!("user{}", i),
|
id: format!("user{i}"),
|
||||||
email: format!("user{}@example.com", i),
|
email: format!("user{i}@example.com"),
|
||||||
password: "password".to_string(),
|
password: "password".to_string(),
|
||||||
user_id: format!("uid{}", i),
|
user_id: format!("uid{i}"),
|
||||||
created_at: Utc::now(),
|
created_at: Utc::now(),
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
db.store_item(user)
|
db.store_item(user).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to create test user");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test users amount after adding users
|
// Test users amount after adding users
|
||||||
let count = Analytics::get_users_amount(&db)
|
let count = Analytics::get_users_amount(&db).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to get users amount after adding users");
|
|
||||||
assert_eq!(count, 3);
|
assert_eq!(count, 3);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() {
|
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(analytics.visitors, 1);
|
||||||
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(analytics.page_loads, 1);
|
||||||
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_visitors.visitors, 1);
|
||||||
|
assert_eq!(after_visitors.page_loads, 0);
|
||||||
|
|
||||||
|
let after_page_load = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(after_page_load.visitors, 1);
|
||||||
|
assert_eq!(after_page_load.page_loads, 1);
|
||||||
|
|
||||||
|
let after_second_visitor = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_second_visitor.visitors, 2);
|
||||||
|
assert_eq!(after_second_visitor.page_loads, 1);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_record_page_view() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let first_view = Analytics::record_page_view(&db, true).await?;
|
||||||
|
assert_eq!(first_view.visitors, 1);
|
||||||
|
assert_eq!(first_view.page_loads, 1);
|
||||||
|
|
||||||
|
let returning_view = Analytics::record_page_view(&db, false).await?;
|
||||||
|
assert_eq!(returning_view.visitors, 1);
|
||||||
|
assert_eq!(returning_view.page_loads, 2);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
// Don't initialize analytics and try to get it
|
// Don't initialize analytics and try to get it
|
||||||
let result = Analytics::get_current(&db).await;
|
let result = Analytics::get_current(&db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
if let Err(err) = result {
|
match result {
|
||||||
match err {
|
Ok(_) => anyhow::bail!("Expected NotFound error, got success"),
|
||||||
AppError::NotFound(_) => {
|
Err(AppError::NotFound(_)) => {}
|
||||||
// Expected error
|
Err(err) => anyhow::bail!("Expected NotFound error, got: {err:?}"),
|
||||||
}
|
|
||||||
_ => panic!("Expected NotFound error, got: {:?}", err),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ stored_object!(Conversation, "conversation", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
pub struct SidebarConversation {
|
pub struct SidebarConversation {
|
||||||
#[serde(deserialize_with = "deserialize_sidebar_id")]
|
#[serde(deserialize_with = "deserialize_sidebar_id")]
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -59,6 +60,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
|
#[must_use]
|
||||||
pub fn new(user_id: String, title: String) -> Self {
|
pub fn new(user_id: String, title: String) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
@@ -78,7 +80,7 @@ impl Conversation {
|
|||||||
let conversation: Conversation = db
|
let conversation: Conversation = db
|
||||||
.get_item(conversation_id)
|
.get_item(conversation_id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||||
|
|
||||||
if conversation.user_id != user_id {
|
if conversation.user_id != user_id {
|
||||||
return Err(AppError::Auth(
|
return Err(AppError::Auth(
|
||||||
@@ -86,10 +88,15 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let messages:Vec<Message> = db.client.
|
let messages: Vec<Message> = db
|
||||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
.client
|
||||||
bind(("table_name", Message::table_name())).
|
.query(
|
||||||
bind(("conversation_id", conversation_id.to_string()))
|
"SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at",
|
||||||
|
)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Self::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
@@ -104,7 +111,7 @@ impl Conversation {
|
|||||||
// First verify ownership by getting conversation user_id
|
// First verify ownership by getting conversation user_id
|
||||||
let conversation: Option<Conversation> = db.get_item(id).await?;
|
let conversation: Option<Conversation> = db.get_item(id).await?;
|
||||||
let conversation =
|
let conversation =
|
||||||
conversation.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
|
conversation.ok_or_else(|| AppError::NotFound("conversation not found".to_string()))?;
|
||||||
|
|
||||||
if conversation.user_id != user_id {
|
if conversation.user_id != user_id {
|
||||||
return Err(AppError::Auth(
|
return Err(AppError::Auth(
|
||||||
@@ -112,7 +119,7 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let _updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.update((Self::table_name(), id))
|
.update((Self::table_name(), id))
|
||||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||||
.patch(PatchOp::replace(
|
.patch(PatchOp::replace(
|
||||||
@@ -121,6 +128,10 @@ impl Conversation {
|
|||||||
))
|
))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if updated.is_none() {
|
||||||
|
return Err(AppError::NotFound("conversation not found".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,76 +155,91 @@ impl Conversation {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use crate::storage::types::message::MessageRole;
|
use crate::storage::types::message::MessageRole;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
const MESSAGE_QUERY_FOR_OWNER: &str = "SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at";
|
||||||
|
|
||||||
|
async fn fetch_messages_for_owner(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
conversation_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<Message>, AppError> {
|
||||||
|
db.client
|
||||||
|
.query(MESSAGE_QUERY_FOR_OWNER)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Conversation::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::from)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_conversation() {
|
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create a new conversation
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let title = "Test Conversation";
|
let title = "Test Conversation";
|
||||||
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
let conversation = Conversation::new(user_id.to_string(), title.to_string());
|
||||||
|
|
||||||
// Verify conversation properties
|
|
||||||
assert_eq!(conversation.user_id, user_id);
|
assert_eq!(conversation.user_id, user_id);
|
||||||
assert_eq!(conversation.title, title);
|
assert_eq!(conversation.title, title);
|
||||||
assert!(!conversation.id.is_empty());
|
assert!(!conversation.id.is_empty());
|
||||||
|
|
||||||
// Store the conversation
|
|
||||||
let result = db.store_item(conversation.clone()).await;
|
let result = db.store_item(conversation.clone()).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Verify it can be retrieved
|
|
||||||
let retrieved: Option<Conversation> = db
|
let retrieved: Option<Conversation> = db
|
||||||
.get_item(&conversation.id)
|
.get_item(&conversation.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve conversation");
|
.with_context(|| "Failed to retrieve conversation".to_string())?;
|
||||||
assert!(retrieved.is_some());
|
|
||||||
|
|
||||||
let retrieved = retrieved.unwrap();
|
let retrieved =
|
||||||
|
retrieved.ok_or_else(|| anyhow::anyhow!("Expected conversation to exist"))?;
|
||||||
assert_eq!(retrieved.id, conversation.id);
|
assert_eq!(retrieved.id, conversation.id);
|
||||||
assert_eq!(retrieved.user_id, user_id);
|
assert_eq!(retrieved.user_id, user_id);
|
||||||
assert_eq!(retrieved.title, title);
|
assert_eq!(retrieved.title, title);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_not_found() {
|
async fn test_get_complete_conversation_not_found() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Try to get a conversation that doesn't exist
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
Conversation::get_complete_conversation("nonexistent_id", "test_user", &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::NotFound(_)) => { /* expected error */ }
|
Err(AppError::NotFound(_)) => {}
|
||||||
_ => panic!("Expected NotFound error"),
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_unauthorized() {
|
async fn test_get_complete_conversation_unauthorized() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create and store a conversation for user_id_1
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation =
|
let conversation =
|
||||||
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
Conversation::new(user_id_1.to_string(), "Private Conversation".to_string());
|
||||||
@@ -221,27 +247,28 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Try to access with a different user
|
|
||||||
let user_id_2 = "user_2";
|
let user_id_2 = "user_2";
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_success() {
|
async fn test_patch_title_success() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let user_id = "user_1";
|
let user_id = "user_1";
|
||||||
let original_title = "Original Title";
|
let original_title = "Original Title";
|
||||||
@@ -250,49 +277,50 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
let new_title = "Updated Title";
|
let new_title = "Updated Title";
|
||||||
|
|
||||||
// Patch title successfully
|
|
||||||
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
let result = Conversation::patch_title(&conversation_id, user_id, new_title, &db).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Retrieve from DB to verify
|
|
||||||
let updated_conversation = db
|
let updated_conversation = db
|
||||||
.get_item::<Conversation>(&conversation_id)
|
.get_item::<Conversation>(&conversation_id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get conversation")
|
.with_context(|| "Failed to get conversation".to_string())?
|
||||||
.expect("Conversation missing");
|
.ok_or_else(|| anyhow::anyhow!("Conversation missing"))?;
|
||||||
assert_eq!(updated_conversation.title, new_title);
|
assert_eq!(updated_conversation.title, new_title);
|
||||||
assert_eq!(updated_conversation.user_id, user_id);
|
assert_eq!(updated_conversation.user_id, user_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_not_found() {
|
async fn test_patch_title_not_found() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Try to patch non-existing conversation
|
|
||||||
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
let result = Conversation::patch_title("nonexistent", "user_x", "New Title", &db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::NotFound(_)) => {}
|
Err(AppError::NotFound(_)) => {}
|
||||||
_ => panic!("Expected NotFound error"),
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_patch_title_unauthorized() {
|
async fn test_patch_title_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user_id = "intruder";
|
let other_user_id = "intruder";
|
||||||
@@ -301,17 +329,18 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Attempt patch with unauthorized user
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
Conversation::patch_title(&conversation_id, other_user_id, "Hacked Title", &db).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => {}
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -356,12 +385,15 @@ mod tests {
|
|||||||
.expect("Failed to get sidebar conversations");
|
.expect("Failed to get sidebar conversations");
|
||||||
|
|
||||||
assert_eq!(sidebar_items.len(), 3);
|
assert_eq!(sidebar_items.len(), 3);
|
||||||
assert_eq!(sidebar_items[0].id, newest.id);
|
let s0 = sidebar_items.first().expect("expected 3 items");
|
||||||
assert_eq!(sidebar_items[0].title, "Newest");
|
let s1 = sidebar_items.get(1).expect("expected 3 items");
|
||||||
assert_eq!(sidebar_items[1].id, middle.id);
|
let s2 = sidebar_items.get(2).expect("expected 3 items");
|
||||||
assert_eq!(sidebar_items[1].title, "Middle");
|
assert_eq!(s0.id, newest.id);
|
||||||
assert_eq!(sidebar_items[2].id, oldest.id);
|
assert_eq!(s0.title, "Newest");
|
||||||
assert_eq!(sidebar_items[2].title, "Oldest");
|
assert_eq!(s1.id, middle.id);
|
||||||
|
assert_eq!(s1.title, "Middle");
|
||||||
|
assert_eq!(s2.id, oldest.id);
|
||||||
|
assert_eq!(s2.title, "Oldest");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -391,7 +423,8 @@ mod tests {
|
|||||||
let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get sidebar conversations before patch");
|
.expect("Failed to get sidebar conversations before patch");
|
||||||
assert_eq!(before_patch[0].id, second.id);
|
let before = before_patch.first().expect("expected at least 1 item");
|
||||||
|
assert_eq!(before.id, second.id);
|
||||||
|
|
||||||
Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
|
Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
|
||||||
.await
|
.await
|
||||||
@@ -400,29 +433,27 @@ mod tests {
|
|||||||
let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get sidebar conversations after patch");
|
.expect("Failed to get sidebar conversations after patch");
|
||||||
assert_eq!(after_patch[0].id, first.id);
|
let after = after_patch.first().expect("expected at least 1 item");
|
||||||
assert_eq!(after_patch[0].title, "First (renamed)");
|
assert_eq!(after.id, first.id);
|
||||||
|
assert_eq!(after.title, "First (renamed)");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_complete_conversation_with_messages() {
|
async fn test_get_complete_conversation_with_messages() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create and store a conversation for user_id_1
|
|
||||||
let user_id_1 = "user_1";
|
let user_id_1 = "user_1";
|
||||||
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
let conversation = Conversation::new(user_id_1.to_string(), "Conversation".to_string());
|
||||||
let conversation_id = conversation.id.clone();
|
let conversation_id = conversation.id.clone();
|
||||||
|
|
||||||
db.store_item(conversation)
|
db.store_item(conversation)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
|
|
||||||
// Create messages
|
|
||||||
let message1 = Message::new(
|
let message1 = Message::new(
|
||||||
conversation_id.clone(),
|
conversation_id.clone(),
|
||||||
MessageRole::User,
|
MessageRole::User,
|
||||||
@@ -442,46 +473,200 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Store messages
|
|
||||||
db.store_item(message1)
|
db.store_item(message1)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message1");
|
.with_context(|| "Failed to store message1".to_string())?;
|
||||||
db.store_item(message2)
|
db.store_item(message2)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message2");
|
.with_context(|| "Failed to store message2".to_string())?;
|
||||||
db.store_item(message3)
|
db.store_item(message3)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message3");
|
.with_context(|| "Failed to store message3".to_string())?;
|
||||||
|
|
||||||
// Retrieve the complete conversation
|
|
||||||
let result =
|
let result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_1, &db).await;
|
||||||
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
assert!(result.is_ok(), "Failed to retrieve complete conversation");
|
||||||
|
|
||||||
let (retrieved_conversation, messages) = result.unwrap();
|
let (retrieved_conversation, retrieved_messages) =
|
||||||
|
result.with_context(|| "Failed to retrieve complete conversation".to_string())?;
|
||||||
|
|
||||||
// Verify conversation data
|
|
||||||
assert_eq!(retrieved_conversation.id, conversation_id);
|
assert_eq!(retrieved_conversation.id, conversation_id);
|
||||||
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
assert_eq!(retrieved_conversation.user_id, user_id_1);
|
||||||
assert_eq!(retrieved_conversation.title, "Conversation");
|
assert_eq!(retrieved_conversation.title, "Conversation");
|
||||||
|
|
||||||
// Verify messages
|
assert_eq!(retrieved_messages.len(), 3);
|
||||||
assert_eq!(messages.len(), 3);
|
|
||||||
|
|
||||||
// Verify messages are sorted by updated_at
|
let message_contents: Vec<&str> = retrieved_messages
|
||||||
let message_contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
|
.iter()
|
||||||
|
.map(|m| m.content.as_str())
|
||||||
|
.collect();
|
||||||
assert!(message_contents.contains(&"Hello, AI!"));
|
assert!(message_contents.contains(&"Hello, AI!"));
|
||||||
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
assert!(message_contents.contains(&"Hello, human! How can I help you today?"));
|
||||||
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
assert!(message_contents.contains(&"Tell me about Rust programming."));
|
||||||
|
|
||||||
// Make sure we can't access with different user
|
|
||||||
let user_id_2 = "user_2";
|
let user_id_2 = "user_2";
|
||||||
let unauthorized_result =
|
let unauthorized_result =
|
||||||
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
Conversation::get_complete_conversation(&conversation_id, user_id_2, &db).await;
|
||||||
assert!(unauthorized_result.is_err());
|
assert!(unauthorized_result.is_err());
|
||||||
match unauthorized_result {
|
match unauthorized_result {
|
||||||
Err(AppError::Auth(_)) => { /* expected error */ }
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||||
|
let item: SidebarConversation =
|
||||||
|
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
|
||||||
|
.expect("valid sidebar conversation json");
|
||||||
|
assert_eq!(item.id, "conv-plain");
|
||||||
|
assert_eq!(item.title, "My chat");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
|
||||||
|
let owner = "sidebar_owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||||
|
let expected_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation)
|
||||||
|
.await
|
||||||
|
.expect("Failed to store conversation");
|
||||||
|
|
||||||
|
let items = Conversation::get_user_sidebar_conversations(owner, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to load sidebar");
|
||||||
|
assert_eq!(items.len(), 1);
|
||||||
|
let item = items.first().expect("expected one sidebar item");
|
||||||
|
assert_eq!(item.id, expected_id);
|
||||||
|
assert_eq!(item.title, "Sidebar title");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner_user";
|
||||||
|
let intruder = "intruder_user";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Private".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.store_item(Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"secret message".to_string(),
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let owner_messages = fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||||
|
assert_eq!(owner_messages.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
owner_messages
|
||||||
|
.first()
|
||||||
|
.expect("expected owner message")
|
||||||
|
.content,
|
||||||
|
"secret message"
|
||||||
|
);
|
||||||
|
|
||||||
|
let intruder_messages = fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||||
|
assert!(
|
||||||
|
intruder_messages.is_empty(),
|
||||||
|
"SQL owner filter must not return messages for a non-owner user_id"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let user_id = "order_user";
|
||||||
|
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
|
||||||
|
let base = Utc::now();
|
||||||
|
let mut first = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"first".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
first.updated_at = base - chrono::Duration::minutes(20);
|
||||||
|
|
||||||
|
let mut second = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::AI,
|
||||||
|
"second".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
second.updated_at = base - chrono::Duration::minutes(5);
|
||||||
|
|
||||||
|
db.store_item(first).await?;
|
||||||
|
db.store_item(second).await?;
|
||||||
|
|
||||||
|
let (_, messages) =
|
||||||
|
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
|
||||||
|
|
||||||
|
assert_eq!(messages.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
messages.first().expect("expected first message").content,
|
||||||
|
"first"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
messages.get(1).expect("expected second message").content,
|
||||||
|
"second"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.delete_item::<Conversation>(&conversation_id).await?;
|
||||||
|
|
||||||
|
let result = Conversation::patch_title(&conversation_id, owner, "New title", &db).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(_)) => {}
|
||||||
|
other => anyhow::bail!("expected NotFound, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_conversation_new_initializes_timestamps_and_id() {
|
||||||
|
let before = Utc::now();
|
||||||
|
let conversation = Conversation::new("user".to_string(), "Title".to_string());
|
||||||
|
let after = Utc::now();
|
||||||
|
|
||||||
|
assert!(!conversation.id.is_empty());
|
||||||
|
assert!(conversation.created_at >= before && conversation.created_at <= after);
|
||||||
|
assert_eq!(conversation.created_at, conversation.updated_at);
|
||||||
|
assert_eq!(conversation.user_id, "user");
|
||||||
|
assert_eq!(conversation.title, "Title");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,4 @@
|
|||||||
#![allow(
|
#![allow(clippy::result_large_err)]
|
||||||
clippy::result_large_err,
|
|
||||||
clippy::needless_pass_by_value,
|
|
||||||
clippy::implicit_clone,
|
|
||||||
clippy::semicolon_if_nothing_returned
|
|
||||||
)]
|
|
||||||
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
use crate::{error::AppError, storage::types::file_info::FileInfo};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -31,78 +26,150 @@ pub enum IngestionPayload {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for IngestionPayload {
|
||||||
|
/// An empty text payload, used as a cheap placeholder when the real content
|
||||||
|
/// has been moved out of a task (see [`crate::storage::types::ingestion_task::IngestionTask::take_content`]).
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Text {
|
||||||
|
text: String::new(),
|
||||||
|
context: String::new(),
|
||||||
|
category: String::new(),
|
||||||
|
user_id: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared ingest metadata moved or cloned into each payload variant.
|
||||||
|
struct IngestFields {
|
||||||
|
context: String,
|
||||||
|
category: String,
|
||||||
|
user_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of parsing optional ingest content before file payloads are built.
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum ParsedContent {
|
||||||
|
/// No URL or text payload should be appended.
|
||||||
|
Skip,
|
||||||
|
Url(String),
|
||||||
|
Text(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParsedContent {
|
||||||
|
#[must_use]
|
||||||
|
fn follows(&self) -> bool {
|
||||||
|
!matches!(self, Self::Skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl IngestionPayload {
|
impl IngestionPayload {
|
||||||
/// Creates ingestion payloads from the provided content, context, and files.
|
/// Creates ingestion payloads from the provided content, context, and files.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// Files are emitted first. When both files and content are present, shared
|
||||||
/// * `content` - Optional textual content to be ingressed
|
/// metadata is cloned per file; otherwise the last file-only payload moves
|
||||||
/// * `context` - context for processing the ingress content
|
/// `context`, `category`, and `user_id` without cloning.
|
||||||
/// * `category` - Category to classify the ingressed content
|
|
||||||
/// * `files` - Vector of `FileInfo` objects containing information about uploaded files
|
|
||||||
/// * `user_id` - Identifier of the user performing the ingress operation
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Errors
|
||||||
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
|
///
|
||||||
/// (one per file/content type). On failure, returns an `AppError`.
|
/// Returns [`AppError::NotFound`] when no valid files or content are provided.
|
||||||
#[allow(clippy::similar_names)]
|
#[allow(clippy::similar_names)]
|
||||||
pub fn create_ingestion_payload(
|
pub fn create_ingestion_payload(
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
context: String,
|
context: String,
|
||||||
category: String,
|
category: String,
|
||||||
files: Vec<FileInfo>,
|
files: Vec<FileInfo>,
|
||||||
user_id: &str,
|
user_id: String,
|
||||||
) -> Result<Vec<IngestionPayload>, AppError> {
|
) -> Result<Vec<IngestionPayload>, AppError> {
|
||||||
// Initialize list
|
let parsed = Self::parse_content(content);
|
||||||
let mut object_list = Vec::new();
|
let content_follows = parsed.follows();
|
||||||
|
let file_count = files.len();
|
||||||
|
#[allow(clippy::arithmetic_side_effects)]
|
||||||
|
let capacity = file_count + usize::from(content_follows);
|
||||||
|
let mut object_list = Vec::with_capacity(capacity);
|
||||||
|
let mut fields = Some(IngestFields {
|
||||||
|
context,
|
||||||
|
category,
|
||||||
|
user_id,
|
||||||
|
});
|
||||||
|
|
||||||
// Create a IngestionPayload from content if it exists, checking for URL or text
|
for (index, file) in files.into_iter().enumerate() {
|
||||||
if let Some(input_content) = content {
|
let is_last_file = index.saturating_add(1) == file_count;
|
||||||
match Url::parse(&input_content) {
|
if content_follows || !is_last_file {
|
||||||
Ok(url) => {
|
let Some(shared) = fields.as_ref() else {
|
||||||
info!("Detected URL: {}", url);
|
return Err(AppError::internal("shared ingest fields consumed early"));
|
||||||
object_list.push(IngestionPayload::Url {
|
};
|
||||||
url: url.to_string(),
|
object_list.push(Self::File {
|
||||||
context: context.clone(),
|
file_info: file,
|
||||||
category: category.clone(),
|
context: shared.context.clone(),
|
||||||
user_id: user_id.into(),
|
category: shared.category.clone(),
|
||||||
});
|
user_id: shared.user_id.clone(),
|
||||||
}
|
});
|
||||||
Err(_) => {
|
} else {
|
||||||
if input_content.len() > 2 {
|
let Some(shared) = fields.take() else {
|
||||||
info!("Treating input as plain text");
|
return Err(AppError::internal("shared ingest fields missing for file"));
|
||||||
object_list.push(IngestionPayload::Text {
|
};
|
||||||
text: input_content.to_string(),
|
object_list.push(Self::File {
|
||||||
context: context.clone(),
|
file_info: file,
|
||||||
category: category.clone(),
|
context: shared.context,
|
||||||
user_id: user_id.into(),
|
category: shared.category,
|
||||||
});
|
user_id: shared.user_id,
|
||||||
}
|
});
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for file in files {
|
if let ParsedContent::Url(url) = parsed {
|
||||||
object_list.push(IngestionPayload::File {
|
info!("Detected URL: {url}");
|
||||||
file_info: file,
|
let Some(shared) = fields.take() else {
|
||||||
context: context.clone(),
|
return Err(AppError::internal("shared ingest fields missing for url"));
|
||||||
category: category.clone(),
|
};
|
||||||
user_id: user_id.into(),
|
object_list.push(Self::Url {
|
||||||
})
|
url,
|
||||||
|
context: shared.context,
|
||||||
|
category: shared.category,
|
||||||
|
user_id: shared.user_id,
|
||||||
|
});
|
||||||
|
} else if let ParsedContent::Text(text) = parsed {
|
||||||
|
info!("Treating input as plain text");
|
||||||
|
let Some(shared) = fields.take() else {
|
||||||
|
return Err(AppError::internal("shared ingest fields missing for text"));
|
||||||
|
};
|
||||||
|
object_list.push(Self::Text {
|
||||||
|
text,
|
||||||
|
context: shared.context,
|
||||||
|
category: shared.category,
|
||||||
|
user_id: shared.user_id,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no objects are constructed, we return Err
|
|
||||||
if object_list.is_empty() {
|
if object_list.is_empty() {
|
||||||
return Err(AppError::NotFound(
|
return Err(AppError::NotFound(
|
||||||
"No valid content or files provided".into(),
|
"no valid content or files provided".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(object_list)
|
Ok(object_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_content(content: Option<String>) -> ParsedContent {
|
||||||
|
let Some(input_content) = content else {
|
||||||
|
return ParsedContent::Skip;
|
||||||
|
};
|
||||||
|
|
||||||
|
if input_content.len() <= 2 {
|
||||||
|
return ParsedContent::Skip;
|
||||||
|
}
|
||||||
|
|
||||||
|
match Url::parse(&input_content) {
|
||||||
|
Ok(url) => ParsedContent::Url(url.to_string()),
|
||||||
|
Err(_) => ParsedContent::Text(input_content),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -131,24 +198,23 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_url() {
|
fn test_create_ingestion_payload_with_url() -> anyhow::Result<()> {
|
||||||
let url = "https://example.com";
|
let url = "https://example.com";
|
||||||
let context = "Process this URL";
|
let context = "Process this URL";
|
||||||
let category = "websites";
|
let category = "websites";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(url.to_string()),
|
Some(url.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::Url {
|
IngestionPayload::Url {
|
||||||
url: payload_url,
|
url: payload_url,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
@@ -156,34 +222,34 @@ mod tests {
|
|||||||
user_id: payload_user_id,
|
user_id: payload_user_id,
|
||||||
} => {
|
} => {
|
||||||
// URL parser may normalize the URL by adding a trailing slash
|
// URL parser may normalize the URL by adding a trailing slash
|
||||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||||
assert_eq!(payload_context, &context);
|
assert_eq!(payload_context, &context);
|
||||||
assert_eq!(payload_category, &category);
|
assert_eq!(payload_category, &category);
|
||||||
assert_eq!(payload_user_id, &user_id);
|
assert_eq!(payload_user_id, &user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected Url variant"),
|
_ => anyhow::bail!("Expected Url variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_text() {
|
fn test_create_ingestion_payload_with_text() -> anyhow::Result<()> {
|
||||||
let text = "This is some text content";
|
let text = "This is some text content";
|
||||||
let context = "Process this text";
|
let context = "Process this text";
|
||||||
let category = "notes";
|
let category = "notes";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(text.to_string()),
|
Some(text.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.with_context(|| "create_ingestion_payload".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::Text {
|
IngestionPayload::Text {
|
||||||
text: payload_text,
|
text: payload_text,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
@@ -195,12 +261,13 @@ mod tests {
|
|||||||
assert_eq!(payload_category, category);
|
assert_eq!(payload_category, category);
|
||||||
assert_eq!(payload_user_id, user_id);
|
assert_eq!(payload_user_id, user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected Text variant"),
|
_ => anyhow::bail!("Expected Text variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_file() {
|
fn test_create_ingestion_payload_with_file() -> anyhow::Result<()> {
|
||||||
let context = "Process this file";
|
let context = "Process this file";
|
||||||
let category = "documents";
|
let category = "documents";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
@@ -211,36 +278,36 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let file_info: FileInfo = mock_file.into();
|
let file_info: FileInfo = mock_file.into();
|
||||||
let files = vec![file_info.clone()];
|
let file_id = file_info.id.clone();
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
None,
|
None,
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![file_info],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
match &result[0] {
|
match result.first().context("expected one result")? {
|
||||||
IngestionPayload::File {
|
IngestionPayload::File {
|
||||||
file_info: payload_file_info,
|
file_info: payload_file_info,
|
||||||
context: payload_context,
|
context: payload_context,
|
||||||
category: payload_category,
|
category: payload_category,
|
||||||
user_id: payload_user_id,
|
user_id: payload_user_id,
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(payload_file_info.id, file_info.id);
|
assert_eq!(payload_file_info.id, file_id);
|
||||||
assert_eq!(payload_context, context);
|
assert_eq!(payload_context, context);
|
||||||
assert_eq!(payload_category, category);
|
assert_eq!(payload_category, category);
|
||||||
assert_eq!(payload_user_id, user_id);
|
assert_eq!(payload_user_id, user_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected File variant"),
|
_ => anyhow::bail!("Expected File variant"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_with_url_and_file() {
|
fn test_create_ingestion_payload_with_url_and_file() -> anyhow::Result<()> {
|
||||||
let url = "https://example.com";
|
let url = "https://example.com";
|
||||||
let context = "Process this data";
|
let context = "Process this data";
|
||||||
let category = "mixed";
|
let category = "mixed";
|
||||||
@@ -252,88 +319,207 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let file_info: FileInfo = mock_file.into();
|
let file_info: FileInfo = mock_file.into();
|
||||||
let files = vec![file_info.clone()];
|
let file_id = file_info.id.clone();
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
Some(url.to_string()),
|
Some(url.to_string()),
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
vec![file_info],
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
|
|
||||||
// Check first item is URL
|
// Check first item is File (files processed first to minimize clones)
|
||||||
match &result[0] {
|
match result.first().context("expected first item")? {
|
||||||
IngestionPayload::Url {
|
|
||||||
url: payload_url, ..
|
|
||||||
} => {
|
|
||||||
// URL parser may normalize the URL by adding a trailing slash
|
|
||||||
assert!(payload_url == &url.to_string() || payload_url == &format!("{}/", url));
|
|
||||||
}
|
|
||||||
_ => panic!("Expected first item to be Url variant"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check second item is File
|
|
||||||
match &result[1] {
|
|
||||||
IngestionPayload::File {
|
IngestionPayload::File {
|
||||||
file_info: payload_file_info,
|
file_info: payload_file_info,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(payload_file_info.id, file_info.id);
|
assert_eq!(payload_file_info.id, file_id);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected second item to be File variant"),
|
_ => anyhow::bail!("Expected first item to be File variant"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check second item is URL
|
||||||
|
match result.get(1).context("expected second item")? {
|
||||||
|
IngestionPayload::Url {
|
||||||
|
url: payload_url, ..
|
||||||
|
} => {
|
||||||
|
// URL parser may normalize the URL by adding a trailing slash
|
||||||
|
assert!(payload_url == &url.to_string() || payload_url == &format!("{url}/"));
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected second item to be Url variant"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_ingestion_payload_empty_input() {
|
fn test_create_ingestion_payload_empty_input() -> anyhow::Result<()> {
|
||||||
let context = "Process something";
|
let context = "Process something";
|
||||||
let category = "empty";
|
let category = "empty";
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let files = vec![];
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
None,
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![],
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(msg)) => {
|
||||||
|
assert_eq!(msg, "no valid content or files provided");
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_with_empty_text() -> anyhow::Result<()> {
|
||||||
|
let text = ""; // Empty text
|
||||||
|
let context = "Process this";
|
||||||
|
let category = "notes";
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some(text.to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![],
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(msg)) => {
|
||||||
|
assert_eq!(msg, "no valid content or files provided");
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_with_file_and_text() -> anyhow::Result<()> {
|
||||||
|
let text = "plain notes";
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
let file_info: FileInfo = MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some(text.to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![file_info],
|
||||||
|
user_id.to_string(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 2);
|
||||||
|
let first = result.first().expect("expected first payload");
|
||||||
|
let second = result.get(1).expect("expected second payload");
|
||||||
|
match (first, second) {
|
||||||
|
(
|
||||||
|
IngestionPayload::File {
|
||||||
|
file_info: payload_file,
|
||||||
|
context: file_context,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
IngestionPayload::Text {
|
||||||
|
text: payload_text,
|
||||||
|
context: text_context,
|
||||||
|
category: text_category,
|
||||||
|
user_id: text_user_id,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
assert_eq!(payload_file.id, "file1");
|
||||||
|
assert_eq!(file_context, context);
|
||||||
|
assert_eq!(payload_text, text);
|
||||||
|
assert_eq!(text_context, context);
|
||||||
|
assert_eq!(text_category, category);
|
||||||
|
assert_eq!(text_user_id, user_id);
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("expected File then Text"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()>
|
||||||
|
{
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
let file_info: FileInfo = MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
|
Some("ab".to_string()),
|
||||||
|
context.to_string(),
|
||||||
|
category.to_string(),
|
||||||
|
vec![file_info],
|
||||||
|
user_id.to_string(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
match result.first().context("expected one file payload")? {
|
||||||
|
IngestionPayload::File {
|
||||||
|
file_info,
|
||||||
|
context: payload_context,
|
||||||
|
category: payload_category,
|
||||||
|
user_id: payload_user_id,
|
||||||
|
} => {
|
||||||
|
assert_eq!(file_info.id, "file1");
|
||||||
|
assert_eq!(payload_context, context);
|
||||||
|
assert_eq!(payload_category, category);
|
||||||
|
assert_eq!(payload_user_id, user_id);
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("expected File variant only"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_ingestion_payload_two_files_without_content() -> anyhow::Result<()> {
|
||||||
|
let context = "ctx";
|
||||||
|
let category = "cat";
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let files = vec![
|
||||||
|
MockFileInfo {
|
||||||
|
id: "file1".to_string(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
MockFileInfo {
|
||||||
|
id: "file2".to_string(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
];
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
let result = IngestionPayload::create_ingestion_payload(
|
||||||
None,
|
None,
|
||||||
context.to_string(),
|
context.to_string(),
|
||||||
category.to_string(),
|
category.to_string(),
|
||||||
files,
|
files,
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
);
|
)?;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert_eq!(result.len(), 2);
|
||||||
match result {
|
assert!(matches!(
|
||||||
Err(AppError::NotFound(msg)) => {
|
result.first(),
|
||||||
assert_eq!(msg, "No valid content or files provided");
|
Some(IngestionPayload::File { .. })
|
||||||
}
|
));
|
||||||
_ => panic!("Expected NotFound error"),
|
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
|
||||||
}
|
Ok(())
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_ingestion_payload_with_empty_text() {
|
|
||||||
let text = ""; // Empty text
|
|
||||||
let context = "Process this";
|
|
||||||
let category = "notes";
|
|
||||||
let user_id = "user123";
|
|
||||||
let files = vec![];
|
|
||||||
|
|
||||||
let result = IngestionPayload::create_ingestion_payload(
|
|
||||||
Some(text.to_string()),
|
|
||||||
context.to_string(),
|
|
||||||
category.to_string(),
|
|
||||||
files,
|
|
||||||
user_id,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(result.is_err());
|
|
||||||
match result {
|
|
||||||
Err(AppError::NotFound(msg)) => {
|
|
||||||
assert_eq!(msg, "No valid content or files provided");
|
|
||||||
}
|
|
||||||
_ => panic!("Expected NotFound error"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,7 @@
|
|||||||
#![allow(
|
use std::{sync::Arc, time::Duration};
|
||||||
clippy::cast_possible_wrap,
|
|
||||||
clippy::items_after_statements,
|
|
||||||
clippy::arithmetic_side_effects,
|
|
||||||
clippy::cast_sign_loss,
|
|
||||||
clippy::missing_docs_in_private_items,
|
|
||||||
clippy::trivially_copy_pass_by_ref,
|
|
||||||
clippy::expect_used
|
|
||||||
)]
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use chrono::Duration as ChronoDuration;
|
use chrono::Duration as ChronoDuration;
|
||||||
|
use futures::future::try_join_all;
|
||||||
use state_machines::state_machine;
|
use state_machines::state_machine;
|
||||||
use surrealdb::sql::Datetime as SurrealDatetime;
|
use surrealdb::sql::Datetime as SurrealDatetime;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -22,7 +14,7 @@ pub const MAX_ATTEMPTS: u32 = 3;
|
|||||||
pub const DEFAULT_LEASE_SECS: i64 = 300;
|
pub const DEFAULT_LEASE_SECS: i64 = 300;
|
||||||
pub const DEFAULT_PRIORITY: i32 = 0;
|
pub const DEFAULT_PRIORITY: i32 = 0;
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Default, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||||
pub enum TaskState {
|
pub enum TaskState {
|
||||||
#[serde(rename = "Pending")]
|
#[serde(rename = "Pending")]
|
||||||
#[default]
|
#[default]
|
||||||
@@ -42,6 +34,7 @@ pub enum TaskState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TaskState {
|
impl TaskState {
|
||||||
|
#[must_use]
|
||||||
pub fn as_str(&self) -> &'static str {
|
pub fn as_str(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
TaskState::Pending => "Pending",
|
TaskState::Pending => "Pending",
|
||||||
@@ -54,6 +47,7 @@ impl TaskState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn is_terminal(&self) -> bool {
|
pub fn is_terminal(&self) -> bool {
|
||||||
matches!(
|
matches!(
|
||||||
self,
|
self,
|
||||||
@@ -61,6 +55,7 @@ impl TaskState {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn display_label(&self) -> &'static str {
|
pub fn display_label(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
TaskState::Pending => "Pending",
|
TaskState::Pending => "Pending",
|
||||||
@@ -74,12 +69,16 @@ impl TaskState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Information about an error that occurred during task processing.
|
||||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
|
||||||
pub struct TaskErrorInfo {
|
pub struct TaskErrorInfo {
|
||||||
|
/// Machine-readable error code (e.g., `"pipeline_error"`).
|
||||||
pub code: Option<String>,
|
pub code: Option<String>,
|
||||||
|
/// Human-readable error description.
|
||||||
pub message: String,
|
pub message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Internal events that drive the task state machine transitions.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
enum TaskTransition {
|
enum TaskTransition {
|
||||||
StartProcessing,
|
StartProcessing,
|
||||||
@@ -91,7 +90,7 @@ enum TaskTransition {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TaskTransition {
|
impl TaskTransition {
|
||||||
fn as_str(&self) -> &'static str {
|
fn as_str(self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
TaskTransition::StartProcessing => "start_processing",
|
TaskTransition::StartProcessing => "start_processing",
|
||||||
TaskTransition::Succeed => "succeed",
|
TaskTransition::Succeed => "succeed",
|
||||||
@@ -141,34 +140,20 @@ mod lifecycle {
|
|||||||
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
|
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
|
||||||
TaskLifecycleMachine::new(())
|
TaskLifecycleMachine::new(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn reserved() -> TaskLifecycleMachine<(), Reserved> {
|
|
||||||
pending()
|
|
||||||
.reserve()
|
|
||||||
.expect("reserve transition from Pending should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn processing() -> TaskLifecycleMachine<(), Processing> {
|
|
||||||
reserved()
|
|
||||||
.start_processing()
|
|
||||||
.expect("start_processing transition from Reserved should exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn failed() -> TaskLifecycleMachine<(), Failed> {
|
|
||||||
processing()
|
|
||||||
.fail()
|
|
||||||
.expect("fail transition from Processing should exist")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn invalid_transition(state: &TaskState, event: TaskTransition) -> AppError {
|
fn invalid_transition(state: TaskState, event: TaskTransition) -> AppError {
|
||||||
AppError::Validation(format!(
|
AppError::Validation(format!(
|
||||||
"Invalid task transition: {} -> {}",
|
"invalid task transition: {} -> {}",
|
||||||
state.as_str(),
|
state.as_str(),
|
||||||
event.as_str()
|
event.as_str()
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn worker_id_for_bind(worker_id: Option<&String>) -> String {
|
||||||
|
worker_id.cloned().unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
stored_object!(IngestionTask, "ingestion_task", {
|
stored_object!(IngestionTask, "ingestion_task", {
|
||||||
content: IngestionPayload,
|
content: IngestionPayload,
|
||||||
state: TaskState,
|
state: TaskState,
|
||||||
@@ -197,6 +182,7 @@ stored_object!(IngestionTask, "ingestion_task", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl IngestionTask {
|
impl IngestionTask {
|
||||||
|
#[must_use]
|
||||||
pub fn new(content: IngestionPayload, user_id: String) -> Self {
|
pub fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
|
|
||||||
@@ -220,33 +206,85 @@ impl IngestionTask {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn can_retry(&self) -> bool {
|
pub fn can_retry(&self) -> bool {
|
||||||
self.attempts < self.max_attempts
|
self.attempts < self.max_attempts
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn lease_duration(&self) -> Duration {
|
/// Moves the payload out of the task, leaving an empty placeholder behind.
|
||||||
Duration::from_secs(self.lease_duration_secs.max(0) as u64)
|
///
|
||||||
|
/// The task's `content` is only needed while driving the pipeline; the
|
||||||
|
/// terminal `user_id`, `state`, and bookkeeping fields are stored separately,
|
||||||
|
/// so replacing it with the default placeholder avoids cloning large payloads.
|
||||||
|
#[must_use]
|
||||||
|
pub fn take_content(&mut self) -> IngestionPayload {
|
||||||
|
std::mem::take(&mut self.content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn lease_duration(&self) -> Duration {
|
||||||
|
Duration::from_secs(u64::try_from(self.lease_duration_secs.max(0)).unwrap_or(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new task and immediately persist it to the database.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Database` if the store operation fails.
|
||||||
|
/// Returns `AppError::internal` if the database returns no stored record.
|
||||||
pub async fn create_and_add_to_db(
|
pub async fn create_and_add_to_db(
|
||||||
content: IngestionPayload,
|
content: IngestionPayload,
|
||||||
user_id: String,
|
user_id: impl AsRef<str>,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<IngestionTask, AppError> {
|
) -> Result<IngestionTask, AppError> {
|
||||||
let task = Self::new(content, user_id);
|
let task = Self::new(content, user_id.as_ref().to_string());
|
||||||
db.store_item(task.clone()).await?;
|
db.store_item(task)
|
||||||
Ok(task)
|
.await?
|
||||||
|
.ok_or_else(|| AppError::internal("ingestion task store returned no record"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create and persist multiple tasks concurrently (one `CREATE` per payload).
|
||||||
|
///
|
||||||
|
/// Use this when ingest produces several payloads (files plus URL/text). For a
|
||||||
|
/// single payload, call [`Self::create_and_add_to_db`] instead.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns the first [`AppError`] from any failed store, same as [`try_join_all`].
|
||||||
|
pub async fn create_all_and_add_to_db(
|
||||||
|
contents: Vec<IngestionPayload>,
|
||||||
|
user_id: impl AsRef<str>,
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
) -> Result<Vec<IngestionTask>, AppError> {
|
||||||
|
if contents.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_id = Arc::new(user_id.as_ref().to_string());
|
||||||
|
let db = db.clone();
|
||||||
|
|
||||||
|
try_join_all(contents.into_iter().map(|content| {
|
||||||
|
let user_id = Arc::clone(&user_id);
|
||||||
|
let db = db.clone();
|
||||||
|
async move { Self::create_and_add_to_db(content, user_id.as_ref(), &db).await }
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Claim the next ready task for processing.
|
||||||
|
///
|
||||||
|
/// Atomically reserves a task by transitioning it from a candidate state to `Reserved`.
|
||||||
|
/// Returns `Ok(None)` if no task is ready to claim.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Database` if the update query fails.
|
||||||
pub async fn claim_next_ready(
|
pub async fn claim_next_ready(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
worker_id: &str,
|
worker_id: &str,
|
||||||
now: chrono::DateTime<chrono::Utc>,
|
now: chrono::DateTime<chrono::Utc>,
|
||||||
lease_duration: Duration,
|
lease_duration: Duration,
|
||||||
) -> Result<Option<IngestionTask>, AppError> {
|
) -> Result<Option<IngestionTask>, AppError> {
|
||||||
debug_assert!(lifecycle::pending().reserve().is_ok());
|
|
||||||
debug_assert!(lifecycle::failed().reserve().is_ok());
|
|
||||||
|
|
||||||
const CLAIM_QUERY: &str = r#"
|
const CLAIM_QUERY: &str = r#"
|
||||||
UPDATE (
|
UPDATE (
|
||||||
SELECT * FROM type::table($table)
|
SELECT * FROM type::table($table)
|
||||||
@@ -276,6 +314,11 @@ impl IngestionTask {
|
|||||||
RETURN *;
|
RETURN *;
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
|
debug_assert!(lifecycle::pending().reserve().is_ok());
|
||||||
|
debug_assert!(lifecycle::pending().reserve().is_ok_and(|m| m
|
||||||
|
.start_processing()
|
||||||
|
.is_ok_and(|m| m.fail().is_ok_and(|m| m.reserve().is_ok()))));
|
||||||
|
|
||||||
let mut result = db
|
let mut result = db
|
||||||
.client
|
.client
|
||||||
.query(CLAIM_QUERY)
|
.query(CLAIM_QUERY)
|
||||||
@@ -300,13 +343,22 @@ impl IngestionTask {
|
|||||||
.bind(("reserved_state", TaskState::Reserved.as_str()))
|
.bind(("reserved_state", TaskState::Reserved.as_str()))
|
||||||
.bind(("now", SurrealDatetime::from(now)))
|
.bind(("now", SurrealDatetime::from(now)))
|
||||||
.bind(("worker_id", worker_id.to_string()))
|
.bind(("worker_id", worker_id.to_string()))
|
||||||
.bind(("lease_secs", lease_duration.as_secs() as i64))
|
.bind((
|
||||||
|
"lease_secs",
|
||||||
|
i64::try_from(lease_duration.as_secs()).unwrap_or(i64::MAX),
|
||||||
|
))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let task: Option<IngestionTask> = result.take(0)?;
|
let task: Option<IngestionTask> = result.take(0)?;
|
||||||
Ok(task)
|
Ok(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transition this task from `Reserved` to `Processing`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is not in `Reserved` state
|
||||||
|
/// or belongs to a different worker. Returns `AppError::Database` on DB failure.
|
||||||
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||||
const START_PROCESSING_QUERY: &str = r#"
|
const START_PROCESSING_QUERY: &str = r#"
|
||||||
UPDATE type::thing($table, $id)
|
UPDATE type::thing($table, $id)
|
||||||
@@ -318,6 +370,7 @@ impl IngestionTask {
|
|||||||
"#;
|
"#;
|
||||||
|
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
|
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||||
let mut result = db
|
let mut result = db
|
||||||
.client
|
.client
|
||||||
.query(START_PROCESSING_QUERY)
|
.query(START_PROCESSING_QUERY)
|
||||||
@@ -326,13 +379,19 @@ impl IngestionTask {
|
|||||||
.bind(("processing", TaskState::Processing.as_str()))
|
.bind(("processing", TaskState::Processing.as_str()))
|
||||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||||
.bind(("now", SurrealDatetime::from(now)))
|
.bind(("now", SurrealDatetime::from(now)))
|
||||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
.bind(("worker_id", worker_id))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::StartProcessing))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::StartProcessing))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transition this task from `Processing` to `Succeeded`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is not in `Processing` state
|
||||||
|
/// or belongs to a different worker. Returns `AppError::Database` on DB failure.
|
||||||
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||||
const COMPLETE_QUERY: &str = r#"
|
const COMPLETE_QUERY: &str = r#"
|
||||||
UPDATE type::thing($table, $id)
|
UPDATE type::thing($table, $id)
|
||||||
@@ -349,6 +408,7 @@ impl IngestionTask {
|
|||||||
"#;
|
"#;
|
||||||
|
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
|
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||||
let mut result = db
|
let mut result = db
|
||||||
.client
|
.client
|
||||||
.query(COMPLETE_QUERY)
|
.query(COMPLETE_QUERY)
|
||||||
@@ -357,23 +417,27 @@ impl IngestionTask {
|
|||||||
.bind(("succeeded", TaskState::Succeeded.as_str()))
|
.bind(("succeeded", TaskState::Succeeded.as_str()))
|
||||||
.bind(("processing", TaskState::Processing.as_str()))
|
.bind(("processing", TaskState::Processing.as_str()))
|
||||||
.bind(("now", SurrealDatetime::from(now)))
|
.bind(("now", SurrealDatetime::from(now)))
|
||||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
.bind(("worker_id", worker_id))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Succeed))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::Succeed))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transition this task from `Processing` to `Failed`.
|
||||||
|
///
|
||||||
|
/// The task will be rescheduled for retry after `retry_delay`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is not in `Processing` state
|
||||||
|
/// or belongs to a different worker. Returns `AppError::Database` on DB failure.
|
||||||
pub async fn mark_failed(
|
pub async fn mark_failed(
|
||||||
&self,
|
&self,
|
||||||
error: TaskErrorInfo,
|
error: TaskErrorInfo,
|
||||||
retry_delay: Duration,
|
retry_delay: Duration,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<IngestionTask, AppError> {
|
) -> Result<IngestionTask, AppError> {
|
||||||
let now = chrono::Utc::now();
|
|
||||||
let retry_at = now
|
|
||||||
+ ChronoDuration::from_std(retry_delay).unwrap_or_else(|_| ChronoDuration::seconds(30));
|
|
||||||
|
|
||||||
const FAIL_QUERY: &str = r#"
|
const FAIL_QUERY: &str = r#"
|
||||||
UPDATE type::thing($table, $id)
|
UPDATE type::thing($table, $id)
|
||||||
SET state = $failed,
|
SET state = $failed,
|
||||||
@@ -388,6 +452,15 @@ impl IngestionTask {
|
|||||||
RETURN *;
|
RETURN *;
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let retry_at = now
|
||||||
|
.checked_add_signed(
|
||||||
|
ChronoDuration::from_std(retry_delay)
|
||||||
|
.unwrap_or_else(|_| ChronoDuration::seconds(30)),
|
||||||
|
)
|
||||||
|
.unwrap_or(now);
|
||||||
|
|
||||||
|
let worker_id = worker_id_for_bind(self.worker_id.as_ref());
|
||||||
let mut result = db
|
let mut result = db
|
||||||
.client
|
.client
|
||||||
.query(FAIL_QUERY)
|
.query(FAIL_QUERY)
|
||||||
@@ -399,13 +472,19 @@ impl IngestionTask {
|
|||||||
.bind(("retry_at", SurrealDatetime::from(retry_at)))
|
.bind(("retry_at", SurrealDatetime::from(retry_at)))
|
||||||
.bind(("error_code", error.code.clone()))
|
.bind(("error_code", error.code.clone()))
|
||||||
.bind(("error_message", error.message.clone()))
|
.bind(("error_message", error.message.clone()))
|
||||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
.bind(("worker_id", worker_id))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Fail))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::Fail))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transition this task from `Failed` to `DeadLetter`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is not in `Failed` state.
|
||||||
|
/// Returns `AppError::Database` on DB failure.
|
||||||
pub async fn mark_dead_letter(
|
pub async fn mark_dead_letter(
|
||||||
&self,
|
&self,
|
||||||
error: TaskErrorInfo,
|
error: TaskErrorInfo,
|
||||||
@@ -439,9 +518,15 @@ impl IngestionTask {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::DeadLetter))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::DeadLetter))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Transition this task to `Cancelled` from any non-terminal state.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is in a terminal state.
|
||||||
|
/// Returns `AppError::Database` on DB failure.
|
||||||
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||||
const CANCEL_QUERY: &str = r#"
|
const CANCEL_QUERY: &str = r#"
|
||||||
UPDATE type::thing($table, $id)
|
UPDATE type::thing($table, $id)
|
||||||
@@ -472,9 +557,15 @@ impl IngestionTask {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Cancel))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::Cancel))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Release a reserved task back to `Pending` state.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Validation` if the task is not in `Reserved` state.
|
||||||
|
/// Returns `AppError::Database` on DB failure.
|
||||||
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||||
const RELEASE_QUERY: &str = r#"
|
const RELEASE_QUERY: &str = r#"
|
||||||
UPDATE type::thing($table, $id)
|
UPDATE type::thing($table, $id)
|
||||||
@@ -498,9 +589,14 @@ impl IngestionTask {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let updated: Option<IngestionTask> = result.take(0)?;
|
let updated: Option<IngestionTask> = result.take(0)?;
|
||||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Release))
|
updated.ok_or_else(|| invalid_transition(self.state, TaskTransition::Release))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve all non-terminal tasks across active states.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `AppError::Database` if the query fails.
|
||||||
pub async fn get_unfinished_tasks(
|
pub async fn get_unfinished_tasks(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<Vec<IngestionTask>, AppError> {
|
) -> Result<Vec<IngestionTask>, AppError> {
|
||||||
@@ -529,6 +625,9 @@ impl IngestionTask {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||||
|
|
||||||
@@ -541,16 +640,16 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn memory_db() -> SurrealDbClient {
|
async fn memory_db() -> anyhow::Result<SurrealDbClient> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = Uuid::new_v4().to_string();
|
let database = Uuid::new_v4().to_string();
|
||||||
SurrealDbClient::memory(namespace, &database)
|
SurrealDbClient::memory(namespace, &database)
|
||||||
.await
|
.await
|
||||||
.expect("in-memory surrealdb")
|
.with_context(|| "in-memory surrealdb".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_new_task_defaults() {
|
async fn test_new_task_defaults() -> anyhow::Result<()> {
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
@@ -562,73 +661,140 @@ mod tests {
|
|||||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||||
assert!(task.locked_at.is_none());
|
assert!(task.locked_at.is_none());
|
||||||
assert!(task.worker_id.is_none());
|
assert!(task.worker_id.is_none());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_take_content_moves_payload_and_leaves_default() {
|
||||||
|
let user_id = "user123";
|
||||||
|
let payload = create_payload(user_id);
|
||||||
|
let mut task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
|
|
||||||
|
let taken = task.take_content();
|
||||||
|
|
||||||
|
assert_eq!(taken, payload);
|
||||||
|
assert_eq!(task.content, IngestionPayload::default());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_and_store_task() {
|
async fn test_create_all_and_add_to_db_empty() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
|
let tasks = IngestionTask::create_all_and_add_to_db(vec![], "user123", &db).await?;
|
||||||
|
assert!(tasks.is_empty());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_all_and_add_to_db_stores_multiple() -> anyhow::Result<()> {
|
||||||
|
let db = memory_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
let payloads = vec![
|
||||||
|
create_payload(user_id),
|
||||||
|
IngestionPayload::Text {
|
||||||
|
text: "second payload".to_string(),
|
||||||
|
context: "ctx".to_string(),
|
||||||
|
category: "cat".to_string(),
|
||||||
|
user_id: user_id.to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let created = IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
|
||||||
|
|
||||||
|
assert_eq!(created.len(), 2);
|
||||||
|
let first = created.first().expect("expected first task");
|
||||||
|
let second = created.get(1).expect("expected second task");
|
||||||
|
assert_ne!(first.id, second.id);
|
||||||
|
|
||||||
|
for task in &created {
|
||||||
|
let stored: Option<IngestionTask> = db.get_item::<IngestionTask>(&task.id).await?;
|
||||||
|
let stored = stored.with_context(|| format!("task {} should exist", task.id))?;
|
||||||
|
assert_eq!(stored.id, task.id);
|
||||||
|
assert_eq!(stored.state, TaskState::Pending);
|
||||||
|
assert_eq!(stored.user_id, user_id);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_store_task() -> anyhow::Result<()> {
|
||||||
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
|
|
||||||
let created =
|
let created =
|
||||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||||
.await
|
.await
|
||||||
.expect("store");
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
let stored: Option<IngestionTask> = db
|
let stored: Option<IngestionTask> = db
|
||||||
.get_item::<IngestionTask>(&created.id)
|
.get_item::<IngestionTask>(&created.id)
|
||||||
.await
|
.await
|
||||||
.expect("fetch");
|
.with_context(|| "fetch".to_string())?;
|
||||||
|
|
||||||
let stored = stored.expect("task exists");
|
let stored = stored.with_context(|| "task exists".to_string())?;
|
||||||
assert_eq!(stored.id, created.id);
|
assert_eq!(stored.id, created.id);
|
||||||
assert_eq!(stored.state, TaskState::Pending);
|
assert_eq!(stored.state, TaskState::Pending);
|
||||||
assert_eq!(stored.attempts, 0);
|
assert_eq!(stored.attempts, 0);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_claim_and_transition() {
|
async fn test_claim_and_transition() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
let task = IngestionTask::new(payload, user_id.to_string());
|
let task = IngestionTask::new(payload, user_id.to_string());
|
||||||
db.store_item(task.clone()).await.expect("store");
|
db.store_item(task.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
let worker_id = "worker-1";
|
let worker_id = "worker-1";
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||||
.await
|
.await
|
||||||
.expect("claim");
|
.with_context(|| "claim".to_string())?
|
||||||
|
.with_context(|| "task claimed".to_string())?;
|
||||||
|
|
||||||
let claimed = claimed.expect("task claimed");
|
|
||||||
assert_eq!(claimed.state, TaskState::Reserved);
|
assert_eq!(claimed.state, TaskState::Reserved);
|
||||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||||
|
|
||||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
let processing = claimed
|
||||||
|
.mark_processing(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "processing".to_string())?;
|
||||||
assert_eq!(processing.state, TaskState::Processing);
|
assert_eq!(processing.state, TaskState::Processing);
|
||||||
|
|
||||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
let succeeded = processing
|
||||||
|
.mark_succeeded(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "succeeded".to_string())?;
|
||||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||||
assert!(succeeded.worker_id.is_none());
|
assert!(succeeded.worker_id.is_none());
|
||||||
assert!(succeeded.locked_at.is_none());
|
assert!(succeeded.locked_at.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fail_and_dead_letter() {
|
async fn test_fail_and_dead_letter() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
let task = IngestionTask::new(payload, user_id.to_string());
|
let task = IngestionTask::new(payload, user_id.to_string());
|
||||||
db.store_item(task.clone()).await.expect("store");
|
db.store_item(task.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
let worker_id = "worker-dead";
|
let worker_id = "worker-dead";
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||||
.await
|
.await
|
||||||
.expect("claim")
|
.with_context(|| "claim".to_string())?
|
||||||
.expect("claimed");
|
.with_context(|| "claimed".to_string())?;
|
||||||
|
|
||||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
let processing = claimed
|
||||||
|
.mark_processing(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "processing".to_string())?;
|
||||||
|
|
||||||
let error_info = TaskErrorInfo {
|
let error_info = TaskErrorInfo {
|
||||||
code: Some("pipeline_error".into()),
|
code: Some("pipeline_error".into()),
|
||||||
@@ -638,7 +804,7 @@ mod tests {
|
|||||||
let failed = processing
|
let failed = processing
|
||||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||||
.await
|
.await
|
||||||
.expect("failed update");
|
.with_context(|| "failed update".to_string())?;
|
||||||
assert_eq!(failed.state, TaskState::Failed);
|
assert_eq!(failed.state, TaskState::Failed);
|
||||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||||
assert!(failed.worker_id.is_none());
|
assert!(failed.worker_id.is_none());
|
||||||
@@ -648,24 +814,26 @@ mod tests {
|
|||||||
let dead = failed
|
let dead = failed
|
||||||
.mark_dead_letter(error_info.clone(), &db)
|
.mark_dead_letter(error_info.clone(), &db)
|
||||||
.await
|
.await
|
||||||
.expect("dead letter");
|
.with_context(|| "dead letter".to_string())?;
|
||||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mark_processing_requires_reservation() {
|
async fn test_mark_processing_requires_reservation() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
|
|
||||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
db.store_item(task.clone()).await.expect("store");
|
db.store_item(task.clone())
|
||||||
|
|
||||||
let err = task
|
|
||||||
.mark_processing(&db)
|
|
||||||
.await
|
.await
|
||||||
.expect_err("processing should fail without reservation");
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
|
let Err(err) = task.mark_processing(&db).await else {
|
||||||
|
anyhow::bail!("processing should fail without reservation")
|
||||||
|
};
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
AppError::Validation(message) => {
|
AppError::Validation(message) => {
|
||||||
@@ -674,20 +842,23 @@ mod tests {
|
|||||||
"unexpected message: {message}"
|
"unexpected message: {message}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
other => panic!("expected validation error, got {other:?}"),
|
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_mark_failed_requires_processing() {
|
async fn test_mark_failed_requires_processing() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
|
|
||||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
db.store_item(task.clone()).await.expect("store");
|
db.store_item(task.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
let err = task
|
let Err(err) = task
|
||||||
.mark_failed(
|
.mark_failed(
|
||||||
TaskErrorInfo {
|
TaskErrorInfo {
|
||||||
code: None,
|
code: None,
|
||||||
@@ -697,7 +868,9 @@ mod tests {
|
|||||||
&db,
|
&db,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect_err("failing should require processing state");
|
else {
|
||||||
|
anyhow::bail!("failing should require processing state")
|
||||||
|
};
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
AppError::Validation(message) => {
|
AppError::Validation(message) => {
|
||||||
@@ -706,23 +879,25 @@ mod tests {
|
|||||||
"unexpected message: {message}"
|
"unexpected message: {message}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
other => panic!("expected validation error, got {other:?}"),
|
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_release_requires_reservation() {
|
async fn test_release_requires_reservation() -> anyhow::Result<()> {
|
||||||
let db = memory_db().await;
|
let db = memory_db().await?;
|
||||||
let user_id = "user123";
|
let user_id = "user123";
|
||||||
let payload = create_payload(user_id);
|
let payload = create_payload(user_id);
|
||||||
|
|
||||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
db.store_item(task.clone()).await.expect("store");
|
db.store_item(task.clone())
|
||||||
|
|
||||||
let err = task
|
|
||||||
.release(&db)
|
|
||||||
.await
|
.await
|
||||||
.expect_err("release should require reserved state");
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
|
let Err(err) = task.release(&db).await else {
|
||||||
|
anyhow::bail!("release should require reserved state")
|
||||||
|
};
|
||||||
|
|
||||||
match err {
|
match err {
|
||||||
AppError::Validation(message) => {
|
AppError::Validation(message) => {
|
||||||
@@ -731,7 +906,8 @@ mod tests {
|
|||||||
"unexpected message: {message}"
|
"unexpected message: {message}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
other => panic!("expected validation error, got {other:?}"),
|
other => anyhow::bail!("expected validation error, got {other:?}"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,11 +2,17 @@ use std::collections::HashMap;
|
|||||||
|
|
||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||||
|
stored_object,
|
||||||
|
};
|
||||||
|
|
||||||
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
|
||||||
entity_id: RecordId,
|
entity_id: RecordId,
|
||||||
embedding: Vec<f32>,
|
embedding: Vec<f32>,
|
||||||
|
/// Denormalized source id for bulk deletes
|
||||||
|
source_id: String,
|
||||||
/// Denormalized user id for query scoping
|
/// Denormalized user id for query scoping
|
||||||
user_id: String
|
user_id: String
|
||||||
});
|
});
|
||||||
@@ -17,29 +23,43 @@ impl KnowledgeEntityEmbedding {
|
|||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
dimension: usize,
|
dimension: usize,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let query = format!(
|
let query = hnsw_index_redefine_transaction_sql(
|
||||||
"BEGIN TRANSACTION;
|
"idx_embedding_knowledge_entity_embedding",
|
||||||
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
|
Self::table_name(),
|
||||||
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
dimension,
|
||||||
COMMIT TRANSACTION;",
|
|
||||||
table = Self::table_name(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||||
res.check().map_err(AppError::Database)?;
|
res.check().map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new knowledge entity embedding
|
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||||
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
|
#[allow(clippy::result_large_err)]
|
||||||
|
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||||
|
if embedding.len() != expected {
|
||||||
|
return Err(AppError::Validation(format!(
|
||||||
|
"embedding dimension mismatch: got {}, expected {expected}",
|
||||||
|
embedding.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new knowledge entity embedding.
|
||||||
|
///
|
||||||
|
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: entity_id.to_owned(),
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now,
|
updated_at: now,
|
||||||
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
|
||||||
embedding,
|
embedding,
|
||||||
|
source_id,
|
||||||
user_id,
|
user_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -58,8 +78,8 @@ impl KnowledgeEntityEmbedding {
|
|||||||
.query(query)
|
.query(query)
|
||||||
.bind(("entity_id", entity_id.clone()))
|
.bind(("entity_id", entity_id.clone()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
Ok(embeddings.into_iter().next())
|
Ok(embeddings.into_iter().next())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,8 +92,6 @@ impl KnowledgeEntityEmbedding {
|
|||||||
return Ok(HashMap::new());
|
return Ok(HashMap::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let ids_list: Vec<RecordId> = entity_ids.to_vec();
|
|
||||||
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
|
||||||
Self::table_name()
|
Self::table_name()
|
||||||
@@ -81,10 +99,10 @@ impl KnowledgeEntityEmbedding {
|
|||||||
let mut result = db
|
let mut result = db
|
||||||
.client
|
.client
|
||||||
.query(query)
|
.query(query)
|
||||||
.bind(("entity_ids", ids_list))
|
.bind(("entity_ids", entity_ids.to_vec()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(embeddings
|
Ok(embeddings
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -105,59 +123,41 @@ impl KnowledgeEntityEmbedding {
|
|||||||
.query(query)
|
.query(query)
|
||||||
.bind(("entity_id", entity_id.clone()))
|
.bind(("entity_id", entity_id.clone()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete embeddings by source_id (via joining to knowledge_entity table)
|
/// Delete all embeddings with the given denormalized `source_id`.
|
||||||
#[allow(clippy::items_after_statements)]
|
|
||||||
pub async fn delete_by_source_id(
|
pub async fn delete_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
|
let query = format!(
|
||||||
let mut res = db
|
"DELETE FROM {} WHERE source_id = $source_id",
|
||||||
.client
|
Self::table_name()
|
||||||
|
);
|
||||||
|
db.client
|
||||||
.query(query)
|
.query(query)
|
||||||
.bind(("source_id", source_id.to_owned()))
|
.bind(("source_id", source_id.to_owned()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?
|
||||||
#[allow(clippy::missing_docs_in_private_items)]
|
.check()
|
||||||
#[derive(Deserialize)]
|
.map_err(AppError::from)?;
|
||||||
struct IdRow {
|
|
||||||
id: RecordId,
|
|
||||||
}
|
|
||||||
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
|
|
||||||
|
|
||||||
for row in ids {
|
|
||||||
Self::delete_by_entity_id(&row.id, db).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
|
||||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
|
use crate::test_utils::{prepare_knowledge_entity_test_db, setup_test_db};
|
||||||
|
use anyhow::{self, Context};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use surrealdb::Value as SurrealValue;
|
use surrealdb::Value as SurrealValue;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
async fn setup_test_db() -> SurrealDbClient {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
db
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_knowledge_entity_with_id(
|
fn build_knowledge_entity_with_id(
|
||||||
key: &str,
|
key: &str,
|
||||||
@@ -177,12 +177,27 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_uses_entity_id_as_record_id() {
|
||||||
|
let emb = KnowledgeEntityEmbedding::new(
|
||||||
|
"entity-abc",
|
||||||
|
"source-1".to_owned(),
|
||||||
|
vec![0.1, 0.2],
|
||||||
|
"user-1".to_owned(),
|
||||||
|
);
|
||||||
|
assert_eq!(emb.id, "entity-abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_dimension_rejects_mismatch() {
|
||||||
|
let err = KnowledgeEntityEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
|
||||||
|
.expect_err("expected dimension mismatch");
|
||||||
|
assert!(matches!(err, AppError::Validation(_)));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_and_get_by_entity_id() {
|
async fn test_create_and_get_by_entity_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.expect("set test index dimension");
|
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-1";
|
let entity_key = "entity-1";
|
||||||
let source_id = "source-ke";
|
let source_id = "source-ke";
|
||||||
@@ -192,26 +207,27 @@ mod tests {
|
|||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding by entity_id")
|
.with_context(|| "Failed to get embedding by entity_id".to_string())?
|
||||||
.expect("Expected embedding to exist");
|
.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||||
|
|
||||||
|
assert_eq!(fetched.id, entity_key);
|
||||||
assert_eq!(fetched.user_id, user_id);
|
assert_eq!(fetched.user_id, user_id);
|
||||||
|
assert_eq!(fetched.source_id, source_id);
|
||||||
assert_eq!(fetched.entity_id, entity_rid);
|
assert_eq!(fetched.entity_id, entity_rid);
|
||||||
assert_eq!(fetched.embedding, embedding_vec);
|
assert_eq!(fetched.embedding, embedding_vec);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_entity_id() {
|
async fn test_delete_by_entity_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
|
||||||
.expect("set test index dimension");
|
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-delete";
|
let entity_key = "entity-delete";
|
||||||
let source_id = "source-del";
|
let source_id = "source-del";
|
||||||
@@ -220,61 +236,75 @@ mod tests {
|
|||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding before delete");
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
assert!(existing.is_some());
|
assert!(existing.is_some());
|
||||||
|
|
||||||
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete by entity_id");
|
.with_context(|| "Failed to delete by entity_id".to_string())?;
|
||||||
|
|
||||||
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding after delete");
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
assert!(after.is_none());
|
assert!(after.is_none());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_with_embedding_creates_entity_and_embedding() {
|
async fn test_store_with_embedding_creates_entity_and_embedding() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_store";
|
let user_id = "user_store";
|
||||||
let source_id = "source_store";
|
let source_id = "source_store";
|
||||||
let embedding = vec![0.2_f32, 0.3, 0.4];
|
let embedding = vec![0.2_f32, 0.3, 0.4];
|
||||||
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
|
|
||||||
.await
|
|
||||||
.expect("set test index dimension");
|
|
||||||
|
|
||||||
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
|
let stored_entity: Option<KnowledgeEntity> = db
|
||||||
|
.get_item(&entity.id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get entity".to_string())?;
|
||||||
assert!(stored_entity.is_some());
|
assert!(stored_entity.is_some());
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch embedding");
|
.with_context(|| "Failed to fetch embedding".to_string())?;
|
||||||
assert!(stored_embedding.is_some());
|
let stored_embedding =
|
||||||
let stored_embedding = stored_embedding.unwrap();
|
stored_embedding.ok_or_else(|| anyhow::anyhow!("Expected embedding to exist"))?;
|
||||||
|
assert_eq!(stored_embedding.id, entity.id);
|
||||||
assert_eq!(stored_embedding.user_id, user_id);
|
assert_eq!(stored_embedding.user_id, user_id);
|
||||||
|
assert_eq!(stored_embedding.source_id, source_id);
|
||||||
assert_eq!(stored_embedding.entity_id, entity_rid);
|
assert_eq!(stored_embedding.entity_id, entity_rid);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() {
|
async fn test_store_with_embedding_rejects_wrong_dimension() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
|
||||||
.await
|
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||||
.expect("set test index dimension");
|
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let source_id = "shared-ke";
|
let source_id = "shared-ke";
|
||||||
let other_source = "other-ke";
|
let other_source = "other-ke";
|
||||||
@@ -285,13 +315,13 @@ mod tests {
|
|||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
|
||||||
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
|
||||||
@@ -299,59 +329,75 @@ mod tests {
|
|||||||
|
|
||||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete by source_id");
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get entity1 embedding after delete".to_string())?
|
||||||
.is_none()
|
.is_none()
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get entity2 embedding after delete".to_string())?
|
||||||
.is_none()
|
.is_none()
|
||||||
);
|
);
|
||||||
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get other embedding after delete".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
|
||||||
.await
|
.await
|
||||||
.expect("failed to redefine index");
|
.with_context(|| "failed to redefine index".to_string())?;
|
||||||
|
|
||||||
let mut info_res = db
|
let mut info_res = db
|
||||||
.client
|
.client
|
||||||
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
.query("INFO FOR TABLE knowledge_entity_embedding;")
|
||||||
.await
|
.await
|
||||||
.expect("info query failed");
|
.with_context(|| "info query failed".to_string())?;
|
||||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
let info: SurrealValue = info_res
|
||||||
let info_json: serde_json::Value =
|
.take(0)
|
||||||
serde_json::to_value(info).expect("failed to convert info to json");
|
.with_context(|| "failed to take info result".to_string())?;
|
||||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
["idx_embedding_knowledge_entity_embedding"]["Strand"]
|
.with_context(|| "failed to convert info to json".to_string())?;
|
||||||
.as_str()
|
let idx_sql = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.get("idx_embedding_knowledge_entity_embedding"))
|
||||||
|
.and_then(|v| v.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
idx_sql.contains("DIMENSION 16"),
|
idx_sql.contains("DIMENSION 16"),
|
||||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||||
);
|
);
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIST COSINE"),
|
||||||
|
"expected index definition to use cosine distance, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fetch_entity_via_record_id() {
|
async fn test_fetch_entity_via_record_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
#[derive(Deserialize)]
|
||||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
struct Row {
|
||||||
.await
|
entity_id: KnowledgeEntity,
|
||||||
.expect("set test index dimension");
|
}
|
||||||
|
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
let user_id = "user_ke";
|
let user_id = "user_ke";
|
||||||
let entity_key = "entity-fetch";
|
let entity_key = "entity-fetch";
|
||||||
let source_id = "source-fetch";
|
let source_id = "source-fetch";
|
||||||
@@ -359,15 +405,10 @@ mod tests {
|
|||||||
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
|
||||||
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity with embedding");
|
.with_context(|| "Failed to store entity with embedding".to_string())?;
|
||||||
|
|
||||||
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Row {
|
|
||||||
entity_id: KnowledgeEntity,
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut res = db
|
let mut res = db
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
@@ -375,13 +416,63 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.bind(("id", entity_rid.clone()))
|
.bind(("id", entity_rid.clone()))
|
||||||
.await
|
.await
|
||||||
.expect("failed to fetch embedding with FETCH");
|
.with_context(|| "failed to fetch embedding with FETCH".to_string())?;
|
||||||
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
|
let rows: Vec<Row> = res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to deserialize fetch rows".to_string())?;
|
||||||
|
|
||||||
assert_eq!(rows.len(), 1);
|
assert_eq!(rows.len(), 1);
|
||||||
let fetched_entity = &rows[0].entity_id;
|
let fetched_entity = &rows
|
||||||
|
.first()
|
||||||
|
.context("Expected at least one result")?
|
||||||
|
.entity_id;
|
||||||
assert_eq!(fetched_entity.id, entity_key);
|
assert_eq!(fetched_entity.id, entity_key);
|
||||||
assert_eq!(fetched_entity.name, "Test entity");
|
assert_eq!(fetched_entity.name, "Test entity");
|
||||||
assert_eq!(fetched_entity.user_id, user_id);
|
assert_eq!(fetched_entity.user_id, user_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
|
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user-upsert";
|
||||||
|
let source_id = "source-upsert";
|
||||||
|
let entity = build_knowledge_entity_with_id("entity-upsert", source_id, user_id);
|
||||||
|
|
||||||
|
KnowledgeEntity::store_with_embedding(entity.clone(), vec![1.0_f32, 0.0, 0.0], &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "initial store".to_string())?;
|
||||||
|
|
||||||
|
let replacement = KnowledgeEntityEmbedding::new(
|
||||||
|
&entity.id,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![0.0, 1.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(replacement)
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert replacement embedding".to_string())?;
|
||||||
|
|
||||||
|
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
|
||||||
|
let rows: Vec<KnowledgeEntityEmbedding> = db
|
||||||
|
.client
|
||||||
|
.query(format!(
|
||||||
|
"SELECT * FROM {} WHERE entity_id = $entity_id",
|
||||||
|
KnowledgeEntityEmbedding::table_name()
|
||||||
|
))
|
||||||
|
.bind(("entity_id", entity_rid))
|
||||||
|
.await
|
||||||
|
.with_context(|| "count embeddings".to_string())?
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "take embeddings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected one embedding row");
|
||||||
|
assert_eq!(row.id, entity.id);
|
||||||
|
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
use crate::storage::types::user::User;
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -21,6 +22,7 @@ pub struct KnowledgeRelationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl KnowledgeRelationship {
|
impl KnowledgeRelationship {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
in_: String,
|
in_: String,
|
||||||
out: String,
|
out: String,
|
||||||
@@ -39,7 +41,25 @@ impl KnowledgeRelationship {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
|
||||||
|
pub async fn store_relationship(self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
|
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
||||||
|
.await?;
|
||||||
|
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let Self {
|
||||||
|
id,
|
||||||
|
in_,
|
||||||
|
out,
|
||||||
|
metadata:
|
||||||
|
RelationshipMetadata {
|
||||||
|
user_id,
|
||||||
|
source_id,
|
||||||
|
relationship_type,
|
||||||
|
},
|
||||||
|
} = self;
|
||||||
|
|
||||||
db_client
|
db_client
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
@@ -54,28 +74,36 @@ impl KnowledgeRelationship {
|
|||||||
metadata.relationship_type = $relationship_type;
|
metadata.relationship_type = $relationship_type;
|
||||||
COMMIT TRANSACTION;"#,
|
COMMIT TRANSACTION;"#,
|
||||||
)
|
)
|
||||||
.bind(("rel_id", self.id.clone()))
|
.bind(("rel_id", id))
|
||||||
.bind(("in_id", self.in_.clone()))
|
.bind(("in_id", in_))
|
||||||
.bind(("out_id", self.out.clone()))
|
.bind(("out_id", out))
|
||||||
.bind(("user_id", self.metadata.user_id.clone()))
|
.bind(("user_id", user_id))
|
||||||
.bind(("source_id", self.metadata.source_id.clone()))
|
.bind(("source_id", source_id))
|
||||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
.bind(("relationship_type", relationship_type))
|
||||||
.await?
|
.await
|
||||||
.check()?;
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_relationships_by_source_id(
|
pub async fn delete_relationships_by_source_id(
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
|
user_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
db_client
|
db_client
|
||||||
.client
|
.client
|
||||||
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
|
.query(
|
||||||
|
"DELETE FROM relates_to WHERE metadata.source_id = $source_id AND metadata.user_id = $user_id",
|
||||||
|
)
|
||||||
.bind(("source_id", source_id.to_owned()))
|
.bind(("source_id", source_id.to_owned()))
|
||||||
.await?
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.check()?;
|
.await
|
||||||
|
.map_err(AppError::from)?
|
||||||
|
.check()
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -85,61 +113,48 @@ impl KnowledgeRelationship {
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
db_client: &SurrealDbClient,
|
db_client: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let mut authorized_result = db_client
|
let mut delete_result = db_client
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
|
"DELETE type::thing('relates_to', $id) WHERE metadata.user_id = $user_id RETURN BEFORE;",
|
||||||
)
|
)
|
||||||
.bind(("id", id.to_owned()))
|
.bind(("id", id.to_owned()))
|
||||||
.bind(("user_id", user_id.to_owned()))
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.await?;
|
.await
|
||||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
.map_err(AppError::from)?;
|
||||||
|
let deleted: Vec<KnowledgeRelationship> = delete_result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
if authorized.is_empty() {
|
if !deleted.is_empty() {
|
||||||
let mut exists_result = db_client
|
return Ok(());
|
||||||
.client
|
}
|
||||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
|
||||||
.bind(("id", id.to_owned()))
|
|
||||||
.await?;
|
|
||||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
|
||||||
|
|
||||||
if existing.is_some() {
|
let mut exists_result = db_client
|
||||||
Err(AppError::Auth(
|
.client
|
||||||
"Not authorized to delete relationship".into(),
|
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||||
))
|
.bind(("id", id.to_owned()))
|
||||||
} else {
|
.await
|
||||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
.map_err(AppError::from)?;
|
||||||
}
|
let existing: Option<KnowledgeRelationship> =
|
||||||
|
exists_result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
if existing.is_some() {
|
||||||
|
Err(AppError::Auth(
|
||||||
|
"Not authorized to delete relationship".into(),
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
db_client
|
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||||
.client
|
|
||||||
.query("DELETE type::thing('relates_to', $id)")
|
|
||||||
.bind(("id", id.to_owned()))
|
|
||||||
.await?
|
|
||||||
.check()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
async fn setup_test_db() -> SurrealDbClient {
|
use crate::test_utils::setup_test_db;
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = &Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
db
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_relationship_by_id(
|
async fn get_relationship_by_id(
|
||||||
relationship_id: &str,
|
relationship_id: &str,
|
||||||
@@ -155,12 +170,14 @@ mod tests {
|
|||||||
result.take(0).expect("failed to take relationship by id")
|
result.take(0).expect("failed to take relationship by id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a test knowledge entity for the relationship tests
|
async fn create_test_entity(
|
||||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
name: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let description = format!("Description for {}", name);
|
let description = format!("Description for {name}");
|
||||||
let entity_type = KnowledgeEntityType::Document;
|
let entity_type = KnowledgeEntityType::Document;
|
||||||
let user_id = "user123".to_string();
|
|
||||||
|
|
||||||
let entity = KnowledgeEntity::new(
|
let entity = KnowledgeEntity::new(
|
||||||
source_id,
|
source_id,
|
||||||
@@ -168,18 +185,20 @@ mod tests {
|
|||||||
description,
|
description,
|
||||||
entity_type,
|
entity_type,
|
||||||
None,
|
None,
|
||||||
user_id,
|
user_id.to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let stored: Option<KnowledgeEntity> = db_client
|
let stored: Option<KnowledgeEntity> = db_client
|
||||||
.store_item(entity)
|
.store_item(entity)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store entity");
|
.with_context(|| "Failed to store entity".to_string())?;
|
||||||
stored.unwrap().id
|
stored
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Expected stored entity to return Some"))
|
||||||
|
.map(|e| e.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_relationship_creation() {
|
async fn test_relationship_creation() -> anyhow::Result<()> {
|
||||||
let in_id = "entity1".to_string();
|
let in_id = "entity1".to_string();
|
||||||
let out_id = "entity2".to_string();
|
let out_id = "entity2".to_string();
|
||||||
let user_id = "user123".to_string();
|
let user_id = "user123".to_string();
|
||||||
@@ -194,44 +213,42 @@ mod tests {
|
|||||||
relationship_type.clone(),
|
relationship_type.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify fields are correctly set
|
|
||||||
assert_eq!(relationship.in_, in_id);
|
assert_eq!(relationship.in_, in_id);
|
||||||
assert_eq!(relationship.out, out_id);
|
assert_eq!(relationship.out, out_id);
|
||||||
assert_eq!(relationship.metadata.user_id, user_id);
|
assert_eq!(relationship.metadata.user_id, user_id);
|
||||||
assert_eq!(relationship.metadata.source_id, source_id);
|
assert_eq!(relationship.metadata.source_id, source_id);
|
||||||
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
assert_eq!(relationship.metadata.relationship_type, relationship_type);
|
||||||
assert!(!relationship.id.is_empty());
|
assert!(!relationship.id.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_and_verify_by_source_id() {
|
async fn test_store_and_verify_by_source_id() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let db = setup_test_db().await;
|
let user_id = "user123";
|
||||||
|
|
||||||
// Create two entities to relate
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
|
||||||
|
|
||||||
// Create relationship
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let relationship_type = "references".to_string();
|
let relationship_type = "references".to_string();
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
// Store the relationship
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
let persisted = get_relationship_by_id(&relationship_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Relationship should be retrievable by id");
|
.expect("Relationship should be retrievable by id");
|
||||||
assert_eq!(persisted.in_, entity1_id);
|
assert_eq!(persisted.in_, entity1_id);
|
||||||
@@ -239,8 +256,6 @@ mod tests {
|
|||||||
assert_eq!(persisted.metadata.user_id, user_id);
|
assert_eq!(persisted.metadata.user_id, user_id);
|
||||||
assert_eq!(persisted.metadata.source_id, source_id);
|
assert_eq!(persisted.metadata.source_id, source_id);
|
||||||
|
|
||||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
|
||||||
// This approach is more reliable than trying to look up by ID
|
|
||||||
let mut check_result = db
|
let mut check_result = db
|
||||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||||
.bind(("source_id", source_id.clone()))
|
.bind(("source_id", source_id.clone()))
|
||||||
@@ -253,22 +268,47 @@ mod tests {
|
|||||||
1,
|
1,
|
||||||
"Expected one relationship for source_id"
|
"Expected one relationship for source_id"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship_resists_query_injection() {
|
async fn test_store_relationship_rejects_foreign_entity() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let owner_entity = create_test_entity("Owner entity", "owner-user", &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
let other_entity = create_test_entity("Other entity", "other-user", &db).await?;
|
||||||
|
|
||||||
|
let relationship = KnowledgeRelationship::new(
|
||||||
|
owner_entity,
|
||||||
|
other_entity,
|
||||||
|
"owner-user".to_string(),
|
||||||
|
"source123".to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = relationship.store_relationship(&db).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Auth(_))));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_store_relationship_resists_query_injection() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id,
|
entity1_id,
|
||||||
entity2_id,
|
entity2_id,
|
||||||
"user'123".to_string(),
|
user_id.to_string(),
|
||||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -278,54 +318,52 @@ mod tests {
|
|||||||
let mut res = db
|
let mut res = db
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||||
.bind(("id", relationship.id.clone()))
|
.bind(("id", relationship_id))
|
||||||
.await
|
.await
|
||||||
.expect("query relationship by id failed");
|
.expect("query relationship by id failed");
|
||||||
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||||
|
|
||||||
assert_eq!(rows.len(), 1);
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected 1 row");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rows[0].metadata.source_id,
|
row.metadata.source_id,
|
||||||
"source123'; DELETE FROM relates_to; --"
|
"source123'; DELETE FROM relates_to; --"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_and_delete_relationship() {
|
async fn test_store_and_delete_relationship() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let db = setup_test_db().await;
|
let user_id = "user123";
|
||||||
|
|
||||||
// Create two entities to relate
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
|
||||||
|
|
||||||
// Create relationship
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let relationship_type = "references".to_string();
|
let relationship_type = "references".to_string();
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
// Store relationship
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
// Ensure relationship exists before deletion attempt
|
|
||||||
let mut existing_before_delete = db
|
let mut existing_before_delete = db
|
||||||
.query(format!(
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
.bind(("user_id", user_id.to_string()))
|
||||||
user_id, source_id
|
.bind(("source_id", source_id.clone()))
|
||||||
))
|
|
||||||
.await
|
.await
|
||||||
.expect("Query failed");
|
.with_context(|| "Query failed".to_string())?;
|
||||||
let before_results: Vec<KnowledgeRelationship> =
|
let before_results: Vec<KnowledgeRelationship> =
|
||||||
existing_before_delete.take(0).unwrap_or_default();
|
existing_before_delete.take(0).unwrap_or_default();
|
||||||
assert!(
|
assert!(
|
||||||
@@ -333,55 +371,52 @@ mod tests {
|
|||||||
"Relationship should exist before deletion"
|
"Relationship should exist before deletion"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Delete relationship by ID
|
KnowledgeRelationship::delete_relationship_by_id(&relationship_id, user_id, &db)
|
||||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete relationship by ID");
|
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||||
|
|
||||||
// Query to verify relationship was deleted
|
|
||||||
let mut result = db
|
let mut result = db
|
||||||
.query(format!(
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id AND metadata.source_id = $source_id")
|
||||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
.bind(("user_id", user_id.to_string()))
|
||||||
user_id, source_id
|
.bind(("source_id", source_id))
|
||||||
))
|
|
||||||
.await
|
.await
|
||||||
.expect("Query failed");
|
.with_context(|| "Query failed".to_string())?;
|
||||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||||
|
|
||||||
// Verify relationship no longer exists
|
|
||||||
assert!(results.is_empty(), "Relationship should be deleted");
|
assert!(results.is_empty(), "Relationship should be deleted");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationship_by_id_unauthorized() {
|
async fn test_delete_relationship_by_id_unauthorized() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
let owner_user_id = "owner-user";
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let entity1_id = create_test_entity("Entity 1", owner_user_id, &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
let entity2_id = create_test_entity("Entity 2", owner_user_id, &db).await?;
|
||||||
|
|
||||||
let owner_user_id = "owner-user".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
owner_user_id.clone(),
|
owner_user_id.to_string(),
|
||||||
source_id,
|
source_id,
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship");
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
let mut before_attempt = db
|
let mut before_attempt = db
|
||||||
.query(format!(
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
.bind(("user_id", owner_user_id.to_string()))
|
||||||
owner_user_id
|
|
||||||
))
|
|
||||||
.await
|
.await
|
||||||
.expect("Query failed");
|
.with_context(|| "Query failed".to_string())?;
|
||||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||||
assert!(
|
assert!(
|
||||||
!before_results.is_empty(),
|
!before_results.is_empty(),
|
||||||
@@ -389,7 +424,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||||
&relationship.id,
|
&relationship_id,
|
||||||
"different-user",
|
"different-user",
|
||||||
&db,
|
&db,
|
||||||
)
|
)
|
||||||
@@ -397,44 +432,42 @@ mod tests {
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => {}
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected authorization error when deleting someone else's relationship"),
|
_ => anyhow::bail!(
|
||||||
|
"Expected authorization error when deleting someone else's relationship"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut after_attempt = db
|
let mut after_attempt = db
|
||||||
.query(format!(
|
.query("SELECT * FROM relates_to WHERE metadata.user_id = $user_id")
|
||||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
.bind(("user_id", owner_user_id.to_string()))
|
||||||
owner_user_id
|
|
||||||
))
|
|
||||||
.await
|
.await
|
||||||
.expect("Query failed");
|
.with_context(|| "Query failed".to_string())?;
|
||||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!results.is_empty(),
|
!results.is_empty(),
|
||||||
"Relationship should still exist after unauthorized delete attempt"
|
"Relationship should still exist after unauthorized delete attempt"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_store_relationship_exists() {
|
async fn test_store_relationship_exists() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
let db = setup_test_db().await?;
|
||||||
let db = setup_test_db().await;
|
let user_id = "user123";
|
||||||
|
|
||||||
// Create entities to relate
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
|
||||||
|
|
||||||
// Create relationships with the same source_id
|
|
||||||
let user_id = "user123".to_string();
|
|
||||||
let source_id = "source123".to_string();
|
let source_id = "source123".to_string();
|
||||||
let different_source_id = "different_source".to_string();
|
let different_source_id = "different_source".to_string();
|
||||||
|
|
||||||
// Create two relationships with the same source_id
|
|
||||||
let relationship1 = KnowledgeRelationship::new(
|
let relationship1 = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
@@ -442,35 +475,35 @@ mod tests {
|
|||||||
let relationship2 = KnowledgeRelationship::new(
|
let relationship2 = KnowledgeRelationship::new(
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
entity3_id.clone(),
|
entity3_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
"contains".to_string(),
|
"contains".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create a relationship with a different source_id
|
|
||||||
let different_relationship = KnowledgeRelationship::new(
|
let different_relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity3_id.clone(),
|
entity3_id.clone(),
|
||||||
user_id.clone(),
|
user_id.to_string(),
|
||||||
different_source_id.clone(),
|
different_source_id.clone(),
|
||||||
"mentions".to_string(),
|
"mentions".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship1_id = relationship1.id.clone();
|
||||||
|
let relationship2_id = relationship2.id.clone();
|
||||||
|
let different_relationship_id = different_relationship.id.clone();
|
||||||
|
|
||||||
// Store all relationships
|
|
||||||
relationship1
|
relationship1
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship 1");
|
.with_context(|| "Failed to store relationship 1".to_string())?;
|
||||||
relationship2
|
relationship2
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store relationship 2");
|
.with_context(|| "Failed to store relationship 2".to_string())?;
|
||||||
different_relationship
|
different_relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store different relationship");
|
.with_context(|| "Failed to store different relationship".to_string())?;
|
||||||
|
|
||||||
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
|
|
||||||
let mut before_delete = db
|
let mut before_delete = db
|
||||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||||
.bind(("source_id", source_id.clone()))
|
.bind(("source_id", source_id.clone()))
|
||||||
@@ -489,36 +522,83 @@ mod tests {
|
|||||||
before_delete_different.take(0).unwrap_or_default();
|
before_delete_different.take(0).unwrap_or_default();
|
||||||
assert_eq!(before_delete_different_rows.len(), 1);
|
assert_eq!(before_delete_different_rows.len(), 1);
|
||||||
|
|
||||||
// Delete relationships by source_id
|
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, user_id, &db)
|
||||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete relationships by source_id");
|
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||||
|
|
||||||
// Query to verify the specific relationships with source_id were deleted.
|
let result1 = get_relationship_by_id(&relationship1_id, &db).await;
|
||||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
let result2 = get_relationship_by_id(&relationship2_id, &db).await;
|
||||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
let different_result = get_relationship_by_id(&different_relationship_id, &db).await;
|
||||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
|
||||||
|
|
||||||
// Verify relationships with the source_id are deleted
|
|
||||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||||
let remaining =
|
let remaining =
|
||||||
different_result.expect("Relationship with different source_id should remain");
|
different_result.expect("Relationship with different source_id should remain");
|
||||||
assert_eq!(remaining.metadata.source_id, different_source_id);
|
assert_eq!(remaining.metadata.source_id, different_source_id);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_relationships_by_source_id_resists_query_injection() {
|
async fn test_delete_relationships_by_source_id_scoped_to_user() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
let user_a = "user-a";
|
||||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
let user_b = "user-b";
|
||||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
let shared_source = "shared-source";
|
||||||
|
|
||||||
|
let a1 = create_test_entity("A1", user_a, &db).await?;
|
||||||
|
let a2 = create_test_entity("A2", user_a, &db).await?;
|
||||||
|
let b1 = create_test_entity("B1", user_b, &db).await?;
|
||||||
|
let b2 = create_test_entity("B2", user_b, &db).await?;
|
||||||
|
|
||||||
|
let rel_a = KnowledgeRelationship::new(
|
||||||
|
a1,
|
||||||
|
a2,
|
||||||
|
user_a.to_string(),
|
||||||
|
shared_source.to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
let rel_b = KnowledgeRelationship::new(
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
|
user_b.to_string(),
|
||||||
|
shared_source.to_string(),
|
||||||
|
"references".to_string(),
|
||||||
|
);
|
||||||
|
let owner_relationship_id = rel_a.id.clone();
|
||||||
|
let other_relationship_id = rel_b.id.clone();
|
||||||
|
|
||||||
|
rel_a.store_relationship(&db).await?;
|
||||||
|
rel_b.store_relationship(&db).await?;
|
||||||
|
|
||||||
|
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(get_relationship_by_id(&owner_relationship_id, &db)
|
||||||
|
.await
|
||||||
|
.is_none());
|
||||||
|
assert!(get_relationship_by_id(&other_relationship_id, &db)
|
||||||
|
.await
|
||||||
|
.is_some());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_relationships_by_source_id_resists_query_injection() -> anyhow::Result<()>
|
||||||
|
{
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
let user_id = "user123";
|
||||||
|
|
||||||
|
let entity1_id = create_test_entity("Entity 1", user_id, &db).await?;
|
||||||
|
let entity2_id = create_test_entity("Entity 2", user_id, &db).await?;
|
||||||
|
let entity3_id = create_test_entity("Entity 3", user_id, &db).await?;
|
||||||
|
|
||||||
let safe_relationship = KnowledgeRelationship::new(
|
let safe_relationship = KnowledgeRelationship::new(
|
||||||
entity1_id.clone(),
|
entity1_id.clone(),
|
||||||
entity2_id.clone(),
|
entity2_id.clone(),
|
||||||
"user123".to_string(),
|
user_id.to_string(),
|
||||||
"safe_source".to_string(),
|
"safe_source".to_string(),
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
@@ -526,10 +606,12 @@ mod tests {
|
|||||||
let other_relationship = KnowledgeRelationship::new(
|
let other_relationship = KnowledgeRelationship::new(
|
||||||
entity2_id,
|
entity2_id,
|
||||||
entity3_id,
|
entity3_id,
|
||||||
"user123".to_string(),
|
user_id.to_string(),
|
||||||
"other_source".to_string(),
|
"other_source".to_string(),
|
||||||
"contains".to_string(),
|
"contains".to_string(),
|
||||||
);
|
);
|
||||||
|
let safe_relationship_id = safe_relationship.id.clone();
|
||||||
|
let other_relationship_id = other_relationship.id.clone();
|
||||||
|
|
||||||
safe_relationship
|
safe_relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -540,17 +622,23 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("store other relationship");
|
.expect("store other relationship");
|
||||||
|
|
||||||
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
|
KnowledgeRelationship::delete_relationships_by_source_id(
|
||||||
.await
|
"safe_source' OR 1=1 --",
|
||||||
.expect("delete call should succeed");
|
user_id,
|
||||||
|
&db,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("delete call should succeed");
|
||||||
|
|
||||||
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
|
let remaining_safe = get_relationship_by_id(&safe_relationship_id, &db).await;
|
||||||
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
|
let remaining_other = get_relationship_by_id(&other_relationship_id, &db).await;
|
||||||
|
|
||||||
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
||||||
assert!(
|
assert!(
|
||||||
remaining_other.is_some(),
|
remaining_other.is_some(),
|
||||||
"Other relationship should remain"
|
"Other relationship should remain"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::module_name_repetitions)]
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
use crate::stored_object;
|
use crate::stored_object;
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone, Serialize, PartialEq)]
|
#[derive(Deserialize, Debug, Clone, Copy, Serialize, PartialEq)]
|
||||||
pub enum MessageRole {
|
pub enum MessageRole {
|
||||||
User,
|
User,
|
||||||
AI,
|
AI,
|
||||||
@@ -18,6 +21,7 @@ stored_object!(Message, "message", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
conversation_id: String,
|
conversation_id: String,
|
||||||
role: MessageRole,
|
role: MessageRole,
|
||||||
@@ -54,22 +58,31 @@ impl fmt::Display for Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// helper function to format a vector of messages
|
// helper function to format a vector of messages
|
||||||
|
#[must_use]
|
||||||
pub fn format_history(history: &[Message]) -> String {
|
pub fn format_history(history: &[Message]) -> String {
|
||||||
history
|
let estimated: usize = history
|
||||||
.iter()
|
.iter()
|
||||||
.map(|msg| format!("{msg}"))
|
.map(|m| m.content.len().saturating_add(10))
|
||||||
.collect::<Vec<String>>()
|
.sum();
|
||||||
.join("\n")
|
let mut out = String::with_capacity(estimated);
|
||||||
|
for (i, msg) in history.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
out.push('\n');
|
||||||
|
}
|
||||||
|
let _ = write!(out, "{msg}");
|
||||||
|
}
|
||||||
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_creation() {
|
async fn test_message_creation() -> anyhow::Result<()> {
|
||||||
// Test basic message creation
|
|
||||||
let conversation_id = "test_conversation";
|
let conversation_id = "test_conversation";
|
||||||
let content = "This is a test message";
|
let content = "This is a test message";
|
||||||
let role = MessageRole::User;
|
let role = MessageRole::User;
|
||||||
@@ -77,29 +90,28 @@ mod tests {
|
|||||||
|
|
||||||
let message = Message::new(
|
let message = Message::new(
|
||||||
conversation_id.to_string(),
|
conversation_id.to_string(),
|
||||||
role.clone(),
|
role,
|
||||||
content.to_string(),
|
content.to_string(),
|
||||||
references.clone(),
|
references.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify message properties
|
|
||||||
assert_eq!(message.conversation_id, conversation_id);
|
assert_eq!(message.conversation_id, conversation_id);
|
||||||
assert_eq!(message.content, content);
|
assert_eq!(message.content, content);
|
||||||
assert_eq!(message.role, role);
|
assert_eq!(message.role, role);
|
||||||
assert_eq!(message.references, references);
|
assert_eq!(message.references, references);
|
||||||
assert!(!message.id.is_empty());
|
assert!(!message.id.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_persistence() {
|
async fn test_message_persistence() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &uuid::Uuid::new_v4().to_string();
|
let database = &uuid::Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create and store a message
|
|
||||||
let conversation_id = "test_conversation";
|
let conversation_id = "test_conversation";
|
||||||
let message = Message::new(
|
let message = Message::new(
|
||||||
conversation_id.to_string(),
|
conversation_id.to_string(),
|
||||||
@@ -109,39 +121,37 @@ mod tests {
|
|||||||
);
|
);
|
||||||
let message_id = message.id.clone();
|
let message_id = message.id.clone();
|
||||||
|
|
||||||
// Store the message
|
|
||||||
db.store_item(message.clone())
|
db.store_item(message.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store message");
|
.with_context(|| "Failed to store message".to_string())?;
|
||||||
|
|
||||||
// Retrieve the message
|
|
||||||
let retrieved: Option<Message> = db
|
let retrieved: Option<Message> = db
|
||||||
.get_item(&message_id)
|
.get_item(&message_id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve message");
|
.with_context(|| "Failed to retrieve message".to_string())?;
|
||||||
|
|
||||||
assert!(retrieved.is_some());
|
let retrieved = retrieved.ok_or_else(|| anyhow::anyhow!("Expected message to exist"))?;
|
||||||
let retrieved = retrieved.unwrap();
|
|
||||||
|
|
||||||
// Verify retrieved properties match original
|
|
||||||
assert_eq!(retrieved.id, message.id);
|
assert_eq!(retrieved.id, message.id);
|
||||||
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
assert_eq!(retrieved.conversation_id, message.conversation_id);
|
||||||
assert_eq!(retrieved.role, message.role);
|
assert_eq!(retrieved.role, message.role);
|
||||||
assert_eq!(retrieved.content, message.content);
|
assert_eq!(retrieved.content, message.content);
|
||||||
assert_eq!(retrieved.references, message.references);
|
assert_eq!(retrieved.references, message.references);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_role_display() {
|
async fn test_message_role_display() -> anyhow::Result<()> {
|
||||||
// Test the Display implementation for MessageRole
|
|
||||||
assert_eq!(format!("{}", MessageRole::User), "User");
|
assert_eq!(format!("{}", MessageRole::User), "User");
|
||||||
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
assert_eq!(format!("{}", MessageRole::AI), "AI");
|
||||||
assert_eq!(format!("{}", MessageRole::System), "System");
|
assert_eq!(format!("{}", MessageRole::System), "System");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_message_display() {
|
async fn test_message_display() -> anyhow::Result<()> {
|
||||||
// Test the Display implementation for Message
|
|
||||||
let message = Message {
|
let message = Message {
|
||||||
id: "test_id".to_string(),
|
id: "test_id".to_string(),
|
||||||
created_at: Utc::now(),
|
created_at: Utc::now(),
|
||||||
@@ -152,12 +162,13 @@ mod tests {
|
|||||||
references: None,
|
references: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(format!("{}", message), "User: Hello world");
|
assert_eq!(format!("{message}"), "User: Hello world");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_format_history() {
|
async fn test_format_history() -> anyhow::Result<()> {
|
||||||
// Create a vector of messages
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message {
|
||||||
id: "1".to_string(),
|
id: "1".to_string(),
|
||||||
@@ -179,10 +190,10 @@ mod tests {
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
// Format the history
|
|
||||||
let formatted = format_history(&messages);
|
let formatted = format_history(&messages);
|
||||||
|
|
||||||
// Verify the formatting
|
|
||||||
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
assert_eq!(formatted, "User: Hello\nAI: Hi there!");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,103 +19,21 @@ pub mod user;
|
|||||||
|
|
||||||
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
|
||||||
fn table_name() -> &'static str;
|
fn table_name() -> &'static str;
|
||||||
fn get_id(&self) -> &str;
|
fn id(&self) -> &str;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! stored_object {
|
macro_rules! stored_object {
|
||||||
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
|
||||||
use serde::{Deserialize, Deserializer, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use surrealdb::sql::Thing;
|
|
||||||
use $crate::storage::types::StoredObject;
|
use $crate::storage::types::StoredObject;
|
||||||
use serde::de::{self, Visitor};
|
#[allow(unused_imports)]
|
||||||
use std::fmt;
|
use $crate::utils::serde_helpers::{
|
||||||
|
deserialize_flexible_id, serialize_datetime, deserialize_datetime,
|
||||||
|
serialize_option_datetime, deserialize_option_datetime,
|
||||||
|
};
|
||||||
use chrono::{DateTime, Utc };
|
use chrono::{DateTime, Utc };
|
||||||
|
|
||||||
struct FlexibleIdVisitor;
|
|
||||||
|
|
||||||
impl<'de> Visitor<'de> for FlexibleIdVisitor {
|
|
||||||
type Value = String;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
formatter.write_str("a string or a Thing")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: de::Error,
|
|
||||||
{
|
|
||||||
Ok(value.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: de::Error,
|
|
||||||
{
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
|
||||||
where
|
|
||||||
A: de::MapAccess<'de>,
|
|
||||||
{
|
|
||||||
// Try to deserialize as Thing
|
|
||||||
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
|
|
||||||
Ok(thing.id.to_raw())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
deserializer.deserialize_any(FlexibleIdVisitor)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
|
|
||||||
Ok(DateTime::<Utc>::from(dt))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[allow(clippy::ref_option)]
|
|
||||||
fn serialize_option_datetime<S>(
|
|
||||||
date: &Option<DateTime<Utc>>,
|
|
||||||
serializer: S,
|
|
||||||
) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
match date {
|
|
||||||
Some(dt) => serializer
|
|
||||||
.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
|
|
||||||
None => serializer.serialize_none(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[allow(clippy::ref_option)]
|
|
||||||
fn deserialize_option_datetime<'de, D>(
|
|
||||||
deserializer: D,
|
|
||||||
) -> Result<Option<DateTime<Utc>>, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
|
|
||||||
Ok(value.map(DateTime::<Utc>::from))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
$(#[$struct_attr])*
|
$(#[$struct_attr])*
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct $name {
|
pub struct $name {
|
||||||
@@ -133,7 +51,7 @@ macro_rules! stored_object {
|
|||||||
$table
|
$table
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ stored_object!(Scratchpad, "scratchpad", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl Scratchpad {
|
impl Scratchpad {
|
||||||
|
#[must_use]
|
||||||
pub fn new(user_id: String, title: String) -> Self {
|
pub fn new(user_id: String, title: String) -> Self {
|
||||||
let now = ChronoUtc::now();
|
let now = ChronoUtc::now();
|
||||||
Self {
|
Self {
|
||||||
@@ -78,7 +79,7 @@ impl Scratchpad {
|
|||||||
let scratchpad: Option<Scratchpad> = db.get_item(id).await?;
|
let scratchpad: Option<Scratchpad> = db.get_item(id).await?;
|
||||||
|
|
||||||
let scratchpad =
|
let scratchpad =
|
||||||
scratchpad.ok_or_else(|| AppError::NotFound("Scratchpad not found".to_string()))?;
|
scratchpad.ok_or_else(|| AppError::NotFound("scratchpad not found".to_string()))?;
|
||||||
|
|
||||||
if scratchpad.user_id != user_id {
|
if scratchpad.user_id != user_id {
|
||||||
return Err(AppError::Auth(
|
return Err(AppError::Auth(
|
||||||
@@ -216,20 +217,23 @@ impl Scratchpad {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_scratchpad() {
|
async fn test_create_scratchpad() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
// Create a new scratchpad
|
// Create a new scratchpad
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
@@ -254,29 +258,28 @@ mod tests {
|
|||||||
let retrieved: Option<Scratchpad> = db
|
let retrieved: Option<Scratchpad> = db
|
||||||
.get_item(&scratchpad.id)
|
.get_item(&scratchpad.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve scratchpad");
|
.with_context(|| "Failed to retrieve scratchpad".to_string())?;
|
||||||
assert!(retrieved.is_some());
|
let retrieved = retrieved.with_context(|| "expected scratchpad to exist".to_string())?;
|
||||||
|
|
||||||
let retrieved = retrieved.unwrap();
|
|
||||||
assert_eq!(retrieved.id, scratchpad.id);
|
assert_eq!(retrieved.id, scratchpad.id);
|
||||||
assert_eq!(retrieved.user_id, user_id);
|
assert_eq!(retrieved.user_id, user_id);
|
||||||
assert_eq!(retrieved.title, title);
|
assert_eq!(retrieved.title, title);
|
||||||
assert!(!retrieved.is_archived);
|
assert!(!retrieved.is_archived);
|
||||||
assert!(retrieved.archived_at.is_none());
|
assert!(retrieved.archived_at.is_none());
|
||||||
assert!(retrieved.ingested_at.is_none());
|
assert!(retrieved.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_by_user() {
|
async fn test_get_by_user() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
|
|
||||||
@@ -288,19 +291,30 @@ mod tests {
|
|||||||
// Store them
|
// Store them
|
||||||
let scratchpad1_id = scratchpad1.id.clone();
|
let scratchpad1_id = scratchpad1.id.clone();
|
||||||
let scratchpad2_id = scratchpad2.id.clone();
|
let scratchpad2_id = scratchpad2.id.clone();
|
||||||
db.store_item(scratchpad1).await.unwrap();
|
db.store_item(scratchpad1)
|
||||||
db.store_item(scratchpad2).await.unwrap();
|
.await
|
||||||
db.store_item(scratchpad3).await.unwrap();
|
.with_context(|| "store scratchpad1".to_string())?;
|
||||||
|
db.store_item(scratchpad2)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad2".to_string())?;
|
||||||
|
db.store_item(scratchpad3)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad3".to_string())?;
|
||||||
|
|
||||||
// Archive one of the user's scratchpads
|
// Archive one of the user's scratchpads
|
||||||
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.with_context(|| "archive".to_string())?;
|
||||||
|
|
||||||
// Get scratchpads for user_id
|
// Get scratchpads for user_id
|
||||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
|
let user_scratchpads = Scratchpad::get_by_user(user_id, &db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_by_user".to_string())?;
|
||||||
assert_eq!(user_scratchpads.len(), 1);
|
assert_eq!(user_scratchpads.len(), 1);
|
||||||
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
|
assert_eq!(
|
||||||
|
user_scratchpads.first().map(|s| &s.id),
|
||||||
|
Some(&scratchpad1_id)
|
||||||
|
);
|
||||||
|
|
||||||
// Verify they belong to the user
|
// Verify they belong to the user
|
||||||
for scratchpad in &user_scratchpads {
|
for scratchpad in &user_scratchpads {
|
||||||
@@ -309,177 +323,201 @@ mod tests {
|
|||||||
|
|
||||||
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.with_context(|| "get_archived_by_user".to_string())?;
|
||||||
assert_eq!(archived.len(), 1);
|
assert_eq!(archived.len(), 1);
|
||||||
assert_eq!(archived[0].id, scratchpad2_id);
|
assert_eq!(archived.first().map(|s| &s.id), Some(&scratchpad2_id));
|
||||||
assert!(archived[0].is_archived);
|
assert!(archived.first().is_some_and(|s| s.is_archived));
|
||||||
assert!(archived[0].ingested_at.is_none());
|
assert!(archived.first().is_some_and(|s| s.ingested_at.is_none()));
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_archive_and_restore() {
|
async fn test_archive_and_restore() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to archive");
|
.with_context(|| "Failed to archive".to_string())?;
|
||||||
assert!(archived.is_archived);
|
assert!(archived.is_archived);
|
||||||
assert!(archived.archived_at.is_some());
|
assert!(archived.archived_at.is_some());
|
||||||
assert!(archived.ingested_at.is_some());
|
assert!(archived.ingested_at.is_some());
|
||||||
|
|
||||||
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to restore");
|
.with_context(|| "Failed to restore".to_string())?;
|
||||||
assert!(!restored.is_archived);
|
assert!(!restored.is_archived);
|
||||||
assert!(restored.archived_at.is_none());
|
assert!(restored.archived_at.is_none());
|
||||||
assert!(restored.ingested_at.is_none());
|
assert!(restored.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_content() {
|
async fn test_update_content() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
let new_content = "Updated content";
|
let new_content = "Updated content";
|
||||||
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.with_context(|| "update_content".to_string())?;
|
||||||
|
|
||||||
assert_eq!(updated.content, new_content);
|
assert_eq!(updated.content, new_content);
|
||||||
assert!(!updated.is_dirty);
|
assert!(!updated.is_dirty);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_content_unauthorized() {
|
async fn test_update_content_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user = "other_user";
|
let other_user = "other_user";
|
||||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => {}
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_scratchpad() {
|
async fn test_delete_scratchpad() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let user_id = "test_user";
|
let user_id = "test_user";
|
||||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
// Delete should succeed
|
// Delete should succeed
|
||||||
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Verify it's gone
|
// Verify it's gone
|
||||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
let retrieved: Option<Scratchpad> = db
|
||||||
|
.get_item(&scratchpad_id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_item".to_string())?;
|
||||||
assert!(retrieved.is_none());
|
assert!(retrieved.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_unauthorized() {
|
async fn test_delete_unauthorized() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let owner_id = "owner";
|
let owner_id = "owner";
|
||||||
let other_user = "other_user";
|
let other_user = "other_user";
|
||||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
match result {
|
match result {
|
||||||
Err(AppError::Auth(_)) => {}
|
Err(AppError::Auth(_)) => {}
|
||||||
_ => panic!("Expected Auth error"),
|
_ => anyhow::bail!("Expected Auth error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it still exists
|
// Verify it still exists
|
||||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
let retrieved: Option<Scratchpad> = db
|
||||||
|
.get_item(&scratchpad_id)
|
||||||
|
.await
|
||||||
|
.with_context(|| "get_item".to_string())?;
|
||||||
assert!(retrieved.is_some());
|
assert!(retrieved.is_some());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timezone_aware_scratchpad_conversion() {
|
async fn test_timezone_aware_scratchpad_conversion() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create test database");
|
.with_context(|| "Failed to create test database".to_string())?;
|
||||||
|
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
let user_id = "test_user_123";
|
let user_id = "test_user_123";
|
||||||
let scratchpad =
|
let scratchpad =
|
||||||
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||||
let scratchpad_id = scratchpad.id.clone();
|
let scratchpad_id = scratchpad.id.clone();
|
||||||
|
|
||||||
db.store_item(scratchpad).await.unwrap();
|
db.store_item(scratchpad)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store scratchpad".to_string())?;
|
||||||
|
|
||||||
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.with_context(|| "get_by_id".to_string())?;
|
||||||
|
|
||||||
// Test that datetime fields are preserved and can be used for timezone formatting
|
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||||
assert!(retrieved.created_at.timestamp() > 0);
|
assert!(retrieved.created_at.timestamp() > 0);
|
||||||
@@ -493,10 +531,17 @@ mod tests {
|
|||||||
// Archive the scratchpad to test optional datetime handling
|
// Archive the scratchpad to test optional datetime handling
|
||||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.with_context(|| "archive".to_string())?;
|
||||||
|
|
||||||
assert!(archived.archived_at.is_some());
|
assert!(archived.archived_at.is_some());
|
||||||
assert!(archived.archived_at.unwrap().timestamp() > 0);
|
assert!(
|
||||||
|
archived
|
||||||
|
.archived_at
|
||||||
|
.with_context(|| "expected archived_at".to_string())?
|
||||||
|
.timestamp()
|
||||||
|
> 0
|
||||||
|
);
|
||||||
assert!(archived.ingested_at.is_none());
|
assert!(archived.ingested_at.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
pub static DEFAULT_QUERY_SYSTEM_PROMPT: &str = r#"You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
pub const DEFAULT_QUERY_SYSTEM_PROMPT: &str = r#"You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.
|
||||||
|
|
||||||
Your task is to:
|
Your task is to:
|
||||||
1. Carefully analyze the provided knowledge entities in the context
|
1. Carefully analyze the provided knowledge entities in the context
|
||||||
@@ -20,7 +20,7 @@ Example response formats:
|
|||||||
"I found relevant information in multiple entries: [explanation...]"
|
"I found relevant information in multiple entries: [explanation...]"
|
||||||
"I apologize, but the provided context doesn't contain information about [topic]""#;
|
"I apologize, but the provided context doesn't contain information about [topic]""#;
|
||||||
|
|
||||||
pub static DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT: &str = r#"You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
pub const DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT: &str = r#"You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.
|
||||||
|
|
||||||
The JSON should have the following structure:
|
The JSON should have the following structure:
|
||||||
|
|
||||||
@@ -49,13 +49,13 @@ Guidelines:
|
|||||||
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.
|
||||||
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.
|
||||||
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.
|
||||||
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity"
|
5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity.
|
||||||
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.
|
||||||
7. Only create relationships between existing KnowledgeEntities.
|
7. Only create relationships between existing KnowledgeEntities.
|
||||||
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.
|
||||||
9. A new relationship MUST include a newly created KnowledgeEntity."#;
|
9. A new relationship MUST include a newly created KnowledgeEntity."#;
|
||||||
|
|
||||||
pub static DEFAULT_IMAGE_PROCESSING_PROMPT: &str = r#"Analyze this image and respond based on its primary content:
|
pub const DEFAULT_IMAGE_PROCESSING_PROMPT: &str = r#"Analyze this image and respond based on its primary content:
|
||||||
- If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.
|
- If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.
|
||||||
- If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.
|
- If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.
|
||||||
- For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a "Text:" heading.
|
- For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a "Text:" heading.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use crate::storage::types::file_info::deserialize_flexible_id;
|
use crate::utils::config::EmbeddingBackend;
|
||||||
|
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
use crate::{error::AppError, storage::db::SurrealDbClient, storage::types::StoredObject};
|
||||||
@@ -13,9 +14,9 @@ pub struct SystemSettings {
|
|||||||
pub processing_model: String,
|
pub processing_model: String,
|
||||||
pub embedding_model: String,
|
pub embedding_model: String,
|
||||||
pub embedding_dimensions: u32,
|
pub embedding_dimensions: u32,
|
||||||
/// Active embedding backend ("openai", "fastembed", "hashed"). Read-only, synced from config.
|
/// Active embedding backend. Read-only for admin updates; synced from config at startup.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub embedding_backend: Option<String>,
|
pub embedding_backend: Option<EmbeddingBackend>,
|
||||||
pub query_system_prompt: String,
|
pub query_system_prompt: String,
|
||||||
pub ingestion_system_prompt: String,
|
pub ingestion_system_prompt: String,
|
||||||
pub image_processing_model: String,
|
pub image_processing_model: String,
|
||||||
@@ -23,33 +24,151 @@ pub struct SystemSettings {
|
|||||||
pub voice_processing_model: String,
|
pub voice_processing_model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Partial update for singleton system settings without cloning unchanged fields.
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
|
pub struct SystemSettingsPatch {
|
||||||
|
pub registrations_enabled: Option<bool>,
|
||||||
|
pub require_email_verification: Option<bool>,
|
||||||
|
pub query_model: Option<String>,
|
||||||
|
pub processing_model: Option<String>,
|
||||||
|
pub embedding_model: Option<String>,
|
||||||
|
pub embedding_dimensions: Option<u32>,
|
||||||
|
pub query_system_prompt: Option<String>,
|
||||||
|
pub ingestion_system_prompt: Option<String>,
|
||||||
|
pub image_processing_model: Option<String>,
|
||||||
|
pub image_processing_prompt: Option<String>,
|
||||||
|
pub voice_processing_model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum UpdateMode {
|
||||||
|
User,
|
||||||
|
EmbeddingSync,
|
||||||
|
}
|
||||||
|
|
||||||
impl StoredObject for SystemSettings {
|
impl StoredObject for SystemSettings {
|
||||||
fn table_name() -> &'static str {
|
fn table_name() -> &'static str {
|
||||||
"system_settings"
|
"system_settings"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_id(&self) -> &str {
|
fn id(&self) -> &str {
|
||||||
&self.id
|
&self.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SystemSettingsPatch {
|
||||||
|
pub fn apply_to(self, settings: &mut SystemSettings) {
|
||||||
|
if let Some(value) = self.registrations_enabled {
|
||||||
|
settings.registrations_enabled = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.require_email_verification {
|
||||||
|
settings.require_email_verification = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.query_model {
|
||||||
|
settings.query_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.processing_model {
|
||||||
|
settings.processing_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.embedding_model {
|
||||||
|
settings.embedding_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.embedding_dimensions {
|
||||||
|
settings.embedding_dimensions = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.query_system_prompt {
|
||||||
|
settings.query_system_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.ingestion_system_prompt {
|
||||||
|
settings.ingestion_system_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.image_processing_model {
|
||||||
|
settings.image_processing_model = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.image_processing_prompt {
|
||||||
|
settings.image_processing_prompt = value;
|
||||||
|
}
|
||||||
|
if let Some(value) = self.voice_processing_model {
|
||||||
|
settings.voice_processing_model = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn apply(self, db: &SurrealDbClient) -> Result<SystemSettings, AppError> {
|
||||||
|
let mut current = SystemSettings::get_current(db).await?;
|
||||||
|
self.apply_to(&mut current);
|
||||||
|
SystemSettings::update(db, current).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SystemSettings {
|
impl SystemSettings {
|
||||||
|
pub const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
fn validate(&self) -> Result<(), AppError> {
|
||||||
|
if self.embedding_dimensions == 0 {
|
||||||
|
return Err(AppError::Validation(
|
||||||
|
"embedding_dimensions must be greater than 0".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_fields = [
|
||||||
|
("query_model", &self.query_model),
|
||||||
|
("processing_model", &self.processing_model),
|
||||||
|
("embedding_model", &self.embedding_model),
|
||||||
|
("image_processing_model", &self.image_processing_model),
|
||||||
|
("voice_processing_model", &self.voice_processing_model),
|
||||||
|
];
|
||||||
|
for (name, value) in model_fields {
|
||||||
|
if value.trim().is_empty() {
|
||||||
|
return Err(AppError::Validation(format!("{name} must not be empty")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let prompt_fields = [
|
||||||
|
("query_system_prompt", &self.query_system_prompt),
|
||||||
|
("ingestion_system_prompt", &self.ingestion_system_prompt),
|
||||||
|
("image_processing_prompt", &self.image_processing_prompt),
|
||||||
|
];
|
||||||
|
for (name, value) in prompt_fields {
|
||||||
|
if value.trim().is_empty() {
|
||||||
|
return Err(AppError::Validation(format!("{name} must not be empty")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let settings: Option<Self> = db.get_item("current").await?;
|
let settings: Option<Self> = db.get_item(Self::RECORD_ID).await?;
|
||||||
settings.ok_or(AppError::NotFound("System settings not found".into()))
|
settings.ok_or(AppError::NotFound("system settings not found".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
pub async fn update(db: &SurrealDbClient, changes: Self) -> Result<Self, AppError> {
|
||||||
// We need to use a direct query for the update with MERGE
|
Self::update_with_mode(db, changes, UpdateMode::User).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_with_mode(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
mut changes: Self,
|
||||||
|
mode: UpdateMode,
|
||||||
|
) -> Result<Self, AppError> {
|
||||||
|
let current = Self::get_current(db).await?;
|
||||||
|
if matches!(mode, UpdateMode::User) {
|
||||||
|
changes.embedding_backend = current.embedding_backend;
|
||||||
|
}
|
||||||
|
changes.id = Self::RECORD_ID.to_string();
|
||||||
|
changes.validate()?;
|
||||||
|
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('system_settings', 'current') MERGE $changes RETURN AFTER")
|
.query("UPDATE type::thing('system_settings', $id) MERGE $changes RETURN AFTER")
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.bind(("changes", changes))
|
.bind(("changes", changes))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
updated.ok_or(AppError::Validation(
|
updated.ok_or(AppError::NotFound(
|
||||||
"Something went wrong updating the settings".into(),
|
"system settings record missing after update".into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,17 +182,23 @@ impl SystemSettings {
|
|||||||
let mut settings = Self::get_current(db).await?;
|
let mut settings = Self::get_current(db).await?;
|
||||||
let mut needs_update = false;
|
let mut needs_update = false;
|
||||||
|
|
||||||
let backend_label = provider.backend_label().to_string();
|
let provider_backend = provider
|
||||||
let provider_dimensions = provider.dimension() as u32;
|
.backend_label()
|
||||||
|
.parse::<EmbeddingBackend>()
|
||||||
|
.map_err(|e| AppError::Validation(e.to_string()))?;
|
||||||
|
let provider_dimensions = u32::try_from(provider.dimension()).map_err(|_| {
|
||||||
|
AppError::Validation(format!(
|
||||||
|
"embedding provider dimension {} exceeds u32::MAX",
|
||||||
|
provider.dimension()
|
||||||
|
))
|
||||||
|
})?;
|
||||||
let provider_model = provider.model_code();
|
let provider_model = provider.model_code();
|
||||||
|
|
||||||
// Sync backend label
|
if settings.embedding_backend != Some(provider_backend) {
|
||||||
if settings.embedding_backend.as_deref() != Some(&backend_label) {
|
settings.embedding_backend = Some(provider_backend);
|
||||||
settings.embedding_backend = Some(backend_label);
|
|
||||||
needs_update = true;
|
needs_update = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync dimensions
|
|
||||||
if settings.embedding_dimensions != provider_dimensions {
|
if settings.embedding_dimensions != provider_dimensions {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
old_dimensions = settings.embedding_dimensions,
|
old_dimensions = settings.embedding_dimensions,
|
||||||
@@ -84,7 +209,6 @@ impl SystemSettings {
|
|||||||
needs_update = true;
|
needs_update = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync model if provider has one
|
|
||||||
if let Some(model) = provider_model {
|
if let Some(model) = provider_model {
|
||||||
if settings.embedding_model != model {
|
if settings.embedding_model != model {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
@@ -98,7 +222,7 @@ impl SystemSettings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if needs_update {
|
if needs_update {
|
||||||
settings = Self::update(db, settings).await?;
|
settings = Self::update_with_mode(db, settings, UpdateMode::EmbeddingSync).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((settings, needs_update))
|
Ok((settings, needs_update))
|
||||||
@@ -107,9 +231,10 @@ impl SystemSettings {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::storage::indexes::ensure_runtime_indexes;
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use crate::storage::indexes::ensure_runtime;
|
||||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||||
use async_openai::Client;
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -118,88 +243,115 @@ mod tests {
|
|||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
table_name: &str,
|
table_name: &str,
|
||||||
index_name: &str,
|
index_name: &str,
|
||||||
) -> u32 {
|
) -> anyhow::Result<u32> {
|
||||||
let query = format!("INFO FOR TABLE {table_name};");
|
let query = format!("INFO FOR TABLE {table_name};");
|
||||||
let mut response = db
|
let mut response = db
|
||||||
.client
|
.client
|
||||||
.query(query)
|
.query(query)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch table info");
|
.with_context(|| "Failed to fetch table info".to_string())?;
|
||||||
|
|
||||||
let info: surrealdb::Value = response
|
let info: surrealdb::Value = response
|
||||||
.take(0)
|
.take(0)
|
||||||
.expect("Failed to extract table info response");
|
.with_context(|| "Failed to extract table info response".to_string())?;
|
||||||
|
|
||||||
let info_json: serde_json::Value =
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
serde_json::to_value(info).expect("Failed to convert info to json");
|
.with_context(|| "Failed to convert info to json".to_string())?;
|
||||||
|
|
||||||
let indexes = info_json["Object"]["indexes"]["Object"]
|
let indexes = info_json
|
||||||
.as_object()
|
.get("Object")
|
||||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.as_object())
|
||||||
|
.with_context(|| format!("Indexes collection missing in table info: {info_json:#?}"))?;
|
||||||
|
|
||||||
let definition = indexes
|
let definition = indexes
|
||||||
.get(index_name)
|
.get(index_name)
|
||||||
.and_then(|definition| definition.get("Strand"))
|
.and_then(|definition| definition.get("Strand"))
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
|
.with_context(|| format!("Index definition not found in table info: {info_json:#?}"))?;
|
||||||
|
|
||||||
let dimension_part = definition
|
let dimension_part = definition
|
||||||
.split("DIMENSION")
|
.split("DIMENSION")
|
||||||
.nth(1)
|
.nth(1)
|
||||||
.expect("Index definition missing DIMENSION clause");
|
.with_context(|| "Index definition missing DIMENSION clause".to_string())?;
|
||||||
|
|
||||||
let dimension_token = dimension_part
|
let dimension_token = dimension_part
|
||||||
.split_whitespace()
|
.split_whitespace()
|
||||||
.next()
|
.next()
|
||||||
.expect("Dimension value missing in definition")
|
.with_context(|| "Dimension value missing in definition".to_string())?
|
||||||
.trim_end_matches(';');
|
.trim_end_matches(';');
|
||||||
|
|
||||||
dimension_token
|
dimension_token
|
||||||
.parse::<u32>()
|
.parse::<u32>()
|
||||||
.expect("Dimension value is not a valid number")
|
.with_context(|| "Dimension value is not a valid number".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn simulate_reembedding(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
target_dimension: usize,
|
||||||
|
initial_chunk: TextChunk,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
db.query(
|
||||||
|
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.with_context(|| "remove index".to_string())?;
|
||||||
|
let define_index_query = format!(
|
||||||
|
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {target_dimension};"
|
||||||
|
);
|
||||||
|
db.query(define_index_query)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Re-defining index should succeed".to_string())?;
|
||||||
|
|
||||||
|
let new_embedding = vec![0.5; target_dimension];
|
||||||
|
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
||||||
|
|
||||||
|
db.client
|
||||||
|
.query(sql)
|
||||||
|
.bind(("id", initial_chunk.id.clone()))
|
||||||
|
.bind(("user_id", initial_chunk.user_id.clone()))
|
||||||
|
.bind(("embedding", new_embedding))
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert embedding".to_string())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_settings_initialization() {
|
async fn test_settings_initialization() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Test initialization of system settings
|
// Test initialization of system settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get system settings");
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
|
||||||
// Verify initial state after initialization
|
// Verify initial state after initialization
|
||||||
assert_eq!(settings.id, "current");
|
assert_eq!(settings.id, "current");
|
||||||
assert_eq!(settings.registrations_enabled, true);
|
assert!(settings.registrations_enabled);
|
||||||
assert_eq!(settings.require_email_verification, false);
|
assert!(!settings.require_email_verification);
|
||||||
assert_eq!(settings.query_model, "gpt-4o-mini");
|
assert_eq!(settings.query_model, "gpt-4o-mini");
|
||||||
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
assert_eq!(settings.processing_model, "gpt-4o-mini");
|
||||||
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
assert_eq!(settings.image_processing_model, "gpt-4o-mini");
|
||||||
// Dont test these for now, having a hard time getting the formatting exactly the same
|
assert!(!settings.ingestion_system_prompt.contains("entity\"\n6."));
|
||||||
// assert_eq!(
|
assert!(settings.ingestion_system_prompt.contains("related entity."));
|
||||||
// settings.query_system_prompt,
|
|
||||||
// crate::storage::types::system_prompts::DEFAULT_QUERY_SYSTEM_PROMPT
|
|
||||||
// );
|
|
||||||
// assert_eq!(
|
|
||||||
// settings.ingestion_system_prompt,
|
|
||||||
// crate::storage::types::system_prompts::DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT
|
|
||||||
// );
|
|
||||||
|
|
||||||
// Test idempotency - ensure calling it again doesn't change anything
|
// Test idempotency - ensure calling it again doesn't change anything
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
let settings_again = SystemSettings::get_current(&db)
|
let settings_again = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get settings after initialization");
|
.with_context(|| "Failed to get settings after initialization".to_string())?;
|
||||||
|
|
||||||
assert_eq!(settings.id, settings_again.id);
|
assert_eq!(settings.id, settings_again.id);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -210,49 +362,52 @@ mod tests {
|
|||||||
settings.require_email_verification,
|
settings.require_email_verification,
|
||||||
settings_again.require_email_verification
|
settings_again.require_email_verification
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_settings() {
|
async fn test_get_current_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Initialize settings
|
// Initialize settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
// Test get_current method
|
// Test get_current method
|
||||||
let settings = SystemSettings::get_current(&db)
|
let settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get current settings");
|
.with_context(|| "Failed to get current settings".to_string())?;
|
||||||
|
|
||||||
assert_eq!(settings.id, "current");
|
assert_eq!(settings.id, "current");
|
||||||
assert_eq!(settings.registrations_enabled, true);
|
assert!(settings.registrations_enabled);
|
||||||
assert_eq!(settings.require_email_verification, false);
|
assert!(!settings.require_email_verification);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_settings() {
|
async fn test_update_settings() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Initialize settings
|
// Initialize settings
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Failed to apply migrations");
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
// Create updated settings
|
// Create updated settings
|
||||||
let mut updated_settings = SystemSettings::get_current(&db).await.unwrap();
|
let mut updated_settings = SystemSettings::get_current(&db)
|
||||||
updated_settings.id = "current".to_string();
|
.await
|
||||||
|
.with_context(|| "get_current".to_string())?;
|
||||||
updated_settings.registrations_enabled = false;
|
updated_settings.registrations_enabled = false;
|
||||||
updated_settings.require_email_verification = true;
|
updated_settings.require_email_verification = true;
|
||||||
updated_settings.query_model = "gpt-4".to_string();
|
updated_settings.query_model = "gpt-4".to_string();
|
||||||
@@ -260,31 +415,32 @@ mod tests {
|
|||||||
// Test update method
|
// Test update method
|
||||||
let result = SystemSettings::update(&db, updated_settings)
|
let result = SystemSettings::update(&db, updated_settings)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to update settings");
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
assert_eq!(result.id, "current");
|
assert_eq!(result.id, "current");
|
||||||
assert_eq!(result.registrations_enabled, false);
|
assert!(!result.registrations_enabled);
|
||||||
assert_eq!(result.require_email_verification, true);
|
assert!(result.require_email_verification);
|
||||||
assert_eq!(result.query_model, "gpt-4");
|
assert_eq!(result.query_model, "gpt-4");
|
||||||
|
|
||||||
// Verify changes persisted by getting current settings
|
// Verify changes persisted by getting current settings
|
||||||
let current = SystemSettings::get_current(&db)
|
let current = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get current settings after update");
|
.with_context(|| "Failed to get current settings after update".to_string())?;
|
||||||
|
|
||||||
assert_eq!(current.registrations_enabled, false);
|
assert!(!current.registrations_enabled);
|
||||||
assert_eq!(current.require_email_verification, true);
|
assert!(current.require_email_verification);
|
||||||
assert_eq!(current.query_model, "gpt-4");
|
assert_eq!(current.query_model, "gpt-4");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() {
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Don't initialize settings and try to get them
|
// Don't initialize settings and try to get them
|
||||||
let result = SystemSettings::get_current(&db).await;
|
let result = SystemSettings::get_current(&db).await;
|
||||||
@@ -294,21 +450,240 @@ mod tests {
|
|||||||
Err(AppError::NotFound(_)) => {
|
Err(AppError::NotFound(_)) => {
|
||||||
// Expected error
|
// Expected error
|
||||||
}
|
}
|
||||||
Err(e) => panic!("Expected NotFound error, got: {:?}", e),
|
Err(e) => anyhow::bail!("Expected NotFound error, got: {e:?}"),
|
||||||
Ok(_) => panic!("Expected error but got Ok"),
|
Ok(_) => anyhow::bail!("Expected error but got Ok"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_migration_after_changing_embedding_length() {
|
async fn test_update_rejects_zero_embedding_dimensions() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
invalid_settings.embedding_dimensions = 0;
|
||||||
|
|
||||||
|
let result = SystemSettings::update(&db, invalid_settings).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_updates_without_cloning_full_settings() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let updated = SystemSettingsPatch {
|
||||||
|
registrations_enabled: Some(false),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.apply(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to patch settings".to_string())?;
|
||||||
|
|
||||||
|
assert!(!updated.registrations_enabled);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_leaves_unmentioned_fields_unchanged() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let original = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
let sentinel = "custom-query-prompt-sentinel".to_string();
|
||||||
|
|
||||||
|
let patched = SystemSettingsPatch {
|
||||||
|
query_system_prompt: Some(sentinel.clone()),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.apply(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to patch query prompt".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(patched.query_system_prompt, sentinel);
|
||||||
|
assert_eq!(
|
||||||
|
patched.ingestion_system_prompt,
|
||||||
|
original.ingestion_system_prompt
|
||||||
|
);
|
||||||
|
assert_eq!(patched.query_model, original.query_model);
|
||||||
|
assert_eq!(
|
||||||
|
patched.registrations_enabled,
|
||||||
|
original.registrations_enabled
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_rejects_empty_model_name() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut invalid_settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
invalid_settings.query_model = " ".to_string();
|
||||||
|
|
||||||
|
let result = SystemSettings::update(&db, invalid_settings).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_normalizes_record_id() -> anyhow::Result<()> {
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let mut settings = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get system settings".to_string())?;
|
||||||
|
settings.id = "wrong-id".to_string();
|
||||||
|
|
||||||
|
let updated = SystemSettings::update(&db, settings)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
assert_eq!(updated.id, SystemSettings::RECORD_ID);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_preserves_embedding_backend() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to sync embedding provider".to_string())?;
|
||||||
|
|
||||||
|
let synced = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to get synced settings".to_string())?;
|
||||||
|
assert_eq!(synced.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
|
||||||
|
let mut tampered = synced;
|
||||||
|
tampered.embedding_backend = Some(EmbeddingBackend::OpenAI);
|
||||||
|
let updated = SystemSettings::update(&db, tampered)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(updated.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_from_embedding_provider_updates_mismatched_settings() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
let (settings, changed) = SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to sync embedding provider".to_string())?;
|
||||||
|
|
||||||
|
assert!(changed);
|
||||||
|
assert_eq!(settings.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
assert_eq!(settings.embedding_dimensions, 384);
|
||||||
|
|
||||||
|
let persisted = SystemSettings::get_current(&db)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to reload synced settings".to_string())?;
|
||||||
|
assert_eq!(persisted.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
assert_eq!(persisted.embedding_dimensions, 384);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_from_embedding_provider_is_noop_when_already_synced() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed(384)
|
||||||
|
.with_context(|| "Failed to create hashed embedding provider".to_string())?;
|
||||||
|
SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to initial sync".to_string())?;
|
||||||
|
|
||||||
|
let (_, changed) = SystemSettings::sync_from_embedding_provider(&db, &provider)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to repeat sync".to_string())?;
|
||||||
|
assert!(!changed);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sync_rejects_provider_dimension_above_u32_max() -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to apply migrations".to_string())?;
|
||||||
|
|
||||||
|
let provider = EmbeddingProvider::new_hashed((u32::MAX as usize) + 1)
|
||||||
|
.with_context(|| "Failed to create oversized hashed provider".to_string())?;
|
||||||
|
let result = SystemSettings::sync_from_embedding_provider(&db, &provider).await;
|
||||||
|
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_migration_after_changing_embedding_length() -> anyhow::Result<()> {
|
||||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start DB");
|
.with_context(|| "Failed to start DB".to_string())?;
|
||||||
|
|
||||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Initial migration failed");
|
.with_context(|| "Initial migration failed".to_string())?;
|
||||||
|
|
||||||
let initial_chunk = TextChunk::new(
|
let initial_chunk = TextChunk::new(
|
||||||
"source1".into(),
|
"source1".into(),
|
||||||
@@ -318,43 +693,11 @@ mod tests {
|
|||||||
|
|
||||||
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store initial chunk with embedding");
|
.with_context(|| "Failed to store initial chunk with embedding".to_string())?;
|
||||||
|
|
||||||
async fn simulate_reembedding(
|
|
||||||
db: &SurrealDbClient,
|
|
||||||
target_dimension: usize,
|
|
||||||
initial_chunk: TextChunk,
|
|
||||||
) {
|
|
||||||
db.query(
|
|
||||||
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let define_index_query = format!(
|
|
||||||
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
|
|
||||||
target_dimension
|
|
||||||
);
|
|
||||||
db.query(define_index_query)
|
|
||||||
.await
|
|
||||||
.expect("Re-defining index should succeed");
|
|
||||||
|
|
||||||
let new_embedding = vec![0.5; target_dimension];
|
|
||||||
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
|
|
||||||
|
|
||||||
let update_result = db
|
|
||||||
.client
|
|
||||||
.query(sql)
|
|
||||||
.bind(("id", initial_chunk.id.clone()))
|
|
||||||
.bind(("user_id", initial_chunk.user_id.clone()))
|
|
||||||
.bind(("embedding", new_embedding))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
assert!(update_result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
|
||||||
let target_dimension = 1536usize;
|
let target_dimension = 1536usize;
|
||||||
simulate_reembedding(&db, target_dimension, initial_chunk).await;
|
simulate_reembedding(&db, target_dimension, initial_chunk).await?;
|
||||||
|
|
||||||
let migration_result = db.apply_migrations().await;
|
let migration_result = db.apply_migrations().await;
|
||||||
|
|
||||||
@@ -363,34 +706,38 @@ mod tests {
|
|||||||
"Migrations should not fail: {:?}",
|
"Migrations should not fail: {:?}",
|
||||||
migration_result.err()
|
migration_result.err()
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
async fn test_should_change_embedding_length_on_indexes_when_switching_length(
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
use crate::utils::embedding::EmbeddingProvider;
|
||||||
|
|
||||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start DB");
|
.with_context(|| "Failed to start DB".to_string())?;
|
||||||
|
|
||||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||||
db.apply_migrations()
|
db.apply_migrations()
|
||||||
.await
|
.await
|
||||||
.expect("Initial migration failed");
|
.with_context(|| "Initial migration failed".to_string())?;
|
||||||
|
|
||||||
let mut current_settings = SystemSettings::get_current(&db)
|
let mut current_settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to load current settings");
|
.with_context(|| "Failed to load current settings".to_string())?;
|
||||||
|
|
||||||
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
|
||||||
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
|
ensure_runtime(&db, current_settings.embedding_dimensions as usize)
|
||||||
.await
|
.await
|
||||||
.expect("failed to build runtime indexes");
|
.with_context(|| "failed to build runtime indexes".to_string())?;
|
||||||
|
|
||||||
let initial_chunk_dimension = get_hnsw_index_dimension(
|
let initial_chunk_dimension = get_hnsw_index_dimension(
|
||||||
&db,
|
&db,
|
||||||
"text_chunk_embedding",
|
"text_chunk_embedding",
|
||||||
"idx_embedding_text_chunk_embedding",
|
"idx_embedding_text_chunk_embedding",
|
||||||
)
|
)
|
||||||
.await;
|
.await?;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||||
@@ -405,34 +752,37 @@ mod tests {
|
|||||||
|
|
||||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to update settings");
|
.with_context(|| "Failed to update settings".to_string())?;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
updated_settings.embedding_dimensions, new_dimension,
|
updated_settings.embedding_dimensions, new_dimension,
|
||||||
"Settings should reflect the new embedding dimension"
|
"Settings should reflect the new embedding dimension"
|
||||||
);
|
);
|
||||||
|
|
||||||
let openai_client = Client::new();
|
let provider = EmbeddingProvider::new_hashed(new_dimension as usize)
|
||||||
|
.map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||||
|
|
||||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
TextChunk::update_all_embeddings(&db, &provider)
|
||||||
.await
|
.await
|
||||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
.with_context(|| "TextChunk re-embedding should succeed on fresh DB".to_string())?;
|
||||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
KnowledgeEntity::update_all_embeddings(&db, &provider)
|
||||||
.await
|
.await
|
||||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
.with_context(|| {
|
||||||
|
"KnowledgeEntity re-embedding should succeed on fresh DB".to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
let text_chunk_dimension = get_hnsw_index_dimension(
|
let text_chunk_dimension = get_hnsw_index_dimension(
|
||||||
&db,
|
&db,
|
||||||
"text_chunk_embedding",
|
"text_chunk_embedding",
|
||||||
"idx_embedding_text_chunk_embedding",
|
"idx_embedding_text_chunk_embedding",
|
||||||
)
|
)
|
||||||
.await;
|
.await?;
|
||||||
let knowledge_dimension = get_hnsw_index_dimension(
|
let knowledge_dimension = get_hnsw_index_dimension(
|
||||||
&db,
|
&db,
|
||||||
"knowledge_entity_embedding",
|
"knowledge_entity_embedding",
|
||||||
"idx_embedding_knowledge_entity_embedding",
|
"idx_embedding_knowledge_entity_embedding",
|
||||||
)
|
)
|
||||||
.await;
|
.await?;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
text_chunk_dimension, new_dimension,
|
text_chunk_dimension, new_dimension,
|
||||||
@@ -445,10 +795,11 @@ mod tests {
|
|||||||
|
|
||||||
let persisted_settings = SystemSettings::get_current(&db)
|
let persisted_settings = SystemSettings::get_current(&db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to reload updated settings");
|
.with_context(|| "Failed to reload updated settings".to_string())?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
persisted_settings.embedding_dimensions, new_dimension,
|
persisted_settings.embedding_dimensions, new_dimension,
|
||||||
"Settings should persist new embedding dimension"
|
"Settings should persist new embedding dimension"
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,11 @@
|
|||||||
use surrealdb::RecordId;
|
use surrealdb::RecordId;
|
||||||
|
|
||||||
use crate::storage::types::text_chunk::TextChunk;
|
use crate::storage::types::text_chunk::TextChunk;
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
use crate::{
|
||||||
|
error::AppError,
|
||||||
|
storage::{db::SurrealDbClient, indexes::hnsw_index_redefine_transaction_sql},
|
||||||
|
stored_object,
|
||||||
|
};
|
||||||
|
|
||||||
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
|
||||||
/// Record link to the owning text_chunk
|
/// Record link to the owning text_chunk
|
||||||
@@ -23,33 +27,42 @@ impl TextChunkEmbedding {
|
|||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
dimension: usize,
|
dimension: usize,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let query = format!(
|
let query = hnsw_index_redefine_transaction_sql(
|
||||||
"BEGIN TRANSACTION;
|
"idx_embedding_text_chunk_embedding",
|
||||||
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table};
|
Self::table_name(),
|
||||||
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
|
dimension,
|
||||||
COMMIT TRANSACTION;",
|
|
||||||
table = Self::table_name(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = db.client.query(query).await.map_err(AppError::Database)?;
|
let res = db.client.query(query).await.map_err(AppError::from)?;
|
||||||
res.check().map_err(AppError::Database)?;
|
res.check().map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new text chunk embedding
|
/// Validates that an embedding vector matches the configured HNSW dimension.
|
||||||
|
#[allow(clippy::result_large_err)]
|
||||||
|
pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<(), AppError> {
|
||||||
|
if embedding.len() != expected {
|
||||||
|
return Err(AppError::Validation(format!(
|
||||||
|
"embedding dimension mismatch: got {}, expected {expected}",
|
||||||
|
embedding.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new text chunk embedding.
|
||||||
///
|
///
|
||||||
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID),
|
/// The embedding record id equals `chunk_id` so each chunk has at most one embedding row.
|
||||||
/// not "text_chunk:uuid".
|
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID), not "text_chunk:uuid".
|
||||||
|
#[must_use]
|
||||||
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
// NOTE: `stored_object!` macro defines `id` as `String`
|
id: chunk_id.to_owned(),
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now,
|
updated_at: now,
|
||||||
// Create a record<text_chunk> link: text_chunk:<chunk_id>
|
|
||||||
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
|
||||||
source_id,
|
source_id,
|
||||||
embedding,
|
embedding,
|
||||||
@@ -72,9 +85,9 @@ impl TextChunkEmbedding {
|
|||||||
.query(query)
|
.query(query)
|
||||||
.bind(("chunk_id", chunk_id.clone()))
|
.bind(("chunk_id", chunk_id.clone()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
|
let embeddings: Vec<Self> = result.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(embeddings.into_iter().next())
|
Ok(embeddings.into_iter().next())
|
||||||
}
|
}
|
||||||
@@ -93,9 +106,9 @@ impl TextChunkEmbedding {
|
|||||||
.query(query)
|
.query(query)
|
||||||
.bind(("chunk_id", chunk_id.clone()))
|
.bind(("chunk_id", chunk_id.clone()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?
|
.map_err(AppError::from)?
|
||||||
.check()
|
.check()
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -116,9 +129,9 @@ impl TextChunkEmbedding {
|
|||||||
.query(query)
|
.query(query)
|
||||||
.bind(("source_id", source_id.to_owned()))
|
.bind(("source_id", source_id.to_owned()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?
|
.map_err(AppError::from)?
|
||||||
.check()
|
.check()
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -126,33 +139,20 @@ impl TextChunkEmbedding {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::db::SurrealDbClient;
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use crate::test_utils::{prepare_text_chunk_test_db, setup_test_db};
|
||||||
use surrealdb::Value as SurrealValue;
|
use surrealdb::Value as SurrealValue;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
/// Helper to create an in-memory DB and apply migrations
|
|
||||||
async fn setup_test_db() -> SurrealDbClient {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to apply migrations");
|
|
||||||
|
|
||||||
db
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper: create a text_chunk with a known key, return its RecordId
|
|
||||||
async fn create_text_chunk_with_id(
|
async fn create_text_chunk_with_id(
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
key: &str,
|
key: &str,
|
||||||
source_id: &str,
|
source_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> RecordId {
|
) -> anyhow::Result<RecordId> {
|
||||||
let chunk = TextChunk {
|
let chunk = TextChunk {
|
||||||
id: key.to_owned(),
|
id: key.to_owned(),
|
||||||
created_at: Utc::now(),
|
created_at: Utc::now(),
|
||||||
@@ -164,23 +164,62 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(chunk)
|
db.store_item(chunk)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create text_chunk");
|
.with_context(|| "Failed to create text_chunk".to_string())?;
|
||||||
|
|
||||||
RecordId::from_table_key(TextChunk::table_name(), key)
|
Ok(RecordId::from_table_key(TextChunk::table_name(), key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_idx_sql(db: &SurrealDbClient) -> anyhow::Result<String> {
|
||||||
|
let mut info_res = db
|
||||||
|
.client
|
||||||
|
.query("INFO FOR TABLE text_chunk_embedding;")
|
||||||
|
.await
|
||||||
|
.with_context(|| "info query failed".to_string())?;
|
||||||
|
let info: SurrealValue = info_res
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "failed to take info result".to_string())?;
|
||||||
|
let info_json: serde_json::Value = serde_json::to_value(info)
|
||||||
|
.with_context(|| "failed to convert info to json".to_string())?;
|
||||||
|
let idx_sql = info_json
|
||||||
|
.get("Object")
|
||||||
|
.and_then(|v| v.get("indexes"))
|
||||||
|
.and_then(|v| v.get("Object"))
|
||||||
|
.and_then(|v| v.get("idx_embedding_text_chunk_embedding"))
|
||||||
|
.and_then(|v| v.get("Strand"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_string();
|
||||||
|
Ok(idx_sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_uses_chunk_id_as_record_id() {
|
||||||
|
let emb = TextChunkEmbedding::new(
|
||||||
|
"chunk-abc",
|
||||||
|
"source-1".to_owned(),
|
||||||
|
vec![0.1, 0.2],
|
||||||
|
"user-1".to_owned(),
|
||||||
|
);
|
||||||
|
assert_eq!(emb.id, "chunk-abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_dimension_rejects_mismatch() {
|
||||||
|
let err = TextChunkEmbedding::validate_dimension(&[0.1, 0.2, 0.3], 2)
|
||||||
|
.expect_err("expected dimension mismatch");
|
||||||
|
assert!(matches!(err, AppError::Validation(_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_and_get_by_chunk_id() {
|
async fn test_create_and_get_by_chunk_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user_a";
|
let user_id = "user_a";
|
||||||
let chunk_key = "chunk-123";
|
let chunk_key = "chunk-123";
|
||||||
let source_id = "source-1";
|
let source_id = "source-1";
|
||||||
|
|
||||||
// 1) Create a text_chunk with a known key
|
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
|
||||||
|
|
||||||
// 2) Create and store an embedding for that chunk
|
|
||||||
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
|
||||||
let emb = TextChunkEmbedding::new(
|
let emb = TextChunkEmbedding::new(
|
||||||
chunk_key,
|
chunk_key,
|
||||||
@@ -189,41 +228,31 @@ mod tests {
|
|||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
db.upsert_item(emb)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to redefine index length");
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
let _: Option<TextChunkEmbedding> = db
|
|
||||||
.client
|
|
||||||
.create(TextChunkEmbedding::table_name())
|
|
||||||
.content(emb)
|
|
||||||
.await
|
|
||||||
.expect("Failed to store embedding")
|
|
||||||
.take()
|
|
||||||
.expect("Failed to deserialize stored embedding");
|
|
||||||
|
|
||||||
// 3) Fetch it via get_by_chunk_id
|
|
||||||
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding by chunk_id");
|
.with_context(|| "Failed to get embedding by chunk_id".to_string())?
|
||||||
|
.with_context(|| "Expected an embedding to be found".to_string())?;
|
||||||
assert!(fetched.is_some(), "Expected an embedding to be found");
|
|
||||||
let fetched = fetched.unwrap();
|
|
||||||
|
|
||||||
|
assert_eq!(fetched.id, chunk_key);
|
||||||
assert_eq!(fetched.user_id, user_id);
|
assert_eq!(fetched.user_id, user_id);
|
||||||
assert_eq!(fetched.chunk_id, chunk_rid);
|
assert_eq!(fetched.chunk_id, chunk_rid);
|
||||||
assert_eq!(fetched.embedding, embedding_vec);
|
assert_eq!(fetched.embedding, embedding_vec);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_chunk_id() {
|
async fn test_delete_by_chunk_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
let user_id = "user_b";
|
let user_id = "user_b";
|
||||||
let chunk_key = "chunk-delete";
|
let chunk_key = "chunk-delete";
|
||||||
let source_id = "source-del";
|
let source_id = "source-del";
|
||||||
|
|
||||||
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
|
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
|
|
||||||
let emb = TextChunkEmbedding::new(
|
let emb = TextChunkEmbedding::new(
|
||||||
chunk_key,
|
chunk_key,
|
||||||
@@ -232,180 +261,171 @@ mod tests {
|
|||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
|
db.upsert_item(emb)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to redefine index length");
|
.with_context(|| "Failed to store embedding".to_string())?;
|
||||||
|
|
||||||
let _: Option<TextChunkEmbedding> = db
|
|
||||||
.client
|
|
||||||
.create(TextChunkEmbedding::table_name())
|
|
||||||
.content(emb)
|
|
||||||
.await
|
|
||||||
.expect("Failed to store embedding")
|
|
||||||
.take()
|
|
||||||
.expect("Failed to deserialize stored embedding");
|
|
||||||
|
|
||||||
// Ensure it exists
|
|
||||||
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding before delete");
|
.with_context(|| "Failed to get embedding before delete".to_string())?;
|
||||||
assert!(existing.is_some(), "Embedding should exist before delete");
|
assert!(existing.is_some(), "Embedding should exist before delete");
|
||||||
|
|
||||||
// Delete by chunk_id
|
|
||||||
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete by chunk_id");
|
.with_context(|| "Failed to delete by chunk_id".to_string())?;
|
||||||
|
|
||||||
// Ensure it no longer exists
|
|
||||||
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get embedding after delete");
|
.with_context(|| "Failed to get embedding after delete".to_string())?;
|
||||||
assert!(after.is_none(), "Embedding should have been deleted");
|
assert!(after.is_none(), "Embedding should have been deleted");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_delete_by_source_id() {
|
async fn test_delete_by_source_id() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_text_chunk_test_db(1).await?;
|
||||||
|
|
||||||
let user_id = "user_c";
|
let user_id = "user_c";
|
||||||
let source_id = "shared-source";
|
let source_id = "shared-source";
|
||||||
let other_source = "other-source";
|
let other_source = "other-source";
|
||||||
|
|
||||||
// Two chunks with the same source_id
|
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await?;
|
||||||
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
|
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await?;
|
||||||
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
|
|
||||||
|
|
||||||
// One chunk with a different source_id
|
|
||||||
let chunk_other_rid =
|
let chunk_other_rid =
|
||||||
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
|
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await?;
|
||||||
|
|
||||||
// Create embeddings for all three
|
for (key, src, vec) in [
|
||||||
let emb1 = TextChunkEmbedding::new(
|
("chunk-s1", source_id, vec![0.1]),
|
||||||
"chunk-s1",
|
("chunk-s2", source_id, vec![0.2]),
|
||||||
source_id.to_string(),
|
("chunk-other", other_source, vec![0.3]),
|
||||||
vec![0.1],
|
] {
|
||||||
user_id.to_string(),
|
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||||
);
|
db.upsert_item(emb)
|
||||||
let emb2 = TextChunkEmbedding::new(
|
|
||||||
"chunk-s2",
|
|
||||||
source_id.to_string(),
|
|
||||||
vec![0.2],
|
|
||||||
user_id.to_string(),
|
|
||||||
);
|
|
||||||
let emb3 = TextChunkEmbedding::new(
|
|
||||||
"chunk-other",
|
|
||||||
other_source.to_string(),
|
|
||||||
vec![0.3],
|
|
||||||
user_id.to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Update length on index
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
|
|
||||||
.await
|
|
||||||
.expect("Failed to redefine index length");
|
|
||||||
|
|
||||||
for emb in [emb1, emb2, emb3] {
|
|
||||||
let _: Option<TextChunkEmbedding> = db
|
|
||||||
.client
|
|
||||||
.create(TextChunkEmbedding::table_name())
|
|
||||||
.content(emb)
|
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store embedding")
|
.with_context(|| format!("store embedding for {key}"))?;
|
||||||
.take()
|
|
||||||
.expect("Failed to deserialize stored embedding");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanity check: they all exist
|
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get chunk1".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get chunk2".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "get chunk_other".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
|
|
||||||
// Delete embeddings by source_id (shared-source)
|
|
||||||
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
TextChunkEmbedding::delete_by_source_id(source_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete by source_id");
|
.with_context(|| "Failed to delete by source_id".to_string())?;
|
||||||
|
|
||||||
// Chunks from shared-source should have no embeddings
|
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "check chunk1".to_string())?
|
||||||
.is_none());
|
.is_none());
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "check chunk2".to_string())?
|
||||||
.is_none());
|
.is_none());
|
||||||
|
|
||||||
// The other chunk should still have its embedding
|
|
||||||
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.with_context(|| "check chunk_other".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_redefine_hnsw_index_updates_dimension() {
|
async fn test_upsert_replaces_existing_embedding_row() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = prepare_text_chunk_test_db(3).await?;
|
||||||
|
|
||||||
|
let user_id = "user-upsert";
|
||||||
|
let source_id = "source-upsert";
|
||||||
|
let chunk_key = "chunk-upsert";
|
||||||
|
|
||||||
|
create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await?;
|
||||||
|
|
||||||
|
let initial = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![1.0_f32, 0.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(initial)
|
||||||
|
.await
|
||||||
|
.with_context(|| "initial upsert".to_string())?;
|
||||||
|
|
||||||
|
let replacement = TextChunkEmbedding::new(
|
||||||
|
chunk_key,
|
||||||
|
source_id.to_owned(),
|
||||||
|
vec![0.0, 1.0, 0.0],
|
||||||
|
user_id.to_owned(),
|
||||||
|
);
|
||||||
|
db.upsert_item(replacement)
|
||||||
|
.await
|
||||||
|
.with_context(|| "upsert replacement embedding".to_string())?;
|
||||||
|
|
||||||
|
let chunk_rid = RecordId::from_table_key(TextChunk::table_name(), chunk_key);
|
||||||
|
let rows: Vec<TextChunkEmbedding> = db
|
||||||
|
.client
|
||||||
|
.query(format!(
|
||||||
|
"SELECT * FROM {} WHERE chunk_id = $chunk_id",
|
||||||
|
TextChunkEmbedding::table_name()
|
||||||
|
))
|
||||||
|
.bind(("chunk_id", chunk_rid))
|
||||||
|
.await
|
||||||
|
.with_context(|| "count embeddings".to_string())?
|
||||||
|
.take(0)
|
||||||
|
.with_context(|| "take embeddings".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
let row = rows.first().expect("expected one embedding row");
|
||||||
|
assert_eq!(row.id, chunk_key);
|
||||||
|
assert_eq!(row.embedding, vec![0.0, 1.0, 0.0]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_redefine_hnsw_index_updates_dimension() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Change the index dimension from default (1536) to a smaller test value.
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
|
||||||
.await
|
.await
|
||||||
.expect("failed to redefine index");
|
.with_context(|| "failed to redefine index".to_string())?;
|
||||||
|
|
||||||
let mut info_res = db
|
let idx_sql = get_idx_sql(&db).await?;
|
||||||
.client
|
|
||||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
|
||||||
.await
|
|
||||||
.expect("info query failed");
|
|
||||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
|
||||||
let info_json: serde_json::Value =
|
|
||||||
serde_json::to_value(info).expect("failed to convert info to json");
|
|
||||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
|
||||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
idx_sql.contains("DIMENSION 8"),
|
idx_sql.contains("DIMENSION 8"),
|
||||||
"expected index definition to contain new dimension, got: {idx_sql}"
|
"expected index definition to contain new dimension, got: {idx_sql}"
|
||||||
);
|
);
|
||||||
|
assert!(
|
||||||
|
idx_sql.contains("DIST COSINE"),
|
||||||
|
"expected index definition to use cosine distance, got: {idx_sql}"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_redefine_hnsw_index_is_idempotent() {
|
async fn test_redefine_hnsw_index_is_idempotent() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||||
.await
|
.await
|
||||||
.expect("first redefine failed");
|
.with_context(|| "first redefine failed".to_string())?;
|
||||||
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
|
||||||
.await
|
.await
|
||||||
.expect("second redefine failed");
|
.with_context(|| "second redefine failed".to_string())?;
|
||||||
|
|
||||||
let mut info_res = db
|
let idx_sql = get_idx_sql(&db).await?;
|
||||||
.client
|
|
||||||
.query("INFO FOR TABLE text_chunk_embedding;")
|
|
||||||
.await
|
|
||||||
.expect("info query failed");
|
|
||||||
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
|
|
||||||
let info_json: serde_json::Value =
|
|
||||||
serde_json::to_value(info).expect("failed to convert info to json");
|
|
||||||
let idx_sql = info_json["Object"]["indexes"]["Object"]
|
|
||||||
["idx_embedding_text_chunk_embedding"]["Strand"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
idx_sql.contains("DIMENSION 4"),
|
idx_sql.contains("DIMENSION 4"),
|
||||||
"expected index definition to retain dimension 4, got: {idx_sql}"
|
"expected index definition to retain dimension 4, got: {idx_sql}"
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
use surrealdb::opt::PatchOp;
|
use surrealdb::opt::PatchOp;
|
||||||
|
use surrealdb::RecordId;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||||
@@ -69,6 +73,7 @@ stored_object!(TextContent, "text_content", {
|
|||||||
});
|
});
|
||||||
|
|
||||||
impl TextContent {
|
impl TextContent {
|
||||||
|
#[must_use]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
text: String,
|
text: String,
|
||||||
context: Option<String>,
|
context: Option<String>,
|
||||||
@@ -100,7 +105,7 @@ impl TextContent {
|
|||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
||||||
let _res: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.update((Self::table_name(), id))
|
.update((Self::table_name(), id))
|
||||||
.patch(PatchOp::replace("/context", context))
|
.patch(PatchOp::replace("/context", context))
|
||||||
.patch(PatchOp::replace("/category", category))
|
.patch(PatchOp::replace("/category", category))
|
||||||
@@ -109,7 +114,12 @@ impl TextContent {
|
|||||||
"/updated_at",
|
"/updated_at",
|
||||||
surrealdb::Datetime::from(now),
|
surrealdb::Datetime::from(now),
|
||||||
))
|
))
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
if updated.is_none() {
|
||||||
|
return Err(AppError::NotFound(format!("text content {id} not found")));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -127,9 +137,10 @@ impl TextContent {
|
|||||||
.bind(("table_name", TextContent::table_name()))
|
.bind(("table_name", TextContent::table_name()))
|
||||||
.bind(("file_id", file_id.to_owned()))
|
.bind(("file_id", file_id.to_owned()))
|
||||||
.bind(("exclude_id", exclude_id.to_owned()))
|
.bind(("exclude_id", exclude_id.to_owned()))
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
let existing: Option<surrealdb::sql::Thing> = response.take(0)?;
|
let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(existing.is_some())
|
Ok(existing.is_some())
|
||||||
}
|
}
|
||||||
@@ -140,7 +151,8 @@ impl TextContent {
|
|||||||
user_id: &str,
|
user_id: &str,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
) -> Result<Vec<TextContentSearchResult>, AppError> {
|
) -> Result<Vec<TextContentSearchResult>, AppError> {
|
||||||
let sql = r#"
|
let sql = format!(
|
||||||
|
r#"
|
||||||
SELECT
|
SELECT
|
||||||
*,
|
*,
|
||||||
search::highlight('<b>', '</b>', 0) AS highlighted_text,
|
search::highlight('<b>', '</b>', 0) AS highlighted_text,
|
||||||
@@ -157,7 +169,7 @@ impl TextContent {
|
|||||||
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
|
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
|
||||||
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
|
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
|
||||||
) AS score
|
) AS score
|
||||||
FROM text_content
|
FROM {table}
|
||||||
WHERE
|
WHERE
|
||||||
(
|
(
|
||||||
text @0@ $terms OR
|
text @0@ $terms OR
|
||||||
@@ -170,25 +182,192 @@ impl TextContent {
|
|||||||
AND user_id = $user_id
|
AND user_id = $user_id
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit;
|
LIMIT $limit;
|
||||||
"#;
|
"#,
|
||||||
|
table = Self::table_name(),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(db
|
db.client
|
||||||
.client
|
|
||||||
.query(sql)
|
.query(sql)
|
||||||
.bind(("terms", search_terms.to_owned()))
|
.bind(("terms", search_terms.to_owned()))
|
||||||
.bind(("user_id", user_id.to_owned()))
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.bind(("limit", limit))
|
.bind(("limit", limit))
|
||||||
.await?
|
.await
|
||||||
.take(0)?)
|
.map_err(AppError::from)?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a fallback display label for a source id when no matching content row exists.
|
||||||
|
#[must_use]
|
||||||
|
pub fn fallback_source_label(source_id: &str) -> String {
|
||||||
|
format!("Text snippet: {}", source_id_suffix(source_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolves human-readable labels for the given source ids owned by `user_id`.
|
||||||
|
pub async fn resolve_source_labels(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
user_id: &str,
|
||||||
|
source_ids: impl IntoIterator<Item = impl AsRef<str>>,
|
||||||
|
) -> Result<HashMap<String, String>, AppError> {
|
||||||
|
let source_ids: HashSet<String> = source_ids
|
||||||
|
.into_iter()
|
||||||
|
.map(|id| id.as_ref().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if source_ids.is_empty() {
|
||||||
|
return Ok(HashMap::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let record_ids: Vec<RecordId> = source_ids
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| {
|
||||||
|
if id.contains(':') {
|
||||||
|
RecordId::from_str(id).ok()
|
||||||
|
} else {
|
||||||
|
Some(RecordId::from_table_key(Self::table_name(), id))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut response = db
|
||||||
|
.client
|
||||||
|
.query(
|
||||||
|
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||||
|
)
|
||||||
|
.bind(("table_name", Self::table_name()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
|
.bind(("record_ids", record_ids))
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
|
let contents: Vec<SourceLabelRow> = response.take(0).map_err(AppError::from)?;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
source_id_count = source_ids.len(),
|
||||||
|
label_row_count = contents.len(),
|
||||||
|
"resolved source labels"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut labels = HashMap::new();
|
||||||
|
for content in contents {
|
||||||
|
let label = build_source_label(&content);
|
||||||
|
labels.insert(content.id.clone(), label.clone());
|
||||||
|
labels.insert(format!("{}:{}", Self::table_name(), content.id), label);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(labels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const SOURCE_LABEL_MAX_CHARS: usize = 80;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct SourceLabelRow {
|
||||||
|
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||||
|
id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
url_info: Option<UrlInfo>,
|
||||||
|
#[serde(default)]
|
||||||
|
file_info: Option<FileInfo>,
|
||||||
|
#[serde(default)]
|
||||||
|
context: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
category: String,
|
||||||
|
#[serde(default)]
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn source_id_suffix(source_id: &str) -> String {
|
||||||
|
let start = source_id.len().saturating_sub(8);
|
||||||
|
source_id[start..].to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_with_ellipsis(value: &str, max_chars: usize) -> String {
|
||||||
|
const ELLIPSIS: &str = "…";
|
||||||
|
|
||||||
|
if max_chars == 0 {
|
||||||
|
return if value.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
ELLIPSIS.to_string()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut end_byte = value.len();
|
||||||
|
for (count, (idx, _)) in value.char_indices().enumerate() {
|
||||||
|
if count == max_chars {
|
||||||
|
end_byte = idx;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if end_byte == value.len() {
|
||||||
|
return value.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
format!("{}{}", &value[..end_byte], ELLIPSIS)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn first_non_empty_line(text: &str, max_chars: usize) -> Option<String> {
|
||||||
|
text.lines().find_map(|line| {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(truncate_with_ellipsis(trimmed, max_chars))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_source_label(row: &SourceLabelRow) -> String {
|
||||||
|
if let Some(url_info) = row.url_info.as_ref() {
|
||||||
|
let title = url_info.title.trim();
|
||||||
|
if !title.is_empty() {
|
||||||
|
return title.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = url_info.url.trim();
|
||||||
|
if !url.is_empty() {
|
||||||
|
return url.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(file_info) = row.file_info.as_ref() {
|
||||||
|
let name = file_info.file_name.trim();
|
||||||
|
if !name.is_empty() {
|
||||||
|
return name.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(context) = row.context.as_ref() {
|
||||||
|
let trimmed = context.trim();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
return truncate_with_ellipsis(trimmed, SOURCE_LABEL_MAX_CHARS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(text_label) = first_non_empty_line(&row.text, SOURCE_LABEL_MAX_CHARS) {
|
||||||
|
return text_label;
|
||||||
|
}
|
||||||
|
|
||||||
|
let category = row.category.trim();
|
||||||
|
if !category.is_empty() {
|
||||||
|
return truncate_with_ellipsis(category, SOURCE_LABEL_MAX_CHARS);
|
||||||
|
}
|
||||||
|
|
||||||
|
TextContent::fallback_source_label(&row.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::setup_test_db_with_runtime_indexes;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_creation() {
|
async fn test_text_content_creation() -> anyhow::Result<()> {
|
||||||
// Test basic object creation
|
// Test basic object creation
|
||||||
let text = "Test content text".to_string();
|
let text = "Test content text".to_string();
|
||||||
let context = "Test context".to_string();
|
let context = "Test context".to_string();
|
||||||
@@ -212,10 +391,11 @@ mod tests {
|
|||||||
assert!(text_content.file_info.is_none());
|
assert!(text_content.file_info.is_none());
|
||||||
assert!(text_content.url_info.is_none());
|
assert!(text_content.url_info.is_none());
|
||||||
assert!(!text_content.id.is_empty());
|
assert!(!text_content.id.is_empty());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_with_url() {
|
async fn test_text_content_with_url() -> anyhow::Result<()> {
|
||||||
// Test creating with URL
|
// Test creating with URL
|
||||||
let text = "Content with URL".to_string();
|
let text = "Content with URL".to_string();
|
||||||
let context = "URL context".to_string();
|
let context = "URL context".to_string();
|
||||||
@@ -232,26 +412,27 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let text_content = TextContent::new(
|
let text_content = TextContent::new(
|
||||||
text.clone(),
|
text,
|
||||||
Some(context.clone()),
|
Some(context),
|
||||||
category.clone(),
|
category,
|
||||||
None,
|
None,
|
||||||
url_info.clone(),
|
url_info.clone(),
|
||||||
user_id.clone(),
|
user_id,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check URL field is set
|
// Check URL field is set
|
||||||
assert_eq!(text_content.url_info, url_info);
|
assert_eq!(text_content.url_info, url_info);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_text_content_patch() {
|
async fn test_text_content_patch() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Create initial text content
|
// Create initial text content
|
||||||
let initial_text = "Initial text".to_string();
|
let initial_text = "Initial text".to_string();
|
||||||
@@ -272,7 +453,7 @@ mod tests {
|
|||||||
let stored: Option<TextContent> = db
|
let stored: Option<TextContent> = db
|
||||||
.store_item(text_content.clone())
|
.store_item(text_content.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store text content");
|
.with_context(|| "Failed to store text content".to_string())?;
|
||||||
assert!(stored.is_some());
|
assert!(stored.is_some());
|
||||||
|
|
||||||
// New values for patch
|
// New values for patch
|
||||||
@@ -283,31 +464,42 @@ mod tests {
|
|||||||
// Apply the patch
|
// Apply the patch
|
||||||
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
TextContent::patch(&text_content.id, new_context, new_category, new_text, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to patch text content");
|
.with_context(|| "Failed to patch text content".to_string())?;
|
||||||
|
|
||||||
// Retrieve the updated content
|
// Retrieve the updated content
|
||||||
let updated: Option<TextContent> = db
|
let updated: Option<TextContent> = db
|
||||||
.get_item(&text_content.id)
|
.get_item(&text_content.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get updated text content");
|
.with_context(|| "Failed to get updated text content".to_string())?;
|
||||||
assert!(updated.is_some());
|
let updated_content = updated.with_context(|| "expected updated content".to_string())?;
|
||||||
|
|
||||||
let updated_content = updated.unwrap();
|
|
||||||
|
|
||||||
// Verify the updates
|
// Verify the updates
|
||||||
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
assert_eq!(updated_content.context, Some(new_context.to_string()));
|
||||||
assert_eq!(updated_content.category, new_category);
|
assert_eq!(updated_content.category, new_category);
|
||||||
assert_eq!(updated_content.text, new_text);
|
assert_eq!(updated_content.text, new_text);
|
||||||
assert!(updated_content.updated_at > text_content.updated_at);
|
assert!(updated_content.updated_at > text_content.updated_at);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_has_other_with_file_detects_shared_usage() {
|
async fn test_text_content_patch_not_found() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
|
let err = TextContent::patch("missing-id", "ctx", "cat", "text", &db)
|
||||||
|
.await
|
||||||
|
.expect_err("expected not found");
|
||||||
|
|
||||||
|
assert!(matches!(err, AppError::NotFound(_)));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_has_other_with_file_detects_shared_usage() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
let database = &Uuid::new_v4().to_string();
|
let database = &Uuid::new_v4().to_string();
|
||||||
let db = SurrealDbClient::memory(namespace, database)
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to start in-memory surrealdb");
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
let user_id = "user123".to_string();
|
let user_id = "user123".to_string();
|
||||||
let file_info = FileInfo {
|
let file_info = FileInfo {
|
||||||
@@ -340,24 +532,110 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(content_a.clone())
|
db.store_item(content_a.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store first content");
|
.with_context(|| "Failed to store first content".to_string())?;
|
||||||
db.store_item(content_b.clone())
|
db.store_item(content_b.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store second content");
|
.with_context(|| "Failed to store second content".to_string())?;
|
||||||
|
|
||||||
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to check for shared file usage");
|
.with_context(|| "Failed to check for shared file usage".to_string())?;
|
||||||
assert!(has_other);
|
assert!(has_other);
|
||||||
|
|
||||||
let _removed: Option<TextContent> = db
|
let _removed: Option<TextContent> = db
|
||||||
.delete_item(&content_b.id)
|
.delete_item(&content_b.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to delete second content");
|
.with_context(|| "Failed to delete second content".to_string())?;
|
||||||
|
|
||||||
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to check shared usage after delete");
|
.with_context(|| "Failed to check shared usage after delete".to_string())?;
|
||||||
assert!(!has_other_after);
|
assert!(!has_other_after);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_search_returns_empty_when_no_content() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
|
||||||
|
let results = TextContent::search(&db, "hello", "user", 5)
|
||||||
|
.await
|
||||||
|
.with_context(|| "search".to_string())?;
|
||||||
|
|
||||||
|
assert!(results.is_empty());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_search_finds_matching_text_and_filters_user() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
let user_id = "search_user";
|
||||||
|
|
||||||
|
let matching = TextContent::new(
|
||||||
|
"rust programming language".to_string(),
|
||||||
|
Some("context".to_string()),
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
let other_user = TextContent::new(
|
||||||
|
"rust programming language".to_string(),
|
||||||
|
None,
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
"other_user".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
db.store_item(matching.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store matching".to_string())?;
|
||||||
|
db.store_item(other_user)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store other user".to_string())?;
|
||||||
|
|
||||||
|
let results = TextContent::search(&db, "rust", user_id, 5)
|
||||||
|
.await
|
||||||
|
.with_context(|| "search".to_string())?;
|
||||||
|
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
let row = results.first().context("expected one result")?;
|
||||||
|
assert_eq!(row.id, matching.id);
|
||||||
|
assert_eq!(row.user_id, user_id);
|
||||||
|
assert!(row.score.is_finite());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_resolve_source_labels_uses_url_title() -> anyhow::Result<()> {
|
||||||
|
let db = setup_test_db_with_runtime_indexes().await?;
|
||||||
|
let user_id = "label_user";
|
||||||
|
|
||||||
|
let content = TextContent::new(
|
||||||
|
"body".to_string(),
|
||||||
|
None,
|
||||||
|
"notes".to_string(),
|
||||||
|
None,
|
||||||
|
Some(UrlInfo {
|
||||||
|
url: "https://example.com/doc".to_string(),
|
||||||
|
title: "Example Document".to_string(),
|
||||||
|
image_id: String::new(),
|
||||||
|
}),
|
||||||
|
user_id.to_string(),
|
||||||
|
);
|
||||||
|
db.store_item(content.clone()).await?;
|
||||||
|
|
||||||
|
let labels = TextContent::resolve_source_labels(&db, user_id, [content.id.clone()]).await?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
labels.get(&content.id),
|
||||||
|
Some(&"Example Document".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
labels.get(&format!("text_content:{}", content.id)),
|
||||||
|
Some(&"Example Document".to_string())
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+122
-111
@@ -55,6 +55,7 @@ impl FromStr for Theme {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Theme {
|
impl Theme {
|
||||||
|
#[must_use]
|
||||||
pub fn as_str(&self) -> &'static str {
|
pub fn as_str(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::Light => "light",
|
Self::Light => "light",
|
||||||
@@ -67,6 +68,7 @@ impl Theme {
|
|||||||
|
|
||||||
/// Returns the theme that should be initially applied.
|
/// Returns the theme that should be initially applied.
|
||||||
/// For "system", defaults to "light".
|
/// For "system", defaults to "light".
|
||||||
|
#[must_use]
|
||||||
pub fn initial_theme(&self) -> &'static str {
|
pub fn initial_theme(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::System => "light",
|
Self::System => "light",
|
||||||
@@ -371,7 +373,7 @@ impl User {
|
|||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
"UPDATE type::thing('user', $id)
|
"UPDATE type::thing('user', $id)
|
||||||
SET api_key = test_string_nullish
|
SET api_key = NONE
|
||||||
RETURN AFTER",
|
RETURN AFTER",
|
||||||
)
|
)
|
||||||
.bind(("id", id.to_owned()))
|
.bind(("id", id.to_owned()))
|
||||||
@@ -532,7 +534,6 @@ impl User {
|
|||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
|
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
|
||||||
.bind(("table_name", Self::table_name()))
|
|
||||||
.bind(("user_id", user_id.to_string()))
|
.bind(("user_id", user_id.to_string()))
|
||||||
.bind(("timezone", timezone.to_string()))
|
.bind(("timezone", timezone.to_string()))
|
||||||
.await?;
|
.await?;
|
||||||
@@ -579,7 +580,7 @@ impl User {
|
|||||||
let entity: KnowledgeEntity = db
|
let entity: KnowledgeEntity = db
|
||||||
.get_item(id)
|
.get_item(id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| AppError::NotFound("Entity not found".into()))?;
|
.ok_or_else(|| AppError::NotFound("entity not found".into()))?;
|
||||||
|
|
||||||
if entity.user_id != user_id {
|
if entity.user_id != user_id {
|
||||||
return Err(AppError::Auth("Access denied".into()));
|
return Err(AppError::Auth("Access denied".into()));
|
||||||
@@ -596,7 +597,7 @@ impl User {
|
|||||||
let text_content: TextContent = db
|
let text_content: TextContent = db
|
||||||
.get_item(id)
|
.get_item(id)
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| AppError::NotFound("Content not found".into()))?;
|
.ok_or_else(|| AppError::NotFound("content not found".into()))?;
|
||||||
|
|
||||||
if text_content.user_id != user_id {
|
if text_content.user_id != user_id {
|
||||||
return Err(AppError::Auth("Access denied".into()));
|
return Err(AppError::Auth("Access denied".into()));
|
||||||
@@ -687,7 +688,7 @@ impl User {
|
|||||||
|
|
||||||
db.delete_item::<IngestionTask>(id)
|
db.delete_item::<IngestionTask>(id)
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::Database)?;
|
.map_err(AppError::from)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -723,30 +724,20 @@ impl User {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
// Helper function to set up a test database with SystemSettings
|
use crate::test_utils::setup_test_db;
|
||||||
async fn setup_test_db() -> SurrealDbClient {
|
|
||||||
let namespace = "test_ns";
|
|
||||||
let database = Uuid::new_v4().to_string();
|
|
||||||
let db = SurrealDbClient::memory(namespace, &database)
|
|
||||||
.await
|
|
||||||
.expect("Failed to start in-memory surrealdb");
|
|
||||||
|
|
||||||
db.apply_migrations()
|
|
||||||
.await
|
|
||||||
.expect("Failed to setup the migrations");
|
|
||||||
|
|
||||||
db
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_user_creation() {
|
async fn test_user_creation() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create a user
|
// Create a user
|
||||||
let email = "test@example.com";
|
let email = "test@example.com";
|
||||||
@@ -761,7 +752,7 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
// Verify user properties
|
// Verify user properties
|
||||||
assert!(!user.id.is_empty());
|
assert!(!user.id.is_empty());
|
||||||
@@ -774,18 +765,17 @@ mod tests {
|
|||||||
let retrieved: Option<User> = db
|
let retrieved: Option<User> = db
|
||||||
.get_item(&user.id)
|
.get_item(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve user");
|
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||||
assert!(retrieved.is_some());
|
let retrieved = retrieved.with_context(|| "expected user to exist".to_string())?;
|
||||||
|
|
||||||
let retrieved = retrieved.unwrap();
|
|
||||||
assert_eq!(retrieved.id, user.id);
|
assert_eq!(retrieved.id, user.id);
|
||||||
assert_eq!(retrieved.email, email);
|
assert_eq!(retrieved.email, email);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_user_authentication() {
|
async fn test_user_authentication() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create a user
|
// Create a user
|
||||||
let email = "auth_test@example.com";
|
let email = "auth_test@example.com";
|
||||||
@@ -799,7 +789,7 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
// Test successful authentication
|
// Test successful authentication
|
||||||
let auth_result = User::authenticate(email, password, &db).await;
|
let auth_result = User::authenticate(email, password, &db).await;
|
||||||
@@ -812,11 +802,12 @@ mod tests {
|
|||||||
// Test failed authentication with non-existent user
|
// Test failed authentication with non-existent user
|
||||||
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
let nonexistent = User::authenticate("nonexistent@example.com", password, &db).await;
|
||||||
assert!(nonexistent.is_err());
|
assert!(nonexistent.is_err());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
|
async fn test_get_unfinished_ingestion_tasks_filters_correctly() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "unfinished_user";
|
let user_id = "unfinished_user";
|
||||||
let other_user_id = "other_user";
|
let other_user_id = "other_user";
|
||||||
|
|
||||||
@@ -830,14 +821,14 @@ mod tests {
|
|||||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
db.store_item(created_task.clone())
|
db.store_item(created_task.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store created task");
|
.with_context(|| "Failed to store created task".to_string())?;
|
||||||
|
|
||||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
processing_task.state = TaskState::Processing;
|
processing_task.state = TaskState::Processing;
|
||||||
processing_task.attempts = 1;
|
processing_task.attempts = 1;
|
||||||
db.store_item(processing_task.clone())
|
db.store_item(processing_task.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store processing task");
|
.with_context(|| "Failed to store processing task".to_string())?;
|
||||||
|
|
||||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
failed_retry_task.state = TaskState::Failed;
|
failed_retry_task.state = TaskState::Failed;
|
||||||
@@ -845,7 +836,7 @@ mod tests {
|
|||||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||||
db.store_item(failed_retry_task.clone())
|
db.store_item(failed_retry_task.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store retryable failed task");
|
.with_context(|| "Failed to store retryable failed task".to_string())?;
|
||||||
|
|
||||||
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
failed_blocked_task.state = TaskState::Failed;
|
failed_blocked_task.state = TaskState::Failed;
|
||||||
@@ -853,13 +844,13 @@ mod tests {
|
|||||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||||
db.store_item(failed_blocked_task.clone())
|
db.store_item(failed_blocked_task.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store blocked task");
|
.with_context(|| "Failed to store blocked task".to_string())?;
|
||||||
|
|
||||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
completed_task.state = TaskState::Succeeded;
|
completed_task.state = TaskState::Succeeded;
|
||||||
db.store_item(completed_task.clone())
|
db.store_item(completed_task.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store completed task");
|
.with_context(|| "Failed to store completed task".to_string())?;
|
||||||
|
|
||||||
let other_payload = IngestionPayload::Text {
|
let other_payload = IngestionPayload::Text {
|
||||||
text: "Other".to_string(),
|
text: "Other".to_string(),
|
||||||
@@ -870,11 +861,11 @@ mod tests {
|
|||||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||||
db.store_item(other_task)
|
db.store_item(other_task)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store other user task");
|
.with_context(|| "Failed to store other user task".to_string())?;
|
||||||
|
|
||||||
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch unfinished tasks");
|
.with_context(|| "Failed to fetch unfinished tasks".to_string())?;
|
||||||
|
|
||||||
let unfinished_ids: HashSet<String> =
|
let unfinished_ids: HashSet<String> =
|
||||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||||
@@ -885,11 +876,12 @@ mod tests {
|
|||||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||||
assert_eq!(unfinished_ids.len(), 3);
|
assert_eq!(unfinished_ids.len(), 3);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
async fn test_get_all_ingestion_tasks_returns_sorted() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "archive_user";
|
let user_id = "archive_user";
|
||||||
let other_user_id = "other_user";
|
let other_user_id = "other_user";
|
||||||
|
|
||||||
@@ -902,15 +894,19 @@ mod tests {
|
|||||||
|
|
||||||
// Oldest task
|
// Oldest task
|
||||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
first.created_at -= chrono::Duration::minutes(1);
|
||||||
first.updated_at = first.created_at;
|
first.updated_at = first.created_at;
|
||||||
first.state = TaskState::Succeeded;
|
first.state = TaskState::Succeeded;
|
||||||
db.store_item(first.clone()).await.expect("store first");
|
db.store_item(first.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store first".to_string())?;
|
||||||
|
|
||||||
// Latest task
|
// Latest task
|
||||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||||
second.state = TaskState::Processing;
|
second.state = TaskState::Processing;
|
||||||
db.store_item(second.clone()).await.expect("store second");
|
db.store_item(second.clone())
|
||||||
|
.await
|
||||||
|
.with_context(|| "store second".to_string())?;
|
||||||
|
|
||||||
let other_payload = IngestionPayload::Text {
|
let other_payload = IngestionPayload::Text {
|
||||||
text: "Other".to_string(),
|
text: "Other".to_string(),
|
||||||
@@ -919,21 +915,24 @@ mod tests {
|
|||||||
user_id: other_user_id.to_string(),
|
user_id: other_user_id.to_string(),
|
||||||
};
|
};
|
||||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||||
db.store_item(other_task).await.expect("store other");
|
db.store_item(other_task)
|
||||||
|
.await
|
||||||
|
.with_context(|| "store other".to_string())?;
|
||||||
|
|
||||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("fetch all tasks");
|
.with_context(|| "fetch all tasks".to_string())?;
|
||||||
|
|
||||||
assert_eq!(tasks.len(), 2);
|
assert_eq!(tasks.len(), 2);
|
||||||
assert_eq!(tasks[0].id, second.id); // newest first
|
assert_eq!(tasks.first().map(|t| &t.id), Some(&second.id)); // newest first
|
||||||
assert_eq!(tasks[1].id, first.id);
|
assert_eq!(tasks.get(1).map(|t| &t.id), Some(&first.id));
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_find_by_email() {
|
async fn test_find_by_email() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create a user
|
// Create a user
|
||||||
let email = "find_test@example.com";
|
let email = "find_test@example.com";
|
||||||
@@ -947,28 +946,28 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
// Test finding user by email
|
// Test finding user by email
|
||||||
let found_user = User::find_by_email(email, &db)
|
let found_user = User::find_by_email(email, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Error searching for user");
|
.with_context(|| "Error searching for user".to_string())?
|
||||||
assert!(found_user.is_some());
|
.with_context(|| "expected user to exist".to_string())?;
|
||||||
let found_user = found_user.unwrap();
|
|
||||||
assert_eq!(found_user.id, created_user.id);
|
assert_eq!(found_user.id, created_user.id);
|
||||||
assert_eq!(found_user.email, email);
|
assert_eq!(found_user.email, email);
|
||||||
|
|
||||||
// Test finding non-existent user
|
// Test finding non-existent user
|
||||||
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
let not_found = User::find_by_email("nonexistent@example.com", &db)
|
||||||
.await
|
.await
|
||||||
.expect("Error searching for user");
|
.with_context(|| "Error searching for user".to_string())?;
|
||||||
assert!(not_found.is_none());
|
assert!(not_found.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_api_key_management() {
|
async fn test_api_key_management() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create a user
|
// Create a user
|
||||||
let email = "apikey_test@example.com";
|
let email = "apikey_test@example.com";
|
||||||
@@ -982,7 +981,7 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
// Initially, user should have no API key
|
// Initially, user should have no API key
|
||||||
assert!(user.api_key.is_none());
|
assert!(user.api_key.is_none());
|
||||||
@@ -990,7 +989,7 @@ mod tests {
|
|||||||
// Generate API key
|
// Generate API key
|
||||||
let api_key = User::set_api_key(&user.id, &db)
|
let api_key = User::set_api_key(&user.id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to set API key");
|
.with_context(|| "Failed to set API key".to_string())?;
|
||||||
assert!(!api_key.is_empty());
|
assert!(!api_key.is_empty());
|
||||||
assert!(api_key.starts_with("sk_"));
|
assert!(api_key.starts_with("sk_"));
|
||||||
|
|
||||||
@@ -998,43 +997,41 @@ mod tests {
|
|||||||
let updated_user: Option<User> = db
|
let updated_user: Option<User> = db
|
||||||
.get_item(&user.id)
|
.get_item(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve user");
|
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||||
assert!(updated_user.is_some());
|
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||||
let updated_user = updated_user.unwrap();
|
|
||||||
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
assert_eq!(updated_user.api_key, Some(api_key.clone()));
|
||||||
|
|
||||||
// Test finding user by API key
|
// Test finding user by API key
|
||||||
let found_user = User::find_by_api_key(&api_key, &db)
|
let found_user = User::find_by_api_key(&api_key, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Error searching by API key");
|
.with_context(|| "Error searching by API key".to_string())?
|
||||||
assert!(found_user.is_some());
|
.with_context(|| "expected user found by api key".to_string())?;
|
||||||
let found_user = found_user.unwrap();
|
|
||||||
assert_eq!(found_user.id, user.id);
|
assert_eq!(found_user.id, user.id);
|
||||||
|
|
||||||
// Revoke API key
|
// Revoke API key
|
||||||
User::revoke_api_key(&user.id, &db)
|
User::revoke_api_key(&user.id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to revoke API key");
|
.with_context(|| "Failed to revoke API key".to_string())?;
|
||||||
|
|
||||||
// Verify API key was revoked
|
// Verify API key was revoked
|
||||||
let revoked_user: Option<User> = db
|
let revoked_user: Option<User> = db
|
||||||
.get_item(&user.id)
|
.get_item(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve user");
|
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||||
assert!(revoked_user.is_some());
|
let revoked_user = revoked_user.with_context(|| "expected revoked user".to_string())?;
|
||||||
let revoked_user = revoked_user.unwrap();
|
|
||||||
assert!(revoked_user.api_key.is_none());
|
assert!(revoked_user.api_key.is_none());
|
||||||
|
|
||||||
// Test searching by revoked API key
|
// Test searching by revoked API key
|
||||||
let not_found = User::find_by_api_key(&api_key, &db)
|
let not_found = User::find_by_api_key(&api_key, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Error searching by API key");
|
.with_context(|| "Error searching by API key".to_string())?;
|
||||||
assert!(not_found.is_none());
|
assert!(not_found.is_none());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_set_api_key_with_none_theme() {
|
async fn test_set_api_key_with_none_theme() {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await.expect("Failed to setup test db");
|
||||||
|
|
||||||
let user = User::create_new(
|
let user = User::create_new(
|
||||||
"legacy_theme@example.com".to_string(),
|
"legacy_theme@example.com".to_string(),
|
||||||
@@ -1069,9 +1066,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_password_update() {
|
async fn test_password_update() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create a user
|
// Create a user
|
||||||
let email = "pwd_test@example.com";
|
let email = "pwd_test@example.com";
|
||||||
@@ -1086,7 +1083,7 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
// Authenticate with old password
|
// Authenticate with old password
|
||||||
let auth_result = User::authenticate(email, old_password, &db).await;
|
let auth_result = User::authenticate(email, old_password, &db).await;
|
||||||
@@ -1095,7 +1092,7 @@ mod tests {
|
|||||||
// Update password
|
// Update password
|
||||||
User::patch_password(email, new_password, &db)
|
User::patch_password(email, new_password, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to update password");
|
.with_context(|| "Failed to update password".to_string())?;
|
||||||
|
|
||||||
// Old password should no longer work
|
// Old password should no longer work
|
||||||
let old_auth = User::authenticate(email, old_password, &db).await;
|
let old_auth = User::authenticate(email, old_password, &db).await;
|
||||||
@@ -1104,10 +1101,11 @@ mod tests {
|
|||||||
// New password should work
|
// New password should work
|
||||||
let new_auth = User::authenticate(email, new_password, &db).await;
|
let new_auth = User::authenticate(email, new_password, &db).await;
|
||||||
assert!(new_auth.is_ok());
|
assert!(new_auth.is_ok());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validate_timezone() {
|
async fn test_validate_timezone() -> anyhow::Result<()> {
|
||||||
// Valid timezones should be accepted as-is
|
// Valid timezones should be accepted as-is
|
||||||
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
assert_eq!(validate_timezone("America/New_York"), "America/New_York");
|
||||||
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
assert_eq!(validate_timezone("Europe/London"), "Europe/London");
|
||||||
@@ -1117,12 +1115,13 @@ mod tests {
|
|||||||
// Invalid timezones should be replaced with UTC
|
// Invalid timezones should be replaced with UTC
|
||||||
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
assert_eq!(validate_timezone("Invalid/Timezone"), "UTC");
|
||||||
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
assert_eq!(validate_timezone("Not_Real"), "UTC");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_timezone_update() {
|
async fn test_timezone_update() -> anyhow::Result<()> {
|
||||||
// Setup test database
|
// Setup test database
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
|
|
||||||
// Create user with default timezone
|
// Create user with default timezone
|
||||||
let email = "timezone_test@example.com";
|
let email = "timezone_test@example.com";
|
||||||
@@ -1134,7 +1133,7 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
assert_eq!(user.timezone, "UTC");
|
assert_eq!(user.timezone, "UTC");
|
||||||
|
|
||||||
@@ -1142,58 +1141,64 @@ mod tests {
|
|||||||
let new_timezone = "Europe/Paris";
|
let new_timezone = "Europe/Paris";
|
||||||
User::update_timezone(&user.id, new_timezone, &db)
|
User::update_timezone(&user.id, new_timezone, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to update timezone");
|
.with_context(|| "Failed to update timezone".to_string())?;
|
||||||
|
|
||||||
// Verify timezone was updated
|
// Verify timezone was updated
|
||||||
let updated_user: Option<User> = db
|
let updated_user: Option<User> = db
|
||||||
.get_item(&user.id)
|
.get_item(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to retrieve user");
|
.with_context(|| "Failed to retrieve user".to_string())?;
|
||||||
assert!(updated_user.is_some());
|
let updated_user = updated_user.with_context(|| "expected updated user".to_string())?;
|
||||||
let updated_user = updated_user.unwrap();
|
|
||||||
assert_eq!(updated_user.timezone, new_timezone);
|
assert_eq!(updated_user.timezone, new_timezone);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_conversations_order() {
|
async fn test_conversations_order() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "user_order_test";
|
let user_id = "user_order_test";
|
||||||
|
|
||||||
// Create conversations with varying updated_at timestamps
|
// Create conversations with varying updated_at timestamps
|
||||||
let mut conversations = Vec::new();
|
let mut conversations = Vec::new();
|
||||||
for i in 0..5 {
|
for i in 0..5 {
|
||||||
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {}", i));
|
let mut conv = Conversation::new(user_id.to_string(), format!("Conv {i}"));
|
||||||
// Fake updated_at i minutes apart
|
// Fake updated_at i minutes apart
|
||||||
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
|
conv.created_at = chrono::Utc::now() - chrono::Duration::minutes(i);
|
||||||
db.store_item(conv.clone())
|
db.store_item(conv.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store conversation");
|
.with_context(|| "Failed to store conversation".to_string())?;
|
||||||
conversations.push(conv);
|
conversations.push(conv);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve via get_user_conversations - should be ordered by updated_at DESC
|
// Retrieve via get_user_conversations - should be ordered by updated_at DESC
|
||||||
let retrieved = User::get_user_conversations(user_id, &db)
|
let retrieved = User::get_user_conversations(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to get conversations");
|
.with_context(|| "Failed to get conversations".to_string())?;
|
||||||
|
|
||||||
assert_eq!(retrieved.len(), conversations.len());
|
assert_eq!(retrieved.len(), conversations.len());
|
||||||
|
|
||||||
for window in retrieved.windows(2) {
|
for pair in retrieved.windows(2) {
|
||||||
// Assert each earlier conversation has updated_at >= later conversation
|
let a = pair.first().context("expected first in pair")?;
|
||||||
|
let b = pair.get(1).context("expected second in pair")?;
|
||||||
assert!(
|
assert!(
|
||||||
window[0].created_at >= window[1].created_at,
|
a.created_at >= b.created_at,
|
||||||
"Conversations not ordered descending by created_at"
|
"Conversations not ordered descending by created_at"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check first conversation title matches the most recently updated
|
// Check first conversation title matches the most recently updated
|
||||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
|
let most_recent = conversations
|
||||||
assert_eq!(retrieved[0].id, most_recent.id);
|
.iter()
|
||||||
|
.max_by_key(|c| c.created_at)
|
||||||
|
.context("expected most recent")?;
|
||||||
|
let r0 = retrieved.first().context("expected first result")?;
|
||||||
|
assert_eq!(r0.id, most_recent.id);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_latest_text_contents_returns_last_five() {
|
async fn test_get_latest_text_contents_returns_last_five() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let user_id = "latest_text_user";
|
let user_id = "latest_text_user";
|
||||||
|
|
||||||
let mut inserted_ids = Vec::new();
|
let mut inserted_ids = Vec::new();
|
||||||
@@ -1201,8 +1206,8 @@ mod tests {
|
|||||||
|
|
||||||
for i in 0..12 {
|
for i in 0..12 {
|
||||||
let mut item = TextContent::new(
|
let mut item = TextContent::new(
|
||||||
format!("Text {}", i),
|
format!("Text {i}"),
|
||||||
Some(format!("Context {}", i)),
|
Some(format!("Context {i}")),
|
||||||
"Category".to_string(),
|
"Category".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
@@ -1215,18 +1220,19 @@ mod tests {
|
|||||||
|
|
||||||
db.store_item(item.clone())
|
db.store_item(item.clone())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to store text content");
|
.with_context(|| "Failed to store text content".to_string())?;
|
||||||
|
|
||||||
inserted_ids.push(item.id.clone());
|
inserted_ids.push(item.id.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
let latest = User::get_latest_text_contents(user_id, &db)
|
let latest = User::get_latest_text_contents(user_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to fetch latest text contents");
|
.with_context(|| "Failed to fetch latest text contents".to_string())?;
|
||||||
|
|
||||||
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
||||||
|
|
||||||
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
|
let start = inserted_ids.len().saturating_sub(5);
|
||||||
|
let mut expected_ids = inserted_ids.get(start..).unwrap_or_default().to_vec();
|
||||||
expected_ids.reverse();
|
expected_ids.reverse();
|
||||||
|
|
||||||
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
||||||
@@ -1235,25 +1241,29 @@ mod tests {
|
|||||||
"Latest items did not match expectation"
|
"Latest items did not match expectation"
|
||||||
);
|
);
|
||||||
|
|
||||||
for window in latest.windows(2) {
|
for pair in latest.windows(2) {
|
||||||
|
let a = pair.first().context("expected first in pair")?;
|
||||||
|
let b = pair.get(1).context("expected second in pair")?;
|
||||||
assert!(
|
assert!(
|
||||||
window[0].created_at >= window[1].created_at,
|
a.created_at >= b.created_at,
|
||||||
"Results are not ordered by created_at descending"
|
"Results are not ordered by created_at descending"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_validate_theme() {
|
async fn test_validate_theme() -> anyhow::Result<()> {
|
||||||
assert_eq!(validate_theme("light"), Theme::Light);
|
assert_eq!(validate_theme("light"), Theme::Light);
|
||||||
assert_eq!(validate_theme("dark"), Theme::Dark);
|
assert_eq!(validate_theme("dark"), Theme::Dark);
|
||||||
assert_eq!(validate_theme("system"), Theme::System);
|
assert_eq!(validate_theme("system"), Theme::System);
|
||||||
assert_eq!(validate_theme("invalid"), Theme::System);
|
assert_eq!(validate_theme("invalid"), Theme::System);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_theme_update() {
|
async fn test_theme_update() -> anyhow::Result<()> {
|
||||||
let db = setup_test_db().await;
|
let db = setup_test_db().await?;
|
||||||
let email = "theme_test@example.com";
|
let email = "theme_test@example.com";
|
||||||
let user = User::create_new(
|
let user = User::create_new(
|
||||||
email.to_string(),
|
email.to_string(),
|
||||||
@@ -1263,30 +1273,31 @@ mod tests {
|
|||||||
"system".to_string(),
|
"system".to_string(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to create user");
|
.with_context(|| "Failed to create user".to_string())?;
|
||||||
|
|
||||||
assert_eq!(user.theme, Theme::System);
|
assert_eq!(user.theme, Theme::System);
|
||||||
|
|
||||||
User::update_theme(&user.id, "dark", &db)
|
User::update_theme(&user.id, "dark", &db)
|
||||||
.await
|
.await
|
||||||
.expect("update theme");
|
.with_context(|| "update theme".to_string())?;
|
||||||
|
|
||||||
let updated = db
|
let updated = db
|
||||||
.get_item::<User>(&user.id)
|
.get_item::<User>(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("get user")
|
.with_context(|| "get user".to_string())?
|
||||||
.unwrap();
|
.with_context(|| "expected user".to_string())?;
|
||||||
assert_eq!(updated.theme, Theme::Dark);
|
assert_eq!(updated.theme, Theme::Dark);
|
||||||
|
|
||||||
// Invalid theme should default to system (but update_theme calls validate_theme)
|
// Invalid theme should default to system (but update_theme calls validate_theme)
|
||||||
User::update_theme(&user.id, "invalid", &db)
|
User::update_theme(&user.id, "invalid", &db)
|
||||||
.await
|
.await
|
||||||
.expect("update theme invalid");
|
.with_context(|| "update theme invalid".to_string())?;
|
||||||
let updated2 = db
|
let updated2 = db
|
||||||
.get_item::<User>(&user.id)
|
.get_item::<User>(&user.id)
|
||||||
.await
|
.await
|
||||||
.expect("get user")
|
.with_context(|| "get user".to_string())?
|
||||||
.unwrap();
|
.with_context(|| "expected user".to_string())?;
|
||||||
assert_eq!(updated2.theme, Theme::System);
|
assert_eq!(updated2.theme, Theme::System);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,93 @@
|
|||||||
|
//! Shared helpers for in-memory SurrealDB tests.
|
||||||
|
#![cfg(any(test, feature = "test-utils"))]
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::storage::{
|
||||||
|
db::SurrealDbClient,
|
||||||
|
indexes::{ensure_runtime, rebuild},
|
||||||
|
types::{
|
||||||
|
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||||
|
text_chunk_embedding::TextChunkEmbedding,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const TEST_NAMESPACE: &str = "test_ns";
|
||||||
|
|
||||||
|
/// Starts an in-memory database, applies migrations, and returns a client.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the database cannot be started or migrations fail.
|
||||||
|
pub async fn setup_test_db() -> Result<SurrealDbClient> {
|
||||||
|
let database = Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(TEST_NAMESPACE, &database)
|
||||||
|
.await
|
||||||
|
.context("start in-memory surrealdb")?;
|
||||||
|
|
||||||
|
db.apply_migrations().await.context("apply migrations")?;
|
||||||
|
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates singleton [`SystemSettings`] embedding dimensions for tests.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if settings cannot be loaded or updated.
|
||||||
|
pub async fn configure_embedding_dimension(db: &SurrealDbClient, dimension: u32) -> Result<()> {
|
||||||
|
let mut settings = SystemSettings::get_current(db).await?;
|
||||||
|
settings.embedding_dimensions = dimension;
|
||||||
|
SystemSettings::update(db, settings).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and sets the embedding dimension in system settings.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup or settings update fails.
|
||||||
|
pub async fn setup_test_db_with_embedding_dimension(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
configure_embedding_dimension(&db, dimension).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for text-chunk embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_text_chunk_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
TextChunkEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set text chunk index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepares a database for knowledge-entity embedding tests at the given dimension.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, settings update, or index redefinition fails.
|
||||||
|
pub async fn prepare_knowledge_entity_test_db(dimension: u32) -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db_with_embedding_dimension(dimension).await?;
|
||||||
|
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, dimension as usize)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("set knowledge entity index dimension to {dimension}"))?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a test database and ensures runtime FTS/HNSW indexes are ready.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if setup, index creation, or rebuild fails.
|
||||||
|
pub async fn setup_test_db_with_runtime_indexes() -> Result<SurrealDbClient> {
|
||||||
|
let db = setup_test_db().await?;
|
||||||
|
ensure_runtime(&db, 1536).await?;
|
||||||
|
rebuild(&db).await?;
|
||||||
|
Ok(db)
|
||||||
|
}
|
||||||
+85
-15
@@ -1,9 +1,18 @@
|
|||||||
use config::{Config, ConfigError, Environment, File};
|
use config::{Config, ConfigError, Environment, File};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use std::env;
|
use std::{env, str::FromStr, sync::Once};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// Error returned when parsing an embedding backend name.
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
|
#[error("unknown embedding backend '{input}': expected 'openai', 'hashed', or 'fastembed'")]
|
||||||
|
pub struct ParseEmbeddingBackendError {
|
||||||
|
/// The unrecognized input string.
|
||||||
|
pub input: String,
|
||||||
|
}
|
||||||
|
|
||||||
/// Selects the embedding backend for vector generation.
|
/// Selects the embedding backend for vector generation.
|
||||||
#[derive(Clone, Deserialize, Debug, Default, PartialEq)]
|
#[derive(Clone, Copy, Deserialize, Serialize, Debug, Default, PartialEq, Eq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum EmbeddingBackend {
|
pub enum EmbeddingBackend {
|
||||||
/// Use OpenAI-compatible API for embeddings.
|
/// Use OpenAI-compatible API for embeddings.
|
||||||
@@ -15,7 +24,33 @@ pub enum EmbeddingBackend {
|
|||||||
Hashed,
|
Hashed,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Debug, PartialEq)]
|
impl EmbeddingBackend {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::OpenAI => "openai",
|
||||||
|
Self::FastEmbed => "fastembed",
|
||||||
|
Self::Hashed => "hashed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for EmbeddingBackend {
|
||||||
|
type Err = ParseEmbeddingBackendError;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
match s.to_ascii_lowercase().as_str() {
|
||||||
|
"openai" => Ok(Self::OpenAI),
|
||||||
|
"hashed" => Ok(Self::Hashed),
|
||||||
|
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
||||||
|
other => Err(ParseEmbeddingBackendError {
|
||||||
|
input: other.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Deserialize, Debug, PartialEq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum StorageKind {
|
pub enum StorageKind {
|
||||||
Local,
|
Local,
|
||||||
@@ -28,12 +63,12 @@ fn default_storage_kind() -> StorageKind {
|
|||||||
StorageKind::Local
|
StorageKind::Local
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_s3_region() -> Option<String> {
|
fn default_s3_region() -> String {
|
||||||
Some("us-east-1".to_string())
|
"us-east-1".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Selects the strategy used for PDF ingestion.
|
/// Selects the strategy used for PDF ingestion.
|
||||||
#[derive(Clone, Deserialize, Debug)]
|
#[derive(Clone, Copy, Deserialize, Debug)]
|
||||||
#[serde(rename_all = "kebab-case")]
|
#[serde(rename_all = "kebab-case")]
|
||||||
pub enum PdfIngestMode {
|
pub enum PdfIngestMode {
|
||||||
/// Only rely on classic text extraction (no LLM fallbacks).
|
/// Only rely on classic text extraction (no LLM fallbacks).
|
||||||
@@ -69,7 +104,7 @@ pub struct AppConfig {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub s3_endpoint: Option<String>,
|
pub s3_endpoint: Option<String>,
|
||||||
#[serde(default = "default_s3_region")]
|
#[serde(default = "default_s3_region")]
|
||||||
pub s3_region: Option<String>,
|
pub s3_region: String,
|
||||||
#[serde(default = "default_pdf_ingest_mode")]
|
#[serde(default = "default_pdf_ingest_mode")]
|
||||||
pub pdf_ingest_mode: PdfIngestMode,
|
pub pdf_ingest_mode: PdfIngestMode,
|
||||||
#[serde(default = "default_reranking_enabled")]
|
#[serde(default = "default_reranking_enabled")]
|
||||||
@@ -82,10 +117,14 @@ pub struct AppConfig {
|
|||||||
pub fastembed_show_download_progress: Option<bool>,
|
pub fastembed_show_download_progress: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub fastembed_max_length: Option<usize>,
|
pub fastembed_max_length: Option<usize>,
|
||||||
|
/// HuggingFace-style FastEmbed `model_code` (e.g. `Xenova/bge-small-en-v1.5`). Overrides
|
||||||
|
/// `system_settings.embedding_model` when `embedding_backend` is `fastembed`.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub retrieval_strategy: Option<String>,
|
pub fastembed_model: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub embedding_backend: EmbeddingBackend,
|
pub embedding_backend: EmbeddingBackend,
|
||||||
|
#[serde(default)]
|
||||||
|
pub embedding_pool_size: Option<usize>,
|
||||||
#[serde(default = "default_ingest_max_body_bytes")]
|
#[serde(default = "default_ingest_max_body_bytes")]
|
||||||
pub ingest_max_body_bytes: usize,
|
pub ingest_max_body_bytes: usize,
|
||||||
#[serde(default = "default_ingest_max_files")]
|
#[serde(default = "default_ingest_max_files")]
|
||||||
@@ -133,11 +172,17 @@ fn default_ingest_max_category_bytes() -> usize {
|
|||||||
128
|
128
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ORT_PATH_INIT: Once = Once::new();
|
||||||
|
|
||||||
|
/// Sets `ORT_DYLIB_PATH` once per process when a bundled ONNX runtime library is found.
|
||||||
pub fn ensure_ort_path() {
|
pub fn ensure_ort_path() {
|
||||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
ORT_PATH_INIT.call_once(|| {
|
||||||
return;
|
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||||
}
|
return;
|
||||||
if let Ok(mut exe) = env::current_exe() {
|
}
|
||||||
|
let Ok(mut exe) = env::current_exe() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
exe.pop();
|
exe.pop();
|
||||||
|
|
||||||
if cfg!(target_os = "windows") {
|
if cfg!(target_os = "windows") {
|
||||||
@@ -160,7 +205,7 @@ pub fn ensure_ort_path() {
|
|||||||
if p.exists() {
|
if p.exists() {
|
||||||
env::set_var("ORT_DYLIB_PATH", p);
|
env::set_var("ORT_DYLIB_PATH", p);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AppConfig {
|
impl Default for AppConfig {
|
||||||
@@ -185,8 +230,9 @@ impl Default for AppConfig {
|
|||||||
fastembed_cache_dir: None,
|
fastembed_cache_dir: None,
|
||||||
fastembed_show_download_progress: None,
|
fastembed_show_download_progress: None,
|
||||||
fastembed_max_length: None,
|
fastembed_max_length: None,
|
||||||
retrieval_strategy: None,
|
fastembed_model: None,
|
||||||
embedding_backend: EmbeddingBackend::default(),
|
embedding_backend: EmbeddingBackend::default(),
|
||||||
|
embedding_pool_size: None,
|
||||||
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||||
ingest_max_files: default_ingest_max_files(),
|
ingest_max_files: default_ingest_max_files(),
|
||||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||||
@@ -208,3 +254,27 @@ pub fn get_config() -> Result<AppConfig, ConfigError> {
|
|||||||
|
|
||||||
config.try_deserialize()
|
config.try_deserialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used)]
|
||||||
|
|
||||||
|
use super::EmbeddingBackend;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embedding_backend_defaults_to_fastembed() {
|
||||||
|
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn embedding_backend_parses_aliases() {
|
||||||
|
assert_eq!(
|
||||||
|
"openai".parse::<EmbeddingBackend>().expect("openai"),
|
||||||
|
EmbeddingBackend::OpenAI
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
"fast".parse::<EmbeddingBackend>().expect("fast"),
|
||||||
|
EmbeddingBackend::FastEmbed
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+512
-178
@@ -2,44 +2,25 @@ use std::{
|
|||||||
collections::hash_map::DefaultHasher,
|
collections::hash_map::DefaultHasher,
|
||||||
hash::{Hash, Hasher},
|
hash::{Hash, Hasher},
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
sync::Arc,
|
sync::{Arc, Mutex},
|
||||||
|
thread::available_parallelism,
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use serde::Serialize;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
|
||||||
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::AppError,
|
error::{AppError, EmbeddingError},
|
||||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||||
|
utils::config::AppConfig,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Supported embedding backends.
|
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
pub use crate::utils::config::{EmbeddingBackend, ParseEmbeddingBackendError};
|
||||||
pub enum EmbeddingBackend {
|
|
||||||
#[default]
|
|
||||||
OpenAI,
|
|
||||||
FastEmbed,
|
|
||||||
Hashed,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::str::FromStr for EmbeddingBackend {
|
|
||||||
type Err = anyhow::Error;
|
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
||||||
match s.to_ascii_lowercase().as_str() {
|
|
||||||
"openai" => Ok(Self::OpenAI),
|
|
||||||
"hashed" => Ok(Self::Hashed),
|
|
||||||
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
|
|
||||||
other => Err(anyhow!(
|
|
||||||
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wrapper around the chosen embedding backend.
|
/// Wrapper around the chosen embedding backend.
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
@@ -57,7 +38,7 @@ enum EmbeddingInner {
|
|||||||
/// Client used to issue embedding requests.
|
/// Client used to issue embedding requests.
|
||||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||||
/// Model identifier for the API.
|
/// Model identifier for the API.
|
||||||
model: String,
|
model: Arc<str>,
|
||||||
/// Expected output dimensions.
|
/// Expected output dimensions.
|
||||||
dimensions: u32,
|
dimensions: u32,
|
||||||
},
|
},
|
||||||
@@ -68,8 +49,8 @@ enum EmbeddingInner {
|
|||||||
},
|
},
|
||||||
/// Uses `FastEmbed` running locally.
|
/// Uses `FastEmbed` running locally.
|
||||||
FastEmbed {
|
FastEmbed {
|
||||||
/// Shared `FastEmbed` model.
|
/// Pool of `FastEmbed` engines providing bounded-concurrency local embedding.
|
||||||
model: Arc<Mutex<TextEmbedding>>,
|
pool: Arc<FastEmbedPool>,
|
||||||
/// Model metadata used for info logging.
|
/// Model metadata used for info logging.
|
||||||
model_name: EmbeddingModel,
|
model_name: EmbeddingModel,
|
||||||
/// Output vector length.
|
/// Output vector length.
|
||||||
@@ -77,7 +58,250 @@ enum EmbeddingInner {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Batch size used when re-embedding stored data in bulk. Bounds peak memory and preserves
|
||||||
|
/// progress logging while still amortising per-call lock/dispatch overhead.
|
||||||
|
pub const RE_EMBED_BATCH_SIZE: usize = 128;
|
||||||
|
|
||||||
|
/// Default FastEmbed model (`BGESmallENV15`) when config and DB do not specify a valid code.
|
||||||
|
pub const DEFAULT_FASTEMBED_MODEL_CODE: &str = "Xenova/bge-small-en-v1.5";
|
||||||
|
|
||||||
|
/// A supported FastEmbed model for admin UI and documentation.
|
||||||
|
#[derive(Clone, Debug, Serialize)]
|
||||||
|
pub struct FastEmbedModelOption {
|
||||||
|
/// HuggingFace-style `model_code` accepted by [`EmbeddingModel::from_str`].
|
||||||
|
pub model_code: String,
|
||||||
|
/// Fixed output dimension for this model.
|
||||||
|
pub dimension: u32,
|
||||||
|
/// Short human-readable description from fastembed metadata.
|
||||||
|
pub description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lists supported FastEmbed text embedding models (sorted by `model_code`).
|
||||||
|
#[must_use]
|
||||||
|
pub fn list_fastembed_embedding_models() -> Vec<FastEmbedModelOption> {
|
||||||
|
let mut list: Vec<FastEmbedModelOption> = TextEmbedding::list_supported_models()
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|info| {
|
||||||
|
let dimension = u32::try_from(info.dim).ok()?;
|
||||||
|
Some(FastEmbedModelOption {
|
||||||
|
model_code: info.model_code,
|
||||||
|
dimension,
|
||||||
|
description: info.description,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
list.sort_by(|left, right| left.model_code.cmp(&right.model_code));
|
||||||
|
list
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true when `code` is a supported FastEmbed `model_code` (HuggingFace-style id).
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_valid_fastembed_model_code(code: &str) -> bool {
|
||||||
|
!code.trim().is_empty() && EmbeddingModel::from_str(code.trim()).is_ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Vector dimension for a supported FastEmbed `model_code`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError::UnknownModel`] when the code is not recognized.
|
||||||
|
pub fn fastembed_model_dimension(code: &str) -> Result<u32, EmbeddingError> {
|
||||||
|
let model = EmbeddingModel::from_str(code.trim())
|
||||||
|
.map_err(|_| EmbeddingError::UnknownModel(unknown_fastembed_model_message(code)))?;
|
||||||
|
let dim = EmbeddingModel::get_model_info(&model)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
EmbeddingError::Config(format!("fastembed model metadata missing for {code}"))
|
||||||
|
})?
|
||||||
|
.dim;
|
||||||
|
u32::try_from(dim).map_err(|_| {
|
||||||
|
EmbeddingError::Config(format!("fastembed model dimension {dim} exceeds u32::MAX"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolves the FastEmbed model code to load: config override, then DB, then default.
|
||||||
|
///
|
||||||
|
/// When `config.fastembed_model` is set it must be valid. When only the DB value is used and it
|
||||||
|
/// is not a FastEmbed code (e.g. legacy `text-embedding-3-small`), returns the default model.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError::UnknownModel`] if `config.fastembed_model` is set but invalid.
|
||||||
|
pub fn resolve_fastembed_model_code(
|
||||||
|
config: &AppConfig,
|
||||||
|
settings_model: &str,
|
||||||
|
) -> Result<String, EmbeddingError> {
|
||||||
|
if let Some(code) = config.fastembed_model.as_deref() {
|
||||||
|
let trimmed = code.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return Err(EmbeddingError::Config(
|
||||||
|
"fastembed_model must not be empty when set".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
EmbeddingModel::from_str(trimmed)
|
||||||
|
.map_err(|_| EmbeddingError::UnknownModel(unknown_fastembed_model_message(trimmed)))?;
|
||||||
|
return Ok(trimmed.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
let trimmed = settings_model.trim();
|
||||||
|
if is_valid_fastembed_model_code(trimmed) {
|
||||||
|
return Ok(trimmed.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
warn!(
|
||||||
|
stored_model = trimmed,
|
||||||
|
default_model = DEFAULT_FASTEMBED_MODEL_CODE,
|
||||||
|
"system_settings.embedding_model is not a FastEmbed model code; using default"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(DEFAULT_FASTEMBED_MODEL_CODE.to_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Persists a FastEmbed-compatible `embedding_model` and `embedding_dimensions` before startup
|
||||||
|
/// when the active backend is FastEmbed and stored settings still carry OpenAI defaults.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`AppError`] if settings cannot be loaded, resolved, or updated.
|
||||||
|
pub async fn align_fastembed_system_settings(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
config: &AppConfig,
|
||||||
|
) -> Result<SystemSettings, AppError> {
|
||||||
|
if config.embedding_backend != EmbeddingBackend::FastEmbed {
|
||||||
|
return SystemSettings::get_current(db).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut settings = SystemSettings::get_current(db).await?;
|
||||||
|
let resolved = resolve_fastembed_model_code(config, &settings.embedding_model)?;
|
||||||
|
let dimension = fastembed_model_dimension(&resolved)?;
|
||||||
|
|
||||||
|
if settings.embedding_model == resolved && settings.embedding_dimensions == dimension {
|
||||||
|
return Ok(settings);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
old_model = %settings.embedding_model,
|
||||||
|
new_model = %resolved,
|
||||||
|
old_dimensions = settings.embedding_dimensions,
|
||||||
|
new_dimensions = dimension,
|
||||||
|
"Aligning system settings with FastEmbed model"
|
||||||
|
);
|
||||||
|
settings.embedding_model = resolved;
|
||||||
|
settings.embedding_dimensions = dimension;
|
||||||
|
SystemSettings::update(db, settings).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unknown_fastembed_model_message(code: &str) -> String {
|
||||||
|
let mut codes: Vec<String> = TextEmbedding::list_supported_models()
|
||||||
|
.into_iter()
|
||||||
|
.map(|info| info.model_code)
|
||||||
|
.collect();
|
||||||
|
codes.sort();
|
||||||
|
let examples: Vec<&str> = codes.iter().take(6).map(String::as_str).collect();
|
||||||
|
format!(
|
||||||
|
"unknown FastEmbed model '{code}' (expected a HuggingFace model_code such as {}). \
|
||||||
|
Set fastembed_model in config.yaml or update system_settings; \
|
||||||
|
see docs/configuration.md ({count} models supported)",
|
||||||
|
examples.join(", "),
|
||||||
|
count = codes.len()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default FastEmbed pool size.
|
||||||
|
///
|
||||||
|
/// Kept small on purpose: the ONNX runtime already uses intra-op threads per inference, so
|
||||||
|
/// running many engines concurrently oversubscribes the CPU and each engine duplicates the
|
||||||
|
/// model weights in memory. Mirrors the reranker pool default.
|
||||||
|
#[must_use]
|
||||||
|
pub fn default_embedding_pool_size() -> usize {
|
||||||
|
available_parallelism()
|
||||||
|
.map_or(2, |value| value.get().min(2))
|
||||||
|
.max(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pool of `FastEmbed` engines enabling bounded-concurrency local embedding.
|
||||||
|
///
|
||||||
|
/// A single [`TextEmbedding`] embeds one batch at a time (`&mut self`), so the pool keeps
|
||||||
|
/// several instances and hands out a distinct idle engine per checkout. The semaphore bounds
|
||||||
|
/// total in-flight embeds (backpressure); the free list guarantees each active lease holds a
|
||||||
|
/// different engine — unlike a round-robin index, which can hand the same engine to two callers.
|
||||||
|
struct FastEmbedPool {
|
||||||
|
/// Idle engines; one is popped on checkout and returned on lease drop.
|
||||||
|
engines: Mutex<Vec<Arc<Mutex<TextEmbedding>>>>,
|
||||||
|
/// Sized to the engine count; gates concurrent checkouts.
|
||||||
|
semaphore: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FastEmbedPool {
|
||||||
|
fn new(engines: Vec<Arc<Mutex<TextEmbedding>>>) -> Self {
|
||||||
|
let permits = engines.len().max(1);
|
||||||
|
Self {
|
||||||
|
engines: Mutex::new(engines),
|
||||||
|
semaphore: Arc::new(Semaphore::new(permits)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Acquire a permit and borrow a distinct idle engine. The permit guarantees an engine is
|
||||||
|
/// available, so the pop always succeeds for a correctly sized pool.
|
||||||
|
async fn checkout(self: &Arc<Self>) -> Result<FastEmbedLease, EmbeddingError> {
|
||||||
|
let permit = Arc::clone(&self.semaphore)
|
||||||
|
.acquire_owned()
|
||||||
|
.await
|
||||||
|
.map_err(|_| EmbeddingError::Config("embedding pool is closed".into()))?;
|
||||||
|
let engine = self
|
||||||
|
.engines
|
||||||
|
.lock()
|
||||||
|
.map_err(EmbeddingError::mutex_poisoned)?
|
||||||
|
.pop()
|
||||||
|
.ok_or_else(|| EmbeddingError::Config("embedding pool unexpectedly empty".into()))?;
|
||||||
|
Ok(FastEmbedLease {
|
||||||
|
pool: Arc::clone(self),
|
||||||
|
engine,
|
||||||
|
_permit: permit,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Active borrow of a single `FastEmbed` engine; returns it to the pool on drop.
|
||||||
|
struct FastEmbedLease {
|
||||||
|
pool: Arc<FastEmbedPool>,
|
||||||
|
engine: Arc<Mutex<TextEmbedding>>,
|
||||||
|
/// Released after the engine is returned, unblocking the next checkout.
|
||||||
|
_permit: OwnedSemaphorePermit,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FastEmbedLease {
|
||||||
|
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
|
let engine = Arc::clone(&self.engine);
|
||||||
|
let texts = texts.to_vec();
|
||||||
|
tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
|
let mut guard = engine.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||||
|
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(EmbeddingError::from)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for FastEmbedLease {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Ok(mut free) = self.pool.engines.lock() {
|
||||||
|
free.push(Arc::clone(&self.engine));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_fastembed(
|
||||||
|
pool: &Arc<FastEmbedPool>,
|
||||||
|
texts: &[String],
|
||||||
|
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
|
let lease = pool.checkout().await?;
|
||||||
|
lease.embed(texts).await
|
||||||
|
}
|
||||||
|
|
||||||
impl EmbeddingProvider {
|
impl EmbeddingProvider {
|
||||||
|
#[must_use]
|
||||||
pub fn backend_label(&self) -> &'static str {
|
pub fn backend_label(&self) -> &'static str {
|
||||||
match self.inner {
|
match self.inner {
|
||||||
EmbeddingInner::Hashed { .. } => "hashed",
|
EmbeddingInner::Hashed { .. } => "hashed",
|
||||||
@@ -86,6 +310,7 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn dimension(&self) -> usize {
|
pub fn dimension(&self) -> usize {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
|
EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
|
||||||
@@ -95,26 +320,28 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn model_code(&self) -> Option<String> {
|
pub fn model_code(&self) -> Option<String> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
|
||||||
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
|
EmbeddingInner::OpenAI { model, .. } => Some(model.as_ref().to_owned()),
|
||||||
EmbeddingInner::Hashed { .. } => None,
|
EmbeddingInner::Hashed { .. } => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
/// Generate an embedding vector for the given text.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the backend API call fails, FastEmbed initialisation fails,
|
||||||
|
/// or the backend returns no embedding data.
|
||||||
|
pub async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
|
||||||
EmbeddingInner::FastEmbed { model, .. } => {
|
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||||
let mut guard = model.lock().await;
|
let text = text.to_owned();
|
||||||
let embeddings = guard
|
let embeddings = run_fastembed(pool, std::slice::from_ref(&text)).await?;
|
||||||
.embed(vec![text.to_owned()], None)
|
embeddings.into_iter().next().ok_or(EmbeddingError::NoData)
|
||||||
.context("generating fastembed vector")?;
|
|
||||||
embeddings
|
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
|
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -122,7 +349,7 @@ impl EmbeddingProvider {
|
|||||||
dimensions,
|
dimensions,
|
||||||
} => {
|
} => {
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model(model.clone())
|
.model(model.as_ref())
|
||||||
.input([text])
|
.input([text])
|
||||||
.dimensions(*dimensions)
|
.dimensions(*dimensions)
|
||||||
.build()?;
|
.build()?;
|
||||||
@@ -132,7 +359,7 @@ impl EmbeddingProvider {
|
|||||||
let embedding = response
|
let embedding = response
|
||||||
.data
|
.data
|
||||||
.first()
|
.first()
|
||||||
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
|
.ok_or(EmbeddingError::NoData)?
|
||||||
.embedding
|
.embedding
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
@@ -141,20 +368,23 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
/// Generate embedding vectors for a batch of texts.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||||
|
/// Returns an empty `Vec` when `texts` is empty.
|
||||||
|
pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|text| hashed_embedding(&text, *dimension))
|
.map(|text| hashed_embedding(text, *dimension))
|
||||||
.collect()),
|
.collect()),
|
||||||
EmbeddingInner::FastEmbed { model, .. } => {
|
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||||
if texts.is_empty() {
|
if texts.is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
let mut guard = model.lock().await;
|
run_fastembed(pool, texts).await
|
||||||
guard
|
|
||||||
.embed(texts, None)
|
|
||||||
.context("generating fastembed batch embeddings")
|
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -166,8 +396,8 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model(model.clone())
|
.model(model.as_ref())
|
||||||
.input(texts)
|
.input(texts.to_vec())
|
||||||
.dimensions(*dimensions)
|
.dimensions(*dimensions)
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
@@ -184,51 +414,84 @@ impl EmbeddingProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Currently infallible; reserved for future validation.
|
||||||
pub fn new_openai(
|
pub fn new_openai(
|
||||||
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
client: Arc<Client<async_openai::config::OpenAIConfig>>,
|
||||||
model: String,
|
model: impl AsRef<str>,
|
||||||
dimensions: u32,
|
dimensions: u32,
|
||||||
) -> Result<Self> {
|
) -> Result<Self, EmbeddingError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: EmbeddingInner::OpenAI {
|
inner: EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
model,
|
model: Arc::from(model.as_ref()),
|
||||||
dimensions,
|
dimensions,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
|
/// Initialise a local FastEmbed provider backed by a pool of `pool_size` engines.
|
||||||
|
///
|
||||||
|
/// `pool_size` is clamped to at least 1. Larger pools allow concurrent embeds at the cost of
|
||||||
|
/// `pool_size`× model memory; see [`default_embedding_pool_size`] for guidance.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the model name is unknown or FastEmbed initialisation fails.
|
||||||
|
pub async fn new_fastembed(
|
||||||
|
model_override: Option<String>,
|
||||||
|
pool_size: usize,
|
||||||
|
) -> Result<Self, EmbeddingError> {
|
||||||
|
let pool_size = pool_size.max(1);
|
||||||
let model_name = if let Some(code) = model_override {
|
let model_name = if let Some(code) = model_override {
|
||||||
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
|
EmbeddingModel::from_str(code.trim())
|
||||||
|
.map_err(|_| EmbeddingError::UnknownModel(unknown_fastembed_model_message(&code)))?
|
||||||
} else {
|
} else {
|
||||||
EmbeddingModel::default()
|
EmbeddingModel::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
|
|
||||||
let model_name_for_task = model_name.clone();
|
let model_name_for_task = model_name.clone();
|
||||||
let model_name_code = model_name.to_string();
|
let model_name_code = model_name.to_string();
|
||||||
|
|
||||||
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
|
let (engines, dimension) =
|
||||||
let model =
|
match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||||
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
|
let info =
|
||||||
let info = EmbeddingModel::get_model_info(&model_name_for_task)
|
EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||||
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
|
EmbeddingError::Config(format!(
|
||||||
Ok((model, info.dim))
|
"fastembed model metadata missing for {model_name_code}"
|
||||||
})
|
))
|
||||||
.await
|
})?;
|
||||||
.context("joining FastEmbed initialisation task")??;
|
let mut engines = Vec::with_capacity(pool_size);
|
||||||
|
for index in 0..pool_size {
|
||||||
|
let options = TextInitOptions::new(model_name_for_task.clone())
|
||||||
|
// Only the first engine reports download progress; the rest reuse the cache.
|
||||||
|
.with_show_download_progress(index == 0);
|
||||||
|
let model =
|
||||||
|
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||||
|
engines.push(Arc::new(Mutex::new(model)));
|
||||||
|
}
|
||||||
|
Ok((engines, info.dim))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(result) => result?,
|
||||||
|
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||||
|
};
|
||||||
|
|
||||||
Ok(EmbeddingProvider {
|
Ok(EmbeddingProvider {
|
||||||
inner: EmbeddingInner::FastEmbed {
|
inner: EmbeddingInner::FastEmbed {
|
||||||
model: Arc::new(Mutex::new(model)),
|
pool: Arc::new(FastEmbedPool::new(engines)),
|
||||||
model_name,
|
model_name,
|
||||||
dimension,
|
dimension,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_hashed(dimension: usize) -> Result<Self> {
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Currently infallible; reserved for future validation.
|
||||||
|
pub fn new_hashed(dimension: usize) -> Result<Self, EmbeddingError> {
|
||||||
Ok(EmbeddingProvider {
|
Ok(EmbeddingProvider {
|
||||||
inner: EmbeddingInner::Hashed {
|
inner: EmbeddingInner::Hashed {
|
||||||
dimension: dimension.max(1),
|
dimension: dimension.max(1),
|
||||||
@@ -236,30 +499,44 @@ impl EmbeddingProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an embedding provider based on application configuration.
|
/// Creates an embedding provider from persisted settings and bootstrap config.
|
||||||
///
|
///
|
||||||
/// Dispatches to the appropriate constructor based on `config.embedding_backend`:
|
/// OpenAI/hashed model settings come from [`SystemSettings`]. FastEmbed uses
|
||||||
/// - `OpenAI`: Requires a valid OpenAI client
|
/// [`resolve_fastembed_model_code`] (config `fastembed_model` overrides DB). The active
|
||||||
/// - `FastEmbed`: Uses local embedding model
|
/// backend is taken from `config.embedding_backend`; [`SystemSettings::sync_from_embedding_provider`]
|
||||||
/// - `Hashed`: Uses deterministic hashed embeddings (for testing)
|
/// persists the resolved backend to the database after startup.
|
||||||
pub async fn from_config(
|
///
|
||||||
config: &crate::utils::config::AppConfig,
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns [`EmbeddingError`] if the selected backend cannot be initialised.
|
||||||
|
pub async fn from_system_settings(
|
||||||
|
settings: &SystemSettings,
|
||||||
|
config: &AppConfig,
|
||||||
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self, EmbeddingError> {
|
||||||
use crate::utils::config::EmbeddingBackend;
|
let dimensions = settings.embedding_dimensions;
|
||||||
|
|
||||||
match config.embedding_backend {
|
match config.embedding_backend {
|
||||||
EmbeddingBackend::OpenAI => {
|
EmbeddingBackend::OpenAI => {
|
||||||
let client = openai_client
|
let client = openai_client.ok_or_else(|| {
|
||||||
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
|
EmbeddingError::Config(
|
||||||
// Use defaults that match SystemSettings initial values
|
"openai embedding backend requires an openai client".into(),
|
||||||
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
|
)
|
||||||
|
})?;
|
||||||
|
Self::new_openai(client, settings.embedding_model.as_str(), dimensions)
|
||||||
}
|
}
|
||||||
EmbeddingBackend::FastEmbed => {
|
EmbeddingBackend::FastEmbed => {
|
||||||
// Use nomic-embed-text-v1.5 as the default FastEmbed model
|
let pool_size = config
|
||||||
Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await
|
.embedding_pool_size
|
||||||
|
.unwrap_or_else(default_embedding_pool_size);
|
||||||
|
let model_code = resolve_fastembed_model_code(config, &settings.embedding_model)?;
|
||||||
|
Self::new_fastembed(Some(model_code), pool_size).await
|
||||||
|
}
|
||||||
|
EmbeddingBackend::Hashed => {
|
||||||
|
let dimension = usize::try_from(dimensions).map_err(|_| {
|
||||||
|
EmbeddingError::Config("embedding_dimensions exceeds usize::MAX".into())
|
||||||
|
})?;
|
||||||
|
Self::new_hashed(dimension)
|
||||||
}
|
}
|
||||||
EmbeddingBackend::Hashed => Self::new_hashed(384),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -306,94 +583,151 @@ fn bucket(token: &str, dimension: usize) -> usize {
|
|||||||
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
|
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward compatibility function
|
#[cfg(test)]
|
||||||
pub async fn generate_embedding_with_provider(
|
mod tests {
|
||||||
provider: &EmbeddingProvider,
|
#![allow(clippy::expect_used)]
|
||||||
input: &str,
|
|
||||||
) -> Result<Vec<f32>, AppError> {
|
use super::{
|
||||||
provider.embed(input).await.map_err(AppError::from)
|
align_fastembed_system_settings, fastembed_model_dimension,
|
||||||
}
|
list_fastembed_embedding_models, resolve_fastembed_model_code, EmbeddingError,
|
||||||
|
DEFAULT_FASTEMBED_MODEL_CODE,
|
||||||
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
|
};
|
||||||
///
|
use crate::storage::types::system_settings::SystemSettings;
|
||||||
/// This function takes a text input and converts it into a numerical vector representation (embedding)
|
use crate::utils::config::{AppConfig, EmbeddingBackend, ParseEmbeddingBackendError};
|
||||||
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
|
use serde_json::json;
|
||||||
/// comparisons, vector search, and other natural language processing tasks.
|
|
||||||
///
|
#[test]
|
||||||
/// # Arguments
|
fn embedding_backend_defaults_to_fastembed() {
|
||||||
///
|
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
|
||||||
/// * `client`: The `OpenAI` client instance used to make API requests.
|
}
|
||||||
/// * `input`: The text string to generate embeddings for.
|
|
||||||
///
|
#[test]
|
||||||
/// # Returns
|
fn embedding_backend_as_str_matches_serde_names() {
|
||||||
///
|
assert_eq!(EmbeddingBackend::OpenAI.as_str(), "openai");
|
||||||
/// Returns a `Result` containing either:
|
assert_eq!(EmbeddingBackend::FastEmbed.as_str(), "fastembed");
|
||||||
/// * `Ok(Vec<f32>)`: A vector of 32-bit floating point numbers representing the text embedding
|
assert_eq!(EmbeddingBackend::Hashed.as_str(), "hashed");
|
||||||
/// * `Err(ProcessingError)`: An error if the embedding generation fails
|
|
||||||
///
|
assert_eq!(
|
||||||
/// # Errors
|
serde_json::to_string(&EmbeddingBackend::FastEmbed).expect("serialize"),
|
||||||
///
|
"\"fastembed\""
|
||||||
/// This function can return a `AppError` in the following cases:
|
);
|
||||||
/// * If the `OpenAI` API request fails
|
}
|
||||||
/// * If the request building fails
|
|
||||||
/// * If no embedding data is received in the response
|
#[test]
|
||||||
#[allow(clippy::module_name_repetitions)]
|
fn embedding_backend_deserializes_lowercase_values() {
|
||||||
pub async fn generate_embedding(
|
let openai: EmbeddingBackend = serde_json::from_str("\"openai\"").expect("openai");
|
||||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
let fastembed: EmbeddingBackend = serde_json::from_str("\"fastembed\"").expect("fastembed");
|
||||||
input: &str,
|
let hashed: EmbeddingBackend = serde_json::from_str("\"hashed\"").expect("hashed");
|
||||||
db: &SurrealDbClient,
|
|
||||||
) -> Result<Vec<f32>, AppError> {
|
assert_eq!(openai, EmbeddingBackend::OpenAI);
|
||||||
let model = SystemSettings::get_current(db).await?;
|
assert_eq!(fastembed, EmbeddingBackend::FastEmbed);
|
||||||
|
assert_eq!(hashed, EmbeddingBackend::Hashed);
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
}
|
||||||
.model(model.embedding_model)
|
|
||||||
.dimensions(model.embedding_dimensions)
|
#[test]
|
||||||
.input([input])
|
fn embedding_backend_from_str_accepts_aliases() {
|
||||||
.build()?;
|
assert_eq!(
|
||||||
|
"fast-embed"
|
||||||
// Send the request to OpenAI
|
.parse::<EmbeddingBackend>()
|
||||||
let response = client.embeddings().create(request).await?;
|
.expect("fast-embed"),
|
||||||
|
EmbeddingBackend::FastEmbed
|
||||||
// Extract the embedding vector
|
);
|
||||||
let embedding: Vec<f32> = response
|
assert_eq!(
|
||||||
.data
|
"FASTEMBED".parse::<EmbeddingBackend>().expect("FASTEMBED"),
|
||||||
.first()
|
EmbeddingBackend::FastEmbed
|
||||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received".into()))?
|
);
|
||||||
.embedding
|
assert!(matches!(
|
||||||
.clone();
|
"unknown-backend".parse::<EmbeddingBackend>(),
|
||||||
|
Err(ParseEmbeddingBackendError { .. })
|
||||||
Ok(embedding)
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates an embedding vector using a specific model and dimension.
|
#[test]
|
||||||
///
|
fn list_fastembed_embedding_models_includes_default() {
|
||||||
/// This is used for the re-embedding process where the model and dimensions
|
let models = list_fastembed_embedding_models();
|
||||||
/// are known ahead of time and shouldn't be repeatedly fetched from settings.
|
assert!(
|
||||||
pub async fn generate_embedding_with_params(
|
models
|
||||||
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
.iter()
|
||||||
input: &str,
|
.any(|m| m.model_code == DEFAULT_FASTEMBED_MODEL_CODE),
|
||||||
model: &str,
|
"catalog should include the default FastEmbed model"
|
||||||
dimensions: u32,
|
);
|
||||||
) -> Result<Vec<f32>, AppError> {
|
}
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
|
||||||
.model(model)
|
#[test]
|
||||||
.input([input])
|
fn resolve_fastembed_model_prefers_config_over_db() {
|
||||||
.dimensions(dimensions)
|
let config = AppConfig {
|
||||||
.build()?;
|
fastembed_model: Some("Xenova/bge-base-en-v1.5".into()),
|
||||||
|
..AppConfig::default()
|
||||||
let response = client.embeddings().create(request).await?;
|
};
|
||||||
|
let resolved =
|
||||||
let embedding = response
|
resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("config model");
|
||||||
.data
|
assert_eq!(resolved, "Xenova/bge-base-en-v1.5");
|
||||||
.first()
|
}
|
||||||
.ok_or_else(|| AppError::LLMParsing("No embedding data received from API".into()))?
|
|
||||||
.embedding
|
#[test]
|
||||||
.clone();
|
fn resolve_fastembed_model_falls_back_from_openai_default() {
|
||||||
|
let config = AppConfig::default();
|
||||||
debug!(
|
let resolved =
|
||||||
"Embedding was created with {:?} dimensions",
|
resolve_fastembed_model_code(&config, "text-embedding-3-small").expect("default model");
|
||||||
embedding.len()
|
assert_eq!(resolved, DEFAULT_FASTEMBED_MODEL_CODE);
|
||||||
);
|
}
|
||||||
|
|
||||||
Ok(embedding)
|
#[test]
|
||||||
|
fn resolve_fastembed_model_rejects_invalid_config_override() {
|
||||||
|
let config = AppConfig {
|
||||||
|
fastembed_model: Some("not-a-real-model".into()),
|
||||||
|
..AppConfig::default()
|
||||||
|
};
|
||||||
|
let err = resolve_fastembed_model_code(&config, "Xenova/bge-small-en-v1.5")
|
||||||
|
.expect_err("invalid config model");
|
||||||
|
assert!(matches!(err, EmbeddingError::UnknownModel(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fastembed_model_dimension_matches_model_metadata() {
|
||||||
|
let dim = fastembed_model_dimension(DEFAULT_FASTEMBED_MODEL_CODE).expect("dim");
|
||||||
|
assert_eq!(dim, 384);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn align_fastembed_system_settings_replaces_openai_default() -> anyhow::Result<()> {
|
||||||
|
use crate::storage::db::SurrealDbClient;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
let db = SurrealDbClient::memory("align_fe", &Uuid::new_v4().to_string()).await?;
|
||||||
|
db.apply_migrations().await?;
|
||||||
|
|
||||||
|
let config = AppConfig {
|
||||||
|
embedding_backend: EmbeddingBackend::FastEmbed,
|
||||||
|
..AppConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let settings = align_fastembed_system_settings(&db, &config).await?;
|
||||||
|
assert_eq!(settings.embedding_model, DEFAULT_FASTEMBED_MODEL_CODE);
|
||||||
|
assert_eq!(settings.embedding_dimensions, 384);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn system_settings_deserializes_embedding_backend_field() {
|
||||||
|
let value = json!({
|
||||||
|
"id": "current",
|
||||||
|
"registrations_enabled": true,
|
||||||
|
"require_email_verification": false,
|
||||||
|
"query_model": "gpt-4o-mini",
|
||||||
|
"processing_model": "gpt-4o-mini",
|
||||||
|
"embedding_model": "text-embedding-3-small",
|
||||||
|
"embedding_dimensions": 1536,
|
||||||
|
"embedding_backend": "hashed",
|
||||||
|
"query_system_prompt": "query",
|
||||||
|
"ingestion_system_prompt": "ingestion",
|
||||||
|
"image_processing_model": "gpt-4o-mini",
|
||||||
|
"image_processing_prompt": "image",
|
||||||
|
"voice_processing_model": "whisper-1",
|
||||||
|
});
|
||||||
|
|
||||||
|
let settings: SystemSettings =
|
||||||
|
serde_json::from_value(value).expect("deserialize system settings");
|
||||||
|
assert_eq!(settings.embedding_backend, Some(EmbeddingBackend::Hashed));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,48 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
use super::config::AppConfig;
|
use super::config::AppConfig;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
/// Errors raised when validating ingestion payloads against configured limits.
|
||||||
|
#[derive(Error, Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum IngestValidationError {
|
pub enum IngestValidationError {
|
||||||
|
/// The payload exceeds a configured size limit (content, context, or category).
|
||||||
|
#[error("payload too large: {0}")]
|
||||||
PayloadTooLarge(String),
|
PayloadTooLarge(String),
|
||||||
|
/// The request violates a non-size constraint (e.g., too many files).
|
||||||
|
#[error("bad request: {0}")]
|
||||||
BadRequest(String),
|
BadRequest(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validates ingestion input against configured limits.
|
||||||
|
///
|
||||||
|
/// Checks file count, content size, context size, and category length.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `IngestValidationError::BadRequest` if the file count exceeds the maximum.
|
||||||
|
/// Returns `IngestValidationError::PayloadTooLarge` if content, context, or
|
||||||
|
/// category exceed their configured byte limits.
|
||||||
pub fn validate_ingest_input(
|
pub fn validate_ingest_input(
|
||||||
config: &AppConfig,
|
config: &AppConfig,
|
||||||
content: Option<&str>,
|
content: Option<&str>,
|
||||||
context: &str,
|
ctx: &str,
|
||||||
category: &str,
|
category: &str,
|
||||||
file_count: usize,
|
file_count: usize,
|
||||||
) -> Result<(), IngestValidationError> {
|
) -> Result<(), IngestValidationError> {
|
||||||
|
let content_bytes = content.map_or(0, str::len);
|
||||||
|
let text_field_bytes = content_bytes
|
||||||
|
.saturating_add(ctx.len())
|
||||||
|
.saturating_add(category.len());
|
||||||
|
if text_field_bytes > config.ingest_max_body_bytes {
|
||||||
|
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||||
|
"request text fields exceed maximum allowed body size of {} bytes",
|
||||||
|
config.ingest_max_body_bytes
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
if file_count > config.ingest_max_files {
|
if file_count > config.ingest_max_files {
|
||||||
return Err(IngestValidationError::BadRequest(format!(
|
return Err(IngestValidationError::BadRequest(format!(
|
||||||
"Too many files. Maximum allowed is {}",
|
"too many files: maximum allowed is {}",
|
||||||
config.ingest_max_files
|
config.ingest_max_files
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@@ -23,22 +50,22 @@ pub fn validate_ingest_input(
|
|||||||
if let Some(content) = content {
|
if let Some(content) = content {
|
||||||
if content.len() > config.ingest_max_content_bytes {
|
if content.len() > config.ingest_max_content_bytes {
|
||||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||||
"Content is too large. Maximum allowed is {} bytes",
|
"content is too large: maximum allowed is {} bytes",
|
||||||
config.ingest_max_content_bytes
|
config.ingest_max_content_bytes
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if context.len() > config.ingest_max_context_bytes {
|
if ctx.len() > config.ingest_max_context_bytes {
|
||||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||||
"Context is too large. Maximum allowed is {} bytes",
|
"context is too large: maximum allowed is {} bytes",
|
||||||
config.ingest_max_context_bytes
|
config.ingest_max_context_bytes
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if category.len() > config.ingest_max_category_bytes {
|
if category.len() > config.ingest_max_category_bytes {
|
||||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||||
"Category is too large. Maximum allowed is {} bytes",
|
"category is too large: maximum allowed is {} bytes",
|
||||||
config.ingest_max_category_bytes
|
config.ingest_max_category_bytes
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@@ -48,6 +75,7 @@ pub fn validate_ingest_input(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
#![allow(clippy::expect_used, clippy::must_use_candidate)]
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -110,4 +138,18 @@ mod tests {
|
|||||||
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_ingest_input_rejects_oversized_text_fields() {
|
||||||
|
let config = AppConfig {
|
||||||
|
ingest_max_body_bytes: 10,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let result = validate_ingest_input(&config, Some("123456"), "ctx", "cat", 0);
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(IngestValidationError::PayloadTooLarge(_))
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod ingest_limits;
|
pub mod ingest_limits;
|
||||||
|
pub mod serde_helpers;
|
||||||
pub mod template_engine;
|
pub mod template_engine;
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::de::{self, Visitor};
|
||||||
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
use std::fmt;
|
||||||
|
use surrealdb::sql::Thing;
|
||||||
|
|
||||||
|
struct FlexibleIdVisitor;
|
||||||
|
|
||||||
|
impl<'de> Visitor<'de> for FlexibleIdVisitor {
|
||||||
|
type Value = String;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
formatter.write_str("a string or a Thing")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(value.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: de::MapAccess<'de>,
|
||||||
|
{
|
||||||
|
let thing = Thing::deserialize(de::value::MapAccessDeserializer::new(map))?;
|
||||||
|
Ok(thing.id.to_raw())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_flexible_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_any(FlexibleIdVisitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn serialize_datetime<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
Into::<surrealdb::sql::Datetime>::into(*date).serialize(serializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let dt = surrealdb::sql::Datetime::deserialize(deserializer)?;
|
||||||
|
Ok(DateTime::<Utc>::from(dt))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn serialize_option_datetime<S>(
|
||||||
|
date: &Option<DateTime<Utc>>,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
match date {
|
||||||
|
Some(dt) => serializer.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
|
||||||
|
None => serializer.serialize_none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_option_datetime<'de, D>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> Result<Option<DateTime<Utc>>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
|
||||||
|
Ok(value.map(DateTime::<Utc>::from))
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ pub enum TemplateEngine {
|
|||||||
Embedded(Arc<Environment<'static>>),
|
Embedded(Arc<Environment<'static>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::module_name_repetitions)]
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! create_template_engine {
|
macro_rules! create_template_engine {
|
||||||
// Single path argument
|
// Single path argument
|
||||||
|
|||||||
+12
-4
@@ -4,7 +4,15 @@
|
|||||||
config,
|
config,
|
||||||
inputs,
|
inputs,
|
||||||
...
|
...
|
||||||
}: {
|
}:
|
||||||
|
let
|
||||||
|
ortVersion = lib.removeSuffix "\n" (builtins.readFile "${toString ./.}/ort-version");
|
||||||
|
_ortVersionCheck =
|
||||||
|
if pkgs.onnxruntime.version == ortVersion
|
||||||
|
then null
|
||||||
|
else
|
||||||
|
throw "pkgs.onnxruntime.version (${pkgs.onnxruntime.version}) must match ort-version (${ortVersion})";
|
||||||
|
in {
|
||||||
cachix.enable = false;
|
cachix.enable = false;
|
||||||
|
|
||||||
packages = [
|
packages = [
|
||||||
@@ -22,8 +30,9 @@
|
|||||||
|
|
||||||
languages.rust = {
|
languages.rust = {
|
||||||
enable = true;
|
enable = true;
|
||||||
|
channel = "stable";
|
||||||
|
version = "1.91.1";
|
||||||
components = ["rustc" "clippy" "rustfmt" "cargo" "rust-analyzer"];
|
components = ["rustc" "clippy" "rustfmt" "cargo" "rust-analyzer"];
|
||||||
channel = "nightly";
|
|
||||||
targets = ["x86_64-unknown-linux-gnu" "x86_64-pc-windows-msvc"];
|
targets = ["x86_64-unknown-linux-gnu" "x86_64-pc-windows-msvc"];
|
||||||
mold.enable = true;
|
mold.enable = true;
|
||||||
};
|
};
|
||||||
@@ -47,8 +56,7 @@
|
|||||||
};
|
};
|
||||||
|
|
||||||
processes = {
|
processes = {
|
||||||
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
|
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:v2.6.5-dev start rocksdb:/database/database.db --user root_user --pass root_password";
|
||||||
server.exec = "cargo watch -x 'run --bin main'";
|
|
||||||
tailwind.exec = "tailwindcss --cwd html-router -i app.css -o assets/style.css --watch=always";
|
tailwind.exec = "tailwindcss --cwd html-router -i app.css -o assets/style.css --watch=always";
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -4,7 +4,7 @@ members = ["cargo:."]
|
|||||||
# Config for 'dist'
|
# Config for 'dist'
|
||||||
[dist]
|
[dist]
|
||||||
# The preferred dist version to use in CI (Cargo.toml SemVer syntax)
|
# The preferred dist version to use in CI (Cargo.toml SemVer syntax)
|
||||||
cargo-dist-version = "0.30.0"
|
cargo-dist-version = "0.30.3"
|
||||||
# CI backends to support
|
# CI backends to support
|
||||||
ci = "github"
|
ci = "github"
|
||||||
# Extra static files to include in each App (path relative to this Cargo.toml's dir)
|
# Extra static files to include in each App (path relative to this Cargo.toml's dir)
|
||||||
|
|||||||
+1
-1
@@ -22,7 +22,7 @@ services:
|
|||||||
command: ["sh", "-c", "echo 'Waiting for SurrealDB to start...' && sleep 10 && echo 'Starting application...' && /usr/local/bin/main"]
|
command: ["sh", "-c", "echo 'Waiting for SurrealDB to start...' && sleep 10 && echo 'Starting application...' && /usr/local/bin/main"]
|
||||||
|
|
||||||
surrealdb:
|
surrealdb:
|
||||||
image: surrealdb/surrealdb:latest
|
image: surrealdb/surrealdb:v2.6.5
|
||||||
container_name: minne_surrealdb
|
container_name: minne_surrealdb
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
|
|||||||
| `RUST_LOG` | Logging level | `info` |
|
| `RUST_LOG` | Logging level | `info` |
|
||||||
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
|
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
|
||||||
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
|
||||||
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
|
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`, `hashed`) | `fastembed` |
|
||||||
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
|
| `FASTEMBED_MODEL` | FastEmbed HuggingFace `model_code` (overrides DB when set) | `Xenova/bge-small-en-v1.5` |
|
||||||
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
||||||
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
||||||
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
|
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
|
||||||
@@ -77,6 +77,8 @@ storage: "local"
|
|||||||
# s3_region: "us-east-1"
|
# s3_region: "us-east-1"
|
||||||
pdf_ingest_mode: "llm-first"
|
pdf_ingest_mode: "llm-first"
|
||||||
embedding_backend: "fastembed"
|
embedding_backend: "fastembed"
|
||||||
|
# HuggingFace model_code (see fastembed docs); dimensions are fixed per model
|
||||||
|
fastembed_model: "Xenova/bge-small-en-v1.5"
|
||||||
|
|
||||||
# Optional reranking
|
# Optional reranking
|
||||||
reranking_enabled: true
|
reranking_enabled: true
|
||||||
|
|||||||
+6
-3
@@ -27,13 +27,16 @@ The D3-based graph visualization shows entities as nodes and relationships as ed
|
|||||||
|
|
||||||
## Hybrid Retrieval
|
## Hybrid Retrieval
|
||||||
|
|
||||||
Minne combines multiple retrieval strategies:
|
Minne uses hybrid retrieval over the knowledge base:
|
||||||
|
|
||||||
- **Vector similarity** — Semantic matching via embeddings
|
- **Vector similarity** — Semantic matching via embeddings
|
||||||
- **Full-text search** — Keyword matching with BM25
|
- **Full-text search** — Keyword matching with BM25
|
||||||
- **Graph traversal** — Following relationships between entities
|
|
||||||
|
|
||||||
Results are merged using Reciprocal Rank Fusion (RRF) for optimal relevance.
|
For **content search** (chat, global search, ingestion linking), retrieval is chunk-first: vector and FTS run over `text_chunk` rows, merged with Reciprocal Rank Fusion (RRF). When entities are needed, they are derived from the top retrieved chunks grouped by `source_id`.
|
||||||
|
|
||||||
|
For **relationship suggestions** when creating an entity, retrieval is entity-first: vector and FTS run directly over `knowledge_entity` name/description and embedding indexes, then merged with the same RRF approach.
|
||||||
|
|
||||||
|
Optional **reranking** can rescore fused chunk lists with a cross-encoder model; see below.
|
||||||
|
|
||||||
## Reranking (Optional)
|
## Reranking (Optional)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ name = "evaluations"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
async-openai = { workspace = true }
|
async-openai = { workspace = true }
|
||||||
|
|||||||
+19
-27
@@ -5,7 +5,6 @@ use std::{
|
|||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use clap::{Args, Parser, ValueEnum};
|
use clap::{Args, Parser, ValueEnum};
|
||||||
use retrieval_pipeline::RetrievalStrategy;
|
|
||||||
|
|
||||||
use crate::datasets::DatasetKind;
|
use crate::datasets::DatasetKind;
|
||||||
|
|
||||||
@@ -55,15 +54,11 @@ pub struct RetrievalSettings {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_fts_take: Option<usize>,
|
pub chunk_fts_take: Option<usize>,
|
||||||
|
|
||||||
/// Override average characters per token used for budgeting
|
|
||||||
#[arg(long)]
|
|
||||||
pub chunk_avg_chars_per_token: Option<usize>,
|
|
||||||
|
|
||||||
/// Override maximum chunks attached per entity
|
/// Override maximum chunks attached per entity
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub max_chunks_per_entity: Option<usize>,
|
pub max_chunks_per_entity: Option<usize>,
|
||||||
|
|
||||||
/// Enable the FastEmbed reranking stage
|
/// Enable the `FastEmbed` reranking stage
|
||||||
#[arg(long = "rerank", action = clap::ArgAction::SetTrue, default_value_t = false)]
|
#[arg(long = "rerank", action = clap::ArgAction::SetTrue, default_value_t = false)]
|
||||||
pub rerank: bool,
|
pub rerank: bool,
|
||||||
|
|
||||||
@@ -71,41 +66,37 @@ pub struct RetrievalSettings {
|
|||||||
#[arg(long, default_value_t = 4)]
|
#[arg(long, default_value_t = 4)]
|
||||||
pub rerank_pool_size: usize,
|
pub rerank_pool_size: usize,
|
||||||
|
|
||||||
/// Keep top-N entities after reranking
|
/// Keep top-N chunks after reranking
|
||||||
#[arg(long, default_value_t = 10)]
|
#[arg(long, default_value_t = 10)]
|
||||||
pub rerank_keep_top: usize,
|
pub rerank_keep_top: usize,
|
||||||
|
|
||||||
/// Cap the number of chunks returned by retrieval (revised strategy)
|
/// Cap the number of chunks returned by retrieval
|
||||||
#[arg(long, default_value_t = 5)]
|
#[arg(long, default_value_t = 5)]
|
||||||
pub chunk_result_cap: usize,
|
pub chunk_result_cap: usize,
|
||||||
|
|
||||||
/// Reciprocal rank fusion k value for revised chunk merging
|
/// Reciprocal rank fusion k value for chunk merging
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_k: Option<f32>,
|
pub chunk_rrf_k: Option<f32>,
|
||||||
|
|
||||||
/// Weight for vector ranks in revised RRF
|
/// Weight for vector ranks in RRF
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_vector_weight: Option<f32>,
|
pub chunk_rrf_vector_weight: Option<f32>,
|
||||||
|
|
||||||
/// Weight for chunk FTS ranks in revised RRF
|
/// Weight for chunk FTS ranks in RRF
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_fts_weight: Option<f32>,
|
pub chunk_rrf_fts_weight: Option<f32>,
|
||||||
|
|
||||||
/// Include vector ranks in revised RRF (default: true)
|
/// Include vector ranks in RRF (default: true)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_use_vector: Option<bool>,
|
pub chunk_rrf_use_vector: Option<bool>,
|
||||||
|
|
||||||
/// Include chunk FTS ranks in revised RRF (default: true)
|
/// Include chunk FTS ranks in RRF (default: true)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub chunk_rrf_use_fts: Option<bool>,
|
pub chunk_rrf_use_fts: Option<bool>,
|
||||||
|
|
||||||
/// Require verified chunks (disable with --llm-mode)
|
/// Require verified chunks (disable with --llm-mode)
|
||||||
#[arg(skip = true)]
|
#[arg(skip = true)]
|
||||||
pub require_verified_chunks: bool,
|
pub require_verified_chunks: bool,
|
||||||
|
|
||||||
/// Select the retrieval pipeline strategy
|
|
||||||
#[arg(long, default_value_t = RetrievalStrategy::Default)]
|
|
||||||
pub strategy: RetrievalStrategy,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RetrievalSettings {
|
impl Default for RetrievalSettings {
|
||||||
@@ -113,7 +104,6 @@ impl Default for RetrievalSettings {
|
|||||||
Self {
|
Self {
|
||||||
chunk_vector_take: None,
|
chunk_vector_take: None,
|
||||||
chunk_fts_take: None,
|
chunk_fts_take: None,
|
||||||
chunk_avg_chars_per_token: None,
|
|
||||||
max_chunks_per_entity: None,
|
max_chunks_per_entity: None,
|
||||||
rerank: false,
|
rerank: false,
|
||||||
rerank_pool_size: 4,
|
rerank_pool_size: 4,
|
||||||
@@ -125,7 +115,6 @@ impl Default for RetrievalSettings {
|
|||||||
chunk_rrf_use_vector: None,
|
chunk_rrf_use_vector: None,
|
||||||
chunk_rrf_use_fts: None,
|
chunk_rrf_use_fts: None,
|
||||||
require_verified_chunks: true,
|
require_verified_chunks: true,
|
||||||
strategy: RetrievalStrategy::Default,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,23 +160,23 @@ pub struct IngestConfig {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Args)]
|
#[derive(Debug, Clone, Args)]
|
||||||
pub struct DatabaseArgs {
|
pub struct DatabaseArgs {
|
||||||
/// SurrealDB server endpoint
|
/// `SurrealDB` server endpoint
|
||||||
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
|
||||||
pub db_endpoint: String,
|
pub db_endpoint: String,
|
||||||
|
|
||||||
/// SurrealDB root username
|
/// `SurrealDB` root username
|
||||||
#[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")]
|
#[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")]
|
||||||
pub db_username: String,
|
pub db_username: String,
|
||||||
|
|
||||||
/// SurrealDB root password
|
/// `SurrealDB` root password
|
||||||
#[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")]
|
#[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")]
|
||||||
pub db_password: String,
|
pub db_password: String,
|
||||||
|
|
||||||
/// Override the namespace used on the SurrealDB server
|
/// Override the namespace used on the `SurrealDB` server
|
||||||
#[arg(long, env = "EVAL_DB_NAMESPACE")]
|
#[arg(long, env = "EVAL_DB_NAMESPACE")]
|
||||||
pub db_namespace: Option<String>,
|
pub db_namespace: Option<String>,
|
||||||
|
|
||||||
/// Override the database used on the SurrealDB server
|
/// Override the database used on the `SurrealDB` server
|
||||||
#[arg(long, env = "EVAL_DB_DATABASE")]
|
#[arg(long, env = "EVAL_DB_DATABASE")]
|
||||||
pub db_database: Option<String>,
|
pub db_database: Option<String>,
|
||||||
|
|
||||||
@@ -198,6 +187,7 @@ pub struct DatabaseArgs {
|
|||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
/// Convert the selected dataset and exit
|
/// Convert the selected dataset and exit
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@@ -258,7 +248,7 @@ pub struct Config {
|
|||||||
#[arg(long, default_value_t = EmbeddingBackend::FastEmbed)]
|
#[arg(long, default_value_t = EmbeddingBackend::FastEmbed)]
|
||||||
pub embedding_backend: EmbeddingBackend,
|
pub embedding_backend: EmbeddingBackend,
|
||||||
|
|
||||||
/// FastEmbed model code
|
/// `FastEmbed` model code
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub embedding_model: Option<String>,
|
pub embedding_model: Option<String>,
|
||||||
|
|
||||||
@@ -277,7 +267,7 @@ pub struct Config {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub slice: Option<String>,
|
pub slice: Option<String>,
|
||||||
|
|
||||||
/// Ignore cached corpus state and rebuild the slice's SurrealDB corpus
|
/// Ignore cached corpus state and rebuild the slice's `SurrealDB` corpus
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub reseed_slice: bool,
|
pub reseed_slice: bool,
|
||||||
|
|
||||||
@@ -313,7 +303,7 @@ pub struct Config {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub inspect_manifest: Option<PathBuf>,
|
pub inspect_manifest: Option<PathBuf>,
|
||||||
|
|
||||||
/// Override the SurrealDB system settings query model
|
/// Override the `SurrealDB` system settings query model
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub query_model: Option<String>,
|
pub query_model: Option<String>,
|
||||||
|
|
||||||
@@ -344,10 +334,12 @@ pub struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
#[allow(clippy::unused_self)]
|
||||||
pub fn context_token_limit(&self) -> Option<usize> {
|
pub fn context_token_limit(&self) -> Option<usize> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn finalize(&mut self) -> Result<()> {
|
pub fn finalize(&mut self) -> Result<()> {
|
||||||
// Handle dataset paths
|
// Handle dataset paths
|
||||||
if let Some(raw) = &self.raw {
|
if let Some(raw) = &self.raw {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
path::{Path, PathBuf},
|
path::Path,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicBool, Ordering},
|
atomic::{AtomicBool, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
@@ -19,7 +19,7 @@ struct EmbeddingCacheData {
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct EmbeddingCache {
|
pub struct EmbeddingCache {
|
||||||
path: Arc<PathBuf>,
|
path: Arc<Path>,
|
||||||
data: Arc<Mutex<EmbeddingCacheData>>,
|
data: Arc<Mutex<EmbeddingCacheData>>,
|
||||||
dirty: Arc<AtomicBool>,
|
dirty: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
@@ -39,7 +39,7 @@ impl EmbeddingCache {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
path: Arc::new(path),
|
path: Arc::from(path.as_path()),
|
||||||
data: Arc::new(Mutex::new(data)),
|
data: Arc::new(Mutex::new(data)),
|
||||||
dirty: Arc::new(AtomicBool::new(false)),
|
dirty: Arc::new(AtomicBool::new(false)),
|
||||||
})
|
})
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user