Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2c9bb848d | ||
|
|
04ee225732 | ||
|
|
13b7ad6f3a | ||
|
|
112a6965a4 | ||
|
|
911e830be5 | ||
|
|
3196e65172 | ||
|
|
380c900c86 | ||
|
|
a99e5ada8b | ||
|
|
b0deabaf3f | ||
|
|
a8f0d9fa88 | ||
|
|
56a1dfddb8 | ||
|
|
863b921fb4 | ||
|
|
f13791cfcf | ||
|
|
75c200b2ba | ||
|
|
1b7c24747a | ||
|
|
241ad9a089 | ||
|
|
72578296db | ||
|
|
a0e9387c76 | ||
|
|
798b1468b6 | ||
|
|
3b805778b4 | ||
|
|
07b3e1a0e8 | ||
|
|
83d39afad4 | ||
|
|
21e4ab1f42 | ||
|
|
3c97d8ead5 | ||
|
|
ab68bccb80 | ||
|
|
99b88c3063 | ||
|
|
44e5d8a2fc | ||
|
|
7332347f1a | ||
|
|
199186e5a3 | ||
|
|
64728468cd | ||
|
|
c3a7e8dc59 | ||
|
|
35ff4e1464 | ||
|
|
2964f1a5a5 | ||
|
|
cb7f625b81 | ||
|
|
dc40cf7663 | ||
|
|
aa0b1462a1 | ||
|
|
41fc7bb99c | ||
|
|
61d8d7abe7 | ||
|
|
b7344644dc | ||
|
|
3742598a6d | ||
|
|
c6a6080e1c | ||
|
|
1159712724 | ||
|
|
e5e1414f54 | ||
|
|
fcc49b1954 | ||
|
|
022f4d8575 | ||
|
|
945a2b7f37 | ||
|
|
ff4ea55cd5 | ||
|
|
c4c76efe92 | ||
|
|
c0fcad5952 | ||
|
|
b0ed69330d | ||
|
|
5cb15dab45 | ||
|
|
7403195df5 | ||
|
|
9faef31387 | ||
|
|
110f7b8a8f | ||
|
|
f343005af8 | ||
|
|
e1d98b0c35 | ||
|
|
c12d00edaa | ||
|
|
903585bfef | ||
|
|
f592eb7200 | ||
|
|
c2839f8db3 | ||
|
|
9a7c57cb19 | ||
|
|
fe5143cd7f | ||
|
|
3f774302c7 | ||
|
|
6ea51095e8 | ||
|
|
62d909bb7e | ||
|
|
69954cf78e | ||
|
|
153efd1a98 | ||
|
|
fdf29bb735 | ||
|
|
e150b476c3 | ||
|
|
2e076c8236 | ||
|
|
a0632c9768 | ||
|
|
33300d3193 |
49
.github/build-setup.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
- name: Prepare lib dir
|
||||
run: mkdir -p lib
|
||||
|
||||
# Linux
|
||||
- name: Fetch ONNX Runtime (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
ARCH="$(uname -m)"
|
||||
case "$ARCH" in
|
||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||
esac
|
||||
curl -fsSL -o ort.tgz "$URL"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||
|
||||
# macOS
|
||||
- name: Fetch ONNX Runtime (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||
tar -xzf ort.tgz
|
||||
# copy the main dylib; rename to stable name if needed
|
||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||
# optional: ensure a stable name
|
||||
if [ ! -f lib/libonnxruntime.dylib ]; then
|
||||
cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||
fi
|
||||
|
||||
# Windows
|
||||
- name: Fetch ONNX Runtime (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||
Invoke-WebRequest $url -OutFile ort.zip
|
||||
Expand-Archive ort.zip -DestinationPath ort
|
||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||
|
||||
242
.github/workflows/release.yml
vendored
@@ -1,44 +1,8 @@
|
||||
# This file was autogenerated by dist: https://opensource.axo.dev/cargo-dist/
|
||||
#
|
||||
# Copyright 2022-2024, axodotdev
|
||||
# SPDX-License-Identifier: MIT or Apache-2.0
|
||||
#
|
||||
# CI that:
|
||||
#
|
||||
# * checks for a Git Tag that looks like a release
|
||||
# * builds artifacts with dist (archives, installers, hashes)
|
||||
# * uploads those artifacts to temporary workflow zip
|
||||
# * on success, uploads the artifacts to a GitHub Release
|
||||
#
|
||||
# Note that the GitHub Release will be created with a generated
|
||||
# title/body based on your changelogs.
|
||||
|
||||
name: Release
|
||||
permissions:
|
||||
"contents": "write"
|
||||
"packages": "write"
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
# This task will run whenever you push a git tag that looks like a version
|
||||
# like "1.0.0", "v0.1.0-prerelease.1", "my-app/0.1.0", "releases/v1.0.0", etc.
|
||||
# Various formats will be parsed into a VERSION and an optional PACKAGE_NAME, where
|
||||
# PACKAGE_NAME must be the name of a Cargo package in your workspace, and VERSION
|
||||
# must be a Cargo-style SemVer Version (must have at least major.minor.patch).
|
||||
#
|
||||
# If PACKAGE_NAME is specified, then the announcement will be for that
|
||||
# package (erroring out if it doesn't have the given version or isn't dist-able).
|
||||
#
|
||||
# If PACKAGE_NAME isn't specified, then the announcement will be for all
|
||||
# (dist-able) packages in the workspace with that version (this mode is
|
||||
# intended for workspaces with only one dist-able package, or with all dist-able
|
||||
# packages versioned/released in lockstep).
|
||||
#
|
||||
# If you push multiple tags at once, separate instances of this workflow will
|
||||
# spin up, creating an independent announcement for each one. However, GitHub
|
||||
# will hard limit this to 3 tags per commit, as it will assume more tags is a
|
||||
# mistake.
|
||||
#
|
||||
# If there's a prerelease-style suffix to the version, then the release(s)
|
||||
# will be marked as a prerelease.
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
@@ -46,9 +10,8 @@ on:
|
||||
- '**[0-9]+.[0-9]+.[0-9]+*'
|
||||
|
||||
jobs:
|
||||
# Run 'dist plan' (or host) to determine what tasks we need to do
|
||||
plan:
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
val: ${{ steps.plan.outputs.manifest }}
|
||||
tag: ${{ !github.event.pull_request && github.ref_name || '' }}
|
||||
@@ -60,52 +23,36 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install dist
|
||||
# we specify bash to get pipefail; it guards against the `curl` command
|
||||
# failing. otherwise `sh` won't catch that `curl` returned non-0
|
||||
shell: bash
|
||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.28.0/cargo-dist-installer.sh | sh"
|
||||
run: "curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.30.0/cargo-dist-installer.sh | sh"
|
||||
|
||||
- name: Cache dist
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/dist
|
||||
# sure would be cool if github gave us proper conditionals...
|
||||
# so here's a doubly-nested ternary-via-truthiness to try to provide the best possible
|
||||
# functionality based on whether this is a pull_request, and whether it's from a fork.
|
||||
# (PRs run on the *source* but secrets are usually on the *target* -- that's *good*
|
||||
# but also really annoying to build CI around when it needs secrets to work right.)
|
||||
|
||||
- id: plan
|
||||
run: |
|
||||
dist ${{ (!github.event.pull_request && format('host --steps=create --tag={0}', github.ref_name)) || 'plan' }} --output-format=json > plan-dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
cat plan-dist-manifest.json
|
||||
echo "manifest=$(jq -c "." plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
- name: "Upload dist-manifest.json"
|
||||
echo "manifest=$(jq -c . plan-dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-plan-dist-manifest
|
||||
path: plan-dist-manifest.json
|
||||
|
||||
# Build and packages all the platform-specific things
|
||||
build-local-artifacts:
|
||||
name: build-local-artifacts (${{ join(matrix.targets, ', ') }})
|
||||
# Let the initial task tell us to not run (currently very blunt)
|
||||
needs:
|
||||
- plan
|
||||
needs: [plan]
|
||||
if: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix.include != null && (needs.plan.outputs.publishing == 'true' || fromJson(needs.plan.outputs.val).ci.github.pr_run_mode == 'upload') }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
# Target platforms/runners are computed by dist in create-release.
|
||||
# Each member of the matrix has the following arguments:
|
||||
#
|
||||
# - runner: the github runner
|
||||
# - dist-args: cli flags to pass to dist
|
||||
# - install-dist: expression to run to install dist on the runner
|
||||
#
|
||||
# Typically there will be:
|
||||
# - 1 "global" task that builds universal installers
|
||||
# - N "local" tasks that build each platform's binaries and platform-specific installers
|
||||
matrix: ${{ fromJson(needs.plan.outputs.val).ci.github.artifacts_matrix }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
container: ${{ matrix.container && matrix.container.image || null }}
|
||||
@@ -114,11 +61,12 @@ jobs:
|
||||
BUILD_MANIFEST_NAME: target/distrib/${{ join(matrix.targets, '-') }}-dist-manifest.json
|
||||
steps:
|
||||
- name: enable windows longpaths
|
||||
run: |
|
||||
git config --global core.longpaths true
|
||||
run: git config --global core.longpaths true
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Rust non-interactively if not already installed
|
||||
if: ${{ matrix.container }}
|
||||
run: |
|
||||
@@ -126,37 +74,103 @@ jobs:
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install dist
|
||||
run: ${{ matrix.install_dist.run }}
|
||||
# Get the dist-manifest
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
# ===== BEGIN: Injected ORT staging for cargo-dist bundling =====
|
||||
- run: echo "=== BUILD-SETUP START ==="
|
||||
|
||||
# Unix shells
|
||||
- name: Prepare lib dir (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p lib
|
||||
rm -f lib/*
|
||||
|
||||
# Windows PowerShell
|
||||
- name: Prepare lib dir (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
New-Item -ItemType Directory -Force -Path lib | Out-Null
|
||||
# remove contents if any
|
||||
Get-ChildItem -Path lib -Force | Remove-Item -Force -Recurse -ErrorAction SilentlyContinue
|
||||
|
||||
- name: Fetch ONNX Runtime (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
ARCH="$(uname -m)"
|
||||
case "$ARCH" in
|
||||
x86_64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-x64-${ORT_VER}.tgz" ;;
|
||||
aarch64) URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-linux-aarch64-${ORT_VER}.tgz" ;;
|
||||
*) echo "Unsupported arch $ARCH"; exit 1 ;;
|
||||
esac
|
||||
curl -fsSL -o ort.tgz "$URL"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime.so* lib/
|
||||
# normalize to stable name if needed
|
||||
[ -f lib/libonnxruntime.so ] || cp -v lib/libonnxruntime.so.* lib/libonnxruntime.so
|
||||
|
||||
- name: Fetch ONNX Runtime (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
set -euo pipefail
|
||||
curl -fsSL -o ort.tgz "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VER}/onnxruntime-osx-universal2-${ORT_VER}.tgz"
|
||||
tar -xzf ort.tgz
|
||||
cp -v onnxruntime-*/lib/libonnxruntime*.dylib lib/
|
||||
[ -f lib/libonnxruntime.dylib ] || cp -v lib/libonnxruntime*.dylib lib/libonnxruntime.dylib
|
||||
|
||||
- name: Fetch ONNX Runtime (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
env:
|
||||
ORT_VER: 1.22.0
|
||||
run: |
|
||||
$url = "https://github.com/microsoft/onnxruntime/releases/download/v$env:ORT_VER/onnxruntime-win-x64-$env:ORT_VER.zip"
|
||||
Invoke-WebRequest $url -OutFile ort.zip
|
||||
Expand-Archive ort.zip -DestinationPath ort
|
||||
$dll = Get-ChildItem -Recurse -Path ort -Filter onnxruntime.dll | Select-Object -First 1
|
||||
Copy-Item $dll.FullName lib\onnxruntime.dll
|
||||
|
||||
- run: |
|
||||
echo "=== BUILD-SETUP END ==="
|
||||
echo "lib/ contents:"
|
||||
ls -l lib || dir lib
|
||||
# ===== END: Injected ORT staging =====
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
${{ matrix.packages_install }}
|
||||
|
||||
- name: Build artifacts
|
||||
run: |
|
||||
# Actually do builds and make zips and whatnot
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --print=linkage --output-format=json ${{ matrix.dist_args }} > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
|
||||
- id: cargo-dist
|
||||
name: Post-build
|
||||
# We force bash here just because github makes it really hard to get values up
|
||||
# to "real" actions without writing to env-vars, and writing to env-vars has
|
||||
# inconsistent syntax between shell and powershell.
|
||||
shell: bash
|
||||
run: |
|
||||
# Parse out what we just built and upload it to scratch storage
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
dist print-upload-files-from-manifest --manifest dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
- name: "Upload artifacts"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-local-${{ join(matrix.targets, '_') }}
|
||||
@@ -167,16 +181,16 @@ jobs:
|
||||
build_and_push_docker_image:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
needs: [plan]
|
||||
if: ${{ needs.plan.outputs.publishing == 'true' }}
|
||||
permissions:
|
||||
contents: read # Permission to checkout the repository
|
||||
packages: write # Permission to push Docker image to GHCR
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive # Matches your other checkout steps
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -185,33 +199,28 @@ jobs:
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }} # User triggering the workflow
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
# This action automatically uses the Git tag as the Docker image tag.
|
||||
# For example, a Git tag 'v1.2.3' will result in Docker tag 'ghcr.io/owner/repo:v1.2.3'.
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha # Enable Docker layer caching from GitHub Actions cache
|
||||
cache-to: type=gha,mode=max # Enable Docker layer caching to GitHub Actions cache
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# Build and package all the platform-agnostic(ish) things
|
||||
build-global-artifacts:
|
||||
needs:
|
||||
- plan
|
||||
- build-local-artifacts
|
||||
runs-on: "ubuntu-22.04"
|
||||
needs: [plan, build-local-artifacts]
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BUILD_MANIFEST_NAME: target/distrib/global-dist-manifest.json
|
||||
@@ -219,92 +228,90 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
# Get all the local artifacts for the global tasks to use (for e.g. checksums)
|
||||
|
||||
- name: Fetch local artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: cargo-dist
|
||||
shell: bash
|
||||
run: |
|
||||
dist build ${{ needs.plan.outputs.tag-flag }} --output-format=json "--artifacts=global" > dist-manifest.json
|
||||
echo "dist ran successfully"
|
||||
|
||||
# Parse out what we just built and upload it to scratch storage
|
||||
echo "paths<<EOF" >> "$GITHUB_OUTPUT"
|
||||
jq --raw-output ".upload_files[]" dist-manifest.json >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
|
||||
cp dist-manifest.json "$BUILD_MANIFEST_NAME"
|
||||
- name: "Upload artifacts"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: artifacts-build-global
|
||||
path: |
|
||||
${{ steps.cargo-dist.outputs.paths }}
|
||||
${{ env.BUILD_MANIFEST_NAME }}
|
||||
# Determines if we should publish/announce
|
||||
|
||||
host:
|
||||
needs:
|
||||
- plan
|
||||
- build-local-artifacts
|
||||
- build-global-artifacts
|
||||
# Only run if we're "publishing", and only if local and global didn't fail (skipped is fine)
|
||||
needs: [plan, build-local-artifacts, build-global-artifacts]
|
||||
if: ${{ always() && needs.plan.outputs.publishing == 'true' && (needs.build-global-artifacts.result == 'skipped' || needs.build-global-artifacts.result == 'success') && (needs.build-local-artifacts.result == 'skipped' || needs.build-local-artifacts.result == 'success') }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
val: ${{ steps.host.outputs.manifest }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install cached dist
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: cargo-dist-cache
|
||||
path: ~/.cargo/bin/
|
||||
- run: chmod +x ~/.cargo/bin/dist
|
||||
# Fetch artifacts from scratch-storage
|
||||
|
||||
- name: Fetch artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: target/distrib/
|
||||
merge-multiple: true
|
||||
|
||||
- id: host
|
||||
shell: bash
|
||||
run: |
|
||||
dist host ${{ needs.plan.outputs.tag-flag }} --steps=upload --steps=release --output-format=json > dist-manifest.json
|
||||
echo "artifacts uploaded and released successfully"
|
||||
cat dist-manifest.json
|
||||
echo "manifest=$(jq -c "." dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
- name: "Upload dist-manifest.json"
|
||||
echo "manifest=$(jq -c . dist-manifest.json)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Upload dist-manifest.json
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
# Overwrite the previous copy
|
||||
name: artifacts-dist-manifest
|
||||
path: dist-manifest.json
|
||||
# Create a GitHub Release while uploading all files to it
|
||||
- name: "Download GitHub Artifacts"
|
||||
|
||||
- name: Download GitHub Artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: artifacts-*
|
||||
path: artifacts
|
||||
merge-multiple: true
|
||||
|
||||
- name: Cleanup
|
||||
run: |
|
||||
# Remove the granular manifests
|
||||
rm -f artifacts/*-dist-manifest.json
|
||||
run: rm -f artifacts/*-dist-manifest.json
|
||||
|
||||
- name: Create GitHub Release
|
||||
env:
|
||||
PRERELEASE_FLAG: "${{ fromJson(steps.host.outputs.manifest).announcement_is_prerelease && '--prerelease' || '' }}"
|
||||
@@ -312,20 +319,13 @@ jobs:
|
||||
ANNOUNCEMENT_BODY: "${{ fromJson(steps.host.outputs.manifest).announcement_github_body }}"
|
||||
RELEASE_COMMIT: "${{ github.sha }}"
|
||||
run: |
|
||||
# Write and read notes from a file to avoid quoting breaking things
|
||||
echo "$ANNOUNCEMENT_BODY" > $RUNNER_TEMP/notes.txt
|
||||
|
||||
gh release create "${{ needs.plan.outputs.tag }}" --target "$RELEASE_COMMIT" $PRERELEASE_FLAG --title "$ANNOUNCEMENT_TITLE" --notes-file "$RUNNER_TEMP/notes.txt" artifacts/*
|
||||
|
||||
announce:
|
||||
needs:
|
||||
- plan
|
||||
- host
|
||||
# use "always() && ..." to allow us to wait for all publish jobs while
|
||||
# still allowing individual publish jobs to skip themselves (for prereleases).
|
||||
# "host" however must run to completion, no skipping allowed!
|
||||
needs: [plan, host]
|
||||
if: ${{ always() && needs.host.result == 'success' }}
|
||||
runs-on: "ubuntu-22.04"
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
|
||||
64
CHANGELOG.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
|
||||
## Version 0.2.7 (2025-12-04)
|
||||
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
|
||||
- Fix: timezone aware info in scratchpad
|
||||
|
||||
## Version 0.2.6 (2025-10-29)
|
||||
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
|
||||
- Fix: default name for relationships harmonized across application
|
||||
|
||||
## Version 0.2.5 (2025-10-24)
|
||||
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships
|
||||
- Scratchpad feature, with the feature to convert scratchpads to content.
|
||||
- Added knowledge entity search results to the global search
|
||||
- Backend fixes for improved performance when ingesting and retrieval
|
||||
|
||||
## Version 0.2.4 (2025-10-15)
|
||||
- Improved retrieval performance. Ingestion and chat now utilizes full text search, vector comparison and graph traversal.
|
||||
- Ingestion task archive
|
||||
|
||||
## Version 0.2.3 (2025-10-12)
|
||||
- Fix changing vector dimensions on a fresh database (#3)
|
||||
|
||||
## Version 0.2.2 (2025-10-07)
|
||||
- Support for ingestion of PDF files
|
||||
- Improved ingestion speed
|
||||
- Fix deletion of items work as expected
|
||||
- Fix enabling GPT-5 use via OpenAI API
|
||||
|
||||
## Version 0.2.1 (2025-09-24)
|
||||
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
|
||||
|
||||
## Version 0.2.0 (2025-09-23)
|
||||
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
|
||||
- Added pagination for entities and content plus new observability metrics on the dashboard.
|
||||
- Enabled audio ingestion and merged the new storage backend.
|
||||
- Improved performance, request filtering, and journalctl/systemd compatibility.
|
||||
|
||||
## Version 0.1.4 (2025-07-01)
|
||||
- Added image ingestion with configurable system settings and updated Docker Compose docs.
|
||||
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
|
||||
|
||||
## Version 0.1.3 (2025-06-08)
|
||||
- Added support for AI providers beyond OpenAI.
|
||||
- Made the HTTP port configurable for deployments.
|
||||
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
|
||||
|
||||
## Version 0.1.2 (2025-05-26)
|
||||
- Introduced full-text search across indexed knowledge.
|
||||
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
|
||||
- Fixed search result links and SurrealDB vector formatting glitches.
|
||||
|
||||
## Version 0.1.1 (2025-05-13)
|
||||
- Added streaming feedback to ingestion tasks for clearer progress updates.
|
||||
- Made the data storage path configurable.
|
||||
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
|
||||
|
||||
## Version 0.1.0 (2025-05-06)
|
||||
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
|
||||
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
|
||||
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
|
||||
- Shipped a Tailwind/daisyUI web UI with responsive layouts, modals, content viewers, and editing flows.
|
||||
- Provided readability-based content ingestion, API/HTML ingress routes, and Docker/Docker Compose tooling.
|
||||
1481
Cargo.lock
generated
57
Cargo.toml
@@ -12,7 +12,7 @@ resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.94"
|
||||
async-openai = "0.24.1"
|
||||
async-openai = "0.29.3"
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1.88"
|
||||
axum-htmx = "0.7.0"
|
||||
@@ -34,7 +34,6 @@ minijinja-autoreload = "2.5.0"
|
||||
minijinja-contrib = { version = "2.6.0", features = ["datetime", "timezone"] }
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
minijinja = { version = "2.5.0", features = ["loader", "multi_template"] }
|
||||
plotly = "0.12.1"
|
||||
reqwest = {version = "0.12.12", features = ["charset", "json"]}
|
||||
serde_json = "1.0.128"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
@@ -46,7 +45,7 @@ text-splitter = "0.18.1"
|
||||
thiserror = "1.0.63"
|
||||
tokio-util = { version = "0.7.15", features = ["io"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tower-http = { version = "0.6.2", features = ["fs"] }
|
||||
tower-http = { version = "0.6.2", features = ["fs", "compression-full"] }
|
||||
tower-serve-static = "0.1.1"
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
@@ -54,7 +53,59 @@ url = { version = "2.5.2", features = ["serde"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
tokio-retry = "0.3.0"
|
||||
base64 = "0.22.1"
|
||||
object_store = { version = "0.11.2" }
|
||||
bytes = "1.7.1"
|
||||
state-machines = "0.2.0"
|
||||
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
lto = "thin"
|
||||
|
||||
[workspace.lints.clippy]
|
||||
# Performance-focused lints
|
||||
perf = { level = "warn", priority = -1 }
|
||||
vec_init_then_push = "warn"
|
||||
large_stack_frames = "warn"
|
||||
redundant_allocation = "warn"
|
||||
single_char_pattern = "warn"
|
||||
string_extend_chars = "warn"
|
||||
format_in_format_args = "warn"
|
||||
slow_vector_initialization = "warn"
|
||||
inefficient_to_string = "warn"
|
||||
implicit_clone = "warn"
|
||||
redundant_clone = "warn"
|
||||
|
||||
# Security-focused lints
|
||||
integer_arithmetic = "warn"
|
||||
indexing_slicing = "warn"
|
||||
unwrap_used = "warn"
|
||||
expect_used = "warn"
|
||||
panic = "warn"
|
||||
unimplemented = "warn"
|
||||
todo = "warn"
|
||||
|
||||
# Async/Network lints
|
||||
async_yields_async = "warn"
|
||||
await_holding_invalid_state = "warn"
|
||||
rc_buffer = "warn"
|
||||
|
||||
# Maintainability-focused lints
|
||||
cargo = { level = "warn", priority = -1 }
|
||||
pedantic = { level = "warn", priority = -1 }
|
||||
clone_on_ref_ptr = "warn"
|
||||
float_cmp = "warn"
|
||||
manual_string_new = "warn"
|
||||
uninlined_format_args = "warn"
|
||||
unused_self = "warn"
|
||||
must_use_candidate = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
missing_panics_doc = "warn"
|
||||
module_name_repetitions = "warn"
|
||||
wildcard_dependencies = "warn"
|
||||
missing_docs_in_private_items = "warn"
|
||||
|
||||
# Allow noisy lints that don't add value for this project
|
||||
manual_must_use = "allow"
|
||||
needless_raw_string_hashes = "allow"
|
||||
multiple_bound_locations = "allow"
|
||||
|
||||
62
Dockerfile
@@ -1,7 +1,10 @@
|
||||
# === Builder Stage ===
|
||||
FROM clux/muslrust:1.86.0-stable as builder
|
||||
|
||||
# === Builder ===
|
||||
FROM rust:1.86-bookworm AS builder
|
||||
WORKDIR /usr/src/minne
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cache deps
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir -p api-router common composite-retrieval html-router ingestion-pipeline json-stream-parser main worker
|
||||
COPY api-router/Cargo.toml ./api-router/
|
||||
@@ -11,43 +14,38 @@ COPY html-router/Cargo.toml ./html-router/
|
||||
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
|
||||
COPY json-stream-parser/Cargo.toml ./json-stream-parser/
|
||||
COPY main/Cargo.toml ./main/
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Build with the MUSL target
|
||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker || true
|
||||
|
||||
# Copy the rest of the source code
|
||||
# Build
|
||||
COPY . .
|
||||
RUN cargo build --release --bin main --features ingestion-pipeline/docker
|
||||
|
||||
# Build the final application binary with the MUSL target
|
||||
RUN cargo build --release --target x86_64-unknown-linux-musl --bin main --features ingestion-pipeline/docker
|
||||
# === Runtime ===
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# === Runtime Stage ===
|
||||
FROM alpine:latest
|
||||
# Chromium + runtime deps + OpenMP for ORT
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
chromium libnss3 libasound2 libgbm1 libxshmfence1 \
|
||||
ca-certificates fonts-dejavu fonts-noto-color-emoji \
|
||||
libgomp1 libstdc++6 curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apk update && apk add --no-cache \
|
||||
chromium \
|
||||
nss \
|
||||
freetype \
|
||||
harfbuzz \
|
||||
ca-certificates \
|
||||
ttf-freefont \
|
||||
font-noto-emoji \
|
||||
&& \
|
||||
rm -rf /var/cache/apk/*
|
||||
# ONNX Runtime (CPU). Change if you bump ort.
|
||||
ARG ORT_VERSION=1.22.0
|
||||
RUN mkdir -p /opt/onnxruntime && \
|
||||
curl -fsSL -o /tmp/ort.tgz \
|
||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||
tar -xzf /tmp/ort.tgz -C /opt/onnxruntime --strip-components=1 && rm /tmp/ort.tgz
|
||||
|
||||
ENV CHROME_BIN=/usr/bin/chromium-browser \
|
||||
CHROME_PATH=/usr/lib/chromium/ \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt
|
||||
ENV CHROME_BIN=/usr/bin/chromium \
|
||||
SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt \
|
||||
ORT_DYLIB_PATH=/opt/onnxruntime/lib/libonnxruntime.so
|
||||
|
||||
# Create a non-root user to run the application
|
||||
RUN adduser -D -h /home/appuser appuser
|
||||
WORKDIR /home/appuser
|
||||
# Non-root
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
WORKDIR /home/appuser
|
||||
|
||||
# Copy the compiled binary from the builder stage (note the target path)
|
||||
COPY --from=builder /usr/src/minne/target/x86_64-unknown-linux-musl/release/main /usr/local/bin/main
|
||||
|
||||
COPY --from=builder /usr/src/minne/target/release/main /usr/local/bin/main
|
||||
EXPOSE 3000
|
||||
# EXPOSE 8000-9000
|
||||
|
||||
CMD ["main"]
|
||||
|
||||
401
README.md
@@ -6,200 +6,148 @@
|
||||
[](https://www.gnu.org/licenses/agpl-3.0)
|
||||
[](https://github.com/perstarkse/minne/releases/latest)
|
||||
|
||||

|
||||

|
||||
|
||||
## Demo deployment
|
||||
|
||||
To test *Minne* out, enter [this](https://minne-demo.stark.pub) read-only demo deployment to view and test functionality out.
|
||||
To test _Minne_ out, enter [this](https://minne-demo.stark.pub) read-only demo deployment to view and test functionality out.
|
||||
|
||||
## Noteworthy Features
|
||||
|
||||
- **Search & Chat Interface** - Find content or knowledge instantly with full-text search, or use the chat mode and conversational AI to find and reason about content
|
||||
- **Manual and AI-assisted connections** - Build entities and relationships manually with full control, let AI create entities and relationships automatically, or blend both approaches with AI suggestions for manual approval
|
||||
- **Hybrid Retrieval System** - Search combining vector similarity, full-text search, and graph traversal for highly relevant results
|
||||
- **Scratchpad Feature** - Quickly capture thoughts and convert them to permanent content when ready
|
||||
- **Visual Graph Explorer** - Interactive D3-based navigation of your knowledge entities and connections
|
||||
- **Multi-Format Support** - Ingest text, URLs, PDFs, audio files, and images into your knowledge base
|
||||
- **Performance Focus** - Built with Rust and server-side rendering for speed and efficiency
|
||||
- **Self-Hosted & Privacy-Focused** - Full control over your data, and compatible with any OpenAI-compatible API that supports structured outputs
|
||||
|
||||
## The "Why" Behind Minne
|
||||
|
||||
For a while I've been fascinated by Zettelkasten-style PKM systems. While tools like Logseq and Obsidian are excellent, I found the manual linking process to be a hindrance for me. I also wanted a centralized storage and easy access across devices.
|
||||
For a while I've been fascinated by personal knowledge management systems. I wanted something that made it incredibly easy to capture content - snippets of text, URLs, and other media - while automatically discovering connections between ideas. But I also wanted to maintain control over my knowledge structure.
|
||||
|
||||
While developing Minne, I discovered [KaraKeep](https://karakeep.com/) (formerly Hoarder), which is an excellent application in a similar space – you probably want to check it out! However, if you're interested in a PKM that builds an automatic network between related concepts using AI, offers search and the **possibility to chat with your knowledge resource**, and provides a blend of manual and AI-driven organization, then Minne might be worth testing.
|
||||
Traditional tools like Logseq and Obsidian are excellent, but the manual linking process often became a hindrance. Meanwhile, fully automated systems sometimes miss important context or create relationships I wouldn't have chosen myself.
|
||||
|
||||
## Core Philosophy & Features
|
||||
So I built Minne to offer the best of both worlds: effortless content capture with AI-assisted relationship discovery, but with the flexibility to manually curate, edit, or override any connections. You can let AI handle the heavy lifting of extracting entities and finding relationships, take full control yourself, or use a hybrid approach where AI suggests connections that you can approve or modify.
|
||||
|
||||
Minne is designed to make it incredibly easy to save snippets of text, URLs, and other content (limited, pending demand). Simply send content along with a category tag. Minne then ingests this, leveraging AI to create relevant nodes and relationships within its graph database, alongside your manual categorization. This graph backend allows for discoverable connections between your pieces of knowledge.
|
||||
While developing Minne, I discovered [KaraKeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder), which is an excellent application in a similar space – you probably want to check it out! However, if you're interested in a PKM that offers both intelligent automation and manual curation, with the ability to chat with your knowledge base, then Minne might be worth testing.
|
||||
|
||||
You can converse with your knowledge base through an LLM-powered chat interface (via OpenAI compatible API, like Ollama or others). For those who like to see the bigger picture, Minne also includes an **experimental feature to visually explore your knowledge graph.**
|
||||
## Table of Contents
|
||||
|
||||
You may switch and choose between models used, and have the possiblity to change the prompts to your liking. There is since release **0.1.3** the option to change embeddings length, making it easy to test another embedding model.
|
||||
- [Quick Start](#quick-start)
|
||||
- [Features in Detail](#features-in-detail)
|
||||
- [Configuration](#configuration)
|
||||
- [Tech Stack](#tech-stack)
|
||||
- [Application Architecture](#application-architecture)
|
||||
- [AI Configuration](#ai-configuration--model-selection)
|
||||
- [Roadmap](#roadmap)
|
||||
- [Development](#development)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
|
||||
The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/9aa960600ec14329837ba4169f57a166) makes sending content to your Minne instance a breeze.
|
||||
## Quick Start
|
||||
|
||||
Minne is open source (AGPL), self-hostable, and can be deployed flexibly: via Nix, Docker Compose, pre-built binaries, or by building from source. It can run as a single `main` binary or as separate `server` and `worker` processes for optimized resource allocation.
|
||||
The fastest way to get Minne running is with Docker Compose:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
|
||||
# Start Minne and its database
|
||||
docker compose up -d
|
||||
|
||||
# Access at http://localhost:3000
|
||||
```
|
||||
|
||||
**Required Setup:**
|
||||
- Replace `your_openai_api_key_here` in `docker-compose.yml` with your actual API key
|
||||
- Configure `OPENAI_BASE_URL` if using a custom AI provider (like Ollama)
|
||||
|
||||
For detailed installation options, see [Configuration](#configuration).
|
||||
|
||||
## Features in Detail
|
||||
|
||||
### Search vs. Chat mode
|
||||
|
||||
**Search** - Use when you know roughly what you're looking for. Full-text search finds items quickly by matching your query terms.
|
||||
|
||||
**Chat Mode** - Use when you want to explore concepts, find connections, or reason about your knowledge. The AI analyzes your query and finds relevant context across your entire knowledge base.
|
||||
|
||||
### Content Processing
|
||||
|
||||
Minne automatically processes content you save:
|
||||
1. **Web scraping** extracts readable text from URLs
|
||||
2. **Text analysis** identifies key concepts and relationships
|
||||
3. **Graph creation** builds connections between related content
|
||||
4. **Embedding generation** enables semantic search capabilities
|
||||
|
||||
### Visual Knowledge Graph
|
||||
|
||||
Explore your knowledge as an interactive network with flexible curation options:
|
||||
|
||||
**Manual Curation** - Create knowledge entities and relationships yourself with full control over your graph structure
|
||||
|
||||
**AI Automation** - Let AI automatically extract entities and discover relationships from your content
|
||||
|
||||
**Hybrid Approach** - Get AI-suggested relationships and entities that you can manually review, edit, or approve
|
||||
|
||||
The graph visualization shows:
|
||||
- Knowledge entities as nodes (manually created or AI-extracted)
|
||||
- Relationships as connections (manually defined, AI-discovered, or suggested)
|
||||
- Interactive navigation for discovery and editing
|
||||
|
||||
### Optional FastEmbed Reranking
|
||||
|
||||
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
|
||||
|
||||
⚠️ **Resource notes**
|
||||
- Enabling reranking downloads and caches ~1.1 GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
|
||||
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a singe user setup, but could work with a pool size on 1 as well.
|
||||
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if you’re comfortable with the additional footprint.
|
||||
|
||||
Example configuration:
|
||||
|
||||
```yaml
|
||||
reranking_enabled: true
|
||||
reranking_pool_size: 2
|
||||
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to .fastembed_cache
|
||||
```
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Backend:** Rust. Server-Side Rendering (SSR). Axum. Minijinja for templating.
|
||||
- **Frontend:** HTML. HTMX and plain JavaScript for interactivity.
|
||||
- **Database:** SurrealDB
|
||||
- **AI Integration:** OpenAI API compatible endpoint (for chat and content processing), with support for structured outputs.
|
||||
- **Web Content Processing:** Relies on a Chromium instance for robust webpage fetching/rendering.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **For Docker/Nix:** Docker or Nix installed. These methods handle SurrealDB and Chromium dependencies.
|
||||
- **For Binaries/Source:**
|
||||
- A running SurrealDB instance.
|
||||
- Chromium (or a compatible Chrome browser) installed and accessible in your `PATH`.
|
||||
- Git (if cloning and building from source).
|
||||
- Rust toolchain (if building from source).
|
||||
|
||||
## Getting Started
|
||||
|
||||
You have several options to get Minne up and running:
|
||||
|
||||
### 1. Nix (Recommended for ease of dependency management)
|
||||
|
||||
If you have Nix installed, you can run Minne directly:
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
This command will fetch Minne and its dependencies (including Chromium) and run the `main` (combined server/worker) application.
|
||||
|
||||
### 2. Docker Compose (Recommended for containerized environments)
|
||||
|
||||
This is a great way to manage Minne and its SurrealDB dependency together.
|
||||
|
||||
1. Clone the repository (or just save the `docker-compose.yml` below).
|
||||
|
||||
1. Create a `docker-compose.yml` file:
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
minne:
|
||||
image: ghcr.io/perstarkse/minne:latest # Pulls the latest pre-built image
|
||||
# Or, to build from local source:
|
||||
# build: .
|
||||
container_name: minne_app
|
||||
ports:
|
||||
- "3000:3000" # Exposes Minne on port 3000
|
||||
environment:
|
||||
# These are examples, ensure they match your SurrealDB setup below
|
||||
# and your actual OpenAI key.
|
||||
SURREALDB_ADDRESS: "ws://surrealdb:8000"
|
||||
SURREALDB_USERNAME: "root_user" # Default from SurrealDB service below
|
||||
SURREALDB_PASSWORD: "root_password" # Default from SurrealDB service below
|
||||
SURREALDB_DATABASE: "minne_db"
|
||||
SURREALDB_NAMESPACE: "minne_ns"
|
||||
OPENAI_API_KEY: "your_openai_api_key_here" # IMPORTANT: Replace with your actual key
|
||||
#OPENAI_BASE_URL: "your_ollama_address" # Uncomment this and change it to override the default openai base url
|
||||
HTTP_PORT: 3000
|
||||
DATA_DIR: "/data" # Data directory inside the container
|
||||
RUST_LOG: "minne=info,tower_http=info" # Example logging level
|
||||
volumes:
|
||||
- ./minne_data:/data # Persists Minne's data (e.g., scraped content) on the host
|
||||
depends_on:
|
||||
- surrealdb
|
||||
networks:
|
||||
- minne-net
|
||||
# Waits for SurrealDB to be ready before starting Minne
|
||||
command: >
|
||||
sh -c "
|
||||
echo 'Waiting for SurrealDB to start...' &&
|
||||
# Adjust sleep time if SurrealDB takes longer to initialize in your environment
|
||||
until nc -z surrealdb 8000; do echo 'Waiting for SurrealDB...'; sleep 2; done &&
|
||||
echo 'SurrealDB is up, starting Minne application...' &&
|
||||
/usr/local/bin/main
|
||||
"
|
||||
# For separate server/worker:
|
||||
# command: /usr/local/bin/server # or /usr/local/bin/worker
|
||||
|
||||
surrealdb:
|
||||
image: surrealdb/surrealdb:latest
|
||||
container_name: minne_surrealdb
|
||||
ports:
|
||||
# Exposes SurrealDB on port 8000 (primarily for direct access/debugging if needed,
|
||||
# not strictly required for Minne if only accessed internally by the minne service)
|
||||
- "127.0.0.1:8000:8000" # Bind to localhost only for SurrealDB by default
|
||||
volumes:
|
||||
# Persists SurrealDB data on the host in a 'surreal_database' folder
|
||||
- ./surreal_database:/database
|
||||
command: >
|
||||
start
|
||||
--log info # Consider 'debug' for troubleshooting
|
||||
--user root_user
|
||||
--pass root_password
|
||||
file:/database/minne_v1.db # Using file-based storage for simplicity
|
||||
networks:
|
||||
- minne-net
|
||||
|
||||
volumes:
|
||||
minne_data: {} # Defines a named volume for Minne data (can be managed by Docker)
|
||||
surreal_database: {} # Defines a named volume for SurrealDB data
|
||||
|
||||
networks:
|
||||
minne-net:
|
||||
driver: bridge
|
||||
```
|
||||
|
||||
1. Run:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
Minne will be accessible at `http://localhost:3000`.
|
||||
|
||||
### 3. Pre-built Binaries (GitHub Releases)
|
||||
|
||||
Binaries for Windows, macOS, and Linux (combined `main` version) are available on the [GitHub Releases page](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
1. Download the appropriate binary for your system.
|
||||
1. **You will need to provide and run SurrealDB and have Chromium installed and accessible in your PATH separately.**
|
||||
1. Set the required [Configuration](#configuration) environment variables or use a `config.yaml`.
|
||||
1. Run the executable.
|
||||
|
||||
### 4. Build from Source
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
```
|
||||
1. **You will need to provide and run SurrealDB and have Chromium installed and accessible in your PATH separately.**
|
||||
1. Set the required [Configuration](#configuration) environment variables or use a `config.yaml`.
|
||||
1. Build and run:
|
||||
- For the combined `main` binary:
|
||||
```bash
|
||||
cargo run --release --bin main
|
||||
```
|
||||
- For the `server` binary:
|
||||
```bash
|
||||
cargo run --release --bin server
|
||||
```
|
||||
- For the `worker` binary (if you want to run it separately):
|
||||
```bash
|
||||
cargo run --release --bin worker
|
||||
```
|
||||
The compiled binaries will be in `target/release/`.
|
||||
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
|
||||
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
|
||||
- **Database:** SurrealDB (graph, document, and vector search)
|
||||
- **AI Integration:** OpenAI-compatible API with structured outputs
|
||||
- **Web Processing:** Headless Chrome for robust webpage content extraction
|
||||
|
||||
## Configuration
|
||||
|
||||
Minne can be configured using environment variables or a `config.yaml` file placed in the working directory where you run the application. Environment variables take precedence over `config.yaml`.
|
||||
Minne can be configured using environment variables or a `config.yaml` file. Environment variables take precedence over `config.yaml`.
|
||||
|
||||
**Required Configuration:**
|
||||
### Required Configuration
|
||||
|
||||
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000` or `ws://surrealdb:8000` for Docker).
|
||||
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`).
|
||||
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`).
|
||||
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`).
|
||||
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`).
|
||||
- `OPENAI_API_KEY`: Your API key for OpenAI (e.g., `sk-YourActualOpenAIKeyGoesHere`).
|
||||
- `HTTP_PORT`: Port for the Minne server to listen on (Default: `3000`).
|
||||
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000`)
|
||||
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`)
|
||||
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`)
|
||||
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`)
|
||||
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`)
|
||||
- `OPENAI_API_KEY`: Your API key for OpenAI compatible endpoint
|
||||
- `HTTP_PORT`: Port for the Minne server (Default: `3000`)
|
||||
|
||||
**Optional Configuration:**
|
||||
### Optional Configuration
|
||||
|
||||
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`).
|
||||
- `DATA_DIR`: Directory to store local data like fetched webpage content (e.g., `./data`).
|
||||
- `OPENAI_BASE_URL`: Base URL to a OpenAI API provider, such as Ollama.
|
||||
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
|
||||
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
|
||||
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
|
||||
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
|
||||
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
|
||||
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
|
||||
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
|
||||
|
||||
**Example `config.yaml`:**
|
||||
### Example config.yaml
|
||||
|
||||
```yaml
|
||||
surrealdb_address: "ws://127.0.0.1:8000"
|
||||
@@ -209,62 +157,109 @@ surrealdb_database: "minne_db"
|
||||
surrealdb_namespace: "minne_ns"
|
||||
openai_api_key: "sk-YourActualOpenAIKeyGoesHere"
|
||||
data_dir: "./minne_app_data"
|
||||
http_port: 3000
|
||||
# rust_log: "info"
|
||||
# http_port: 3000
|
||||
```
|
||||
|
||||
## Application Architecture (Binaries)
|
||||
## Installation Options
|
||||
|
||||
Minne offers flexibility in deployment:
|
||||
### 1. Docker Compose (Recommended)
|
||||
|
||||
- **`main`**: A combined binary running both server (API, web UI) and worker (background tasks) in one process. Ideal for simpler setups.
|
||||
- **`server`**: Runs only the server component.
|
||||
- **`worker`**: Runs only the worker component, suitable for deployment on a machine with more resources for intensive tasks.
|
||||
```bash
|
||||
# Clone and run
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
This modularity allows scaling and resource optimization. The `main` binary or the Docker Compose setup (using `main`) is sufficient for most users.
|
||||
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
|
||||
|
||||
### 2. Nix
|
||||
|
||||
```bash
|
||||
nix run 'github:perstarkse/minne#main'
|
||||
```
|
||||
|
||||
This fetches Minne and all dependencies, including Chromium.
|
||||
|
||||
### 3. Pre-built Binaries
|
||||
|
||||
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
|
||||
|
||||
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
|
||||
|
||||
### 4. Build from Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/perstarkse/minne.git
|
||||
cd minne
|
||||
cargo run --release --bin main
|
||||
```
|
||||
|
||||
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
|
||||
|
||||
## Application Architecture
|
||||
|
||||
Minne offers flexible deployment options:
|
||||
|
||||
- **`main`**: Combined server and worker in one process (recommended for most users)
|
||||
- **`server`**: Web interface and API only
|
||||
- **`worker`**: Background processing only (for resource optimization)
|
||||
|
||||
## Usage
|
||||
|
||||
Once Minne is running:
|
||||
Once Minne is running at `http://localhost:3000`:
|
||||
|
||||
1. Access the web interface at `http://localhost:3000` (or your configured port).
|
||||
1. On iOS, consider setting up the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/9aa960600ec14329837ba4169f57a166) for effortless content sending. **Add the shortcut, replace the [insert_url] and the [insert_api_key] snippets**.
|
||||
1. Start adding notes, URLs and explore your growing knowledge graph.
|
||||
1. Engage with the chat interface to query your saved content.
|
||||
1. Try the experimental visual graph explorer to see connections.
|
||||
1. **Web Interface**: Full-featured experience for desktop and mobile
|
||||
2. **iOS Shortcut**: Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture
|
||||
3. **Content Types**: Save notes, URLs, audio files, and more
|
||||
4. **Knowledge Graph**: Explore automatic connections between your content
|
||||
5. **Chat Interface**: Query your knowledge base conversationally
|
||||
|
||||
## AI Configuration & Model Selection
|
||||
|
||||
Minne relies on an OpenAI-compatible API for processing content, generating graph relationships, and powering the chat feature.
|
||||
### Setting Up AI Providers
|
||||
|
||||
**Environment Variables / `config.yaml` keys:**
|
||||
Minne uses OpenAI-compatible APIs. Configure via environment variables or `config.yaml`:
|
||||
|
||||
- `OPENAI_API_KEY` (required): Your API key for the chosen AI provider.
|
||||
- `OPENAI_BASE_URL` (optional): Use this to override the default OpenAI API URL (`https://api.openai.com/v1`). This is essential for using local models via services like Ollama, or other API providers.
|
||||
- **Example for Ollama:** `http://<your-ollama-ip>:11434/v1`
|
||||
- `OPENAI_API_KEY` (required): Your API key
|
||||
- `OPENAI_BASE_URL` (optional): Custom provider URL (e.g., Ollama: `http://localhost:11434/v1`)
|
||||
|
||||
### Changing Models
|
||||
### Model Selection
|
||||
|
||||
Once you have configured the `OPENAI_BASE_URL` to point to your desired provider, you can select the specific models Minne should use.
|
||||
|
||||
1. Navigate to the `/admin` page in your Minne instance.
|
||||
1. The page will list the models available from your configured endpoint. You can select different models for processing content and for chat.
|
||||
1. **Important:** For content processing, Minne relies on structured outputs (function calling). The model and provider you select for this task **must** support this feature.
|
||||
1. **Embedding Dimensions:** If you change the embedding model, you **must** update the "Embedding Dimensions" setting in the admin panel to match the output dimensions of your new model (e.g., `text-embedding-3-small` uses 1536, `nomic-embed-text` uses 768). Mismatched dimensions will cause errors. Some newer models will accept a dimension argument, and for these setting the dimensions to whatever should work.
|
||||
1. Access the `/admin` page in your Minne instance
|
||||
2. Select models for content processing and chat from your configured provider
|
||||
3. **Content Processing Requirements**: The model must support structured outputs
|
||||
4. **Embedding Dimensions**: Update this setting when changing embedding models (e.g., 1536 for `text-embedding-3-small`, 768 for `nomic-embed-text`)
|
||||
|
||||
## Roadmap
|
||||
|
||||
I've developed Minne primarily for my own use, but having been in the selfhosted space for a long time, and using the efforts by others, I thought I'd share with the community. Feature requests are welcome.
|
||||
The roadmap as of now is:
|
||||
Current development focus:
|
||||
|
||||
- Handle uploaded images wisely.
|
||||
- An updated explorer of the graph database.
|
||||
- A TUI frontend which opens your system default editor for improved writing and document management.
|
||||
- TUI frontend with system editor integration
|
||||
- Enhanced reranking for improved retrieval recall
|
||||
- Additional content type support
|
||||
|
||||
Feature requests and contributions are welcome!
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
cargo test
|
||||
|
||||
# Development build
|
||||
cargo build
|
||||
|
||||
# Comprehensive linting
|
||||
cargo clippy --workspace --all-targets --all-features
|
||||
```
|
||||
|
||||
The codebase includes extensive unit tests. Integration tests and additional contributions are welcome.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Whether it's bug reports, feature suggestions, documentation improvements, or code contributions, please feel free to open an issue or submit a pull request.
|
||||
I've developed Minne primarily for my own use, but having been in the selfhosted space for a long time, and using the efforts by others, I thought I'd share with the community. Feature requests are welcome.
|
||||
|
||||
## License
|
||||
|
||||
Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details. This means if you run a modified version of Minne as a network service, you must also offer the source code of that modified version to its users.
|
||||
Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details.
|
||||
|
||||
@@ -4,6 +4,9 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::{storage::db::SurrealDbClient, utils::config::AppConfig};
|
||||
use common::{
|
||||
storage::{db::SurrealDbClient, store::StorageManager},
|
||||
utils::config::AppConfig,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub db: Arc<SurrealDbClient>,
|
||||
pub config: AppConfig,
|
||||
pub storage: StorageManager,
|
||||
}
|
||||
|
||||
impl ApiState {
|
||||
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub async fn new(
|
||||
config: &AppConfig,
|
||||
storage: StorageManager,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let surreal_db_client = Arc::new(
|
||||
SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
@@ -23,9 +30,10 @@ impl ApiState {
|
||||
|
||||
surreal_db_client.apply_migrations().await?;
|
||||
|
||||
let app_state = ApiState {
|
||||
let app_state = Self {
|
||||
db: surreal_db_client.clone(),
|
||||
config: config.clone(),
|
||||
storage,
|
||||
};
|
||||
|
||||
Ok(app_state)
|
||||
|
||||
@@ -27,40 +27,40 @@ impl From<AppError> for ApiError {
|
||||
match err {
|
||||
AppError::Database(_) | AppError::OpenAI(_) => {
|
||||
tracing::error!("Internal error: {:?}", err);
|
||||
ApiError::InternalError("Internal server error".to_string())
|
||||
Self::InternalError("Internal server error".to_string())
|
||||
}
|
||||
AppError::NotFound(msg) => ApiError::NotFound(msg),
|
||||
AppError::Validation(msg) => ApiError::ValidationError(msg),
|
||||
AppError::Auth(msg) => ApiError::Unauthorized(msg),
|
||||
_ => ApiError::InternalError("Internal server error".to_string()),
|
||||
AppError::NotFound(msg) => Self::NotFound(msg),
|
||||
AppError::Validation(msg) => Self::ValidationError(msg),
|
||||
AppError::Auth(msg) => Self::Unauthorized(msg),
|
||||
_ => Self::InternalError("Internal server error".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_response) = match self {
|
||||
ApiError::InternalError(message) => (
|
||||
Self::InternalError(message) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::ValidationError(message) => (
|
||||
Self::ValidationError(message) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::NotFound(message) => (
|
||||
Self::NotFound(message) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
ApiError::Unauthorized(message) => (
|
||||
Self::Unauthorized(message) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
|
||||
@@ -6,7 +6,7 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::{categories::get_categories, ingress::ingest_data};
|
||||
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready};
|
||||
|
||||
pub mod api_state;
|
||||
pub mod error;
|
||||
@@ -19,9 +19,17 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
ApiState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
// Public, unauthenticated endpoints (for k8s/systemd probes)
|
||||
let public = Router::new()
|
||||
.route("/ready", get(ready))
|
||||
.route("/live", get(live));
|
||||
|
||||
// Protected API endpoints (require auth)
|
||||
let protected = Router::new()
|
||||
.route("/ingress", post(ingest_data))
|
||||
.route("/categories", get(get_categories))
|
||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
||||
|
||||
public.merge(protected)
|
||||
}
|
||||
|
||||
@@ -13,14 +13,12 @@ pub async fn api_auth(
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ApiError> {
|
||||
let api_key = extract_api_key(&request).ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
let api_key = extract_api_key(&request)
|
||||
.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
let user = User::find_by_api_key(&api_key, &state.db).await?;
|
||||
let user = user.ok_or(ApiError::Unauthorized(
|
||||
"You have to be authenticated".to_string(),
|
||||
))?;
|
||||
let user =
|
||||
user.ok_or_else(|| ApiError::Unauthorized("You have to be authenticated".to_string()))?;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
|
||||
@@ -37,7 +35,7 @@ fn extract_api_key(request: &Request) -> Option<String> {
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(|s| s.trim()))
|
||||
.and_then(|auth| auth.strip_prefix("Bearer ").map(str::trim))
|
||||
})
|
||||
.map(String::from)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -8,6 +8,7 @@ use common::{
|
||||
},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
use tracing::info;
|
||||
|
||||
@@ -31,7 +32,8 @@ pub async fn ingest_data(
|
||||
info!("Received input: {:?}", input);
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new(file, &state.db, &user.id, &state.config).map_err(AppError::from)
|
||||
FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage)
|
||||
.map_err(AppError::from)
|
||||
}))
|
||||
.await?;
|
||||
|
||||
@@ -45,12 +47,10 @@ pub async fn ingest_data(
|
||||
|
||||
let futures: Vec<_> = payloads
|
||||
.into_iter()
|
||||
.map(|object| {
|
||||
IngestionTask::create_and_add_to_db(object.clone(), user.id.clone(), &state.db)
|
||||
})
|
||||
.map(|object| IngestionTask::create_and_add_to_db(object, user.id.clone(), &state.db))
|
||||
.collect();
|
||||
|
||||
try_join_all(futures).await.map_err(AppError::from)?;
|
||||
try_join_all(futures).await?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
|
||||
}
|
||||
|
||||
7
api-router/src/routes/liveness.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
|
||||
/// Liveness probe: always returns 200 to indicate the process is running.
|
||||
pub async fn live() -> impl IntoResponse {
|
||||
(StatusCode::OK, Json(json!({"status": "ok"})))
|
||||
}
|
||||
@@ -1,2 +1,4 @@
|
||||
pub mod categories;
|
||||
pub mod ingress;
|
||||
pub mod liveness;
|
||||
pub mod readiness;
|
||||
|
||||
25
api-router/src/routes/readiness.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::api_state::ApiState;
|
||||
|
||||
/// Readiness probe: returns 200 if core dependencies are ready, else 503.
|
||||
pub async fn ready(State(state): State<ApiState>) -> impl IntoResponse {
|
||||
match state.db.client.query("RETURN true").await {
|
||||
Ok(_) => (
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"status": "ok",
|
||||
"checks": { "db": "ok" }
|
||||
})),
|
||||
),
|
||||
Err(e) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"status": "error",
|
||||
"checks": { "db": "fail" },
|
||||
"reason": e.to_string()
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,9 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
tokio = { workspace = true }
|
||||
@@ -39,6 +42,9 @@ url = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
surrealdb-migrations = { workspace = true }
|
||||
tokio-retry = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
|
||||
|
||||
[features]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
|
||||
UPDATE system_settings:current SET
|
||||
voice_processing_model = "whisper-1"
|
||||
WHERE voice_processing_model == NONE;
|
||||
115
common/migrations/20250921_120004_fix_datetime_fields.surql
Normal file
@@ -0,0 +1,115 @@
|
||||
-- Align timestamp fields with SurrealDB native datetime type.
|
||||
|
||||
-- User timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON user FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON user FLEXIBLE;
|
||||
|
||||
UPDATE user SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE user SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON user TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON user TYPE datetime;
|
||||
|
||||
-- Text content timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON text_content FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_content FLEXIBLE;
|
||||
|
||||
UPDATE text_content SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE text_content SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_content TYPE datetime;
|
||||
|
||||
REBUILD INDEX text_content_created_at_idx ON text_content;
|
||||
|
||||
-- Text chunk timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON text_chunk FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_chunk FLEXIBLE;
|
||||
|
||||
UPDATE text_chunk SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE text_chunk SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON text_chunk TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON text_chunk TYPE datetime;
|
||||
|
||||
-- Knowledge entity timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON knowledge_entity FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON knowledge_entity FLEXIBLE;
|
||||
|
||||
UPDATE knowledge_entity SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE knowledge_entity SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON knowledge_entity TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON knowledge_entity TYPE datetime;
|
||||
|
||||
REBUILD INDEX knowledge_entity_created_at_idx ON knowledge_entity;
|
||||
|
||||
-- Conversation timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON conversation FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON conversation FLEXIBLE;
|
||||
|
||||
UPDATE conversation SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE conversation SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON conversation TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON conversation TYPE datetime;
|
||||
|
||||
REBUILD INDEX conversation_created_at_idx ON conversation;
|
||||
|
||||
-- Message timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON message FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON message FLEXIBLE;
|
||||
|
||||
UPDATE message SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE message SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON message TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON message TYPE datetime;
|
||||
|
||||
REBUILD INDEX message_updated_at_idx ON message;
|
||||
|
||||
-- Ingestion task timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON ingestion_task FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON ingestion_task FLEXIBLE;
|
||||
|
||||
UPDATE ingestion_task SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE ingestion_task SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON ingestion_task TYPE datetime;
|
||||
|
||||
REBUILD INDEX idx_ingestion_task_created ON ingestion_task;
|
||||
|
||||
-- File timestamps
|
||||
DEFINE FIELD OVERWRITE created_at ON file FLEXIBLE;
|
||||
DEFINE FIELD OVERWRITE updated_at ON file FLEXIBLE;
|
||||
|
||||
UPDATE file SET created_at = type::datetime(created_at)
|
||||
WHERE type::is::string(created_at) AND created_at != "";
|
||||
|
||||
UPDATE file SET updated_at = type::datetime(updated_at)
|
||||
WHERE type::is::string(updated_at) AND updated_at != "";
|
||||
|
||||
DEFINE FIELD OVERWRITE created_at ON file TYPE datetime;
|
||||
DEFINE FIELD OVERWRITE updated_at ON file TYPE datetime;
|
||||
@@ -0,0 +1,17 @@
|
||||
-- Add FTS indexes for searching name and description on entities
|
||||
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity
|
||||
FIELDS name
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity
|
||||
FIELDS description
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk
|
||||
FIELDS chunk
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
173
common/migrations/20251012_205900_state_machine_migration.surql
Normal file
@@ -0,0 +1,173 @@
|
||||
-- State machine migration for ingestion_task records
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS state ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS max_attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS scheduled_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS locked_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS worker_id ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_code ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_message ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_error_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS priority ON TABLE ingestion_task TYPE option<number>;
|
||||
|
||||
REMOVE FIELD status ON TABLE ingestion_task;
|
||||
DEFINE FIELD status ON TABLE ingestion_task TYPE option<object>;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_state_sched ON TABLE ingestion_task FIELDS state, scheduled_at;
|
||||
|
||||
LET $needs_migration = (SELECT count() AS count FROM type::table('ingestion_task') WHERE state = NONE)[0].count;
|
||||
|
||||
IF $needs_migration > 0 THEN {
|
||||
-- Created -> Pending
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF created_at != NONE THEN created_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Created";
|
||||
|
||||
-- InProgress -> Processing
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Processing",
|
||||
attempts = IF status.attempts != NONE THEN status.attempts ELSE 1 END,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
locked_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "InProgress";
|
||||
|
||||
-- Completed -> Succeeded
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Succeeded",
|
||||
attempts = 1,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Completed";
|
||||
|
||||
-- Error -> DeadLetter (terminal failure)
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "DeadLetter",
|
||||
attempts = 3,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = status.message,
|
||||
last_error_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Error";
|
||||
|
||||
-- Cancelled -> Cancelled
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Cancelled",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Cancelled";
|
||||
|
||||
-- Fallback for any remaining records missing state
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE;
|
||||
} END;
|
||||
|
||||
-- Ensure defaults for newly added fields
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET max_attempts = 3
|
||||
WHERE max_attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET lease_duration_secs = 300
|
||||
WHERE lease_duration_secs = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET attempts = 0
|
||||
WHERE attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET priority = 0
|
||||
WHERE priority = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END
|
||||
WHERE scheduled_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET locked_at = NONE
|
||||
WHERE locked_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET worker_id = NONE
|
||||
WHERE worker_id != NONE AND worker_id = "";
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_code = NONE
|
||||
WHERE error_code = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_message = NONE
|
||||
WHERE error_message = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET last_error_at = NONE
|
||||
WHERE last_error_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET status = NONE
|
||||
WHERE status != NONE;
|
||||
24
common/migrations/20251022_120302_add_scratchpad_table.surql
Normal file
@@ -0,0 +1,24 @@
|
||||
-- Add scratchpad table and schema
|
||||
|
||||
-- Define scratchpad table and schema
|
||||
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||
|
||||
-- Standard fields from stored_object! macro
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||
|
||||
-- Custom fields from the Scratchpad struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||
|
||||
-- Indexes based on query patterns
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -160,6 +160,7 @@\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -18,8 +18,8 @@\n DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n # Custom fields from the Conversation struct\n DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\n@@ -34,8 +34,8 @@\n DEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n # Custom fields from the FileInfo struct\n DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\n@@ -54,8 +54,8 @@\n DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n@@ -71,8 +71,8 @@\n DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n # Custom fields from the KnowledgeEntity struct\n DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;\n@@ -102,8 +102,8 @@\n DEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n # Custom fields from the Message struct\n DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n@@ -167,8 +167,8 @@\n DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n # Custom fields from the TextChunk struct\n DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\n@@ -191,8 +191,8 @@\n DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n # Custom fields from the TextContent struct\n DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n@@ -215,8 +215,8 @@\n DEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n # Custom fields from the User struct\n DEFINE FIELD IF NOT EXISTS email ON user TYPE string;\n","events":null}
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -137,6 +137,30 @@\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\n+# Defines the schema for the 'scratchpad' table.\n+\n+DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;\n+\n+# Standard fields from stored_object! macro\n+DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;\n+\n+# Custom fields from the Scratchpad struct\n+DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;\n+\n+# Indexes based on query patterns\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;\n+DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;\n+DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;\n+\n DEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n","events":null}
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;
|
||||
|
||||
# Custom fields from the Conversation struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS file SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;
|
||||
|
||||
# Custom fields from the FileInfo struct
|
||||
DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;
|
||||
DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;
|
||||
|
||||
# Custom fields from the KnowledgeEntity struct
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS message SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;
|
||||
|
||||
# Custom fields from the Message struct
|
||||
DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;
|
||||
|
||||
23
common/schemas/scratchpad.surql
Normal file
@@ -0,0 +1,23 @@
|
||||
# Defines the schema for the 'scratchpad' table.
|
||||
|
||||
DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;
|
||||
|
||||
# Standard fields from stored_object! macro
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;
|
||||
|
||||
# Custom fields from the Scratchpad struct
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;
|
||||
DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;
|
||||
|
||||
# Indexes based on query patterns
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;
|
||||
DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;
|
||||
@@ -13,3 +13,4 @@ DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;
|
||||
DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;
|
||||
|
||||
# Custom fields from the TextChunk struct
|
||||
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
|
||||
|
||||
# Custom fields from the TextContent struct
|
||||
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
DEFINE TABLE IF NOT EXISTS user SCHEMALESS;
|
||||
|
||||
# Standard fields
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON user TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;
|
||||
DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;
|
||||
|
||||
# Custom fields from the User struct
|
||||
DEFINE FIELD IF NOT EXISTS email ON user TYPE string;
|
||||
|
||||
@@ -80,15 +80,18 @@ impl SurrealDbClient {
|
||||
/// Operation to rebuild indexes
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
debug!("Rebuilding indexes");
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content")
|
||||
.await?;
|
||||
let rebuild_sql = r#"
|
||||
BEGIN TRANSACTION;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
|
||||
COMMIT TRANSACTION;
|
||||
"#;
|
||||
|
||||
self.client.query(rebuild_sql).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod db;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
837
common/src/storage/store.rs
Normal file
@@ -0,0 +1,837 @@
|
||||
use std::io::ErrorKind;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result as AnyResult};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use object_store::local::LocalFileSystem;
|
||||
use object_store::memory::InMemory;
|
||||
use object_store::{path::Path as ObjPath, ObjectStore};
|
||||
|
||||
use crate::utils::config::{AppConfig, StorageKind};
|
||||
|
||||
pub type DynStore = Arc<dyn ObjectStore>;
|
||||
|
||||
/// Storage manager with persistent state and proper lifecycle management.
|
||||
#[derive(Clone)]
|
||||
pub struct StorageManager {
|
||||
store: DynStore,
|
||||
backend_kind: StorageKind,
|
||||
local_base: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl StorageManager {
|
||||
/// Create a new StorageManager with the specified configuration.
|
||||
///
|
||||
/// This method validates the configuration and creates the appropriate
|
||||
/// storage backend with proper initialization.
|
||||
pub async fn new(cfg: &AppConfig) -> object_store::Result<Self> {
|
||||
let backend_kind = cfg.storage.clone();
|
||||
let (store, local_base) = create_storage_backend(cfg).await?;
|
||||
|
||||
Ok(Self {
|
||||
store,
|
||||
backend_kind,
|
||||
local_base,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a StorageManager with a custom storage backend.
|
||||
///
|
||||
/// This method is useful for testing scenarios where you want to inject
|
||||
/// a specific storage backend.
|
||||
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
|
||||
Self {
|
||||
store,
|
||||
backend_kind,
|
||||
local_base: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the storage backend kind.
|
||||
pub fn backend_kind(&self) -> &StorageKind {
|
||||
&self.backend_kind
|
||||
}
|
||||
|
||||
/// Access the resolved local base directory when using the local backend.
|
||||
pub fn local_base_path(&self) -> Option<&Path> {
|
||||
self.local_base.as_deref()
|
||||
}
|
||||
|
||||
/// Resolve an object location to a filesystem path when using the local backend.
|
||||
///
|
||||
/// Returns `None` when the backend is not local or when the provided location includes
|
||||
/// unsupported components (absolute paths or parent traversals).
|
||||
pub fn resolve_local_path(&self, location: &str) -> Option<PathBuf> {
|
||||
let base = self.local_base_path()?;
|
||||
let relative = Path::new(location);
|
||||
if relative.is_absolute()
|
||||
|| relative
|
||||
.components()
|
||||
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(base.join(relative))
|
||||
}
|
||||
|
||||
/// Store bytes at the specified location.
|
||||
///
|
||||
/// This operation persists data using the underlying storage backend.
|
||||
/// For memory backends, data persists for the lifetime of the StorageManager.
|
||||
pub async fn put(&self, location: &str, data: Bytes) -> object_store::Result<()> {
|
||||
let path = ObjPath::from(location);
|
||||
let payload = object_store::PutPayload::from_bytes(data);
|
||||
self.store.put(&path, payload).await.map(|_| ())
|
||||
}
|
||||
|
||||
/// Retrieve bytes from the specified location.
|
||||
///
|
||||
/// Returns the full contents buffered in memory.
|
||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||
let path = ObjPath::from(location);
|
||||
let result = self.store.get(&path).await?;
|
||||
result.bytes().await
|
||||
}
|
||||
|
||||
/// Get a streaming handle for large objects.
|
||||
///
|
||||
/// Returns a fallible stream of Bytes chunks suitable for large file processing.
|
||||
pub async fn get_stream(
|
||||
&self,
|
||||
location: &str,
|
||||
) -> object_store::Result<BoxStream<'static, object_store::Result<Bytes>>> {
|
||||
let path = ObjPath::from(location);
|
||||
let result = self.store.get(&path).await?;
|
||||
Ok(result.into_stream())
|
||||
}
|
||||
|
||||
/// Delete all objects below the specified prefix.
|
||||
///
|
||||
/// For local filesystem backends, this also attempts to clean up empty directories.
|
||||
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
||||
let prefix_path = ObjPath::from(prefix);
|
||||
let locations = self
|
||||
.store
|
||||
.list(Some(&prefix_path))
|
||||
.map_ok(|m| m.location)
|
||||
.boxed();
|
||||
self.store
|
||||
.delete_stream(locations)
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
|
||||
// Cleanup filesystem directories only for local backend
|
||||
if matches!(self.backend_kind, StorageKind::Local) {
|
||||
self.cleanup_filesystem_directories(prefix).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all objects below the specified prefix.
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&str>,
|
||||
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
|
||||
let prefix_path = prefix.map(ObjPath::from);
|
||||
self.store.list(prefix_path.as_ref()).try_collect().await
|
||||
}
|
||||
|
||||
/// Check if an object exists at the specified location.
|
||||
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
||||
let path = ObjPath::from(location);
|
||||
self.store
|
||||
.head(&path)
|
||||
.await
|
||||
.map(|_| true)
|
||||
.or_else(|e| match e {
|
||||
object_store::Error::NotFound { .. } => Ok(false),
|
||||
_ => Err(e),
|
||||
})
|
||||
}
|
||||
|
||||
/// Cleanup filesystem directories for local backend.
|
||||
///
|
||||
/// This is a best-effort cleanup and ignores errors.
|
||||
async fn cleanup_filesystem_directories(&self, prefix: &str) -> object_store::Result<()> {
|
||||
if !matches!(self.backend_kind, StorageKind::Local) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(base) = &self.local_base else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let relative = Path::new(prefix);
|
||||
if relative.is_absolute()
|
||||
|| relative
|
||||
.components()
|
||||
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
|
||||
{
|
||||
tracing::warn!(
|
||||
prefix = %prefix,
|
||||
"Skipping directory cleanup for unsupported prefix components"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut current = base.join(relative);
|
||||
|
||||
while current.starts_with(base) && current.as_path() != base.as_path() {
|
||||
match tokio::fs::remove_dir(¤t).await {
|
||||
Ok(_) => {}
|
||||
Err(err) => match err.kind() {
|
||||
ErrorKind::NotFound => {}
|
||||
ErrorKind::DirectoryNotEmpty => break,
|
||||
_ => tracing::debug!(
|
||||
error = %err,
|
||||
path = %current.display(),
|
||||
"Failed to remove directory during cleanup"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
if let Some(parent) = current.parent() {
|
||||
current = parent.to_path_buf();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a storage backend based on configuration.
|
||||
///
|
||||
/// This factory function handles the creation and initialization of different
|
||||
/// storage backends with proper error handling and validation.
|
||||
async fn create_storage_backend(
|
||||
cfg: &AppConfig,
|
||||
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
|
||||
match cfg.storage {
|
||||
StorageKind::Local => {
|
||||
let base = resolve_base_dir(cfg);
|
||||
if !base.exists() {
|
||||
tokio::fs::create_dir_all(&base).await.map_err(|e| {
|
||||
object_store::Error::Generic {
|
||||
store: "LocalFileSystem",
|
||||
source: e.into(),
|
||||
}
|
||||
})?;
|
||||
}
|
||||
let store = LocalFileSystem::new_with_prefix(base.clone())?;
|
||||
Ok((Arc::new(store), Some(base)))
|
||||
}
|
||||
StorageKind::Memory => {
|
||||
let store = InMemory::new();
|
||||
Ok((Arc::new(store), None))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Testing utilities for storage operations.
|
||||
///
|
||||
/// This module provides specialized utilities for testing scenarios with
|
||||
/// automatic memory backend setup and proper test isolation.
|
||||
#[cfg(test)]
|
||||
pub mod testing {
|
||||
use super::*;
|
||||
use crate::utils::config::{AppConfig, PdfIngestMode};
|
||||
use uuid;
|
||||
|
||||
/// Create a test configuration with memory storage.
|
||||
///
|
||||
/// This provides a ready-to-use configuration for testing scenarios
|
||||
/// that don't require filesystem persistence.
|
||||
pub fn test_config_memory() -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: "/tmp/unused".into(), // Ignored for memory storage
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Memory,
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a test configuration with local storage.
|
||||
///
|
||||
/// This provides a ready-to-use configuration for testing scenarios
|
||||
/// that require actual filesystem operations.
|
||||
pub fn test_config_local() -> AppConfig {
|
||||
let base = format!("/tmp/minne_test_storage_{}", uuid::Uuid::new_v4());
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: base.into(),
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// A specialized StorageManager for testing scenarios.
|
||||
///
|
||||
/// This provides automatic setup for memory storage with proper isolation
|
||||
/// and cleanup capabilities for test environments.
|
||||
#[derive(Clone)]
|
||||
pub struct TestStorageManager {
|
||||
storage: StorageManager,
|
||||
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
|
||||
}
|
||||
|
||||
impl TestStorageManager {
|
||||
/// Create a new TestStorageManager with memory backend.
|
||||
///
|
||||
/// This is the preferred method for unit tests as it provides
|
||||
/// fast execution and complete isolation.
|
||||
pub async fn new_memory() -> object_store::Result<Self> {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new TestStorageManager with local filesystem backend.
|
||||
///
|
||||
/// This method creates a temporary directory that will be automatically
|
||||
/// cleaned up when the TestStorageManager is dropped.
|
||||
pub async fn new_local() -> object_store::Result<Self> {
|
||||
let cfg = test_config_local();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
let resolved = storage
|
||||
.local_base_path()
|
||||
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()));
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a TestStorageManager with custom configuration.
|
||||
pub async fn with_config(cfg: &AppConfig) -> object_store::Result<Self> {
|
||||
let storage = StorageManager::new(cfg).await?;
|
||||
let temp_dir = if matches!(cfg.storage, StorageKind::Local) {
|
||||
storage
|
||||
.local_base_path()
|
||||
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: temp_dir,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying StorageManager.
|
||||
pub fn storage(&self) -> &StorageManager {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
/// Clone the underlying StorageManager.
|
||||
pub fn clone_storage(&self) -> StorageManager {
|
||||
self.storage.clone()
|
||||
}
|
||||
|
||||
/// Store test data at the specified location.
|
||||
pub async fn put(&self, location: &str, data: &[u8]) -> object_store::Result<()> {
|
||||
self.storage.put(location, Bytes::from(data.to_vec())).await
|
||||
}
|
||||
|
||||
/// Retrieve test data from the specified location.
|
||||
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
|
||||
self.storage.get(location).await
|
||||
}
|
||||
|
||||
/// Delete test data below the specified prefix.
|
||||
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
|
||||
self.storage.delete_prefix(prefix).await
|
||||
}
|
||||
|
||||
/// Check if test data exists at the specified location.
|
||||
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
|
||||
self.storage.exists(location).await
|
||||
}
|
||||
|
||||
/// List all test objects below the specified prefix.
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&str>,
|
||||
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
|
||||
self.storage.list(prefix).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestStorageManager {
|
||||
fn drop(&mut self) {
|
||||
// Clean up temporary directories for local storage
|
||||
if let Some((_, path)) = &self._temp_dir {
|
||||
if path.exists() {
|
||||
let _ = std::fs::remove_dir_all(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience macro for creating memory storage tests.
|
||||
///
|
||||
/// This macro simplifies the creation of test storage with memory backend.
|
||||
#[macro_export]
|
||||
macro_rules! test_storage_memory {
|
||||
() => {{
|
||||
async move {
|
||||
$crate::storage::store::testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("Failed to create test memory storage")
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/// Convenience macro for creating local storage tests.
|
||||
///
|
||||
/// This macro simplifies the creation of test storage with local filesystem backend.
|
||||
#[macro_export]
|
||||
macro_rules! test_storage_local {
|
||||
() => {{
|
||||
async move {
|
||||
$crate::storage::store::testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("Failed to create test local storage")
|
||||
}
|
||||
}};
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the absolute base directory used for local storage from config.
|
||||
///
|
||||
/// If `data_dir` is relative, it is resolved against the current working directory.
|
||||
pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf {
|
||||
if cfg.data_dir.starts_with('/') {
|
||||
PathBuf::from(&cfg.data_dir)
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| PathBuf::from("."))
|
||||
.join(&cfg.data_dir)
|
||||
}
|
||||
}
|
||||
|
||||
/// Split an absolute filesystem path into `(parent_dir, file_name)`.
|
||||
pub fn split_abs_path(path: &str) -> AnyResult<(PathBuf, String)> {
|
||||
let pb = PathBuf::from(path);
|
||||
let parent = pb
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("Path has no parent: {path}"))?
|
||||
.to_path_buf();
|
||||
let file = pb
|
||||
.file_name()
|
||||
.ok_or_else(|| anyhow!("Path has no file name: {path}"))?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
Ok((parent, file))
|
||||
}
|
||||
|
||||
/// Split a logical object location `"a/b/c"` into `("a/b", "c")`.
|
||||
pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
|
||||
if let Some((p, f)) = path.rsplit_once('/') {
|
||||
return Ok((p.to_string(), f.to_string()));
|
||||
}
|
||||
Err(anyhow!("Object path has no separator: {path}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
|
||||
use bytes::Bytes;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_config(root: &str) -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: root.into(),
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Local,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn test_config_memory() -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
surrealdb_address: "test".into(),
|
||||
surrealdb_username: "test".into(),
|
||||
surrealdb_password: "test".into(),
|
||||
surrealdb_namespace: "test".into(),
|
||||
surrealdb_database: "test".into(),
|
||||
data_dir: "/tmp/unused".into(), // Ignored for memory storage
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::Memory,
|
||||
pdf_ingest_mode: LlmFirst,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_basic_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
assert!(storage.local_base_path().is_none());
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
let data = b"test data for storage manager";
|
||||
|
||||
// Test put and get
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_local_basic_operations() {
|
||||
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
|
||||
let cfg = test_config(&base);
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
let resolved_base = storage
|
||||
.local_base_path()
|
||||
.expect("resolved base dir")
|
||||
.to_path_buf();
|
||||
assert_eq!(resolved_base, PathBuf::from(&base));
|
||||
|
||||
let location = "test/data/file.txt";
|
||||
let data = b"test data for local storage";
|
||||
|
||||
// Test put and get
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
let object_dir = resolved_base.join("test/data");
|
||||
tokio::fs::metadata(&object_dir)
|
||||
.await
|
||||
.expect("object directory exists after write");
|
||||
|
||||
// Test exists
|
||||
assert!(storage.exists(location).await.expect("exists check"));
|
||||
|
||||
// Test delete
|
||||
storage.delete_prefix("test/data/").await.expect("delete");
|
||||
assert!(!storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists check after delete"));
|
||||
assert!(
|
||||
tokio::fs::metadata(&object_dir).await.is_err(),
|
||||
"object directory should be removed"
|
||||
);
|
||||
tokio::fs::metadata(&resolved_base)
|
||||
.await
|
||||
.expect("base directory remains intact");
|
||||
|
||||
// Clean up
|
||||
let _ = tokio::fs::remove_dir_all(&base).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_memory_persistence() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
let location = "persistence/test.txt";
|
||||
let data1 = b"first data";
|
||||
let data2 = b"second data";
|
||||
|
||||
// Put first data
|
||||
storage
|
||||
.put(location, Bytes::from(data1.to_vec()))
|
||||
.await
|
||||
.expect("put first");
|
||||
|
||||
// Retrieve and verify first data
|
||||
let retrieved1 = storage.get(location).await.expect("get first");
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
|
||||
// Overwrite with second data
|
||||
storage
|
||||
.put(location, Bytes::from(data2.to_vec()))
|
||||
.await
|
||||
.expect("put second");
|
||||
|
||||
// Retrieve and verify second data
|
||||
let retrieved2 = storage.get(location).await.expect("get second");
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
|
||||
// Data persists across multiple operations using the same StorageManager
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_list_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
// Create multiple files
|
||||
let files = vec![
|
||||
("dir1/file1.txt", b"content1"),
|
||||
("dir1/file2.txt", b"content2"),
|
||||
("dir2/file3.txt", b"content3"),
|
||||
];
|
||||
|
||||
for (location, data) in &files {
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
}
|
||||
|
||||
// Test listing without prefix
|
||||
let all_files = storage.list(None).await.expect("list all");
|
||||
assert_eq!(all_files.len(), 3);
|
||||
|
||||
// Test listing with prefix
|
||||
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
|
||||
assert_eq!(dir1_files.len(), 2);
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file1.txt")));
|
||||
assert!(dir1_files
|
||||
.iter()
|
||||
.any(|meta| meta.location.as_ref().contains("file2.txt")));
|
||||
|
||||
// Test listing non-existent prefix
|
||||
let empty_files = storage
|
||||
.list(Some("nonexistent/"))
|
||||
.await
|
||||
.expect("list nonexistent");
|
||||
assert_eq!(empty_files.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_stream_operations() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
let location = "stream/test.bin";
|
||||
let content = vec![42u8; 1024 * 64]; // 64KB of data
|
||||
|
||||
// Put large data
|
||||
storage
|
||||
.put(location, Bytes::from(content.clone()))
|
||||
.await
|
||||
.expect("put large data");
|
||||
|
||||
// Get as stream
|
||||
let mut stream = storage.get_stream(location).await.expect("get stream");
|
||||
let mut collected = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.expect("stream chunk");
|
||||
collected.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
assert_eq!(collected, content);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_with_custom_backend() {
|
||||
use object_store::memory::InMemory;
|
||||
|
||||
// Create custom memory backend
|
||||
let custom_store = InMemory::new();
|
||||
let storage = StorageManager::with_backend(Arc::new(custom_store), StorageKind::Memory);
|
||||
|
||||
let location = "custom/test.txt";
|
||||
let data = b"custom backend test";
|
||||
|
||||
// Test operations with custom backend
|
||||
storage
|
||||
.put(location, Bytes::from(data.to_vec()))
|
||||
.await
|
||||
.expect("put");
|
||||
let retrieved = storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
assert!(storage.exists(location).await.expect("exists"));
|
||||
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_error_handling() {
|
||||
let cfg = test_config_memory();
|
||||
let storage = StorageManager::new(&cfg)
|
||||
.await
|
||||
.expect("create storage manager");
|
||||
|
||||
// Test getting non-existent file
|
||||
let result = storage.get("nonexistent.txt").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test checking existence of non-existent file
|
||||
let exists = storage
|
||||
.exists("nonexistent.txt")
|
||||
.await
|
||||
.expect("exists check");
|
||||
assert!(!exists);
|
||||
|
||||
// Test listing with invalid location (should not panic)
|
||||
let _result = storage.get("").await;
|
||||
// This may or may not error depending on the backend implementation
|
||||
// The important thing is that it doesn't panic
|
||||
}
|
||||
|
||||
// TestStorageManager tests
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_memory() {
|
||||
let test_storage = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
|
||||
let location = "test/storage/file.txt";
|
||||
let data = b"test data with TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
|
||||
// Test list
|
||||
let files = test_storage
|
||||
.list(Some("test/storage/"))
|
||||
.await
|
||||
.expect("list");
|
||||
assert_eq!(files.len(), 1);
|
||||
|
||||
// Test delete
|
||||
test_storage
|
||||
.delete_prefix("test/storage/")
|
||||
.await
|
||||
.expect("delete");
|
||||
assert!(!test_storage
|
||||
.exists(location)
|
||||
.await
|
||||
.expect("exists after delete"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_local() {
|
||||
let test_storage = testing::TestStorageManager::new_local()
|
||||
.await
|
||||
.expect("create test storage");
|
||||
|
||||
let location = "test/local/file.txt";
|
||||
let data = b"test data with local TestStorageManager";
|
||||
|
||||
// Test put and get
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Test existence check
|
||||
assert!(test_storage.exists(location).await.expect("exists"));
|
||||
|
||||
// The storage should be automatically cleaned up when test_storage is dropped
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_isolation() {
|
||||
let storage1 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 1");
|
||||
let storage2 = testing::TestStorageManager::new_memory()
|
||||
.await
|
||||
.expect("create test storage 2");
|
||||
|
||||
let location = "isolation/test.txt";
|
||||
let data1 = b"storage 1 data";
|
||||
let data2 = b"storage 2 data";
|
||||
|
||||
// Put different data in each storage
|
||||
storage1.put(location, data1).await.expect("put storage 1");
|
||||
storage2.put(location, data2).await.expect("put storage 2");
|
||||
|
||||
// Verify isolation
|
||||
let retrieved1 = storage1.get(location).await.expect("get storage 1");
|
||||
let retrieved2 = storage2.get(location).await.expect("get storage 2");
|
||||
|
||||
assert_eq!(retrieved1.as_ref(), data1);
|
||||
assert_eq!(retrieved2.as_ref(), data2);
|
||||
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_storage_manager_config() {
|
||||
let cfg = testing::test_config_memory();
|
||||
let test_storage = testing::TestStorageManager::with_config(&cfg)
|
||||
.await
|
||||
.expect("create test storage with config");
|
||||
|
||||
let location = "config/test.txt";
|
||||
let data = b"test data with custom config";
|
||||
|
||||
test_storage.put(location, data).await.expect("put");
|
||||
let retrieved = test_storage.get(location).await.expect("get");
|
||||
assert_eq!(retrieved.as_ref(), data);
|
||||
|
||||
// Verify it's using memory backend
|
||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,10 @@ impl Conversation {
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||
.patch(PatchOp::replace("/updated_at", Utc::now()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(Utc::now()),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,116 +1,529 @@
|
||||
use futures::Stream;
|
||||
use surrealdb::{opt::PatchOp, Notification};
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::Duration as ChronoDuration;
|
||||
use state_machines::state_machine;
|
||||
use surrealdb::sql::Datetime as SurrealDatetime;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::ingestion_payload::IngestionPayload;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "name")]
|
||||
pub enum IngestionTaskStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
attempts: u32,
|
||||
last_attempt: DateTime<Utc>,
|
||||
},
|
||||
Completed,
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
pub const DEFAULT_LEASE_SECS: i64 = 300;
|
||||
pub const DEFAULT_PRIORITY: i32 = 0;
|
||||
|
||||
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||
pub enum TaskState {
|
||||
#[serde(rename = "Pending")]
|
||||
#[default]
|
||||
Pending,
|
||||
#[serde(rename = "Reserved")]
|
||||
Reserved,
|
||||
#[serde(rename = "Processing")]
|
||||
Processing,
|
||||
#[serde(rename = "Succeeded")]
|
||||
Succeeded,
|
||||
#[serde(rename = "Failed")]
|
||||
Failed,
|
||||
#[serde(rename = "Cancelled")]
|
||||
Cancelled,
|
||||
#[serde(rename = "DeadLetter")]
|
||||
DeadLetter,
|
||||
}
|
||||
|
||||
impl TaskState {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Succeeded",
|
||||
TaskState::Failed => "Failed",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "DeadLetter",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
TaskState::Succeeded | TaskState::Cancelled | TaskState::DeadLetter
|
||||
)
|
||||
}
|
||||
|
||||
pub fn display_label(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Completed",
|
||||
TaskState::Failed => "Retrying",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "Dead Letter",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct TaskErrorInfo {
|
||||
pub code: Option<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum TaskTransition {
|
||||
StartProcessing,
|
||||
Succeed,
|
||||
Fail,
|
||||
Cancel,
|
||||
DeadLetter,
|
||||
Release,
|
||||
}
|
||||
|
||||
impl TaskTransition {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskTransition::StartProcessing => "start_processing",
|
||||
TaskTransition::Succeed => "succeed",
|
||||
TaskTransition::Fail => "fail",
|
||||
TaskTransition::Cancel => "cancel",
|
||||
TaskTransition::DeadLetter => "deadletter",
|
||||
TaskTransition::Release => "release",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod lifecycle {
|
||||
use super::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: TaskLifecycleMachine,
|
||||
initial: Pending,
|
||||
states: [Pending, Reserved, Processing, Succeeded, Failed, Cancelled, DeadLetter],
|
||||
events {
|
||||
reserve {
|
||||
transition: { from: Pending, to: Reserved }
|
||||
transition: { from: Failed, to: Reserved }
|
||||
}
|
||||
start_processing {
|
||||
transition: { from: Reserved, to: Processing }
|
||||
}
|
||||
succeed {
|
||||
transition: { from: Processing, to: Succeeded }
|
||||
}
|
||||
fail {
|
||||
transition: { from: Processing, to: Failed }
|
||||
}
|
||||
cancel {
|
||||
transition: { from: Pending, to: Cancelled }
|
||||
transition: { from: Reserved, to: Cancelled }
|
||||
transition: { from: Processing, to: Cancelled }
|
||||
}
|
||||
deadletter {
|
||||
transition: { from: Failed, to: DeadLetter }
|
||||
}
|
||||
release {
|
||||
transition: { from: Reserved, to: Pending }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
|
||||
TaskLifecycleMachine::new(())
|
||||
}
|
||||
|
||||
pub(super) fn reserved() -> TaskLifecycleMachine<(), Reserved> {
|
||||
pending()
|
||||
.reserve()
|
||||
.expect("reserve transition from Pending should exist")
|
||||
}
|
||||
|
||||
pub(super) fn processing() -> TaskLifecycleMachine<(), Processing> {
|
||||
reserved()
|
||||
.start_processing()
|
||||
.expect("start_processing transition from Reserved should exist")
|
||||
}
|
||||
|
||||
pub(super) fn failed() -> TaskLifecycleMachine<(), Failed> {
|
||||
processing()
|
||||
.fail()
|
||||
.expect("fail transition from Processing should exist")
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_transition(state: &TaskState, event: TaskTransition) -> AppError {
|
||||
AppError::Validation(format!(
|
||||
"Invalid task transition: {} -> {}",
|
||||
state.as_str(),
|
||||
event.as_str()
|
||||
))
|
||||
}
|
||||
|
||||
stored_object!(IngestionTask, "ingestion_task", {
|
||||
content: IngestionPayload,
|
||||
status: IngestionTaskStatus,
|
||||
user_id: String
|
||||
state: TaskState,
|
||||
user_id: String,
|
||||
attempts: u32,
|
||||
max_attempts: u32,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime")]
|
||||
scheduled_at: chrono::DateTime<chrono::Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
locked_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
lease_duration_secs: i64,
|
||||
worker_id: Option<String>,
|
||||
error_code: Option<String>,
|
||||
error_message: Option<String>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
last_error_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
priority: i32
|
||||
});
|
||||
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
|
||||
impl IngestionTask {
|
||||
pub async fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
pub fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
content,
|
||||
status: IngestionTaskStatus::Created,
|
||||
state: TaskState::Pending,
|
||||
user_id,
|
||||
attempts: 0,
|
||||
max_attempts: MAX_ATTEMPTS,
|
||||
scheduled_at: now,
|
||||
locked_at: None,
|
||||
lease_duration_secs: DEFAULT_LEASE_SECS,
|
||||
worker_id: None,
|
||||
error_code: None,
|
||||
error_message: None,
|
||||
last_error_at: None,
|
||||
priority: DEFAULT_PRIORITY,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new job and stores it in the database
|
||||
pub fn can_retry(&self) -> bool {
|
||||
self.attempts < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn lease_duration(&self) -> Duration {
|
||||
Duration::from_secs(self.lease_duration_secs.max(0) as u64)
|
||||
}
|
||||
|
||||
pub async fn create_and_add_to_db(
|
||||
content: IngestionPayload,
|
||||
user_id: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let task = Self::new(content, user_id).await;
|
||||
|
||||
let task = Self::new(content, user_id);
|
||||
db.store_item(task.clone()).await?;
|
||||
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
// Update job status
|
||||
pub async fn update_status(
|
||||
id: &str,
|
||||
status: IngestionTaskStatus,
|
||||
pub async fn claim_next_ready(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let _job: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/status", status))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::sql::Datetime::default(),
|
||||
worker_id: &str,
|
||||
now: chrono::DateTime<chrono::Utc>,
|
||||
lease_duration: Duration,
|
||||
) -> Result<Option<IngestionTask>, AppError> {
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok());
|
||||
debug_assert!(lifecycle::failed().reserve().is_ok());
|
||||
|
||||
const CLAIM_QUERY: &str = r#"
|
||||
UPDATE (
|
||||
SELECT * FROM type::table($table)
|
||||
WHERE state IN $candidate_states
|
||||
AND scheduled_at <= $now
|
||||
AND (
|
||||
attempts < max_attempts
|
||||
OR state IN $sticky_states
|
||||
)
|
||||
AND (
|
||||
locked_at = NONE
|
||||
OR time::unix($now) - time::unix(locked_at) >= lease_duration_secs
|
||||
)
|
||||
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
|
||||
LIMIT 1
|
||||
)
|
||||
SET state = $reserved_state,
|
||||
attempts = if state IN $increment_states THEN
|
||||
if attempts + 1 > max_attempts THEN max_attempts ELSE attempts + 1 END
|
||||
ELSE
|
||||
attempts
|
||||
END,
|
||||
locked_at = $now,
|
||||
worker_id = $worker_id,
|
||||
lease_duration_secs = $lease_secs,
|
||||
updated_at = $now
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CLAIM_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind((
|
||||
"candidate_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind((
|
||||
"sticky_states",
|
||||
vec![TaskState::Reserved.as_str(), TaskState::Processing.as_str()],
|
||||
))
|
||||
.bind((
|
||||
"increment_states",
|
||||
vec![TaskState::Pending.as_str(), TaskState::Failed.as_str()],
|
||||
))
|
||||
.bind(("reserved_state", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", worker_id.to_string()))
|
||||
.bind(("lease_secs", lease_duration.as_secs() as i64))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
let task: Option<IngestionTask> = result.take(0)?;
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
/// Listen for new jobs
|
||||
pub async fn listen_for_tasks(
|
||||
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const START_PROCESSING_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $processing,
|
||||
updated_at = $now,
|
||||
locked_at = $now
|
||||
WHERE state = $reserved AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(START_PROCESSING_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::StartProcessing))
|
||||
}
|
||||
|
||||
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const COMPLETE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $succeeded,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(COMPLETE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("succeeded", TaskState::Succeeded.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Succeed))
|
||||
}
|
||||
|
||||
pub async fn mark_failed(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
retry_delay: Duration,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<impl Stream<Item = Result<Notification<Self>, surrealdb::Error>>, surrealdb::Error>
|
||||
{
|
||||
db.listen::<Self>().await
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let now = chrono::Utc::now();
|
||||
let retry_at = now
|
||||
+ ChronoDuration::from_std(retry_delay).unwrap_or_else(|_| ChronoDuration::seconds(30));
|
||||
|
||||
const FAIL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $failed,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $retry_at,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(FAIL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("retry_at", SurrealDatetime::from(retry_at)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Fail))
|
||||
}
|
||||
|
||||
/// Get all unfinished tasks, ie newly created and in progress up two times
|
||||
pub async fn get_unfinished_tasks(db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||
let jobs: Vec<Self> = db
|
||||
pub async fn mark_dead_letter(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
const DEAD_LETTER_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $dead,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $failed
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(DEAD_LETTER_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("dead", TaskState::DeadLetter.as_str()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::DeadLetter))
|
||||
}
|
||||
|
||||
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const CANCEL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $cancelled,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state IN $allow_states
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CANCEL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("cancelled", TaskState::Cancelled.as_str()))
|
||||
.bind((
|
||||
"allow_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Cancel))
|
||||
}
|
||||
|
||||
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const RELEASE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $pending,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state = $reserved
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(RELEASE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("pending", TaskState::Pending.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Release))
|
||||
}
|
||||
|
||||
pub async fn get_unfinished_tasks(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let tasks: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE
|
||||
status.name = 'Created'
|
||||
OR (
|
||||
status.name = 'InProgress'
|
||||
AND status.attempts < $max_attempts
|
||||
)
|
||||
ORDER BY created_at ASC",
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE state IN $active_states
|
||||
ORDER BY scheduled_at ASC, created_at ASC",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
],
|
||||
))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
Ok(tasks)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
|
||||
// Helper function to create a test ingestion payload
|
||||
fn create_test_payload(user_id: &str) -> IngestionPayload {
|
||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
text: "Test content".to_string(),
|
||||
context: "Test context".to_string(),
|
||||
@@ -119,182 +532,197 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn memory_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("in-memory surrealdb")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_ingestion_task() {
|
||||
async fn test_new_task_defaults() {
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
// Verify task properties
|
||||
assert_eq!(task.user_id, user_id);
|
||||
assert_eq!(task.content, payload);
|
||||
assert!(matches!(task.status, IngestionTaskStatus::Created));
|
||||
assert!(!task.id.is_empty());
|
||||
assert_eq!(task.state, TaskState::Pending);
|
||||
assert_eq!(task.attempts, 0);
|
||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||
assert!(task.locked_at.is_none());
|
||||
assert!(task.worker_id.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_add_to_db() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_create_and_store_task() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create and store task
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
let created =
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("store");
|
||||
|
||||
let stored: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&created.id)
|
||||
.await
|
||||
.expect("Failed to create and add task to db");
|
||||
.expect("fetch");
|
||||
|
||||
// Query to verify task was stored
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE user_id = '{}'",
|
||||
IngestionTask::table_name(),
|
||||
user_id
|
||||
);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify task is in the database
|
||||
assert!(!tasks.is_empty(), "Task should exist in the database");
|
||||
let stored_task = &tasks[0];
|
||||
assert_eq!(stored_task.user_id, user_id);
|
||||
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
|
||||
let stored = stored.expect("task exists");
|
||||
assert_eq!(stored.id, created.id);
|
||||
assert_eq!(stored.state, TaskState::Pending);
|
||||
assert_eq!(stored.attempts, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_status() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_claim_and_transition() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
// Create task manually
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task_id = task.id.clone();
|
||||
let worker_id = "worker-1";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim");
|
||||
|
||||
// Store task
|
||||
db.store_item(task).await.expect("Failed to store task");
|
||||
let claimed = claimed.expect("task claimed");
|
||||
assert_eq!(claimed.state, TaskState::Reserved);
|
||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||
|
||||
// Update status to InProgress
|
||||
let now = Utc::now();
|
||||
let new_status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: now,
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
assert_eq!(processing.state, TaskState::Processing);
|
||||
|
||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||
assert!(succeeded.worker_id.is_none());
|
||||
assert!(succeeded.locked_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fail_and_dead_letter() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let worker_id = "worker-dead";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim")
|
||||
.expect("claimed");
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
|
||||
let error_info = TaskErrorInfo {
|
||||
code: Some("pipeline_error".into()),
|
||||
message: "failed".into(),
|
||||
};
|
||||
|
||||
IngestionTask::update_status(&task_id, new_status.clone(), &db)
|
||||
let failed = processing
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||
.await
|
||||
.expect("Failed to update status");
|
||||
.expect("failed update");
|
||||
assert_eq!(failed.state, TaskState::Failed);
|
||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||
assert!(failed.worker_id.is_none());
|
||||
assert!(failed.locked_at.is_none());
|
||||
assert!(failed.scheduled_at > now);
|
||||
|
||||
// Verify status updated
|
||||
let updated_task: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&task_id)
|
||||
let dead = failed
|
||||
.mark_dead_letter(error_info.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to get updated task");
|
||||
.expect("dead letter");
|
||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||
}
|
||||
|
||||
assert!(updated_task.is_some());
|
||||
let updated_task = updated_task.unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_mark_processing_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
assert_eq!(attempts, 1);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.mark_processing(&db)
|
||||
.await
|
||||
.expect_err("processing should fail without reservation");
|
||||
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> start_processing"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected InProgress status"),
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_tasks() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_mark_failed_requires_processing() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create tasks with different statuses
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
max_attempts_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
|
||||
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
error_task.status = IngestionTaskStatus::Error {
|
||||
message: "Test error".to_string(),
|
||||
};
|
||||
|
||||
// Store all tasks
|
||||
db.store_item(created_task)
|
||||
let err = task
|
||||
.mark_failed(
|
||||
TaskErrorInfo {
|
||||
code: None,
|
||||
message: "boom".into(),
|
||||
},
|
||||
Duration::from_secs(30),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
db.store_item(in_progress_task)
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
db.store_item(max_attempts_task)
|
||||
.await
|
||||
.expect("Failed to store max-attempts task");
|
||||
db.store_item(completed_task)
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
db.store_item(error_task)
|
||||
.await
|
||||
.expect("Failed to store error task");
|
||||
.expect_err("failing should require processing state");
|
||||
|
||||
// Get unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> fail"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_release_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.release(&db)
|
||||
.await
|
||||
.expect("Failed to get unfinished tasks");
|
||||
.expect_err("release should require reserved state");
|
||||
|
||||
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
|
||||
assert_eq!(unfinished_tasks.len(), 2);
|
||||
|
||||
let statuses: Vec<_> = unfinished_tasks
|
||||
.iter()
|
||||
.map(|task| match &task.status {
|
||||
IngestionTaskStatus::Created => "Created",
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
if *attempts < MAX_ATTEMPTS {
|
||||
"InProgress<MAX"
|
||||
} else {
|
||||
"InProgress>=MAX"
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed",
|
||||
IngestionTaskStatus::Error { .. } => "Error",
|
||||
IngestionTaskStatus::Cancelled => "Cancelled",
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert!(statuses.contains(&"Created"));
|
||||
assert!(statuses.contains(&"InProgress<MAX"));
|
||||
assert!(!statuses.contains(&"InProgress>=MAX"));
|
||||
assert!(!statuses.contains(&"Completed"));
|
||||
assert!(!statuses.contains(&"Error"));
|
||||
assert!(!statuses.contains(&"Cancelled"));
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> release"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,38 @@ impl From<String> for KnowledgeEntityType {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct KnowledgeEntitySearchResult {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_datetime",
|
||||
deserialize_with = "deserialize_datetime",
|
||||
default
|
||||
)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
|
||||
pub source_id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub entity_type: KnowledgeEntityType,
|
||||
#[serde(default)]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub user_id: String,
|
||||
|
||||
pub score: f32,
|
||||
#[serde(default)]
|
||||
pub highlighted_name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub highlighted_description: Option<String>,
|
||||
}
|
||||
|
||||
stored_object!(KnowledgeEntity, "knowledge_entity", {
|
||||
source_id: String,
|
||||
name: String,
|
||||
@@ -75,6 +107,50 @@ impl KnowledgeEntity {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
db: &SurrealDbClient,
|
||||
search_terms: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
let sql = r#"
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
source_id,
|
||||
name,
|
||||
description,
|
||||
entity_type,
|
||||
metadata,
|
||||
user_id,
|
||||
search::highlight('<b>', '</b>', 0) AS highlighted_name,
|
||||
search::highlight('<b>', '</b>', 1) AS highlighted_description,
|
||||
(
|
||||
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
|
||||
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END
|
||||
) AS score
|
||||
FROM knowledge_entity
|
||||
WHERE
|
||||
(
|
||||
name @0@ $terms OR
|
||||
description @1@ $terms
|
||||
)
|
||||
AND user_id = $user_id
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit;
|
||||
"#;
|
||||
|
||||
Ok(db
|
||||
.client
|
||||
.query(sql)
|
||||
.bind(("terms", search_terms.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", limit))
|
||||
.await?
|
||||
.take(0)?)
|
||||
}
|
||||
|
||||
pub async fn delete_by_source_id(
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
@@ -103,6 +179,8 @@ impl KnowledgeEntity {
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
@@ -117,7 +195,7 @@ impl KnowledgeEntity {
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", id.to_string()))
|
||||
.bind(("name", name.to_string()))
|
||||
.bind(("updated_at", Utc::now()))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("entity_type", entity_type.to_owned()))
|
||||
.bind(("embedding", embedding))
|
||||
.bind(("description", description.to_string()))
|
||||
@@ -148,7 +226,18 @@ impl KnowledgeEntity {
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Skipping.");
|
||||
info!("No knowledge entities to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query
|
||||
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} entities to process.", total_entities);
|
||||
|
||||
@@ -75,13 +75,36 @@ impl KnowledgeRelationship {
|
||||
|
||||
pub async fn delete_relationship_by_id(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!("DELETE relates_to:`{}`", id);
|
||||
let mut authorized_result = db_client
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'",
|
||||
id, user_id
|
||||
))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
||||
|
||||
db_client.query(query).await?;
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
.query(format!("SELECT * FROM relates_to:`{}`", id))
|
||||
.await?;
|
||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
||||
|
||||
Ok(())
|
||||
if existing.is_some() {
|
||||
Err(AppError::Auth(
|
||||
"Not authorized to delete relationship".into(),
|
||||
))
|
||||
} else {
|
||||
Err(AppError::NotFound(format!("Relationship {} not found", id)))
|
||||
}
|
||||
} else {
|
||||
db_client
|
||||
.query(format!("DELETE relates_to:`{}`", id))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +184,7 @@ mod tests {
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
@@ -209,7 +232,7 @@ mod tests {
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
user_id,
|
||||
user_id.clone(),
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
@@ -220,20 +243,107 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
// Ensure relationship exists before deletion attempt
|
||||
let mut existing_before_delete = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let before_results: Vec<KnowledgeRelationship> =
|
||||
existing_before_delete.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
|
||||
// Delete the relationship by ID
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationship by ID");
|
||||
|
||||
// Query to verify the relationship was deleted
|
||||
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let mut result = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
|
||||
user_id, source_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship no longer exists
|
||||
assert!(results.is_empty(), "Relationship should be deleted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
let owner_user_id = "owner-user".to_string();
|
||||
let source_id = "source123".to_string();
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
owner_user_id.clone(),
|
||||
source_id,
|
||||
"references".to_string(),
|
||||
);
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
let mut before_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
|
||||
assert!(
|
||||
!before_results.is_empty(),
|
||||
"Relationship should exist before unauthorized delete attempt"
|
||||
);
|
||||
|
||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||
&relationship.id,
|
||||
"different-user",
|
||||
&db,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected authorization error when deleting someone else's relationship"),
|
||||
}
|
||||
|
||||
let mut after_attempt = db
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
|
||||
owner_user_id
|
||||
))
|
||||
.await
|
||||
.expect("Query failed");
|
||||
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
"Relationship should still exist after unauthorized delete attempt"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod ingestion_task;
|
||||
pub mod knowledge_entity;
|
||||
pub mod knowledge_relationship;
|
||||
pub mod message;
|
||||
pub mod scratchpad;
|
||||
pub mod system_prompts;
|
||||
pub mod system_settings;
|
||||
pub mod text_chunk;
|
||||
@@ -83,6 +84,32 @@ macro_rules! stored_object {
|
||||
Ok(DateTime::<Utc>::from(dt))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn serialize_option_datetime<S>(
|
||||
date: &Option<DateTime<Utc>>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match date {
|
||||
Some(dt) => serializer
|
||||
.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn deserialize_option_datetime<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<DateTime<Utc>>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
|
||||
Ok(value.map(DateTime::<Utc>::from))
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct $name {
|
||||
@@ -92,7 +119,7 @@ macro_rules! stored_object {
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
$(pub $field: $ty),*
|
||||
$( $(#[$attr])* pub $field: $ty),*
|
||||
}
|
||||
|
||||
impl StoredObject for $name {
|
||||
|
||||
502
common/src/storage/types/scratchpad.rs
Normal file
@@ -0,0 +1,502 @@
|
||||
use chrono::Utc as ChronoUtc;
|
||||
use surrealdb::opt::PatchOp;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
stored_object!(Scratchpad, "scratchpad", {
|
||||
user_id: String,
|
||||
title: String,
|
||||
content: String,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with="deserialize_datetime")]
|
||||
last_saved_at: DateTime<Utc>,
|
||||
is_dirty: bool,
|
||||
#[serde(default)]
|
||||
is_archived: bool,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
archived_at: Option<DateTime<Utc>>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
ingested_at: Option<DateTime<Utc>>
|
||||
});
|
||||
|
||||
impl Scratchpad {
|
||||
pub fn new(user_id: String, title: String) -> Self {
|
||||
let now = ChronoUtc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
title,
|
||||
content: String::new(),
|
||||
last_saved_at: now,
|
||||
is_dirty: false,
|
||||
is_archived: false,
|
||||
archived_at: None,
|
||||
ingested_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_by_user(user_id: &str, db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||
let scratchpads: Vec<Scratchpad> = db.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND (is_archived = false OR is_archived IS NONE) ORDER BY updated_at DESC")
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(scratchpads)
|
||||
}
|
||||
|
||||
pub async fn get_archived_by_user(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<Self>, AppError> {
|
||||
let scratchpads: Vec<Scratchpad> = db.client
|
||||
.query("SELECT * FROM type::table($table_name) WHERE user_id = $user_id AND is_archived = true ORDER BY archived_at DESC, updated_at DESC")
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(scratchpads)
|
||||
}
|
||||
|
||||
pub async fn get_by_id(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
let scratchpad: Option<Scratchpad> = db.get_item(id).await?;
|
||||
|
||||
let scratchpad =
|
||||
scratchpad.ok_or_else(|| AppError::NotFound("Scratchpad not found".to_string()))?;
|
||||
|
||||
if scratchpad.user_id != user_id {
|
||||
return Err(AppError::Auth(
|
||||
"You don't have access to this scratchpad".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(scratchpad)
|
||||
}
|
||||
|
||||
pub async fn update_content(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
new_content: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
// First verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if scratchpad.is_archived {
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/content", new_content.to_string()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.patch(PatchOp::replace(
|
||||
"/last_saved_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.patch(PatchOp::replace("/is_dirty", false))
|
||||
.await?;
|
||||
|
||||
// Return the updated scratchpad
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
|
||||
pub async fn update_title(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
new_title: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
// First verify ownership
|
||||
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(ChronoUtc::now()),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
// First verify ownership
|
||||
let _scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
let _: Option<Self> = db.client.delete((Self::table_name(), id)).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn archive(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
mark_ingested: bool,
|
||||
) -> Result<Self, AppError> {
|
||||
// Verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if scratchpad.is_archived {
|
||||
if mark_ingested && scratchpad.ingested_at.is_none() {
|
||||
// Ensure ingested_at is set if required
|
||||
let surreal_now = surrealdb::Datetime::from(ChronoUtc::now());
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||
.await?;
|
||||
return Self::get_by_id(id, user_id, db).await;
|
||||
}
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let surreal_now = surrealdb::Datetime::from(now);
|
||||
let mut update = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/is_archived", true))
|
||||
.patch(PatchOp::replace("/archived_at", surreal_now.clone()))
|
||||
.patch(PatchOp::replace("/updated_at", surreal_now.clone()));
|
||||
|
||||
update = if mark_ingested {
|
||||
update.patch(PatchOp::replace("/ingested_at", surreal_now))
|
||||
} else {
|
||||
update.patch(PatchOp::remove("/ingested_at"))
|
||||
};
|
||||
|
||||
let _updated: Option<Self> = update.await?;
|
||||
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
|
||||
pub async fn restore(id: &str, user_id: &str, db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
// Verify ownership
|
||||
let scratchpad = Self::get_by_id(id, user_id, db).await?;
|
||||
|
||||
if !scratchpad.is_archived {
|
||||
return Ok(scratchpad);
|
||||
}
|
||||
|
||||
let now = ChronoUtc::now();
|
||||
let surreal_now = surrealdb::Datetime::from(now);
|
||||
let _updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/is_archived", false))
|
||||
.patch(PatchOp::remove("/archived_at"))
|
||||
.patch(PatchOp::remove("/ingested_at"))
|
||||
.patch(PatchOp::replace("/updated_at", surreal_now))
|
||||
.await?;
|
||||
|
||||
Self::get_by_id(id, user_id, db).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_scratchpad() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// Create a new scratchpad
|
||||
let user_id = "test_user";
|
||||
let title = "Test Scratchpad";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), title.to_string());
|
||||
|
||||
// Verify scratchpad properties
|
||||
assert_eq!(scratchpad.user_id, user_id);
|
||||
assert_eq!(scratchpad.title, title);
|
||||
assert_eq!(scratchpad.content, "");
|
||||
assert!(!scratchpad.is_dirty);
|
||||
assert!(!scratchpad.is_archived);
|
||||
assert!(scratchpad.archived_at.is_none());
|
||||
assert!(scratchpad.ingested_at.is_none());
|
||||
assert!(!scratchpad.id.is_empty());
|
||||
|
||||
// Store the scratchpad
|
||||
let result = db.store_item(scratchpad.clone()).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it can be retrieved
|
||||
let retrieved: Option<Scratchpad> = db
|
||||
.get_item(&scratchpad.id)
|
||||
.await
|
||||
.expect("Failed to retrieve scratchpad");
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved = retrieved.unwrap();
|
||||
assert_eq!(retrieved.id, scratchpad.id);
|
||||
assert_eq!(retrieved.user_id, user_id);
|
||||
assert_eq!(retrieved.title, title);
|
||||
assert!(!retrieved.is_archived);
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_user() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
|
||||
// Create multiple scratchpads
|
||||
let scratchpad1 = Scratchpad::new(user_id.to_string(), "First".to_string());
|
||||
let scratchpad2 = Scratchpad::new(user_id.to_string(), "Second".to_string());
|
||||
let scratchpad3 = Scratchpad::new("other_user".to_string(), "Other".to_string());
|
||||
|
||||
// Store them
|
||||
let scratchpad1_id = scratchpad1.id.clone();
|
||||
let scratchpad2_id = scratchpad2.id.clone();
|
||||
db.store_item(scratchpad1).await.unwrap();
|
||||
db.store_item(scratchpad2).await.unwrap();
|
||||
db.store_item(scratchpad3).await.unwrap();
|
||||
|
||||
// Archive one of the user's scratchpads
|
||||
Scratchpad::archive(&scratchpad2_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get scratchpads for user_id
|
||||
let user_scratchpads = Scratchpad::get_by_user(user_id, &db).await.unwrap();
|
||||
assert_eq!(user_scratchpads.len(), 1);
|
||||
assert_eq!(user_scratchpads[0].id, scratchpad1_id);
|
||||
|
||||
// Verify they belong to the user
|
||||
for scratchpad in &user_scratchpads {
|
||||
assert_eq!(scratchpad.user_id, user_id);
|
||||
}
|
||||
|
||||
let archived = Scratchpad::get_archived_by_user(user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(archived.len(), 1);
|
||||
assert_eq!(archived[0].id, scratchpad2_id);
|
||||
assert!(archived[0].is_archived);
|
||||
assert!(archived[0].ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_archive_and_restore() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, true)
|
||||
.await
|
||||
.expect("Failed to archive");
|
||||
assert!(archived.is_archived);
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.ingested_at.is_some());
|
||||
|
||||
let restored = Scratchpad::restore(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.expect("Failed to restore");
|
||||
assert!(!restored.is_archived);
|
||||
assert!(restored.archived_at.is_none());
|
||||
assert!(restored.ingested_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let new_content = "Updated content";
|
||||
let updated = Scratchpad::update_content(&scratchpad_id, user_id, new_content, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(updated.content, new_content);
|
||||
assert!(!updated.is_dirty);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_content_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let result = Scratchpad::update_content(&scratchpad_id, other_user, "Hacked", &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_scratchpad() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user";
|
||||
let scratchpad = Scratchpad::new(user_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
// Delete should succeed
|
||||
let result = Scratchpad::delete(&scratchpad_id, user_id, &db).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify it's gone
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
assert!(retrieved.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let owner_id = "owner";
|
||||
let other_user = "other_user";
|
||||
let scratchpad = Scratchpad::new(owner_id.to_string(), "Test".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let result = Scratchpad::delete(&scratchpad_id, other_user, &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::Auth(_)) => {}
|
||||
_ => panic!("Expected Auth error"),
|
||||
}
|
||||
|
||||
// Verify it still exists
|
||||
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timezone_aware_scratchpad_conversion() {
|
||||
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to create test database");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let user_id = "test_user_123";
|
||||
let scratchpad =
|
||||
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
|
||||
let scratchpad_id = scratchpad.id.clone();
|
||||
|
||||
db.store_item(scratchpad).await.unwrap();
|
||||
|
||||
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test that datetime fields are preserved and can be used for timezone formatting
|
||||
assert!(retrieved.created_at.timestamp() > 0);
|
||||
assert!(retrieved.updated_at.timestamp() > 0);
|
||||
assert!(retrieved.last_saved_at.timestamp() > 0);
|
||||
|
||||
// Test that optional datetime fields work correctly
|
||||
assert!(retrieved.archived_at.is_none());
|
||||
assert!(retrieved.ingested_at.is_none());
|
||||
|
||||
// Archive the scratchpad to test optional datetime handling
|
||||
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(archived.archived_at.is_some());
|
||||
assert!(archived.archived_at.unwrap().timestamp() > 0);
|
||||
assert!(archived.ingested_at.is_none());
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ pub struct SystemSettings {
|
||||
pub ingestion_system_prompt: String,
|
||||
pub image_processing_model: String,
|
||||
pub image_processing_prompt: String,
|
||||
pub voice_processing_model: String,
|
||||
}
|
||||
|
||||
impl StoredObject for SystemSettings {
|
||||
@@ -52,11 +53,60 @@ impl SystemSettings {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use async_openai::Client;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn get_hnsw_index_dimension(
|
||||
db: &SurrealDbClient,
|
||||
table_name: &str,
|
||||
index_name: &str,
|
||||
) -> u32 {
|
||||
let query = format!("INFO FOR TABLE {table_name};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
|
||||
let info: Option<serde_json::Value> = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
|
||||
let info = info.expect("Table info result missing");
|
||||
|
||||
let indexes = info
|
||||
.get("indexes")
|
||||
.or_else(|| {
|
||||
info.get("tables")
|
||||
.and_then(|tables| tables.get(table_name))
|
||||
.and_then(|table| table.get("indexes"))
|
||||
})
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.expect("Index definition missing DIMENSION clause");
|
||||
|
||||
let dimension_token = dimension_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.expect("Dimension value missing in definition")
|
||||
.trim_end_matches(';');
|
||||
|
||||
dimension_token
|
||||
.parse::<u32>()
|
||||
.expect("Dimension value is not a valid number")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
@@ -254,4 +304,74 @@ mod tests {
|
||||
|
||||
assert!(migration_result.is_ok(), "Migrations should not fail");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
|
||||
let initial_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
|
||||
assert_eq!(
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
"embedding size should match initial system settings"
|
||||
);
|
||||
|
||||
let new_dimension = 768;
|
||||
let new_model = "new-test-embedding-model".to_string();
|
||||
|
||||
current_settings.embedding_dimensions = new_dimension;
|
||||
current_settings.embedding_model = new_model.clone();
|
||||
|
||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
|
||||
assert_eq!(
|
||||
updated_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should reflect the new embedding dimension"
|
||||
);
|
||||
|
||||
let openai_client = Client::new();
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
|
||||
let text_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
let knowledge_dimension =
|
||||
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
"text_chunk index dimension should update"
|
||||
);
|
||||
assert_eq!(
|
||||
knowledge_dimension, new_dimension,
|
||||
"knowledge_entity index dimension should update"
|
||||
);
|
||||
|
||||
let persisted_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to reload updated settings");
|
||||
assert_eq!(
|
||||
persisted_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should persist new embedding dimension"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,17 @@ impl TextChunk {
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Skipping.");
|
||||
info!("No text chunks to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
|
||||
@@ -101,12 +101,35 @@ impl TextContent {
|
||||
.patch(PatchOp::replace("/context", context))
|
||||
.patch(PatchOp::replace("/category", category))
|
||||
.patch(PatchOp::replace("/text", text))
|
||||
.patch(PatchOp::replace("/updated_at", now))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(now),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn has_other_with_file(
|
||||
file_id: &str,
|
||||
exclude_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<bool, AppError> {
|
||||
let mut response = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT VALUE id FROM type::table($table_name) WHERE file_info.id = $file_id AND id != type::thing($table_name, $exclude_id) LIMIT 1",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("file_id", file_id.to_owned()))
|
||||
.bind(("exclude_id", exclude_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let existing: Option<surrealdb::sql::Thing> = response.take(0)?;
|
||||
|
||||
Ok(existing.is_some())
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
db: &SurrealDbClient,
|
||||
search_terms: &str,
|
||||
@@ -273,4 +296,64 @@ mod tests {
|
||||
assert_eq!(updated_content.text, new_text);
|
||||
assert!(updated_content.updated_at > text_content.updated_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_has_other_with_file_detects_shared_usage() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
let file_info = FileInfo {
|
||||
id: "file-1".to_string(),
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
sha256: "sha-test".to_string(),
|
||||
path: "user123/file-1/test.txt".to_string(),
|
||||
file_name: "test.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
user_id: user_id.clone(),
|
||||
};
|
||||
|
||||
let content_a = TextContent::new(
|
||||
"First".to_string(),
|
||||
Some("ctx-a".to_string()),
|
||||
"category".to_string(),
|
||||
Some(file_info.clone()),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
let content_b = TextContent::new(
|
||||
"Second".to_string(),
|
||||
Some("ctx-b".to_string()),
|
||||
"category".to_string(),
|
||||
Some(file_info.clone()),
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
db.store_item(content_a.clone())
|
||||
.await
|
||||
.expect("Failed to store first content");
|
||||
db.store_item(content_b.clone())
|
||||
.await
|
||||
.expect("Failed to store second content");
|
||||
|
||||
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check for shared file usage");
|
||||
assert!(has_other);
|
||||
|
||||
let _removed: Option<TextContent> = db
|
||||
.delete_item(&content_b.id)
|
||||
.await
|
||||
.expect("Failed to delete second content");
|
||||
|
||||
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
|
||||
.await
|
||||
.expect("Failed to check shared usage after delete");
|
||||
assert!(!has_other_after);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
use async_trait::async_trait;
|
||||
use axum_session_auth::Authentication;
|
||||
use chrono_tz::Tz;
|
||||
use surrealdb::{engine::any::Any, Surreal};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::text_chunk::TextChunk;
|
||||
use super::{
|
||||
conversation::Conversation, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship, system_settings::SystemSettings,
|
||||
conversation::Conversation,
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
system_settings::SystemSettings,
|
||||
text_content::TextContent,
|
||||
};
|
||||
use chrono::Duration;
|
||||
use futures::try_join;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CategoryResponse {
|
||||
@@ -49,9 +56,6 @@ impl Authentication<User, String, Surreal<Any>> for User {
|
||||
}
|
||||
|
||||
fn validate_timezone(input: &str) -> String {
|
||||
use chrono_tz::Tz;
|
||||
|
||||
// Check if it's a valid IANA timezone identifier
|
||||
match input.parse::<Tz>() {
|
||||
Ok(_) => input.to_owned(),
|
||||
Err(_) => {
|
||||
@@ -61,7 +65,93 @@ fn validate_timezone(input: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct DashboardStats {
|
||||
pub total_documents: i64,
|
||||
pub new_documents_week: i64,
|
||||
pub total_entities: i64,
|
||||
pub new_entities_week: i64,
|
||||
pub total_conversations: i64,
|
||||
pub new_conversations_week: i64,
|
||||
pub total_text_chunks: i64,
|
||||
pub new_text_chunks_week: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CountResult {
|
||||
count: i64,
|
||||
}
|
||||
|
||||
impl User {
|
||||
async fn count_total<T: crate::storage::types::StoredObject>(
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<i64, AppError> {
|
||||
let result: Option<CountResult> = db
|
||||
.client
|
||||
.query("SELECT count() as count FROM type::table($table) WHERE user_id = $user_id GROUP ALL")
|
||||
.bind(("table", T::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
|
||||
async fn count_since<T: crate::storage::types::StoredObject>(
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
since: chrono::DateTime<chrono::Utc>,
|
||||
) -> Result<i64, AppError> {
|
||||
let result: Option<CountResult> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT count() as count FROM type::table($table) WHERE user_id = $user_id AND created_at >= $since GROUP ALL",
|
||||
)
|
||||
.bind(("table", T::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("since", surrealdb::Datetime::from(since)))
|
||||
.await?
|
||||
.take(0)?;
|
||||
Ok(result.map(|r| r.count).unwrap_or(0))
|
||||
}
|
||||
|
||||
pub async fn get_dashboard_stats(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<DashboardStats, AppError> {
|
||||
let since = chrono::Utc::now() - Duration::days(7);
|
||||
|
||||
let (
|
||||
total_documents,
|
||||
new_documents_week,
|
||||
total_entities,
|
||||
new_entities_week,
|
||||
total_conversations,
|
||||
new_conversations_week,
|
||||
total_text_chunks,
|
||||
new_text_chunks_week,
|
||||
) = try_join!(
|
||||
Self::count_total::<TextContent>(db, user_id),
|
||||
Self::count_since::<TextContent>(db, user_id, since),
|
||||
Self::count_total::<KnowledgeEntity>(db, user_id),
|
||||
Self::count_since::<KnowledgeEntity>(db, user_id, since),
|
||||
Self::count_total::<Conversation>(db, user_id),
|
||||
Self::count_since::<Conversation>(db, user_id, since),
|
||||
Self::count_total::<TextChunk>(db, user_id),
|
||||
Self::count_since::<TextChunk>(db, user_id, since)
|
||||
)?;
|
||||
|
||||
Ok(DashboardStats {
|
||||
total_documents,
|
||||
new_documents_week,
|
||||
total_entities,
|
||||
new_entities_week,
|
||||
total_conversations,
|
||||
new_conversations_week,
|
||||
total_text_chunks,
|
||||
new_text_chunks_week,
|
||||
})
|
||||
}
|
||||
pub async fn create_new(
|
||||
email: String,
|
||||
password: String,
|
||||
@@ -78,7 +168,7 @@ impl User {
|
||||
let now = Utc::now();
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"LET $count = (SELECT count() FROM type::table($table))[0].count;
|
||||
@@ -95,8 +185,8 @@ impl User {
|
||||
.bind(("id", id))
|
||||
.bind(("email", email))
|
||||
.bind(("password", password))
|
||||
.bind(("created_at", now))
|
||||
.bind(("updated_at", now))
|
||||
.bind(("created_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("updated_at", surrealdb::Datetime::from(now)))
|
||||
.bind(("timezone", validated_tz))
|
||||
.await?
|
||||
.take(1)?;
|
||||
@@ -127,7 +217,7 @@ impl User {
|
||||
password: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Self, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM user
|
||||
@@ -145,7 +235,7 @@ impl User {
|
||||
email: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE email = $email LIMIT 1")
|
||||
.bind(("email", email.to_string()))
|
||||
@@ -159,7 +249,7 @@ impl User {
|
||||
api_key: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Option<Self>, AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query("SELECT * FROM user WHERE api_key = $api_key LIMIT 1")
|
||||
.bind(("api_key", api_key.to_string()))
|
||||
@@ -174,7 +264,7 @@ impl User {
|
||||
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
|
||||
|
||||
// Update the user record with the new API key
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
@@ -195,7 +285,7 @@ impl User {
|
||||
}
|
||||
|
||||
pub async fn revoke_api_key(id: &str, db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let user: Option<User> = db
|
||||
let user: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"UPDATE type::thing('user', $id)
|
||||
@@ -266,7 +356,10 @@ impl User {
|
||||
// Extract the entity types from the response
|
||||
let entity_types: Vec<String> = response
|
||||
.into_iter()
|
||||
.map(|item| format!("{:?}", item.entity_type))
|
||||
.map(|item| {
|
||||
let normalized = KnowledgeEntityType::from(item.entity_type);
|
||||
format!("{:?}", normalized)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(entity_types)
|
||||
@@ -356,7 +449,7 @@ impl User {
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
db.query("UPDATE type::thing('user', $user_id) SET timezone = $timezone")
|
||||
.bind(("table_name", User::table_name()))
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("timezone", timezone.to_string()))
|
||||
.await?;
|
||||
@@ -442,19 +535,43 @@ impl User {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
status = 'Created'
|
||||
OR (
|
||||
status.InProgress != NONE
|
||||
AND status.InProgress.attempts < $max_attempts
|
||||
)
|
||||
)
|
||||
ORDER BY created_at DESC",
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
state IN $active_states
|
||||
OR (state = $failed_state AND attempts < max_attempts)
|
||||
)
|
||||
ORDER BY scheduled_at ASC, created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("failed_state", TaskState::Failed.as_str()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
/// Gets all ingestion tasks for the specified user ordered by newest first
|
||||
pub async fn get_all_ingestion_tasks(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
ORDER BY created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("max_attempts", 3))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -511,6 +628,9 @@ impl User {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
@@ -596,6 +716,122 @@ mod tests {
|
||||
assert!(nonexistent.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_ingestion_tasks_filters_correctly() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "unfinished_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
let payload = IngestionPayload::Text {
|
||||
text: "Test".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
};
|
||||
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
db.store_item(created_task.clone())
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
|
||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
processing_task.state = TaskState::Processing;
|
||||
processing_task.attempts = 1;
|
||||
db.store_item(processing_task.clone())
|
||||
.await
|
||||
.expect("Failed to store processing task");
|
||||
|
||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_retry_task.state = TaskState::Failed;
|
||||
failed_retry_task.attempts = 1;
|
||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||
db.store_item(failed_retry_task.clone())
|
||||
.await
|
||||
.expect("Failed to store retryable failed task");
|
||||
|
||||
let mut failed_blocked_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
failed_blocked_task.state = TaskState::Failed;
|
||||
failed_blocked_task.attempts = MAX_ATTEMPTS;
|
||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||
db.store_item(failed_blocked_task.clone())
|
||||
.await
|
||||
.expect("Failed to store blocked task");
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
completed_task.state = TaskState::Succeeded;
|
||||
db.store_item(completed_task.clone())
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task)
|
||||
.await
|
||||
.expect("Failed to store other user task");
|
||||
|
||||
let unfinished = User::get_unfinished_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch unfinished tasks");
|
||||
|
||||
let unfinished_ids: HashSet<String> =
|
||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||
|
||||
assert!(unfinished_ids.contains(&created_task.id));
|
||||
assert!(unfinished_ids.contains(&processing_task.id));
|
||||
assert!(unfinished_ids.contains(&failed_retry_task.id));
|
||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||
assert_eq!(unfinished_ids.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "archive_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
let payload = IngestionPayload::Text {
|
||||
text: "One".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
};
|
||||
|
||||
// Oldest task
|
||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
||||
first.updated_at = first.created_at;
|
||||
first.state = TaskState::Succeeded;
|
||||
db.store_item(first.clone()).await.expect("store first");
|
||||
|
||||
// Latest task
|
||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
second.state = TaskState::Processing;
|
||||
db.store_item(second.clone()).await.expect("store second");
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string());
|
||||
db.store_item(other_task).await.expect("store other");
|
||||
|
||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("fetch all tasks");
|
||||
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0].id, second.id); // newest first
|
||||
assert_eq!(tasks[1].id, first.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_by_email() {
|
||||
// Setup test database
|
||||
@@ -816,4 +1052,56 @@ mod tests {
|
||||
let most_recent = conversations.iter().max_by_key(|c| c.created_at).unwrap();
|
||||
assert_eq!(retrieved[0].id, most_recent.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_latest_text_contents_returns_last_five() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "latest_text_user";
|
||||
|
||||
let mut inserted_ids = Vec::new();
|
||||
let base_time = chrono::Utc::now() - chrono::Duration::minutes(60);
|
||||
|
||||
for i in 0..12 {
|
||||
let mut item = TextContent::new(
|
||||
format!("Text {}", i),
|
||||
Some(format!("Context {}", i)),
|
||||
"Category".to_string(),
|
||||
None,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
|
||||
let timestamp = base_time + chrono::Duration::minutes(i);
|
||||
item.created_at = timestamp;
|
||||
item.updated_at = timestamp;
|
||||
|
||||
db.store_item(item.clone())
|
||||
.await
|
||||
.expect("Failed to store text content");
|
||||
|
||||
inserted_ids.push(item.id.clone());
|
||||
}
|
||||
|
||||
let latest = User::get_latest_text_contents(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to fetch latest text contents");
|
||||
|
||||
assert_eq!(latest.len(), 5, "Expected exactly five items");
|
||||
|
||||
let mut expected_ids = inserted_ids[inserted_ids.len() - 5..].to_vec();
|
||||
expected_ids.reverse();
|
||||
|
||||
let returned_ids: Vec<String> = latest.iter().map(|item| item.id.clone()).collect();
|
||||
assert_eq!(
|
||||
returned_ids, expected_ids,
|
||||
"Latest items did not match expectation"
|
||||
);
|
||||
|
||||
for window in latest.windows(2) {
|
||||
assert!(
|
||||
window[0].created_at >= window[1].created_at,
|
||||
"Results are not ordered by created_at descending"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,31 @@
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
|
||||
#[derive(Clone, Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum StorageKind {
|
||||
Local,
|
||||
Memory,
|
||||
}
|
||||
|
||||
fn default_storage_kind() -> StorageKind {
|
||||
StorageKind::Local
|
||||
}
|
||||
|
||||
/// Selects the strategy used for PDF ingestion.
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum PdfIngestMode {
|
||||
/// Only rely on classic text extraction (no LLM fallbacks).
|
||||
Classic,
|
||||
/// Prefer fast text extraction, but fall back to the LLM rendering path when needed.
|
||||
LlmFirst,
|
||||
}
|
||||
|
||||
fn default_pdf_ingest_mode() -> PdfIngestMode {
|
||||
PdfIngestMode::LlmFirst
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
pub struct AppConfig {
|
||||
@@ -14,6 +40,20 @@ pub struct AppConfig {
|
||||
pub http_port: u16,
|
||||
#[serde(default = "default_base_url")]
|
||||
pub openai_base_url: String,
|
||||
#[serde(default = "default_storage_kind")]
|
||||
pub storage: StorageKind,
|
||||
#[serde(default = "default_pdf_ingest_mode")]
|
||||
pub pdf_ingest_mode: PdfIngestMode,
|
||||
#[serde(default = "default_reranking_enabled")]
|
||||
pub reranking_enabled: bool,
|
||||
#[serde(default)]
|
||||
pub reranking_pool_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub fastembed_cache_dir: Option<String>,
|
||||
#[serde(default)]
|
||||
pub fastembed_show_download_progress: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub fastembed_max_length: Option<usize>,
|
||||
}
|
||||
|
||||
fn default_data_dir() -> String {
|
||||
@@ -24,11 +64,70 @@ fn default_base_url() -> String {
|
||||
"https://api.openai.com/v1".to_string()
|
||||
}
|
||||
|
||||
fn default_reranking_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn ensure_ort_path() {
|
||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||
return;
|
||||
}
|
||||
if let Ok(mut exe) = env::current_exe() {
|
||||
exe.pop();
|
||||
|
||||
if cfg!(target_os = "windows") {
|
||||
for p in [
|
||||
exe.join("onnxruntime.dll"),
|
||||
exe.join("lib").join("onnxruntime.dll"),
|
||||
] {
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
let name = if cfg!(target_os = "macos") {
|
||||
"libonnxruntime.dylib"
|
||||
} else {
|
||||
"libonnxruntime.so"
|
||||
};
|
||||
let p = exe.join("lib").join(name);
|
||||
if p.exists() {
|
||||
env::set_var("ORT_DYLIB_PATH", p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
openai_api_key: String::new(),
|
||||
surrealdb_address: String::new(),
|
||||
surrealdb_username: String::new(),
|
||||
surrealdb_password: String::new(),
|
||||
surrealdb_namespace: String::new(),
|
||||
surrealdb_database: String::new(),
|
||||
data_dir: default_data_dir(),
|
||||
http_port: 0,
|
||||
openai_base_url: default_base_url(),
|
||||
storage: default_storage_kind(),
|
||||
pdf_ingest_mode: default_pdf_ingest_mode(),
|
||||
reranking_enabled: default_reranking_enabled(),
|
||||
reranking_pool_size: None,
|
||||
fastembed_cache_dir: None,
|
||||
fastembed_show_download_progress: None,
|
||||
fastembed_max_length: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_config() -> Result<AppConfig, ConfigError> {
|
||||
ensure_ort_path();
|
||||
|
||||
let config = Config::builder()
|
||||
.add_source(File::with_name("config").required(false))
|
||||
.add_source(Environment::default())
|
||||
.build()?;
|
||||
|
||||
Ok(config.try_deserialize()?)
|
||||
config.try_deserialize()
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ pub async fn generate_embedding_with_params(
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model)
|
||||
.input([input])
|
||||
.dimensions(dimensions as u32)
|
||||
.dimensions(dimensions)
|
||||
.build()?;
|
||||
|
||||
let response = client.embeddings().create(request).await?;
|
||||
|
||||
@@ -59,13 +59,13 @@ impl TemplateEngine {
|
||||
match self {
|
||||
// Only compile this arm for debug builds
|
||||
#[cfg(debug_assertions)]
|
||||
TemplateEngine::AutoReload(reloader) => {
|
||||
Self::AutoReload(reloader) => {
|
||||
let env = reloader.acquire_env()?;
|
||||
env.get_template(name)?.render(ctx)
|
||||
}
|
||||
// Only compile this arm for release builds
|
||||
#[cfg(not(debug_assertions))]
|
||||
TemplateEngine::Embedded(env) => env.get_template(name)?.render(ctx),
|
||||
Self::Embedded(env) => env.get_template(name)?.render(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,19 +78,17 @@ impl TemplateEngine {
|
||||
match self {
|
||||
// Only compile this arm for debug builds
|
||||
#[cfg(debug_assertions)]
|
||||
TemplateEngine::AutoReload(reloader) => {
|
||||
let env = reloader.acquire_env()?;
|
||||
let template = env.get_template(template_name)?;
|
||||
let mut state = template.eval_to_state(context)?;
|
||||
state.render_block(block_name)
|
||||
}
|
||||
Self::AutoReload(reloader) => reloader
|
||||
.acquire_env()?
|
||||
.get_template(template_name)?
|
||||
.eval_to_state(context)?
|
||||
.render_block(block_name),
|
||||
// Only compile this arm for release builds
|
||||
#[cfg(not(debug_assertions))]
|
||||
TemplateEngine::Embedded(env) => {
|
||||
let template = env.get_template(template_name)?;
|
||||
let mut state = template.eval_to_state(context)?;
|
||||
state.render_block(block_name)
|
||||
}
|
||||
Self::Embedded(env) => env
|
||||
.get_template(template_name)?
|
||||
.eval_to_state(context)?
|
||||
.render_block(block_name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
@@ -16,5 +19,7 @@ surrealdb = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
state-machines = { workspace = true }
|
||||
|
||||
@@ -8,19 +8,13 @@ use async_openai::{
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
storage::types::{
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::retrieve_entities;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
@@ -37,80 +31,23 @@ pub struct LLMResponseFormat {
|
||||
pub references: Vec<Reference>,
|
||||
}
|
||||
|
||||
/// Orchestrates query processing and returns an answer with references
|
||||
///
|
||||
/// Takes a query and uses the provided clients to generate an answer with supporting references.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `surreal_db_client` - Client for SurrealDB interactions
|
||||
/// * `openai_client` - Client for OpenAI API calls
|
||||
/// * `query` - The user's query string
|
||||
/// * `user_id` - The user's id
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a tuple of the answer and its references, or an API error
|
||||
#[derive(Debug)]
|
||||
pub struct Answer {
|
||||
pub content: String,
|
||||
pub references: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn get_answer_with_references(
|
||||
surreal_db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Answer, AppError> {
|
||||
let entities = retrieve_entities(surreal_db_client, openai_client, query, user_id).await?;
|
||||
let settings = SystemSettings::get_current(surreal_db_client).await?;
|
||||
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let user_message = create_user_message(&entities_json, query);
|
||||
|
||||
let request = create_chat_request(user_message, &settings)?;
|
||||
let response = openai_client.chat().create(request).await?;
|
||||
|
||||
let llm_response = process_llm_response(response).await?;
|
||||
|
||||
Ok(Answer {
|
||||
content: llm_response.answer,
|
||||
references: llm_response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
r"
|
||||
Context Information:
|
||||
==================
|
||||
{}
|
||||
{entities_json}
|
||||
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
entities_json, query
|
||||
{query}
|
||||
"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -120,7 +57,7 @@ pub fn create_user_message_with_history(
|
||||
query: &str,
|
||||
) -> String {
|
||||
format!(
|
||||
r#"
|
||||
r"
|
||||
Chat history:
|
||||
==================
|
||||
{}
|
||||
@@ -132,7 +69,7 @@ pub fn create_user_message_with_history(
|
||||
User Question:
|
||||
==================
|
||||
{}
|
||||
"#,
|
||||
",
|
||||
format_history(history),
|
||||
entities_json,
|
||||
query
|
||||
@@ -154,8 +91,6 @@ pub fn create_chat_request(
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.model(&settings.query_model)
|
||||
.temperature(0.2)
|
||||
.max_tokens(3048u32)
|
||||
.messages([
|
||||
ChatCompletionRequestSystemMessage::from(settings.query_system_prompt.clone()).into(),
|
||||
ChatCompletionRequestUserMessage::from(user_message).into(),
|
||||
@@ -176,7 +111,7 @@ pub async fn process_llm_response(
|
||||
))
|
||||
.and_then(|content| {
|
||||
serde_json::from_str::<LLMResponseFormat>(content).map_err(|e| {
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {}", e))
|
||||
AppError::LLMParsing(format!("Failed to parse LLM response into analysis: {e}"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
265
composite-retrieval/src/fts.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
};
|
||||
|
||||
use crate::scoring::Scored;
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FtsScoreRow {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
fts_score: Option<f32>,
|
||||
}
|
||||
|
||||
/// Executes a full-text search query against SurrealDB and returns scored results.
|
||||
///
|
||||
/// The function expects FTS indexes to exist for the provided table. Currently supports
|
||||
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
|
||||
pub async fn find_items_by_fts<T>(
|
||||
take: usize,
|
||||
query: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let (filter_clause, score_clause) = match table {
|
||||
"knowledge_entity" => (
|
||||
"(name @0@ $terms OR description @1@ $terms)",
|
||||
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
|
||||
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
|
||||
),
|
||||
"text_chunk" => (
|
||||
"(chunk @0@ $terms)",
|
||||
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
|
||||
),
|
||||
_ => {
|
||||
return Err(AppError::Validation(format!(
|
||||
"FTS not configured for table '{table}'"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
"SELECT id, {score_clause} AS fts_score \
|
||||
FROM {table} \
|
||||
WHERE {filter_clause} \
|
||||
AND user_id = $user_id \
|
||||
ORDER BY fts_score DESC \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
filter_clause = filter_clause,
|
||||
score_clause = score_clause
|
||||
);
|
||||
|
||||
debug!(
|
||||
table = table,
|
||||
limit = take,
|
||||
"Executing FTS query with filter clause: {}",
|
||||
filter_clause
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(sql)
|
||||
.bind(("terms", query.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
|
||||
|
||||
if score_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut results = Vec::with_capacity(score_rows.len());
|
||||
for row in score_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let score = row.fts_score.unwrap_or_default();
|
||||
results.push(Scored::new(item).with_fts_score(score));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
text_chunk::TextChunk,
|
||||
StoredObject,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
fn dummy_embedding() -> Vec<f32> {
|
||||
vec![0.0; 1536]
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_single_field_score_for_name() {
|
||||
let namespace = "fts_test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_a".into(),
|
||||
"Rustacean handbook".into(),
|
||||
"completely unrelated description".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to insert entity");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results = find_items_by_fts::<KnowledgeEntity>(
|
||||
5,
|
||||
"rustacean",
|
||||
&db,
|
||||
KnowledgeEntity::table_name(),
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when only the name matched"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_single_field_score_for_description() {
|
||||
let namespace = "fts_test_ns_desc";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts_desc";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_b".into(),
|
||||
"neutral name".into(),
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to insert entity");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results = find_items_by_fts::<KnowledgeEntity>(
|
||||
5,
|
||||
"async",
|
||||
&db,
|
||||
KnowledgeEntity::table_name(),
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when only the description matched"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_scores_for_text_chunks() {
|
||||
let namespace = "fts_test_ns_chunks";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts_chunk";
|
||||
let chunk = TextChunk::new(
|
||||
"source_chunk".into(),
|
||||
"GraphQL documentation reference".into(),
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("failed to insert chunk");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results =
|
||||
find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when chunk field matched"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,14 @@
|
||||
use surrealdb::Error;
|
||||
use tracing::debug;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
|
||||
use surrealdb::{sql::Thing, Error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
StoredObject,
|
||||
},
|
||||
};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
@@ -13,7 +20,7 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt
|
||||
///
|
||||
/// * `source_id` - The identifier to search for in the database
|
||||
/// * `table_name` - The name of the table to search in
|
||||
/// * `db_client` - The SurrealDB client instance for database operations
|
||||
/// * `db_client` - The `SurrealDB` client instance for database operations
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
@@ -31,18 +38,21 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
source_ids: Vec<String>,
|
||||
table_name: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
let query =
|
||||
"SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";
|
||||
|
||||
db.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.bind(("table", table_name.to_owned()))
|
||||
.bind(("source_ids", source_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
@@ -50,16 +60,92 @@ where
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: String,
|
||||
entity_id: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
let mut relationships_response = db
|
||||
.query(
|
||||
"
|
||||
SELECT * FROM relates_to
|
||||
WHERE metadata.user_id = $user_id
|
||||
AND (in = type::thing('knowledge_entity', $entity_id)
|
||||
OR out = type::thing('knowledge_entity', $entity_id))
|
||||
",
|
||||
)
|
||||
.bind(("entity_id", entity_id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
debug!("{}", query);
|
||||
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
|
||||
if relationships.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
let mut neighbor_ids: Vec<String> = Vec::new();
|
||||
let mut seen: HashSet<String> = HashSet::new();
|
||||
for rel in relationships {
|
||||
if rel.in_ == entity_id {
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
} else if rel.out == entity_id {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_);
|
||||
}
|
||||
} else {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_.clone());
|
||||
}
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids.retain(|id| id != entity_id);
|
||||
|
||||
if neighbor_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if limit > 0 && neighbor_ids.len() > limit {
|
||||
neighbor_ids.truncate(limit);
|
||||
}
|
||||
|
||||
let thing_ids: Vec<Thing> = neighbor_ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut neighbors_response = db
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", KnowledgeEntity::table_name().to_owned()))
|
||||
.bind(("things", thing_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
|
||||
if neighbors.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
|
||||
.into_iter()
|
||||
.map(|entity| (entity.id.clone(), entity))
|
||||
.collect();
|
||||
|
||||
let mut ordered = Vec::new();
|
||||
for id in neighbor_ids {
|
||||
if let Some(entity) = neighbor_map.remove(&id) {
|
||||
ordered.push(entity);
|
||||
}
|
||||
if limit > 0 && ordered.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ordered)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -149,7 +235,7 @@ mod tests {
|
||||
// Test finding entities by multiple source_ids
|
||||
let source_ids = vec![source_id1.clone(), source_id2.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to find entities by source_ids");
|
||||
|
||||
@@ -180,7 +266,8 @@ mod tests {
|
||||
let single_source_id = vec![source_id1.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
single_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
KnowledgeEntity::table_name(),
|
||||
&user_id,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -205,7 +292,8 @@ mod tests {
|
||||
let non_existent_source_id = vec!["non_existent_source".to_string()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
non_existent_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
KnowledgeEntity::table_name(),
|
||||
&user_id,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -330,11 +418,15 @@ mod tests {
|
||||
.expect("Failed to store relationship 2");
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(related_entities.len() > 0, "Should find related entities");
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod fts;
|
||||
pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
use common::{
|
||||
@@ -10,81 +14,254 @@ use common::{
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
},
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
||||
use std::collections::HashMap;
|
||||
use vector::find_items_by_vector_similarity;
|
||||
use reranking::RerankerLease;
|
||||
use tracing::instrument;
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedChunk {
|
||||
pub chunk: TextChunk,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
// Final entity representation returned to callers, enriched with ranked chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedEntity {
|
||||
pub entity: KnowledgeEntity,
|
||||
pub score: f32,
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
// Primary orchestrator for the process of retrieving KnowledgeEntitities related to a input_text
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
pipeline::run_pipeline(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
reranker,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_openai::Client;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
};
|
||||
use pipeline::RetrievalConfig;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
vec![0.9, 0.1, 0.0]
|
||||
}
|
||||
|
||||
fn entity_embedding_high() -> Vec<f32> {
|
||||
vec![0.8, 0.2, 0.0]
|
||||
}
|
||||
|
||||
fn entity_embedding_low() -> Vec<f32> {
|
||||
vec![0.1, 0.9, 0.0]
|
||||
}
|
||||
|
||||
fn chunk_embedding_primary() -> Vec<f32> {
|
||||
vec![0.85, 0.15, 0.0]
|
||||
}
|
||||
|
||||
fn chunk_embedding_secondary() -> Vec<f32> {
|
||||
vec![0.2, 0.8, 0.0]
|
||||
}
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db.query(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
.expect("Failed to configure indices");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retrieve_entities_with_embedding_basic_flow() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "test_user";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_1".into(),
|
||||
"Rust async guide".into(),
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk = TextChunk::new(
|
||||
entity.source_id.clone(),
|
||||
"Tokio uses cooperative scheduling for fairness.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
|
||||
let openai_client = Client::new();
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
assert!(
|
||||
!results.is_empty(),
|
||||
"Expected at least one retrieval result"
|
||||
);
|
||||
let top = &results[0];
|
||||
assert!(
|
||||
top.entity.name.contains("Rust"),
|
||||
"Expected Rust entity to be ranked first"
|
||||
);
|
||||
assert!(
|
||||
!top.chunks.is_empty(),
|
||||
"Expected Rust entity to include supporting chunks"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_graph_relationship_enriches_results() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "graph_user";
|
||||
|
||||
let primary = KnowledgeEntity::new(
|
||||
"primary_source".into(),
|
||||
"Async Rust patterns".into(),
|
||||
"Explores async runtimes and scheduling strategies.".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let neighbor = KnowledgeEntity::new(
|
||||
"neighbor_source".into(),
|
||||
"Tokio Scheduler Deep Dive".into(),
|
||||
"Details on Tokio's cooperative scheduler.".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_low(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(primary.clone())
|
||||
.await
|
||||
.expect("Failed to store primary entity");
|
||||
db.store_item(neighbor.clone())
|
||||
.await
|
||||
.expect("Failed to store neighbor entity");
|
||||
|
||||
let primary_chunk = TextChunk::new(
|
||||
primary.source_id.clone(),
|
||||
"Rust async tasks use Tokio's cooperative scheduler.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
let neighbor_chunk = TextChunk::new(
|
||||
neighbor.source_id.clone(),
|
||||
"Tokio's scheduler manages task fairness across executors.".into(),
|
||||
chunk_embedding_secondary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(primary_chunk)
|
||||
.await
|
||||
.expect("Failed to store primary chunk");
|
||||
db.store_item(neighbor_chunk)
|
||||
.await
|
||||
.expect("Failed to store neighbor chunk");
|
||||
|
||||
let openai_client = Client::new();
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
primary.id.clone(),
|
||||
neighbor.id.clone(),
|
||||
user_id.into(),
|
||||
"relationship_source".into(),
|
||||
"references".into(),
|
||||
);
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
RetrievalConfig::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Hybrid retrieval failed");
|
||||
|
||||
let mut neighbor_entry = None;
|
||||
for entity in &results {
|
||||
if entity.entity.id == neighbor.id {
|
||||
neighbor_entry = Some(entity.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let neighbor_entry =
|
||||
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
|
||||
|
||||
assert!(
|
||||
neighbor_entry.score > 0.2,
|
||||
"Graph-enriched entity should have a meaningful fused score"
|
||||
);
|
||||
assert!(
|
||||
neighbor_entry
|
||||
.chunks
|
||||
.iter()
|
||||
.all(|chunk| chunk.chunk.source_id == neighbor.source_id),
|
||||
"Neighbor entity should surface its own supporting chunks"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
67
composite-retrieval/src/pipeline/config.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Tunable parameters that govern each retrieval stage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RetrievalTuning {
|
||||
pub entity_vector_take: usize,
|
||||
pub chunk_vector_take: usize,
|
||||
pub entity_fts_take: usize,
|
||||
pub chunk_fts_take: usize,
|
||||
pub score_threshold: f32,
|
||||
pub fallback_min_results: usize,
|
||||
pub token_budget_estimate: usize,
|
||||
pub avg_chars_per_token: usize,
|
||||
pub max_chunks_per_entity: usize,
|
||||
pub graph_traversal_seed_limit: usize,
|
||||
pub graph_neighbor_limit: usize,
|
||||
pub graph_score_decay: f32,
|
||||
pub graph_seed_min_score: f32,
|
||||
pub graph_vector_inheritance: f32,
|
||||
pub rerank_blend_weight: f32,
|
||||
pub rerank_scores_only: bool,
|
||||
pub rerank_keep_top: usize,
|
||||
}
|
||||
|
||||
impl Default for RetrievalTuning {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
entity_vector_take: 15,
|
||||
chunk_vector_take: 20,
|
||||
entity_fts_take: 10,
|
||||
chunk_fts_take: 20,
|
||||
score_threshold: 0.35,
|
||||
fallback_min_results: 10,
|
||||
token_budget_estimate: 2800,
|
||||
avg_chars_per_token: 4,
|
||||
max_chunks_per_entity: 4,
|
||||
graph_traversal_seed_limit: 5,
|
||||
graph_neighbor_limit: 6,
|
||||
graph_score_decay: 0.75,
|
||||
graph_seed_min_score: 0.4,
|
||||
graph_vector_inheritance: 0.6,
|
||||
rerank_blend_weight: 0.65,
|
||||
rerank_scores_only: false,
|
||||
rerank_keep_top: 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper containing tuning plus future flags for per-request overrides.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalConfig {
|
||||
pub tuning: RetrievalTuning,
|
||||
}
|
||||
|
||||
impl RetrievalConfig {
|
||||
pub fn new(tuning: RetrievalTuning) -> Self {
|
||||
Self { tuning }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tuning: RetrievalTuning::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
106
composite-retrieval/src/pipeline/mod.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
mod config;
|
||||
mod stages;
|
||||
mod state;
|
||||
|
||||
pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
|
||||
use crate::{reranking::RerankerLease, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use tracing::info;
|
||||
|
||||
/// Drives the retrieval pipeline from embedding through final assembly.
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let input_chars = input_text.chars().count();
|
||||
let input_preview: String = input_text.chars().take(120).collect();
|
||||
let input_preview_clean = input_preview.replace('\n', " ");
|
||||
let preview_len = input_preview_clean.chars().count();
|
||||
info!(
|
||||
%user_id,
|
||||
input_chars,
|
||||
preview_truncated = input_chars > preview_len,
|
||||
preview = %input_preview_clean,
|
||||
"Starting ingestion retrieval pipeline"
|
||||
);
|
||||
let mut ctx = stages::PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let machine = state::ready();
|
||||
let mut ctx = stages::PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
query_embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
let machine = stages::embed(machine, &mut ctx).await?;
|
||||
let machine = stages::collect_candidates(machine, &mut ctx).await?;
|
||||
let machine = stages::expand_graph(machine, &mut ctx).await?;
|
||||
let machine = stages::attach_chunks(machine, &mut ctx).await?;
|
||||
let machine = stages::rerank(machine, &mut ctx).await?;
|
||||
let results = stages::assemble(machine, &mut ctx)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
|
||||
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
|
||||
serde_json::json!(entities
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
serde_json::json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
serde_json::json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
(f64::from(value) * 1000.0).round() / 1000.0
|
||||
}
|
||||
769
composite-retrieval/src/pipeline/stages/mod.rs
Normal file
@@ -0,0 +1,769 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use state_machines::core::GuardError;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use tracing::{debug, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
fts::find_items_by_fts,
|
||||
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
|
||||
reranking::RerankerLease,
|
||||
scoring::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
},
|
||||
vector::find_items_by_vector_similarity_with_embedding,
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::RetrievalConfig,
|
||||
state::{
|
||||
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
|
||||
Reranked,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
pub query_embedding: Option<Vec<f32>>,
|
||||
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
|
||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
}
|
||||
|
||||
impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
query_embedding: None,
|
||||
entity_candidates: HashMap::new(),
|
||||
chunk_candidates: HashMap::new(),
|
||||
filtered_entities: Vec::new(),
|
||||
chunk_values: Vec::new(),
|
||||
reranker,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
reranker: Option<RerankerLease>,
|
||||
) -> Self {
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
reranker,
|
||||
);
|
||||
ctx.query_embedding = Some(query_embedding);
|
||||
ctx
|
||||
}
|
||||
|
||||
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
|
||||
self.query_embedding.as_ref().ok_or_else(|| {
|
||||
AppError::InternalError(
|
||||
"query embedding missing before candidate collection".to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn embed(
|
||||
machine: HybridRetrievalMachine<(), Ready>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
|
||||
let embedding_cached = ctx.query_embedding.is_some();
|
||||
if embedding_cached {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding =
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
machine
|
||||
.embed()
|
||||
.map_err(|(_, guard)| map_guard_error("embed", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn collect_candidates(
|
||||
machine: HybridRetrievalMachine<(), Embedded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
|
||||
debug!("Collecting initial candidates via vector and FTS search");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
tuning.entity_vector_take,
|
||||
embedding.clone(),
|
||||
ctx.db_client,
|
||||
"knowledge_entity",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
"text_chunk",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
tuning.entity_fts_take,
|
||||
&ctx.input_text,
|
||||
ctx.db_client,
|
||||
"knowledge_entity",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
tuning.chunk_fts_take,
|
||||
&ctx.input_text,
|
||||
ctx.db_client,
|
||||
"text_chunk",
|
||||
&ctx.user_id
|
||||
),
|
||||
)?;
|
||||
|
||||
debug!(
|
||||
vector_entities = vector_entities.len(),
|
||||
vector_chunks = vector_chunks.len(),
|
||||
fts_entities = fts_entities.len(),
|
||||
fts_chunks = fts_chunks.len(),
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
|
||||
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
|
||||
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
|
||||
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
|
||||
|
||||
apply_fusion(&mut ctx.entity_candidates, weights);
|
||||
apply_fusion(&mut ctx.chunk_candidates, weights);
|
||||
|
||||
machine
|
||||
.collect_candidates()
|
||||
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn expand_graph(
|
||||
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
|
||||
debug!("Expanding candidates using graph relationships");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
if ctx.entity_candidates.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
|
||||
let graph_seeds = seeds_from_candidates(
|
||||
&ctx.entity_candidates,
|
||||
tuning.graph_seed_min_score,
|
||||
tuning.graph_traversal_seed_limit,
|
||||
);
|
||||
|
||||
if graph_seeds.is_empty() {
|
||||
return machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
|
||||
}
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in graph_seeds {
|
||||
let db = ctx.db_client;
|
||||
let user = ctx.user_id.clone();
|
||||
futures.push(async move {
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db,
|
||||
&seed.id,
|
||||
&user,
|
||||
tuning.graph_neighbor_limit,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph || entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.expand_graph()
|
||||
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn attach_chunks(
|
||||
machine: HybridRetrievalMachine<(), GraphExpanded>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
|
||||
debug!("Attaching chunks to surviving entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
|
||||
|
||||
backfill_entities_from_chunks(
|
||||
&mut ctx.entity_candidates,
|
||||
&chunk_by_source,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
)
|
||||
.await?;
|
||||
|
||||
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
|
||||
|
||||
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
||||
ctx.entity_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut entity_results);
|
||||
|
||||
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
||||
.iter()
|
||||
.filter(|candidate| candidate.fused >= tuning.score_threshold)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if filtered_entities.len() < tuning.fallback_min_results {
|
||||
filtered_entities = entity_results
|
||||
.into_iter()
|
||||
.take(tuning.fallback_min_results)
|
||||
.collect();
|
||||
}
|
||||
|
||||
ctx.filtered_entities = filtered_entities;
|
||||
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||
ctx.chunk_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
|
||||
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
for chunk in chunk_results {
|
||||
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
||||
}
|
||||
|
||||
enrich_chunks_from_entities(
|
||||
&mut chunk_by_id,
|
||||
&ctx.filtered_entities,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
ctx.chunk_values = chunk_values;
|
||||
|
||||
machine
|
||||
.attach_chunks()
|
||||
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub async fn rerank(
|
||||
machine: HybridRetrievalMachine<(), ChunksAttached>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
|
||||
let mut applied = false;
|
||||
|
||||
if let Some(reranker) = ctx.reranker.as_ref() {
|
||||
if ctx.filtered_entities.len() > 1 {
|
||||
let documents = build_rerank_documents(ctx, ctx.config.tuning.max_chunks_per_entity);
|
||||
|
||||
if documents.len() > 1 {
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_rerank_results(ctx, results);
|
||||
applied = true;
|
||||
}
|
||||
Ok(_) => {
|
||||
debug!("Reranker returned no results; retaining original ordering");
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
"Reranking failed; continuing with original ordering"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
document_count = documents.len(),
|
||||
"Skipping reranking stage; insufficient document context"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
debug!("Skipping reranking stage; less than two entities available");
|
||||
}
|
||||
} else {
|
||||
debug!("No reranker lease provided; skipping reranking stage");
|
||||
}
|
||||
|
||||
if applied {
|
||||
debug!("Applied reranking adjustments to candidate ordering");
|
||||
}
|
||||
|
||||
machine
|
||||
.rerank()
|
||||
.map_err(|(_, guard)| map_guard_error("rerank", guard))
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub fn assemble(
|
||||
machine: HybridRetrievalMachine<(), Reranked>,
|
||||
ctx: &mut PipelineContext<'_>,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in ctx.chunk_values.drain(..) {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = tuning.token_budget_estimate;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for entity in &ctx.filtered_entities {
|
||||
let mut selected_chunks = Vec::new();
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if per_entity_count >= tuning.max_chunks_per_entity {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens =
|
||||
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
per_entity_count += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results.push(RetrievedEntity {
|
||||
entity: entity.item.clone(),
|
||||
score: entity.fused,
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
machine
|
||||
.assemble()
|
||||
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
|
||||
AppError::InternalError(format!(
|
||||
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
|
||||
err.guard, err.event, err.kind
|
||||
))
|
||||
}
|
||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||
let raw_scores: Vec<f32> = results
|
||||
.iter()
|
||||
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
let normalized = min_max_normalize(&raw_scores);
|
||||
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
||||
candidate.scores.fts = Some(normalized_score);
|
||||
candidate.update_fused(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
for candidate in candidates.values_mut() {
|
||||
let fused = fuse_scores(&candidate.scores, weights);
|
||||
candidate.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
fn group_chunks_by_source(
|
||||
chunks: &HashMap<String, Scored<TextChunk>>,
|
||||
) -> HashMap<String, Vec<Scored<TextChunk>>> {
|
||||
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
|
||||
for chunk in chunks.values() {
|
||||
by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk.clone());
|
||||
}
|
||||
by_source
|
||||
}
|
||||
|
||||
async fn backfill_entities_from_chunks(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
let mut missing_sources = Vec::new();
|
||||
|
||||
for source_id in chunk_by_source.keys() {
|
||||
if !entity_candidates
|
||||
.values()
|
||||
.any(|entity| entity.item.source_id == *source_id)
|
||||
{
|
||||
missing_sources.push(source_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if missing_sources.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
missing_sources.clone(),
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if related_entities.is_empty() {
|
||||
warn!("expected related entities for missing chunk sources, but none were found");
|
||||
}
|
||||
|
||||
for entity in related_entities {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
||||
let fused = fuse_scores(&scored.scores, weights);
|
||||
scored.update_fused(fused);
|
||||
entity_candidates.insert(entity.id.clone(), scored);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn boost_entities_with_chunks(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
|
||||
weights: FusionWeights,
|
||||
) {
|
||||
for entity in entity_candidates.values_mut() {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
if best_chunk_score > 0.0 {
|
||||
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
||||
entity.scores.vector = Some(boosted);
|
||||
let fused = fuse_scores(&entity.scores, weights);
|
||||
entity.update_fused(fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enrich_chunks_from_entities(
|
||||
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
||||
entities: &[Scored<KnowledgeEntity>],
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
source_ids.insert(entity.item.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunks = find_entities_by_source_ids::<TextChunk>(
|
||||
source_ids.into_iter().collect(),
|
||||
"text_chunk",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
||||
for entity in entities {
|
||||
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
||||
}
|
||||
|
||||
for chunk in chunks {
|
||||
let entry = chunk_candidates
|
||||
.entry(chunk.id.clone())
|
||||
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
||||
|
||||
let entity_score = entity_score_lookup
|
||||
.get(&chunk.source_id)
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
|
||||
if ctx.filtered_entities.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut chunk_by_source: HashMap<&str, Vec<&Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in &ctx.chunk_values {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.as_str())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
ctx.filtered_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
let mut doc = format!(
|
||||
"Name: {}\nType: {:?}\nDescription: {}\n",
|
||||
entity.item.name, entity.item.entity_type, entity.item.description
|
||||
);
|
||||
|
||||
if let Some(chunks) = chunk_by_source.get(entity.item.source_id.as_str()) {
|
||||
let mut chunk_refs = chunks.clone();
|
||||
chunk_refs.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let mut header_added = false;
|
||||
for chunk in chunk_refs.into_iter().take(max_chunks_per_entity.max(1)) {
|
||||
let snippet = chunk.item.chunk.trim();
|
||||
if snippet.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !header_added {
|
||||
doc.push_str("Chunks:\n");
|
||||
header_added = true;
|
||||
}
|
||||
doc.push_str("- ");
|
||||
doc.push_str(snippet);
|
||||
doc.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
doc
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
|
||||
if results.is_empty() || ctx.filtered_entities.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<KnowledgeEntity>>> =
|
||||
std::mem::take(&mut ctx.filtered_entities)
|
||||
.into_iter()
|
||||
.map(Some)
|
||||
.collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
let normalized_scores = min_max_normalize(&raw_scores);
|
||||
|
||||
let use_only = ctx.config.tuning.rerank_scores_only;
|
||||
let blend = if use_only {
|
||||
1.0
|
||||
} else {
|
||||
clamp_unit(ctx.config.tuning.rerank_blend_weight)
|
||||
};
|
||||
let mut reranked: Vec<Scored<KnowledgeEntity>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
let original = candidate.fused;
|
||||
let blended = if use_only {
|
||||
clamp_unit(normalized)
|
||||
} else {
|
||||
clamp_unit(original * (1.0 - blend) + normalized * blend)
|
||||
};
|
||||
candidate.update_fused(blended);
|
||||
reranked.push(candidate);
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
result_index = result.index,
|
||||
"Reranker returned out-of-range index; skipping"
|
||||
);
|
||||
}
|
||||
if reranked.len() == remaining.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for slot in remaining.into_iter() {
|
||||
if let Some(candidate) = slot {
|
||||
reranked.push(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.filtered_entities = reranked;
|
||||
let keep_top = ctx.config.tuning.rerank_keep_top;
|
||||
if keep_top > 0 && ctx.filtered_entities.len() > keep_top {
|
||||
ctx.filtered_entities.truncate(keep_top);
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / avg_chars_per_token).max(1)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
fused: f32,
|
||||
}
|
||||
|
||||
fn seeds_from_candidates(
|
||||
entity_candidates: &HashMap<String, Scored<KnowledgeEntity>>,
|
||||
min_score: f32,
|
||||
limit: usize,
|
||||
) -> Vec<GraphSeed> {
|
||||
let mut seeds: Vec<GraphSeed> = entity_candidates
|
||||
.values()
|
||||
.filter(|entity| entity.fused >= min_score)
|
||||
.map(|entity| GraphSeed {
|
||||
id: entity.item.id.clone(),
|
||||
fused: entity.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
seeds.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
if seeds.len() > limit {
|
||||
seeds.truncate(limit);
|
||||
}
|
||||
|
||||
seeds
|
||||
}
|
||||
27
composite-retrieval/src/pipeline/state.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use state_machines::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: HybridRetrievalMachine,
|
||||
state: HybridRetrievalState,
|
||||
initial: Ready,
|
||||
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
|
||||
events {
|
||||
embed { transition: { from: Ready, to: Embedded } }
|
||||
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
|
||||
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
|
||||
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
|
||||
rerank { transition: { from: ChunksAttached, to: Reranked } }
|
||||
assemble { transition: { from: Reranked, to: Completed } }
|
||||
abort {
|
||||
transition: { from: Ready, to: Failed }
|
||||
transition: { from: CandidatesLoaded, to: Failed }
|
||||
transition: { from: GraphExpanded, to: Failed }
|
||||
transition: { from: ChunksAttached, to: Failed }
|
||||
transition: { from: Reranked, to: Failed }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
|
||||
HybridRetrievalMachine::new(())
|
||||
}
|
||||
170
composite-retrieval/src/reranking/mod.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use std::{
|
||||
env, fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
thread::available_parallelism,
|
||||
};
|
||||
|
||||
use common::{error::AppError, utils::config::AppConfig};
|
||||
use fastembed::{RerankInitOptions, RerankResult, TextRerank};
|
||||
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
|
||||
use tracing::debug;
|
||||
|
||||
static NEXT_ENGINE: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn pick_engine_index(pool_len: usize) -> usize {
|
||||
let n = NEXT_ENGINE.fetch_add(1, Ordering::Relaxed);
|
||||
n % pool_len
|
||||
}
|
||||
|
||||
pub struct RerankerPool {
|
||||
engines: Vec<Arc<Mutex<TextRerank>>>,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl RerankerPool {
|
||||
/// Build the pool at startup.
|
||||
/// `pool_size` controls max parallel reranks.
|
||||
pub fn new(pool_size: usize) -> Result<Arc<Self>, AppError> {
|
||||
Self::new_with_options(pool_size, RerankInitOptions::default())
|
||||
}
|
||||
|
||||
fn new_with_options(
|
||||
pool_size: usize,
|
||||
init_options: RerankInitOptions,
|
||||
) -> Result<Arc<Self>, AppError> {
|
||||
if pool_size == 0 {
|
||||
return Err(AppError::Validation(
|
||||
"RERANKING_POOL_SIZE must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
fs::create_dir_all(&init_options.cache_dir)?;
|
||||
|
||||
let mut engines = Vec::with_capacity(pool_size);
|
||||
for x in 0..pool_size {
|
||||
debug!("Creating reranking engine: {x}");
|
||||
let model = TextRerank::try_new(init_options.clone())
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?;
|
||||
engines.push(Arc::new(Mutex::new(model)));
|
||||
}
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
engines,
|
||||
semaphore: Arc::new(Semaphore::new(pool_size)),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Initialize a pool using application configuration.
|
||||
pub fn maybe_from_config(config: &AppConfig) -> Result<Option<Arc<Self>>, AppError> {
|
||||
if !config.reranking_enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let pool_size = config.reranking_pool_size.unwrap_or_else(default_pool_size);
|
||||
|
||||
let init_options = build_rerank_init_options(config)?;
|
||||
Self::new_with_options(pool_size, init_options).map(Some)
|
||||
}
|
||||
|
||||
/// Check out capacity + pick an engine.
|
||||
/// This returns a lease that can perform rerank().
|
||||
pub async fn checkout(self: &Arc<Self>) -> RerankerLease {
|
||||
// Acquire a permit. This enforces backpressure.
|
||||
let permit = self
|
||||
.semaphore
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("semaphore closed");
|
||||
|
||||
// Pick an engine.
|
||||
// This is naive: just pick based on a simple modulo counter.
|
||||
// We use an atomic counter to avoid always choosing index 0.
|
||||
let idx = pick_engine_index(self.engines.len());
|
||||
let engine = self.engines[idx].clone();
|
||||
|
||||
RerankerLease {
|
||||
_permit: permit,
|
||||
engine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_pool_size() -> usize {
|
||||
available_parallelism()
|
||||
.map(|value| value.get().min(2))
|
||||
.unwrap_or(2)
|
||||
.max(1)
|
||||
}
|
||||
|
||||
fn is_truthy(value: &str) -> bool {
|
||||
matches!(
|
||||
value.trim().to_ascii_lowercase().as_str(),
|
||||
"1" | "true" | "yes" | "on"
|
||||
)
|
||||
}
|
||||
|
||||
fn build_rerank_init_options(config: &AppConfig) -> Result<RerankInitOptions, AppError> {
|
||||
let mut options = RerankInitOptions::default();
|
||||
|
||||
let cache_dir = config
|
||||
.fastembed_cache_dir
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| env::var("RERANKING_CACHE_DIR").ok().map(PathBuf::from))
|
||||
.or_else(|| env::var("FASTEMBED_CACHE_DIR").ok().map(PathBuf::from))
|
||||
.unwrap_or_else(|| {
|
||||
Path::new(&config.data_dir)
|
||||
.join("fastembed")
|
||||
.join("reranker")
|
||||
});
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
options.cache_dir = cache_dir;
|
||||
|
||||
let show_progress = config
|
||||
.fastembed_show_download_progress
|
||||
.or_else(|| env_bool("RERANKING_SHOW_DOWNLOAD_PROGRESS"))
|
||||
.or_else(|| env_bool("FASTEMBED_SHOW_DOWNLOAD_PROGRESS"))
|
||||
.unwrap_or(true);
|
||||
options.show_download_progress = show_progress;
|
||||
|
||||
if let Some(max_length) = config.fastembed_max_length.or_else(|| {
|
||||
env::var("RERANKING_MAX_LENGTH")
|
||||
.ok()
|
||||
.and_then(|value| value.parse().ok())
|
||||
}) {
|
||||
options.max_length = max_length;
|
||||
}
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
fn env_bool(key: &str) -> Option<bool> {
|
||||
env::var(key).ok().map(|value| is_truthy(&value))
|
||||
}
|
||||
|
||||
/// Active lease on a single TextRerank instance.
|
||||
pub struct RerankerLease {
|
||||
// When this drops the semaphore permit is released.
|
||||
_permit: OwnedSemaphorePermit,
|
||||
engine: Arc<Mutex<TextRerank>>,
|
||||
}
|
||||
|
||||
impl RerankerLease {
|
||||
pub async fn rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
documents: Vec<String>,
|
||||
) -> Result<Vec<RerankResult>, AppError> {
|
||||
// Lock this specific engine so we get &mut TextRerank
|
||||
let mut guard = self.engine.lock().await;
|
||||
|
||||
guard
|
||||
.rerank(query.to_owned(), documents, false, None)
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))
|
||||
}
|
||||
}
|
||||
183
composite-retrieval/src/scoring.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
|
||||
/// Holds optional subscores gathered from different retrieval signals.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct Scores {
|
||||
pub fts: Option<f32>,
|
||||
pub vector: Option<f32>,
|
||||
pub graph: Option<f32>,
|
||||
}
|
||||
|
||||
/// Generic wrapper combining an item with its accumulated retrieval scores.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Scored<T> {
|
||||
pub item: T,
|
||||
pub scores: Scores,
|
||||
pub fused: f32,
|
||||
}
|
||||
|
||||
impl<T> Scored<T> {
|
||||
pub fn new(item: T) -> Self {
|
||||
Self {
|
||||
item,
|
||||
scores: Scores::default(),
|
||||
fused: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn with_vector_score(mut self, score: f32) -> Self {
|
||||
self.scores.vector = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn with_fts_score(mut self, score: f32) -> Self {
|
||||
self.scores.fts = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn update_fused(&mut self, fused: f32) {
|
||||
self.fused = fused;
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights used for linear score fusion.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FusionWeights {
|
||||
pub vector: f32,
|
||||
pub fts: f32,
|
||||
pub graph: f32,
|
||||
pub multi_bonus: f32,
|
||||
}
|
||||
|
||||
impl Default for FusionWeights {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vector: 0.5,
|
||||
fts: 0.3,
|
||||
graph: 0.2,
|
||||
multi_bonus: 0.02,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn clamp_unit(value: f32) -> f32 {
|
||||
value.clamp(0.0, 1.0)
|
||||
}
|
||||
|
||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
||||
if !distance.is_finite() {
|
||||
return 0.0;
|
||||
}
|
||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
||||
}
|
||||
|
||||
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
|
||||
if scores.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut min = f32::MAX;
|
||||
let mut max = f32::MIN;
|
||||
|
||||
for s in scores {
|
||||
if !s.is_finite() {
|
||||
continue;
|
||||
}
|
||||
if *s < min {
|
||||
min = *s;
|
||||
}
|
||||
if *s > max {
|
||||
max = *s;
|
||||
}
|
||||
}
|
||||
|
||||
if !min.is_finite() || !max.is_finite() {
|
||||
return scores.iter().map(|_| 0.0).collect();
|
||||
}
|
||||
|
||||
if (max - min).abs() < f32::EPSILON {
|
||||
return vec![1.0; scores.len()];
|
||||
}
|
||||
|
||||
scores
|
||||
.iter()
|
||||
.map(|score| {
|
||||
if score.is_finite() {
|
||||
clamp_unit((score - min) / (max - min))
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
let vector = scores.vector.unwrap_or(0.0);
|
||||
let fts = scores.fts.unwrap_or(0.0);
|
||||
let graph = scores.graph.unwrap_or(0.0);
|
||||
|
||||
let mut fused = graph.mul_add(
|
||||
weights.graph,
|
||||
vector.mul_add(weights.vector, fts * weights.fts),
|
||||
);
|
||||
|
||||
let signals_present = scores
|
||||
.vector
|
||||
.iter()
|
||||
.chain(scores.fts.iter())
|
||||
.chain(scores.graph.iter())
|
||||
.count();
|
||||
if signals_present >= 2 {
|
||||
fused += weights.multi_bonus;
|
||||
}
|
||||
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
for scored in incoming {
|
||||
let id = scored.item.get_id().to_owned();
|
||||
target
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if let Some(score) = scored.scores.vector {
|
||||
existing.scores.vector = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.fts {
|
||||
existing.scores.fts = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.graph {
|
||||
existing.scores.graph = Some(score);
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| Scored {
|
||||
item: scored.item.clone(),
|
||||
scores: scored.scores,
|
||||
fused: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
items.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
}
|
||||
@@ -1,4 +1,15 @@
|
||||
use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
@@ -22,24 +33,125 @@ use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::ge
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
take: usize,
|
||||
input_text: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DistanceRow {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
distance: Option<f32>,
|
||||
}
|
||||
|
||||
pub async fn find_items_by_vector_similarity_with_embedding<T>(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
let closest_query = format!(
|
||||
"SELECT id, vector::distance::knn() AS distance \
|
||||
FROM {table} \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let distance_rows: Vec<DistanceRow> = response.take(0)?;
|
||||
|
||||
if distance_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut min_distance = f32::MAX;
|
||||
let mut max_distance = f32::MIN;
|
||||
|
||||
for row in &distance_rows {
|
||||
if let Some(distance) = row.distance {
|
||||
if distance.is_finite() {
|
||||
if distance < min_distance {
|
||||
min_distance = distance;
|
||||
}
|
||||
if distance > max_distance {
|
||||
max_distance = distance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let normalize = min_distance.is_finite()
|
||||
&& max_distance.is_finite()
|
||||
&& (max_distance - min_distance).abs() > f32::EPSILON;
|
||||
|
||||
let mut scored = Vec::with_capacity(distance_rows.len());
|
||||
for row in distance_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let similarity = row
|
||||
.distance
|
||||
.map(|distance| {
|
||||
if normalize {
|
||||
let span = max_distance - min_distance;
|
||||
if span.abs() < f32::EPSILON {
|
||||
1.0
|
||||
} else {
|
||||
let normalized = 1.0 - ((distance - min_distance) / span);
|
||||
clamp_unit(normalized)
|
||||
}
|
||||
} else {
|
||||
distance_to_similarity(distance)
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
scored.push(Scored::new(item).with_vector_score(similarity));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
88
devenv.lock
@@ -3,10 +3,10 @@
|
||||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1746681099,
|
||||
"lastModified": 1761839147,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "a7f2ea275621391209fd702f5ddced32dd56a4e2",
|
||||
"rev": "bb7849648b68035f6b910120252c22b28195cf54",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -16,13 +16,31 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs",
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1733328505,
|
||||
"lastModified": 1761588595,
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
||||
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -40,10 +58,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1746537231,
|
||||
"lastModified": 1760663237,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "fa466640195d38ec97cf0493d6d6882bc4d14969",
|
||||
"rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -74,10 +92,25 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1746576598,
|
||||
"lastModified": 1761672384,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "b3582c75c7f21ce0b429898980eddbbf05c68e55",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nixos",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1761880412,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a7fc11be66bdfb5cdde611ee5ce381c183da8386",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -90,11 +123,48 @@
|
||||
"root": {
|
||||
"inputs": {
|
||||
"devenv": "devenv",
|
||||
"fenix": "fenix",
|
||||
"git-hooks": "git-hooks",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs": "nixpkgs_2",
|
||||
"pre-commit-hooks": [
|
||||
"git-hooks"
|
||||
],
|
||||
"rust-overlay": "rust-overlay"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "rust-lang",
|
||||
"ref": "nightly",
|
||||
"repo": "rust-analyzer",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"rust-overlay": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761878277,
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "6604534e44090c917db714faa58d47861657690c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
10
devenv.nix
@@ -11,14 +11,24 @@
|
||||
pkgs.openssl
|
||||
pkgs.nodejs
|
||||
pkgs.vscode-langservers-extracted
|
||||
pkgs.cargo-dist
|
||||
pkgs.cargo-xwin
|
||||
pkgs.clang
|
||||
pkgs.onnxruntime
|
||||
];
|
||||
|
||||
languages.rust = {
|
||||
enable = true;
|
||||
components = ["rustc" "clippy" "rustfmt" "cargo" "rust-analyzer"];
|
||||
channel = "nightly";
|
||||
targets = ["x86_64-unknown-linux-gnu" "x86_64-pc-windows-msvc"];
|
||||
mold.enable = true;
|
||||
};
|
||||
|
||||
env = {
|
||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||
};
|
||||
|
||||
processes = {
|
||||
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
|
||||
};
|
||||
|
||||
18
devenv.yaml
@@ -1,15 +1,11 @@
|
||||
# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json
|
||||
inputs:
|
||||
fenix:
|
||||
url: github:nix-community/fenix
|
||||
nixpkgs:
|
||||
url: github:nixos/nixpkgs/nixpkgs-unstable
|
||||
|
||||
# If you're using non-OSS software, you can set allowUnfree to true.
|
||||
rust-overlay:
|
||||
url: github:oxalica/rust-overlay
|
||||
inputs:
|
||||
nixpkgs:
|
||||
follows: nixpkgs
|
||||
allowUnfree: true
|
||||
|
||||
# If you're willing to use a package that's vulnerable
|
||||
# permittedInsecurePackages:
|
||||
# - "openssl-1.1.1w"
|
||||
|
||||
# If you have more than one devenv you can merge them
|
||||
#imports:
|
||||
# - ./backend
|
||||
|
||||
@@ -4,9 +4,11 @@ members = ["cargo:."]
|
||||
# Config for 'dist'
|
||||
[dist]
|
||||
# The preferred dist version to use in CI (Cargo.toml SemVer syntax)
|
||||
cargo-dist-version = "0.28.0"
|
||||
cargo-dist-version = "0.30.0"
|
||||
# CI backends to support
|
||||
ci = "github"
|
||||
# Extra static files to include in each App (path relative to this Cargo.toml's dir)
|
||||
include = ["lib"]
|
||||
# The installers to generate for each app
|
||||
installers = []
|
||||
# Target platforms to build apps for (Rust target-triple syntax)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
minne:
|
||||
build: .
|
||||
@@ -12,10 +10,11 @@ services:
|
||||
SURREALDB_PASSWORD: "root_password"
|
||||
SURREALDB_DATABASE: "test"
|
||||
SURREALDB_NAMESPACE: "test"
|
||||
OPENAI_API_KEY: "sk-key"
|
||||
OPENAI_API_KEY: "sk-add-your-key"
|
||||
DATA_DIR: "./data"
|
||||
HTTP_PORT: 3000
|
||||
RUST_LOG: "info"
|
||||
RERANKING_ENABLED: false ## Change to true to enable reranking
|
||||
depends_on:
|
||||
- surrealdb
|
||||
networks:
|
||||
@@ -31,7 +30,7 @@ services:
|
||||
- ./database:/database # Mounts a 'database' folder from your project directory
|
||||
command: >
|
||||
start
|
||||
--log debug
|
||||
--log info
|
||||
--user root_user
|
||||
--pass root_password
|
||||
rocksdb:./database/database.db
|
||||
|
||||
22
flake.lock
generated
@@ -1,5 +1,20 @@
|
||||
{
|
||||
"nodes": {
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1760924934,
|
||||
"narHash": "sha256-tuuqY5aU7cUkR71sO2TraVKK2boYrdW3gCSXUkF4i44=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "c6b4d5308293d0d04fcfeee92705017537cad02f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
@@ -20,11 +35,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1746232882,
|
||||
"narHash": "sha256-MHmBH2rS8KkRRdoU/feC/dKbdlMkcNkB5mwkuipVHeQ=",
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "7a2622e2c0dbad5c4493cb268aba12896e28b008",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -36,6 +51,7 @@
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"crane": "crane",
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
|
||||
133
flake.nix
@@ -4,77 +4,84 @@
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
crane.url = "github:ipetkov/crane";
|
||||
};
|
||||
|
||||
outputs = {
|
||||
self,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
crane,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system: let
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
|
||||
# --- Minne Package Definition ---
|
||||
minne-pkg = pkgs.rustPlatform.buildRustPackage {
|
||||
pname = "minne";
|
||||
version = "0.1.0";
|
||||
|
||||
src = self;
|
||||
|
||||
cargoLock = {
|
||||
lockFile = ./Cargo.lock;
|
||||
};
|
||||
|
||||
# Skip tests due to testing fs operations
|
||||
doCheck = false;
|
||||
|
||||
nativeBuildInputs = [
|
||||
pkgs.pkg-config
|
||||
pkgs.rustfmt
|
||||
pkgs.makeWrapper # For the postInstall hook
|
||||
];
|
||||
buildInputs = [
|
||||
pkgs.openssl
|
||||
pkgs.chromium # Runtime dependency for the browser
|
||||
];
|
||||
|
||||
# Wrap the actual executables to provide CHROME at runtime
|
||||
postInstall = let
|
||||
chromium_executable = "${pkgs.chromium}/bin/chromium";
|
||||
in ''
|
||||
wrapProgram $out/bin/main \
|
||||
--set CHROME "${chromium_executable}"
|
||||
wrapProgram $out/bin/worker \
|
||||
--set CHROME "${chromium_executable}"
|
||||
'';
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "Minne Application";
|
||||
license = licenses.mit;
|
||||
};
|
||||
};
|
||||
in {
|
||||
packages = {
|
||||
minne = minne-pkg;
|
||||
default = self.packages.${system}.minne;
|
||||
flake-utils.lib.eachDefaultSystem (system: let
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
lib = pkgs.lib;
|
||||
craneLib = crane.mkLib pkgs;
|
||||
libExt =
|
||||
if pkgs.stdenv.isDarwin
|
||||
then "dylib"
|
||||
else "so";
|
||||
minne-pkg = craneLib.buildPackage {
|
||||
src = lib.cleanSourceWith {
|
||||
src = ./.;
|
||||
filter = let
|
||||
extraPaths = [
|
||||
(toString ./Cargo.lock)
|
||||
(toString ./common/migrations)
|
||||
(toString ./common/schemas)
|
||||
(toString ./html-router/templates)
|
||||
(toString ./html-router/assets)
|
||||
];
|
||||
in
|
||||
path: type: let
|
||||
p = toString path;
|
||||
in
|
||||
craneLib.filterCargoSources path type
|
||||
|| lib.any (x: lib.hasPrefix x p) extraPaths;
|
||||
};
|
||||
|
||||
apps = {
|
||||
main = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "main";
|
||||
};
|
||||
worker = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "worker";
|
||||
};
|
||||
server = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "server";
|
||||
};
|
||||
default = self.apps.${system}.main;
|
||||
pname = "minne";
|
||||
version = "0.2.6";
|
||||
doCheck = false;
|
||||
|
||||
nativeBuildInputs = [pkgs.pkg-config pkgs.rustfmt pkgs.makeWrapper];
|
||||
buildInputs = [pkgs.openssl pkgs.chromium pkgs.onnxruntime];
|
||||
|
||||
postInstall = ''
|
||||
wrapProgram $out/bin/main \
|
||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
for b in worker server; do
|
||||
if [ -x "$out/bin/$b" ]; then
|
||||
wrapProgram $out/bin/$b \
|
||||
--set CHROME ${pkgs.chromium}/bin/chromium \
|
||||
--set ORT_DYLIB_PATH ${pkgs.onnxruntime}/lib/libonnxruntime.${libExt}
|
||||
fi
|
||||
done
|
||||
'';
|
||||
};
|
||||
in {
|
||||
packages = {
|
||||
minne-pkg = minne-pkg;
|
||||
default = minne-pkg;
|
||||
};
|
||||
apps = {
|
||||
main = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "main";
|
||||
};
|
||||
}
|
||||
);
|
||||
worker = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "worker";
|
||||
};
|
||||
server = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "server";
|
||||
};
|
||||
default = flake-utils.lib.mkApp {
|
||||
drv = minne-pkg;
|
||||
name = "main";
|
||||
};
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
@@ -26,11 +29,13 @@ minijinja-embed = { workspace = true }
|
||||
minijinja-contrib = {workspace = true }
|
||||
axum-htmx = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
plotly = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
chrono-tz = { workspace = true }
|
||||
tower-serve-static = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
common = { path = "../common" }
|
||||
composite-retrieval = { path = "../composite-retrieval" }
|
||||
|
||||
@@ -1,29 +1,108 @@
|
||||
@import 'tailwindcss' source(none);
|
||||
@import 'tailwindcss';
|
||||
|
||||
@source './templates/**/*.html';
|
||||
|
||||
@plugin "daisyui" {
|
||||
exclude: rootscrollbargutter;
|
||||
logs: false;
|
||||
themes: false;
|
||||
include: [ "properties",
|
||||
"scrollbar",
|
||||
"rootscrolllock",
|
||||
"rootcolor",
|
||||
"svg",
|
||||
"button",
|
||||
"menu",
|
||||
"navbar",
|
||||
"drawer",
|
||||
"modal",
|
||||
"chat",
|
||||
"card",
|
||||
"loading",
|
||||
"validator",
|
||||
"fileinput",
|
||||
"alert",
|
||||
"swap"
|
||||
];
|
||||
}
|
||||
|
||||
@plugin "@tailwindcss/typography";
|
||||
|
||||
@config './tailwind.config.js';
|
||||
|
||||
/*
|
||||
The default border color has changed to `currentColor` in Tailwind CSS v4,
|
||||
so we've added these compatibility styles to make sure everything still
|
||||
looks the same as it did with Tailwind CSS v3.
|
||||
|
||||
If we ever want to remove these styles, we need to add an explicit border
|
||||
color utility to any element that depends on these defaults.
|
||||
*/
|
||||
|
||||
@view-transition {
|
||||
navigation: auto;
|
||||
}
|
||||
|
||||
@layer base {
|
||||
:root {
|
||||
--nb-shadow: 4px 4px 0 0 #000;
|
||||
--nb-shadow-hover: 6px 6px 0 0 #000;
|
||||
}
|
||||
|
||||
[data-theme="light"] {
|
||||
color-scheme: light;
|
||||
--color-base-100: oklch(98.42% 0.012 96.42);
|
||||
--color-base-200: oklch(94.52% 0.0122 96.43);
|
||||
--color-base-300: oklch(90.96% 0.0125 91.53);
|
||||
--color-base-content: oklch(17.76% 0 89.88);
|
||||
--color-primary: oklch(20.77% 0.0398 265.75);
|
||||
--color-primary-content: oklch(100% 0 89.88);
|
||||
--color-secondary: oklch(54.61% 0.2152 262.88);
|
||||
--color-secondary-content: oklch(100% 0 89.88);
|
||||
--color-accent: oklch(72% 0.19 80);
|
||||
--color-accent-content: oklch(21% 0.035 80);
|
||||
--color-neutral: oklch(17.76% 0 89.88);
|
||||
--color-neutral-content: oklch(96.99% 0.0013 106.42);
|
||||
--color-info: oklch(60.89% 0.1109 221.72);
|
||||
--color-info-content: oklch(96.99% 0.0013 106.42);
|
||||
--color-success: oklch(62.71% 0.1699 149.21);
|
||||
--color-success-content: oklch(96.99% 0.0013 106.42);
|
||||
--color-warning: oklch(79.52% 0.1617 86.05);
|
||||
--color-warning-content: oklch(17.76% 0 89.88);
|
||||
--color-error: oklch(57.71% 0.2152 27.33);
|
||||
--color-error-content: oklch(96.99% 0.0013 106.42);
|
||||
--radius-selector: 0rem;
|
||||
--radius-field: 0rem;
|
||||
--radius-box: 0rem;
|
||||
--size-selector: 0.25rem;
|
||||
--size-field: 0.25rem;
|
||||
--border: 2px;
|
||||
}
|
||||
|
||||
[data-theme="dark"] {
|
||||
color-scheme: dark;
|
||||
--color-base-100: oklch(22% 0.015 255);
|
||||
--color-base-200: oklch(18% 0.014 253);
|
||||
--color-base-300: oklch(14% 0.012 251);
|
||||
--color-base-content: oklch(97.2% 0.02 255);
|
||||
--color-primary: oklch(58% 0.233 277.12);
|
||||
--color-primary-content: oklch(96% 0.018 272.31);
|
||||
--color-secondary: oklch(65% 0.241 354.31);
|
||||
--color-secondary-content: oklch(94% 0.028 342.26);
|
||||
--color-accent: oklch(78% 0.22 80);
|
||||
--color-accent-content: oklch(20% 0.035 80);
|
||||
--color-neutral: oklch(26% 0.02 255);
|
||||
--color-neutral-content: oklch(97% 0.03 255);
|
||||
--color-info: oklch(74% 0.16 232.66);
|
||||
--color-info-content: oklch(29% 0.066 243.16);
|
||||
--color-success: oklch(76% 0.177 163.22);
|
||||
--color-success-content: oklch(37% 0.077 168.94);
|
||||
--color-warning: oklch(82% 0.189 84.43);
|
||||
--color-warning-content: oklch(41% 0.112 45.9);
|
||||
--color-error: oklch(71% 0.194 13.43);
|
||||
--color-error-content: oklch(27% 0.105 12.09);
|
||||
--radius-selector: 0rem;
|
||||
--radius-field: 0rem;
|
||||
--radius-box: 0rem;
|
||||
--size-selector: 0.25rem;
|
||||
--size-field: 0.25rem;
|
||||
--border: 2px;
|
||||
}
|
||||
|
||||
body {
|
||||
@apply font-satoshi;
|
||||
background-color: var(--color-base-100);
|
||||
color: var(--color-base-content);
|
||||
font-family: 'Satoshi', sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
@apply selection:bg-yellow-300/40 selection:text-neutral;
|
||||
}
|
||||
|
||||
html {
|
||||
@@ -37,6 +116,581 @@
|
||||
::file-selector-button {
|
||||
border-color: var(--color-gray-200, currentColor);
|
||||
}
|
||||
|
||||
.container {
|
||||
padding-inline: 10px;
|
||||
}
|
||||
|
||||
@media (min-width: 640px) {
|
||||
.container {
|
||||
padding-inline: 2rem;
|
||||
}
|
||||
}
|
||||
|
||||
@media (min-width: 1024px) {
|
||||
.container {
|
||||
padding-inline: 4rem;
|
||||
}
|
||||
}
|
||||
|
||||
@media (min-width: 1280px) {
|
||||
.container {
|
||||
padding-inline: 5rem;
|
||||
}
|
||||
}
|
||||
|
||||
@media (min-width: 1536px) {
|
||||
.container {
|
||||
padding-inline: 6rem;
|
||||
}
|
||||
}
|
||||
|
||||
.custom-scrollbar {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: rgba(0, 0, 0, 0.2) transparent;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar {
|
||||
width: 4px;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar-thumb {
|
||||
background-color: rgba(0, 0, 0, 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.hide-scrollbar {
|
||||
-ms-overflow-style: none;
|
||||
scrollbar-width: none;
|
||||
}
|
||||
|
||||
.hide-scrollbar::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
form.htmx-request {
|
||||
opacity: 0.5;
|
||||
}
|
||||
}
|
||||
|
||||
/* Neobrutalist helpers influenced by Tufte principles */
|
||||
@layer components {
|
||||
|
||||
/* Offset, hard-edge shadow; minimal ink with strong contrast */
|
||||
.nb-shadow {
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
.nb-shadow-hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-card {
|
||||
@apply bg-base-100 border-2 border-neutral p-4;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
.nb-card:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-panel {
|
||||
@apply border-2 border-neutral;
|
||||
background-color: var(--nb-panel-bg, var(--color-base-200));
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
.nb-panel:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-panel-canvas {
|
||||
--nb-panel-bg: var(--color-base-100);
|
||||
}
|
||||
|
||||
.nb-canvas {
|
||||
background-color: var(--color-base-100);
|
||||
}
|
||||
|
||||
.nb-btn {
|
||||
@apply btn rounded-none border-2 border-neutral text-base-content;
|
||||
--btn-color: var(--color-base-100);
|
||||
--btn-fg: var(--color-base-content);
|
||||
--btn-noise: none;
|
||||
background-image: none;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
.nb-btn:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-link {
|
||||
@apply underline underline-offset-2 decoration-neutral hover:decoration-4;
|
||||
}
|
||||
|
||||
.nb-stat {
|
||||
@apply bg-base-100 border-2 border-neutral p-5 flex flex-col gap-1;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
/* Hairline rules and quiet gridlines for Tufte feel */
|
||||
.u-hairline {
|
||||
@apply border-t border-neutral/20;
|
||||
}
|
||||
|
||||
.prose-tufte {
|
||||
@apply prose prose-neutral;
|
||||
max-width: min(90ch, 100%);
|
||||
line-height: 1.7;
|
||||
}
|
||||
|
||||
.prose-tufte-compact {
|
||||
@apply prose prose-neutral;
|
||||
max-width: min(90ch, 100%);
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .prose-tufte,
|
||||
[data-theme="dark"] .prose-tufte-compact {
|
||||
color: var(--color-base-content);
|
||||
--tw-prose-body: var(--color-base-content);
|
||||
--tw-prose-headings: var(--color-base-content);
|
||||
--tw-prose-lead: rgba(255, 255, 255, 0.78);
|
||||
--tw-prose-links: var(--color-accent);
|
||||
--tw-prose-bold: var(--color-base-content);
|
||||
--tw-prose-counters: rgba(255, 255, 255, 0.7);
|
||||
--tw-prose-bullets: rgba(255, 255, 255, 0.35);
|
||||
--tw-prose-hr: rgba(255, 255, 255, 0.2);
|
||||
--tw-prose-quotes: var(--color-base-content);
|
||||
--tw-prose-quote-borders: rgba(255, 255, 255, 0.25);
|
||||
--tw-prose-captions: rgba(255, 255, 255, 0.65);
|
||||
--tw-prose-code: var(--color-base-content);
|
||||
--tw-prose-pre-code: inherit;
|
||||
--tw-prose-pre-bg: rgba(255, 255, 255, 0.07);
|
||||
--tw-prose-th-borders: rgba(255, 255, 255, 0.25);
|
||||
--tw-prose-td-borders: rgba(255, 255, 255, 0.2);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .prose-tufte a,
|
||||
[data-theme="dark"] .prose-tufte-compact a {
|
||||
color: var(--color-accent);
|
||||
}
|
||||
|
||||
/* Encourage a consistent card look app-wide */
|
||||
.card {
|
||||
@apply border-2 border-neutral rounded-none;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
.card:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
/* Input styling with good dark/light contrast */
|
||||
.nb-input {
|
||||
@apply rounded-none border-2 border-neutral bg-base-100 text-base-content placeholder:text-base-content/60 px-3 py-[0.5rem];
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms, border-color 150ms;
|
||||
}
|
||||
|
||||
.nb-input:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-input:focus {
|
||||
outline: none;
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
/* Select styling parallels inputs */
|
||||
.nb-select {
|
||||
@apply rounded-none border-2 border-neutral bg-base-100 text-base-content px-3 py-[0.5rem];
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms, border-color 150ms;
|
||||
}
|
||||
|
||||
.nb-select:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
.nb-select:focus {
|
||||
outline: none;
|
||||
box-shadow: var(--nb-shadow-hover);
|
||||
}
|
||||
|
||||
/* Compact variants */
|
||||
.nb-input-sm {
|
||||
@apply text-sm px-2 py-[0.25rem];
|
||||
}
|
||||
|
||||
.nb-select-sm {
|
||||
@apply text-sm px-2 py-[0.25rem];
|
||||
}
|
||||
|
||||
.nb-cta {
|
||||
--btn-color: var(--color-accent);
|
||||
--btn-fg: var(--color-accent-content);
|
||||
--btn-noise: none;
|
||||
background-image: none;
|
||||
background-color: var(--color-accent);
|
||||
color: var(--color-accent-content);
|
||||
}
|
||||
|
||||
.nb-cta:hover {
|
||||
background-color: var(--color-accent);
|
||||
color: var(--color-accent-content);
|
||||
filter: saturate(1.1) brightness(1.05);
|
||||
}
|
||||
|
||||
/* Badges */
|
||||
.nb-badge {
|
||||
@apply inline-flex items-center uppercase tracking-wide text-[10px] px-2 py-0.5 bg-base-100 border-2 border-neutral rounded-none;
|
||||
box-shadow: 3px 3px 0 0 #000;
|
||||
}
|
||||
|
||||
.nb-masonry {
|
||||
column-count: 1;
|
||||
column-gap: 1rem;
|
||||
}
|
||||
|
||||
.nb-masonry>* {
|
||||
break-inside: avoid;
|
||||
display: block;
|
||||
}
|
||||
|
||||
@media (min-width: 768px) {
|
||||
.nb-masonry {
|
||||
column-count: 2;
|
||||
}
|
||||
}
|
||||
|
||||
@media (min-width: 1536px) {
|
||||
.nb-masonry {
|
||||
column-count: 3;
|
||||
}
|
||||
}
|
||||
|
||||
/* Chat bubbles neobrutalist */
|
||||
.chat .chat-bubble {
|
||||
@apply rounded-none border-2 border-neutral bg-base-100 text-neutral;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms;
|
||||
}
|
||||
|
||||
/* Remove DaisyUI tail so our rectangle keeps clean borders/shadows */
|
||||
.chat .chat-bubble::before,
|
||||
.chat .chat-bubble::after {
|
||||
display: none !important;
|
||||
content: none !important;
|
||||
}
|
||||
|
||||
.chat.chat-start .chat-bubble {
|
||||
@apply bg-secondary text-secondary-content;
|
||||
}
|
||||
|
||||
.chat.chat-end .chat-bubble {
|
||||
@apply bg-base-100 text-neutral;
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
.nb-table {
|
||||
@apply w-full;
|
||||
border-collapse: separate;
|
||||
border-spacing: 0;
|
||||
}
|
||||
|
||||
.nb-table thead th {
|
||||
@apply uppercase tracking-wide text-xs border-b-2 border-neutral;
|
||||
}
|
||||
|
||||
.nb-table th,
|
||||
.nb-table td {
|
||||
@apply p-3;
|
||||
}
|
||||
|
||||
.nb-table tbody tr+tr td {
|
||||
@apply border-t border-neutral/30;
|
||||
}
|
||||
|
||||
.nb-table tbody tr:hover {
|
||||
@apply bg-base-200/40;
|
||||
}
|
||||
|
||||
.nb-table tbody tr:hover td:first-child {
|
||||
box-shadow: inset 3px 0 0 0 #000;
|
||||
}
|
||||
|
||||
.kg-overlay {
|
||||
@apply absolute top-4 left-4 right-4 z-10 flex flex-col items-stretch gap-2;
|
||||
max-width: min(420px, calc(100% - 2rem));
|
||||
}
|
||||
|
||||
.kg-control-row {
|
||||
@apply flex flex-wrap items-center gap-2;
|
||||
}
|
||||
|
||||
.kg-control-row-primary {
|
||||
@apply justify-start;
|
||||
}
|
||||
|
||||
.kg-control-row-secondary {
|
||||
@apply justify-center;
|
||||
}
|
||||
|
||||
.kg-search-input {
|
||||
@apply pl-2;
|
||||
height: 2rem;
|
||||
width: 100%;
|
||||
max-width: 320px;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.kg-control-row-primary .kg-search-input {
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
|
||||
.kg-search-btn {
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
|
||||
.kg-toggle {
|
||||
@apply transition-colors;
|
||||
}
|
||||
|
||||
.kg-toggle-active {
|
||||
--btn-color: var(--color-accent);
|
||||
--btn-fg: var(--color-accent-content);
|
||||
--btn-noise: none;
|
||||
background-image: none;
|
||||
background-color: var(--color-accent);
|
||||
color: var(--color-accent-content);
|
||||
}
|
||||
|
||||
.kg-toggle-active:hover {
|
||||
background-color: var(--color-accent);
|
||||
color: var(--color-accent-content);
|
||||
filter: saturate(1.1) brightness(1.05);
|
||||
}
|
||||
|
||||
@media (min-width: 768px) {
|
||||
.kg-overlay {
|
||||
right: auto;
|
||||
max-width: none;
|
||||
width: auto;
|
||||
}
|
||||
}
|
||||
|
||||
.kg-legend {
|
||||
@apply absolute bottom-2 left-2 z-10 flex flex-wrap gap-4;
|
||||
}
|
||||
|
||||
.kg-legend-card {
|
||||
@apply p-2;
|
||||
}
|
||||
|
||||
.kg-legend-heading {
|
||||
@apply mb-1 text-xs opacity-70;
|
||||
}
|
||||
|
||||
.kg-legend-row {
|
||||
@apply flex items-center gap-2 text-xs;
|
||||
}
|
||||
|
||||
/* Checkboxes */
|
||||
.nb-checkbox {
|
||||
@apply appearance-none inline-block align-middle rounded-none border-2 border-neutral bg-base-100;
|
||||
width: 1rem;
|
||||
height: 1rem;
|
||||
box-shadow: var(--nb-shadow);
|
||||
transition: transform 150ms, box-shadow 150ms, border-color 150ms, background-color 150ms;
|
||||
background-repeat: no-repeat;
|
||||
background-position: center;
|
||||
background-size: 80% 80%;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.nb-checkbox:hover {
|
||||
transform: translate(-1px, -1px);
|
||||
box-shadow: 5px 5px 0 0 #000;
|
||||
}
|
||||
|
||||
.nb-checkbox:focus-visible {
|
||||
outline: 2px solid #000;
|
||||
outline-offset: 2px;
|
||||
}
|
||||
|
||||
.nb-checkbox:active {
|
||||
transform: translate(0, 0);
|
||||
box-shadow: 3px 3px 0 0 #000;
|
||||
}
|
||||
|
||||
/* Tick mark in light mode (black) */
|
||||
.nb-checkbox:checked {
|
||||
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%23000' stroke-width='3' stroke-linecap='round' stroke-linejoin='round'><polyline points='20 6 9 17 4 12'/></svg>");
|
||||
}
|
||||
|
||||
/* Tick mark in dark mode (white) */
|
||||
[data-theme="dark"] .nb-checkbox:checked {
|
||||
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%23fff' stroke-width='3' stroke-linecap='round' stroke-linejoin='round'><polyline points='20 6 9 17 4 12'/></svg>");
|
||||
}
|
||||
|
||||
/* Compact size */
|
||||
.nb-checkbox-sm {
|
||||
width: 0.875rem;
|
||||
height: 0.875rem;
|
||||
}
|
||||
|
||||
/* Placeholder style for smaller, quieter helper text */
|
||||
.nb-input::placeholder {
|
||||
font-size: 0.75rem;
|
||||
letter-spacing: 0.02em;
|
||||
opacity: 0.75;
|
||||
}
|
||||
|
||||
.markdown-content {
|
||||
line-height: 1.5;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
margin-bottom: 0.75em;
|
||||
}
|
||||
|
||||
.markdown-content p:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.markdown-content ul,
|
||||
.markdown-content ol {
|
||||
margin-top: 0.5em;
|
||||
margin-bottom: 0.75em;
|
||||
padding-left: 2em;
|
||||
}
|
||||
|
||||
.markdown-content li {
|
||||
margin-bottom: 0.25em;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background-color: var(--color-base-200);
|
||||
color: var(--color-base-content);
|
||||
padding: 0.75em 1em;
|
||||
border-radius: 4px;
|
||||
border: 1px solid rgba(0, 0, 0, 0.08);
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
background-color: transparent;
|
||||
color: inherit;
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
display: block;
|
||||
line-height: inherit;
|
||||
}
|
||||
|
||||
.markdown-content :not(pre) > code {
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
color: var(--color-base-content);
|
||||
padding: 0.15em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
margin: 0.75em 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid rgba(0, 0, 0, 0.15);
|
||||
padding: 6px 12px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .markdown-content th,
|
||||
[data-theme="dark"] .markdown-content td {
|
||||
border-color: rgba(255, 255, 255, 0.25);
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 4px solid rgba(0, 0, 0, 0.15);
|
||||
padding-left: 10px;
|
||||
margin: 0.5em 0 0.5em 0.5em;
|
||||
color: rgba(0, 0, 0, 0.6);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .markdown-content blockquote {
|
||||
border-color: rgba(255, 255, 255, 0.3);
|
||||
color: rgba(255, 255, 255, 0.8);
|
||||
}
|
||||
|
||||
.markdown-content hr {
|
||||
border: none;
|
||||
border-top: 1px solid rgba(0, 0, 0, 0.15);
|
||||
margin: 0.75em 0;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .markdown-content hr {
|
||||
border-top-color: rgba(255, 255, 255, 0.2);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .markdown-content pre {
|
||||
background-color: var(--color-base-200);
|
||||
border-color: rgba(255, 255, 255, 0.12);
|
||||
color: var(--color-base-content);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .markdown-content :not(pre) > code {
|
||||
background-color: rgba(255, 255, 255, 0.12);
|
||||
color: var(--color-base-content);
|
||||
}
|
||||
|
||||
.brand-mark {
|
||||
letter-spacing: 0.02em;
|
||||
}
|
||||
|
||||
.reference-tooltip {
|
||||
@apply bg-base-100 text-base-content border-2 border-neutral p-3 text-sm w-72 max-w-xs;
|
||||
position: fixed;
|
||||
z-index: 9999;
|
||||
box-shadow: var(--nb-shadow);
|
||||
}
|
||||
}
|
||||
|
||||
/* Theme-aware placeholder contrast tweaks */
|
||||
@layer base {
|
||||
|
||||
/* Light theme keeps default neutral tone via utilities */
|
||||
[data-theme="dark"] .nb-input::placeholder,
|
||||
[data-theme="dark"] .input::placeholder,
|
||||
[data-theme="dark"] .textarea::placeholder,
|
||||
[data-theme="dark"] textarea::placeholder,
|
||||
[data-theme="dark"] input::placeholder {
|
||||
color: rgba(255, 255, 255, 0.78) !important;
|
||||
opacity: 0.85;
|
||||
}
|
||||
}
|
||||
|
||||
/* satoshi.css */
|
||||
@@ -58,4 +712,28 @@
|
||||
font-weight: 300 900;
|
||||
font-style: italic;
|
||||
font-display: swap;
|
||||
}
|
||||
}
|
||||
|
||||
/* Minimal override: prevent DaisyUI .menu hover bg on our nb buttons */
|
||||
@layer utilities {
|
||||
|
||||
/* Let plain nb-btns remain transparent on hover within menus */
|
||||
.menu li>.nb-btn:hover {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
/* Keep CTA background on hover within menus */
|
||||
.menu li>.nb-cta:hover {
|
||||
background-color: var(--color-accent);
|
||||
color: var(--color-accent-content);
|
||||
}
|
||||
|
||||
.toast-alert {
|
||||
@apply mt-2 flex flex-col text-left gap-1;
|
||||
box-shadow: var(--nb-shadow);
|
||||
}
|
||||
|
||||
.toast-alert-title {
|
||||
@apply text-lg font-bold;
|
||||
}
|
||||
}
|
||||
|
||||
2
html-router/assets/d3.min.js
vendored
Normal file
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 28 KiB |
|
Before Width: | Height: | Size: 252 KiB After Width: | Height: | Size: 140 KiB |
|
Before Width: | Height: | Size: 42 KiB After Width: | Height: | Size: 25 KiB |
|
Before Width: | Height: | Size: 790 B After Width: | Height: | Size: 963 B |
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 1.9 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
429
html-router/assets/knowledge-graph.js
Normal file
@@ -0,0 +1,429 @@
|
||||
// Knowledge graph renderer: interactive 2D force graph with
|
||||
// zoom/pan, search, neighbor highlighting, curved links with arrows,
|
||||
// responsive resize, and type/relationship legends.
|
||||
(function () {
|
||||
const D3_SRC = '/assets/d3.min.js';
|
||||
|
||||
let d3Loading = null;
|
||||
|
||||
function ensureD3() {
|
||||
if (window.d3) return Promise.resolve();
|
||||
if (d3Loading) return d3Loading;
|
||||
d3Loading = new Promise((resolve, reject) => {
|
||||
const s = document.createElement('script');
|
||||
s.src = D3_SRC;
|
||||
s.async = true;
|
||||
s.onload = () => resolve();
|
||||
s.onerror = () => reject(new Error('Failed to load D3'));
|
||||
document.head.appendChild(s);
|
||||
});
|
||||
return d3Loading;
|
||||
}
|
||||
|
||||
// Simple palettes (kept deterministic across renders)
|
||||
const PALETTE_A = ['#60A5FA', '#34D399', '#F59E0B', '#A78BFA', '#F472B6', '#F87171', '#22D3EE', '#84CC16', '#FB7185'];
|
||||
const PALETTE_B = ['#94A3B8', '#A3A3A3', '#9CA3AF', '#C084FC', '#FDA4AF', '#FCA5A5', '#67E8F9', '#A3E635', '#FDBA74'];
|
||||
|
||||
function buildMap(values) {
|
||||
const unique = Array.from(new Set(values.filter(Boolean)));
|
||||
const map = new Map();
|
||||
unique.forEach((v, i) => map.set(v, PALETTE_A[i % PALETTE_A.length]));
|
||||
return map;
|
||||
}
|
||||
|
||||
function linkColorMap(values) {
|
||||
const unique = Array.from(new Set(values.filter(Boolean)));
|
||||
const map = new Map();
|
||||
unique.forEach((v, i) => map.set(v, PALETTE_B[i % PALETTE_B.length]));
|
||||
return map;
|
||||
}
|
||||
|
||||
function radiusForDegree(deg) {
|
||||
const d = Math.max(0, +deg || 0);
|
||||
const r = 6 + Math.sqrt(d) * 3; // gentle growth
|
||||
return Math.max(6, Math.min(r, 24));
|
||||
}
|
||||
|
||||
function curvedPath(d) {
|
||||
const sx = d.source.x, sy = d.source.y, tx = d.target.x, ty = d.target.y;
|
||||
const dx = tx - sx, dy = ty - sy;
|
||||
const dr = Math.hypot(dx, dy) * 0.7; // curve radius
|
||||
const mx = (sx + tx) / 2;
|
||||
const my = (sy + ty) / 2;
|
||||
// Offset normal to create a consistent arc
|
||||
const nx = -dy / (Math.hypot(dx, dy) || 1);
|
||||
const ny = dx / (Math.hypot(dx, dy) || 1);
|
||||
const cx = mx + nx * 20;
|
||||
const cy = my + ny * 20;
|
||||
return `M ${sx},${sy} Q ${cx},${cy} ${tx},${ty}`;
|
||||
}
|
||||
|
||||
function buildAdjacency(nodes, links) {
|
||||
const idToNode = new Map(nodes.map(n => [n.id, n]));
|
||||
const neighbors = new Map();
|
||||
nodes.forEach(n => neighbors.set(n.id, new Set()));
|
||||
links.forEach(l => {
|
||||
const s = typeof l.source === 'object' ? l.source.id : l.source;
|
||||
const t = typeof l.target === 'object' ? l.target.id : l.target;
|
||||
if (neighbors.has(s)) neighbors.get(s).add(t);
|
||||
if (neighbors.has(t)) neighbors.get(t).add(s);
|
||||
});
|
||||
return { idToNode, neighbors };
|
||||
}
|
||||
|
||||
function attachOverlay(container, { onSearch, onToggleNames, onToggleEdgeLabels, onCenter }) {
|
||||
const overlay = document.createElement('div');
|
||||
overlay.className = 'kg-overlay';
|
||||
|
||||
const primaryRow = document.createElement('div');
|
||||
primaryRow.className = 'kg-control-row kg-control-row-primary';
|
||||
|
||||
const secondaryRow = document.createElement('div');
|
||||
secondaryRow.className = 'kg-control-row kg-control-row-secondary';
|
||||
|
||||
// search box
|
||||
const input = document.createElement('input');
|
||||
input.type = 'text';
|
||||
input.placeholder = 'Search nodes…';
|
||||
input.className = 'nb-input kg-search-input';
|
||||
input.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Enter') onSearch && onSearch(input.value.trim());
|
||||
});
|
||||
|
||||
const searchBtn = document.createElement('button');
|
||||
searchBtn.className = 'nb-btn btn-xs nb-cta kg-search-btn';
|
||||
searchBtn.textContent = 'Go';
|
||||
searchBtn.addEventListener('click', () => onSearch && onSearch(input.value.trim()));
|
||||
|
||||
const namesToggle = document.createElement('button');
|
||||
namesToggle.className = 'nb-btn btn-xs kg-toggle';
|
||||
namesToggle.type = 'button';
|
||||
namesToggle.textContent = 'Names';
|
||||
namesToggle.addEventListener('click', () => onToggleNames && onToggleNames());
|
||||
|
||||
const labelToggle = document.createElement('button');
|
||||
labelToggle.className = 'nb-btn btn-xs kg-toggle';
|
||||
labelToggle.type = 'button';
|
||||
labelToggle.textContent = 'Labels';
|
||||
labelToggle.addEventListener('click', () => onToggleEdgeLabels && onToggleEdgeLabels());
|
||||
|
||||
const centerBtn = document.createElement('button');
|
||||
centerBtn.className = 'nb-btn btn-xs';
|
||||
centerBtn.textContent = 'Center';
|
||||
centerBtn.addEventListener('click', () => onCenter && onCenter());
|
||||
|
||||
primaryRow.appendChild(input);
|
||||
primaryRow.appendChild(searchBtn);
|
||||
|
||||
secondaryRow.appendChild(namesToggle);
|
||||
secondaryRow.appendChild(labelToggle);
|
||||
secondaryRow.appendChild(centerBtn);
|
||||
|
||||
overlay.appendChild(primaryRow);
|
||||
overlay.appendChild(secondaryRow);
|
||||
|
||||
container.style.position = 'relative';
|
||||
container.appendChild(overlay);
|
||||
|
||||
return { input, overlay, namesToggle, labelToggle };
|
||||
}
|
||||
|
||||
function attachLegends(container, typeColor, relColor) {
|
||||
const wrap = document.createElement('div');
|
||||
wrap.className = 'kg-legend';
|
||||
|
||||
function section(title, items) {
|
||||
const sec = document.createElement('div');
|
||||
sec.className = 'nb-card kg-legend-card';
|
||||
const h = document.createElement('div'); h.className = 'kg-legend-heading'; h.textContent = title; sec.appendChild(h);
|
||||
items.forEach(([label, color]) => {
|
||||
const row = document.createElement('div'); row.className = 'kg-legend-row';
|
||||
const sw = document.createElement('span'); sw.style.background = color; sw.style.width = '12px'; sw.style.height = '12px'; sw.style.border = '2px solid #000';
|
||||
const t = document.createElement('span'); t.textContent = label || '—';
|
||||
row.appendChild(sw); row.appendChild(t); sec.appendChild(row);
|
||||
});
|
||||
return sec;
|
||||
}
|
||||
|
||||
const typeItems = Array.from(typeColor.entries());
|
||||
if (typeItems.length) wrap.appendChild(section('Entity Type', typeItems));
|
||||
const relItems = Array.from(relColor.entries());
|
||||
if (relItems.length) wrap.appendChild(section('Relationship', relItems));
|
||||
|
||||
container.appendChild(wrap);
|
||||
return wrap;
|
||||
}
|
||||
|
||||
async function renderKnowledgeGraph(root) {
|
||||
const container = (root || document).querySelector('#knowledge-graph');
|
||||
if (!container) return;
|
||||
|
||||
await ensureD3().catch(() => {
|
||||
const err = document.createElement('div');
|
||||
err.className = 'alert alert-error';
|
||||
err.textContent = 'Unable to load graph library (D3).';
|
||||
container.appendChild(err);
|
||||
});
|
||||
if (!window.d3) return;
|
||||
|
||||
// Clear previous render
|
||||
container.innerHTML = '';
|
||||
|
||||
const width = container.clientWidth || 800;
|
||||
const height = container.clientHeight || 600;
|
||||
|
||||
const et = container.dataset.entityType || '';
|
||||
const cc = container.dataset.contentCategory || '';
|
||||
const qs = new URLSearchParams();
|
||||
if (et) qs.set('entity_type', et);
|
||||
if (cc) qs.set('content_category', cc);
|
||||
|
||||
const url = '/knowledge/graph.json' + (qs.toString() ? ('?' + qs.toString()) : '');
|
||||
let data;
|
||||
try {
|
||||
const res = await fetch(url, { headers: { 'Accept': 'application/json' } });
|
||||
if (!res.ok) throw new Error('Failed to load graph data');
|
||||
data = await res.json();
|
||||
} catch (_e) {
|
||||
const err = document.createElement('div');
|
||||
err.className = 'alert alert-error';
|
||||
err.textContent = 'Unable to load graph data.';
|
||||
container.appendChild(err);
|
||||
return;
|
||||
}
|
||||
|
||||
// Color maps
|
||||
const typeColor = buildMap(data.nodes.map(n => n.entity_type));
|
||||
const relColor = linkColorMap(data.links.map(l => l.relationship_type));
|
||||
const { neighbors } = buildAdjacency(data.nodes, data.links);
|
||||
|
||||
// Build overlay controls
|
||||
let namesVisible = true;
|
||||
let edgeLabelsVisible = true;
|
||||
|
||||
const togglePressedState = (button, state) => {
|
||||
if (!button) return;
|
||||
button.setAttribute('aria-pressed', state ? 'true' : 'false');
|
||||
button.classList.toggle('kg-toggle-active', !!state);
|
||||
};
|
||||
|
||||
const { input, namesToggle, labelToggle } = attachOverlay(container, {
|
||||
onSearch: (q) => focusSearch(q),
|
||||
onToggleNames: () => {
|
||||
namesVisible = !namesVisible;
|
||||
label.style('display', namesVisible ? null : 'none');
|
||||
togglePressedState(namesToggle, namesVisible);
|
||||
},
|
||||
onToggleEdgeLabels: () => {
|
||||
edgeLabelsVisible = !edgeLabelsVisible;
|
||||
linkLabel.style('display', edgeLabelsVisible ? null : 'none');
|
||||
togglePressedState(labelToggle, edgeLabelsVisible);
|
||||
},
|
||||
onCenter: () => zoomTo(1, [width / 2, height / 2])
|
||||
});
|
||||
|
||||
togglePressedState(namesToggle, namesVisible);
|
||||
togglePressedState(labelToggle, edgeLabelsVisible);
|
||||
|
||||
// SVG + zoom
|
||||
const svg = d3.select(container)
|
||||
.append('svg')
|
||||
.attr('width', '100%')
|
||||
.attr('height', height)
|
||||
.attr('viewBox', [0, 0, width, height])
|
||||
.attr('style', 'cursor: grab; touch-action: none; background: transparent;')
|
||||
.call(d3.zoom().scaleExtent([0.25, 5]).on('zoom', (event) => {
|
||||
g.attr('transform', event.transform);
|
||||
}));
|
||||
|
||||
const g = svg.append('g');
|
||||
|
||||
// Defs for arrows
|
||||
const defs = svg.append('defs');
|
||||
const markerFor = (key, color) => {
|
||||
const id = `arrow-${key.replace(/[^a-z0-9_-]/gi, '_')}`;
|
||||
if (!document.getElementById(id)) {
|
||||
defs.append('marker')
|
||||
.attr('id', id)
|
||||
.attr('viewBox', '0 -5 10 10')
|
||||
.attr('refX', 16)
|
||||
.attr('refY', 0)
|
||||
.attr('markerWidth', 6)
|
||||
.attr('markerHeight', 6)
|
||||
.attr('orient', 'auto')
|
||||
.append('path')
|
||||
.attr('d', 'M0,-5L10,0L0,5')
|
||||
.attr('fill', color);
|
||||
}
|
||||
return `url(#${id})`;
|
||||
};
|
||||
|
||||
// Forces
|
||||
const linkForce = d3.forceLink(data.links)
|
||||
.id(d => d.id)
|
||||
.distance(l => 70)
|
||||
.strength(0.5);
|
||||
|
||||
const simulation = d3.forceSimulation(data.nodes)
|
||||
.force('link', linkForce)
|
||||
.force('charge', d3.forceManyBody().strength(-220))
|
||||
.force('center', d3.forceCenter(width / 2, height / 2))
|
||||
.force('collision', d3.forceCollide().radius(d => radiusForDegree(d.degree) + 6))
|
||||
.force('y', d3.forceY(height / 2).strength(0.02))
|
||||
.force('x', d3.forceX(width / 2).strength(0.02));
|
||||
|
||||
// Links as paths so we can curve + arrow
|
||||
const link = g.append('g')
|
||||
.attr('fill', 'none')
|
||||
.attr('stroke-opacity', 0.7)
|
||||
.selectAll('path')
|
||||
.data(data.links)
|
||||
.join('path')
|
||||
.attr('stroke', d => relColor.get(d.relationship_type) || '#CBD5E1')
|
||||
.attr('stroke-width', 1.5)
|
||||
.attr('marker-end', d => markerFor(d.relationship_type || 'rel', relColor.get(d.relationship_type) || '#CBD5E1'));
|
||||
|
||||
// Optional edge labels (midpoint)
|
||||
const linkLabel = g.append('g')
|
||||
.selectAll('text')
|
||||
.data(data.links)
|
||||
.join('text')
|
||||
.attr('font-size', 9)
|
||||
.attr('fill', '#475569')
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('opacity', 0.7)
|
||||
.text(d => d.relationship_type || '');
|
||||
|
||||
// Nodes
|
||||
const node = g.append('g')
|
||||
.attr('stroke', '#fff')
|
||||
.attr('stroke-width', 1.5)
|
||||
.selectAll('circle')
|
||||
.data(data.nodes)
|
||||
.join('circle')
|
||||
.attr('r', d => radiusForDegree(d.degree))
|
||||
.attr('fill', d => typeColor.get(d.entity_type) || '#94A3B8')
|
||||
.attr('cursor', 'pointer')
|
||||
.on('mouseenter', function (_evt, d) { setHighlight(d); })
|
||||
.on('mouseleave', function () { clearHighlight(); })
|
||||
.on('click', function (_evt, d) {
|
||||
// pin/unpin on click
|
||||
if (d.fx == null) { d.fx = d.x; d.fy = d.y; this.setAttribute('data-pinned', 'true'); }
|
||||
else { d.fx = null; d.fy = null; this.removeAttribute('data-pinned'); }
|
||||
})
|
||||
.call(d3.drag()
|
||||
.on('start', (event, d) => {
|
||||
if (!event.active) simulation.alphaTarget(0.3).restart();
|
||||
d.fx = d.x; d.fy = d.y;
|
||||
})
|
||||
.on('drag', (event, d) => { d.fx = event.x; d.fy = event.y; })
|
||||
.on('end', (event, d) => { if (!event.active) simulation.alphaTarget(0); }));
|
||||
|
||||
node.append('title').text(d => `${d.name} • ${d.entity_type} • deg ${d.degree}`);
|
||||
|
||||
// Labels
|
||||
const label = g.append('g')
|
||||
.selectAll('text')
|
||||
.data(data.nodes)
|
||||
.join('text')
|
||||
.text(d => d.name)
|
||||
.attr('font-size', 11)
|
||||
.attr('fill', '#111827')
|
||||
.attr('stroke', 'white')
|
||||
.attr('paint-order', 'stroke')
|
||||
.attr('stroke-width', 3)
|
||||
.attr('dx', d => radiusForDegree(d.degree) + 6)
|
||||
.attr('dy', 4);
|
||||
|
||||
// Legends
|
||||
attachLegends(container, typeColor, relColor);
|
||||
|
||||
// Highlight logic
|
||||
function setHighlight(n) {
|
||||
const ns = neighbors.get(n.id) || new Set();
|
||||
node.attr('opacity', d => (d.id === n.id || ns.has(d.id)) ? 1 : 0.15);
|
||||
label.attr('opacity', d => (d.id === n.id || ns.has(d.id)) ? 1 : 0.15);
|
||||
link
|
||||
.attr('stroke-opacity', d => {
|
||||
const s = (typeof d.source === 'object') ? d.source.id : d.source;
|
||||
const t = (typeof d.target === 'object') ? d.target.id : d.target;
|
||||
return (s === n.id || t === n.id || (ns.has(s) && ns.has(t))) ? 0.9 : 0.05;
|
||||
})
|
||||
.attr('marker-end', d => {
|
||||
const c = relColor.get(d.relationship_type) || '#CBD5E1';
|
||||
return markerFor(d.relationship_type || 'rel', c);
|
||||
});
|
||||
linkLabel.attr('opacity', d => {
|
||||
const s = (typeof d.source === 'object') ? d.source.id : d.source;
|
||||
const t = (typeof d.target === 'object') ? d.target.id : d.target;
|
||||
return (s === n.id || t === n.id) ? 0.9 : 0.05;
|
||||
});
|
||||
}
|
||||
function clearHighlight() {
|
||||
node.attr('opacity', 1);
|
||||
label.attr('opacity', 1);
|
||||
link.attr('stroke-opacity', 0.7);
|
||||
linkLabel.attr('opacity', 0.7);
|
||||
}
|
||||
|
||||
// Search + center helpers
|
||||
function centerOnNode(n) {
|
||||
const k = 1.5; // zoom factor
|
||||
const x = n.x, y = n.y;
|
||||
const transform = d3.zoomIdentity.translate(width / 2 - k * x, height / 2 - k * y).scale(k);
|
||||
svg.transition().duration(350).call(zoom.transform, transform);
|
||||
}
|
||||
function focusSearch(query) {
|
||||
if (!query) return;
|
||||
const q = query.toLowerCase();
|
||||
const found = data.nodes.find(n => (n.name || '').toLowerCase().includes(q));
|
||||
if (found) { setHighlight(found); centerOnNode(found); }
|
||||
}
|
||||
|
||||
// Expose zoom instance
|
||||
const zoom = d3.zoom().scaleExtent([0.25, 5]).on('zoom', (event) => g.attr('transform', event.transform));
|
||||
svg.call(zoom);
|
||||
function zoomTo(k, center) {
|
||||
const transform = d3.zoomIdentity.translate(width / 2 - k * center[0], height / 2 - k * center[1]).scale(k);
|
||||
svg.transition().duration(250).call(zoom.transform, transform);
|
||||
}
|
||||
|
||||
// Tick update
|
||||
simulation.on('tick', () => {
|
||||
link.attr('d', curvedPath);
|
||||
node.attr('cx', d => d.x).attr('cy', d => d.y);
|
||||
label.attr('x', d => d.x).attr('y', d => d.y);
|
||||
linkLabel.attr('x', d => (d.source.x + d.target.x) / 2).attr('y', d => (d.source.y + d.target.y) / 2);
|
||||
});
|
||||
|
||||
// Resize handling
|
||||
const ro = new ResizeObserver(() => {
|
||||
const w = container.clientWidth || width;
|
||||
const h = container.clientHeight || height;
|
||||
svg.attr('viewBox', [0, 0, w, h]).attr('height', h);
|
||||
simulation.force('center', d3.forceCenter(w / 2, h / 2));
|
||||
simulation.alpha(0.3).restart();
|
||||
});
|
||||
ro.observe(container);
|
||||
}
|
||||
|
||||
function tryRender(root) {
|
||||
const container = (root || document).querySelector('#knowledge-graph');
|
||||
if (container) renderKnowledgeGraph(root);
|
||||
}
|
||||
|
||||
// Expose for debugging/manual re-render
|
||||
window.renderKnowledgeGraph = () => renderKnowledgeGraph(document);
|
||||
|
||||
// Full page load
|
||||
document.addEventListener('DOMContentLoaded', () => tryRender(document));
|
||||
|
||||
// HTMX partial swaps
|
||||
document.body.addEventListener('knowledge-graph-refresh', () => {
|
||||
tryRender(document);
|
||||
});
|
||||
|
||||
document.body.addEventListener('htmx:afterSettle', (evt) => {
|
||||
tryRender(evt && evt.target ? evt.target : document);
|
||||
});
|
||||
})();
|
||||
@@ -6,33 +6,31 @@
|
||||
return;
|
||||
}
|
||||
const alert = document.createElement('div');
|
||||
// Base classes for the alert
|
||||
alert.className = `alert alert-${type} mt-2 shadow-md flex flex-col text-start`;
|
||||
alert.className = `alert toast-alert alert-${type}`;
|
||||
alert.style.opacity = '1';
|
||||
alert.style.transition = 'opacity 0.5s ease-out';
|
||||
|
||||
// Build inner HTML based on whether title is provided
|
||||
let innerHTML = '';
|
||||
if (title) {
|
||||
innerHTML += `<div class="font-bold text-lg">${title}</div>`; // Title element
|
||||
innerHTML += `<div>${description}</div>`; // Description element
|
||||
} else {
|
||||
// Structure without title
|
||||
innerHTML += `<span>${description}</span>`;
|
||||
const titleEl = document.createElement('div');
|
||||
titleEl.className = 'toast-alert-title';
|
||||
titleEl.textContent = title;
|
||||
alert.appendChild(titleEl);
|
||||
}
|
||||
|
||||
alert.innerHTML = innerHTML;
|
||||
const bodyEl = document.createElement(title ? 'div' : 'span');
|
||||
bodyEl.textContent = description;
|
||||
alert.appendChild(bodyEl);
|
||||
|
||||
container.appendChild(alert);
|
||||
|
||||
// Auto-remove after a delay
|
||||
setTimeout(() => {
|
||||
// Optional: Add fade-out effect
|
||||
alert.style.opacity = '0';
|
||||
alert.style.transition = 'opacity 0.5s ease-out';
|
||||
setTimeout(() => alert.remove(), 500); // Remove after fade
|
||||
}, 3000); // Start fade-out after 3 seconds
|
||||
setTimeout(() => alert.remove(), 500);
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
document.body.addEventListener('toast', function (event) {
|
||||
console.log(event);
|
||||
// Extract data from the event detail, matching the Rust payload
|
||||
const detail = event.detail;
|
||||
if (detail && detail.description) {
|
||||
@@ -54,4 +52,3 @@
|
||||
if (container) container.innerHTML = '';
|
||||
});
|
||||
})
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"name": "html-router",
|
||||
"version": "1.0.0",
|
||||
"main": "tailwind.config.js",
|
||||
"scripts": {
|
||||
"tailwind": "npx @tailwindcss/cli -i app.css -o assets/style.css -w -m"
|
||||
},
|
||||
@@ -14,4 +13,4 @@
|
||||
"daisyui": "^5.0.12",
|
||||
"tailwindcss": "^4.1.2"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use common::storage::db::SurrealDbClient;
|
||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use composite_retrieval::reranking::RerankerPool;
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
@@ -13,14 +14,18 @@ pub struct HtmlState {
|
||||
pub templates: Arc<TemplateEngine>,
|
||||
pub session_store: Arc<SessionStoreType>,
|
||||
pub config: AppConfig,
|
||||
pub storage: StorageManager,
|
||||
pub reranker_pool: Option<Arc<RerankerPool>>,
|
||||
}
|
||||
|
||||
impl HtmlState {
|
||||
pub fn new_with_resources(
|
||||
pub async fn new_with_resources(
|
||||
db: Arc<SurrealDbClient>,
|
||||
openai_client: Arc<OpenAIClientType>,
|
||||
session_store: Arc<SessionStoreType>,
|
||||
storage: StorageManager,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let template_engine = create_template_engine!("templates");
|
||||
debug!("Template engine created for html_router.");
|
||||
@@ -31,6 +36,8 @@ impl HtmlState {
|
||||
session_store,
|
||||
templates: Arc::new(template_engine),
|
||||
config,
|
||||
storage,
|
||||
reranker_pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ pub mod html_state;
|
||||
pub mod middlewares;
|
||||
pub mod router_factory;
|
||||
pub mod routes;
|
||||
pub mod utils;
|
||||
|
||||
use axum::{extract::FromRef, Router};
|
||||
use axum_session::{Session, SessionStore};
|
||||
@@ -35,5 +36,7 @@ where
|
||||
.add_protected_routes(routes::content::router())
|
||||
.add_protected_routes(routes::knowledge::router())
|
||||
.add_protected_routes(routes::ingestion::router())
|
||||
.add_protected_routes(routes::scratchpad::router())
|
||||
.with_compression()
|
||||
.build()
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::Method,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
@@ -19,7 +20,8 @@ where
|
||||
S: ProvidesDb + Clone + Send + Sync + 'static,
|
||||
{
|
||||
let path = request.uri().path();
|
||||
if !path.starts_with("/assets") && !path.contains('.') {
|
||||
// Only count visits/page loads for GET requests to non-asset, non-static paths
|
||||
if request.method() == Method::GET && !path.starts_with("/assets") && !path.contains('.') {
|
||||
if !session.get::<bool>("counted_visitor").unwrap_or(false) {
|
||||
let _ = Analytics::increment_visitors(state.db()).await;
|
||||
session.set("counted_visitor", true);
|
||||
|
||||
7
html-router/src/middlewares/compression.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use tower_http::compression::CompressionLayer;
|
||||
|
||||
/// Provides a default compression layer that negotiates encoding based on the
|
||||
/// `Accept-Encoding` header of the incoming request.
|
||||
pub fn compression_layer() -> CompressionLayer {
|
||||
CompressionLayer::new()
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod analytics_middleware;
|
||||
pub mod auth_middleware;
|
||||
pub mod compression;
|
||||
pub mod response_middleware;
|
||||
|
||||
@@ -188,7 +188,7 @@ where
|
||||
if is_htmx {
|
||||
(StatusCode::OK, [(axum_htmx::HX_REDIRECT, path)], "").into_response()
|
||||
} else {
|
||||
Redirect::to(&path).into_response()
|
||||
Redirect::to(path).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -205,26 +205,26 @@ pub enum HtmlError {
|
||||
|
||||
impl From<AppError> for HtmlError {
|
||||
fn from(err: AppError) -> Self {
|
||||
HtmlError::AppError(err)
|
||||
Self::AppError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<surrealdb::Error> for HtmlError {
|
||||
fn from(err: surrealdb::Error) -> Self {
|
||||
HtmlError::AppError(AppError::from(err))
|
||||
Self::AppError(AppError::from(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<minijinja::Error> for HtmlError {
|
||||
fn from(err: minijinja::Error) -> Self {
|
||||
HtmlError::TemplateError(err.to_string())
|
||||
Self::TemplateError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for HtmlError {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
HtmlError::AppError(err) => match err {
|
||||
Self::AppError(err) => match err {
|
||||
AppError::NotFound(_) => TemplateResponse::not_found().into_response(),
|
||||
AppError::Auth(_) => TemplateResponse::unauthorized().into_response(),
|
||||
AppError::Validation(msg) => TemplateResponse::bad_request(&msg).into_response(),
|
||||
@@ -233,7 +233,7 @@ impl IntoResponse for HtmlError {
|
||||
TemplateResponse::server_error().into_response()
|
||||
}
|
||||
},
|
||||
HtmlError::TemplateError(err) => {
|
||||
Self::TemplateError(err) => {
|
||||
error!("Template error: {}", err);
|
||||
TemplateResponse::server_error().into_response()
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
analytics_middleware::analytics_middleware, auth_middleware::require_auth,
|
||||
response_middleware::with_template_response,
|
||||
compression::compression_layer, response_middleware::with_template_response,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -48,6 +48,7 @@ pub struct RouterFactory<S> {
|
||||
nested_protected_routes: Vec<(String, Router<S>)>,
|
||||
custom_middleware: MiddleWareVecType<S>,
|
||||
public_assets_config: Option<AssetsConfig>,
|
||||
compression_enabled: bool,
|
||||
}
|
||||
|
||||
struct AssetsConfig {
|
||||
@@ -69,6 +70,7 @@ where
|
||||
nested_protected_routes: Vec::new(),
|
||||
custom_middleware: Vec::new(),
|
||||
public_assets_config: None,
|
||||
compression_enabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,6 +117,12 @@ where
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables response compression when building the router.
|
||||
pub const fn with_compression(mut self) -> Self {
|
||||
self.compression_enabled = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Router<S> {
|
||||
// Start with an empty router
|
||||
let mut public_router = Router::new();
|
||||
@@ -169,21 +177,26 @@ where
|
||||
}
|
||||
|
||||
// Apply common middleware
|
||||
router = router.layer(from_fn_with_state(
|
||||
self.app_state.clone(),
|
||||
analytics_middleware::<HtmlState>,
|
||||
));
|
||||
router = router.layer(map_response_with_state(
|
||||
self.app_state.clone(),
|
||||
with_template_response::<HtmlState>,
|
||||
));
|
||||
router = router.layer(
|
||||
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
|
||||
self.app_state.db.client.clone(),
|
||||
))
|
||||
.with_config(AuthConfig::<String>::default()),
|
||||
);
|
||||
router = router.layer(SessionLayer::new((*self.app_state.session_store).clone()));
|
||||
|
||||
if self.compression_enabled {
|
||||
router = router.layer(compression_layer());
|
||||
}
|
||||
|
||||
router
|
||||
.layer(from_fn_with_state(
|
||||
self.app_state.clone(),
|
||||
analytics_middleware::<HtmlState>,
|
||||
))
|
||||
.layer(map_response_with_state(
|
||||
self.app_state.clone(),
|
||||
with_template_response::<HtmlState>,
|
||||
))
|
||||
.layer(
|
||||
AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
|
||||
self.app_state.db.client.clone(),
|
||||
))
|
||||
.with_config(AuthConfig::<String>::default()),
|
||||
)
|
||||
.layer(SessionLayer::new((*self.app_state.session_store).clone()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,10 @@ pub async fn show_account_page(
|
||||
RequireUser(user): RequireUser,
|
||||
State(state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
|
||||
let timezones = TZ_VARIANTS
|
||||
.iter()
|
||||
.map(std::string::ToString::to_string)
|
||||
.collect();
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -102,7 +105,10 @@ pub async fn update_timezone(
|
||||
..user.clone()
|
||||
};
|
||||
|
||||
let timezones = TZ_VARIANTS.iter().map(|tz| tz.to_string()).collect();
|
||||
let timezones = TZ_VARIANTS
|
||||
.iter()
|
||||
.map(std::string::ToString::to_string)
|
||||
.collect();
|
||||
|
||||
// Render the API key section block
|
||||
Ok(TemplateResponse::new_partial(
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use async_openai::types::ListModelResponse;
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::{
|
||||
@@ -31,44 +35,83 @@ use crate::{
|
||||
pub struct AdminPanelData {
|
||||
user: User,
|
||||
settings: SystemSettings,
|
||||
analytics: Analytics,
|
||||
users: i64,
|
||||
analytics: Option<Analytics>,
|
||||
users: Option<i64>,
|
||||
default_query_prompt: String,
|
||||
default_image_prompt: String,
|
||||
conversation_archive: Vec<Conversation>,
|
||||
available_models: ListModelResponse,
|
||||
available_models: Option<ListModelResponse>,
|
||||
current_section: AdminSection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AdminSection {
|
||||
Overview,
|
||||
Models,
|
||||
}
|
||||
|
||||
impl Default for AdminSection {
|
||||
fn default() -> Self {
|
||||
Self::Overview
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AdminPanelQuery {
|
||||
section: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn show_admin_panel(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Query(query): Query<AdminPanelQuery>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let (
|
||||
settings_res,
|
||||
analytics_res,
|
||||
user_count_res,
|
||||
conversation_archive_res,
|
||||
available_models_res,
|
||||
) = tokio::join!(
|
||||
let section = match query.section.as_deref() {
|
||||
Some("models") => AdminSection::Models,
|
||||
_ => AdminSection::Overview,
|
||||
};
|
||||
|
||||
let (settings, conversation_archive) = tokio::try_join!(
|
||||
SystemSettings::get_current(&state.db),
|
||||
Analytics::get_current(&state.db),
|
||||
Analytics::get_users_amount(&state.db),
|
||||
User::get_user_conversations(&user.id, &state.db),
|
||||
async { state.openai_client.models().list().await }
|
||||
);
|
||||
User::get_user_conversations(&user.id, &state.db)
|
||||
)?;
|
||||
|
||||
let (analytics, users) = if section == AdminSection::Overview {
|
||||
let (analytics, users) = tokio::try_join!(
|
||||
Analytics::get_current(&state.db),
|
||||
Analytics::get_users_amount(&state.db)
|
||||
)?;
|
||||
(Some(analytics), Some(users))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let available_models = if section == AdminSection::Models {
|
||||
Some(
|
||||
state
|
||||
.openai_client
|
||||
.models()
|
||||
.list()
|
||||
.await
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"admin/base.html",
|
||||
AdminPanelData {
|
||||
user,
|
||||
settings: settings_res?,
|
||||
analytics: analytics_res?,
|
||||
available_models: available_models_res
|
||||
.map_err(|e| AppError::InternalError(e.to_string()))?,
|
||||
users: user_count_res?,
|
||||
settings,
|
||||
analytics,
|
||||
available_models,
|
||||
users,
|
||||
default_query_prompt: DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
|
||||
default_image_prompt: DEFAULT_IMAGE_PROCESSING_PROMPT.to_string(),
|
||||
conversation_archive: conversation_archive_res?,
|
||||
conversation_archive,
|
||||
current_section: section,
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -103,7 +146,7 @@ pub async fn toggle_registration_status(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -115,7 +158,7 @@ pub async fn toggle_registration_status(
|
||||
SystemSettings::update(&state.db, new_settings.clone()).await?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"admin/base.html",
|
||||
"admin/sections/overview.html",
|
||||
"registration_status_input",
|
||||
RegistrationToggleData {
|
||||
settings: new_settings,
|
||||
@@ -128,6 +171,7 @@ pub struct ModelSettingsInput {
|
||||
query_model: String,
|
||||
processing_model: String,
|
||||
image_processing_model: String,
|
||||
voice_processing_model: String,
|
||||
embedding_model: String,
|
||||
embedding_dimensions: Option<u32>,
|
||||
}
|
||||
@@ -146,7 +190,7 @@ pub async fn update_model_settings(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -159,6 +203,7 @@ pub async fn update_model_settings(
|
||||
query_model: input.query_model,
|
||||
processing_model: input.processing_model,
|
||||
image_processing_model: input.image_processing_model,
|
||||
voice_processing_model: input.voice_processing_model,
|
||||
embedding_model: input.embedding_model,
|
||||
// Use new dimensions if provided, otherwise retain the current ones.
|
||||
embedding_dimensions: input
|
||||
@@ -215,7 +260,7 @@ pub async fn update_model_settings(
|
||||
.map_err(|_e| AppError::InternalError("Failed to get models".to_string()))?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"admin/base.html",
|
||||
"admin/sections/models.html",
|
||||
"model_settings_form",
|
||||
ModelSettingsData {
|
||||
settings: new_settings,
|
||||
@@ -237,7 +282,7 @@ pub async fn show_edit_system_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -268,7 +313,7 @@ pub async fn patch_query_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -280,7 +325,7 @@ pub async fn patch_query_prompt(
|
||||
SystemSettings::update(&state.db, new_settings.clone()).await?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"admin/base.html",
|
||||
"admin/sections/overview.html",
|
||||
"system_prompt_section",
|
||||
SystemPromptSectionData {
|
||||
settings: new_settings,
|
||||
@@ -301,7 +346,7 @@ pub async fn show_edit_ingestion_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -327,7 +372,7 @@ pub async fn patch_ingestion_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -339,7 +384,7 @@ pub async fn patch_ingestion_prompt(
|
||||
SystemSettings::update(&state.db, new_settings.clone()).await?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"admin/base.html",
|
||||
"admin/sections/overview.html",
|
||||
"system_prompt_section",
|
||||
SystemPromptSectionData {
|
||||
settings: new_settings,
|
||||
@@ -360,7 +405,7 @@ pub async fn show_edit_image_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -386,7 +431,7 @@ pub async fn patch_image_prompt(
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
};
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
@@ -398,10 +443,10 @@ pub async fn patch_image_prompt(
|
||||
SystemSettings::update(&state.db, new_settings.clone()).await?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"admin/base.html",
|
||||
"admin/sections/overview.html",
|
||||
"system_prompt_section",
|
||||
SystemPromptSectionData {
|
||||
settings: new_settings,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,13 +27,14 @@ pub async fn show_signin_form(
|
||||
if auth.is_authenticated() {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
match boosted {
|
||||
true => Ok(TemplateResponse::new_partial(
|
||||
if boosted {
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"auth/signin_base.html",
|
||||
"body",
|
||||
(),
|
||||
)),
|
||||
false => Ok(TemplateResponse::new_template("auth/signin_base.html", ())),
|
||||
))
|
||||
} else {
|
||||
Ok(TemplateResponse::new_template("auth/signin_base.html", ()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,13 +29,14 @@ pub async fn show_signup_form(
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
match boosted {
|
||||
true => Ok(TemplateResponse::new_partial(
|
||||
if boosted {
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"auth/signup_form.html",
|
||||
"body",
|
||||
(),
|
||||
)),
|
||||
false => Ok(TemplateResponse::new_template("auth/signup_form.html", ())),
|
||||
))
|
||||
} else {
|
||||
Ok(TemplateResponse::new_template("auth/signup_form.html", ()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +49,7 @@ pub async fn process_signup_and_show_verification(
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
tracing::error!("{:?}", e);
|
||||
return Ok(Html(format!("<p>{}</p>", e)).into_response());
|
||||
return Ok(Html(format!("<p>{e}</p>")).into_response());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ pub async fn show_existing_chat(
|
||||
ChatPageData {
|
||||
history: messages,
|
||||
user,
|
||||
conversation: Some(conversation.clone()),
|
||||
conversation: Some(conversation),
|
||||
conversation_archive,
|
||||
},
|
||||
))
|
||||
@@ -157,7 +157,7 @@ pub async fn new_user_message(
|
||||
|
||||
if conversation.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
};
|
||||
}
|
||||
|
||||
let user_message = Message::new(conversation_id, MessageRole::User, form.content, None);
|
||||
|
||||
|
||||
@@ -9,11 +9,8 @@ use axum::{
|
||||
},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
answer_retrieval::{
|
||||
create_chat_request, create_user_message_with_history, format_entities_json,
|
||||
LLMResponseFormat,
|
||||
},
|
||||
retrieve_entities,
|
||||
answer_retrieval::{create_chat_request, create_user_message_with_history, LLMResponseFormat},
|
||||
retrieve_entities, retrieved_entities_to_json,
|
||||
};
|
||||
use futures::{
|
||||
stream::{self, once},
|
||||
@@ -121,11 +118,17 @@ pub async fn get_response_stream(
|
||||
};
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let entities = match retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
rerank_lease,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -136,7 +139,7 @@ pub async fn get_response_stream(
|
||||
};
|
||||
|
||||
// 3. Create the OpenAI request
|
||||
let entities_json = format_entities_json(&entities);
|
||||
let entities_json = retrieved_entities_to_json(&entities);
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&entities_json, &history, &user_message.content);
|
||||
let settings = match SystemSettings::get_current(&state.db).await {
|
||||
@@ -251,7 +254,7 @@ pub async fn get_response_stream(
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {}", e)));
|
||||
.data(format!("Stream error: {e}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -260,7 +263,11 @@ pub async fn get_response_stream(
|
||||
.chain(stream::once(async move {
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
// Don't send any event if references is empty
|
||||
if message.references.as_ref().is_some_and(|x| x.is_empty()) {
|
||||
if message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(std::vec::Vec::is_empty)
|
||||
{
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
Form,
|
||||
};
|
||||
use axum_htmx::{HxBoosted, HxRequest};
|
||||
use axum_htmx::{HxBoosted, HxRequest, HxTarget};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use common::storage::types::{
|
||||
conversation::Conversation, file_info::FileInfo, text_content::TextContent, user::User, knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
|
||||
conversation::Conversation, file_info::FileInfo, knowledge_entity::KnowledgeEntity,
|
||||
text_chunk::TextChunk, text_content::TextContent, user::User,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -16,7 +17,12 @@ use crate::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
utils::pagination::{paginate_items, Pagination},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
};
|
||||
use url::form_urlencoded;
|
||||
|
||||
const CONTENTS_PER_PAGE: usize = 12;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ContentPageData {
|
||||
@@ -25,11 +31,20 @@ pub struct ContentPageData {
|
||||
categories: Vec<String>,
|
||||
selected_category: Option<String>,
|
||||
conversation_archive: Vec<Conversation>,
|
||||
pagination: Pagination,
|
||||
page_query: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RecentTextContentData {
|
||||
pub user: User,
|
||||
pub text_contents: Vec<TextContent>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FilterParams {
|
||||
category: Option<String>,
|
||||
page: Option<usize>,
|
||||
}
|
||||
|
||||
pub async fn show_content_page(
|
||||
@@ -40,17 +55,32 @@ pub async fn show_content_page(
|
||||
HxBoosted(is_boosted): HxBoosted,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Normalize empty strings to None
|
||||
let has_category_param = params.category.is_some();
|
||||
let category_filter = params.category.as_deref().unwrap_or("").trim();
|
||||
let category_filter = params
|
||||
.category
|
||||
.as_ref()
|
||||
.map(|c| c.trim())
|
||||
.filter(|c| !c.is_empty());
|
||||
|
||||
// load categories and filtered/all contents
|
||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
let text_contents = if !category_filter.is_empty() {
|
||||
User::get_text_contents_by_category(&user.id, category_filter, &state.db).await?
|
||||
} else {
|
||||
User::get_text_contents(&user.id, &state.db).await?
|
||||
let full_contents = match category_filter {
|
||||
Some(category) => {
|
||||
User::get_text_contents_by_category(&user.id, category, &state.db).await?
|
||||
}
|
||||
None => User::get_text_contents(&user.id, &state.db).await?,
|
||||
};
|
||||
|
||||
let (page_contents, pagination) = paginate_items(full_contents, params.page, CONTENTS_PER_PAGE);
|
||||
let text_contents = truncate_text_contents(page_contents);
|
||||
|
||||
let page_query = category_filter
|
||||
.map(|category| {
|
||||
let mut serializer = form_urlencoded::Serializer::new(String::new());
|
||||
serializer.append_pair("category", category);
|
||||
format!("&{}", serializer.finish())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
let data = ContentPageData {
|
||||
user,
|
||||
@@ -58,19 +88,19 @@ pub async fn show_content_page(
|
||||
categories,
|
||||
selected_category: params.category.clone(),
|
||||
conversation_archive,
|
||||
pagination,
|
||||
page_query,
|
||||
};
|
||||
|
||||
if is_htmx && !is_boosted && has_category_param {
|
||||
// If HTMX partial request with filter applied, return partial content list update
|
||||
return Ok(TemplateResponse::new_partial(
|
||||
if is_htmx && !is_boosted {
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"content/base.html",
|
||||
"main",
|
||||
data,
|
||||
));
|
||||
))
|
||||
} else {
|
||||
Ok(TemplateResponse::new_template("content/base.html", data))
|
||||
}
|
||||
|
||||
// Otherwise full page response including layout
|
||||
Ok(TemplateResponse::new_template("content/base.html", data))
|
||||
}
|
||||
|
||||
pub async fn show_text_content_edit_form(
|
||||
@@ -102,13 +132,32 @@ pub async fn patch_text_content(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
HxTarget(target): HxTarget,
|
||||
Form(form): Form<PatchTextContentParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
TextContent::patch(&id, &form.context, &form.category, &form.text, &state.db).await?;
|
||||
|
||||
let text_contents = User::get_text_contents(&user.id, &state.db).await?;
|
||||
if target.as_deref() == Some("latest_content_section") {
|
||||
let text_contents =
|
||||
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
|
||||
|
||||
return Ok(TemplateResponse::new_template(
|
||||
"dashboard/recent_content.html",
|
||||
RecentTextContentData {
|
||||
user,
|
||||
text_contents,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let (page_contents, pagination) = paginate_items(
|
||||
User::get_text_contents(&user.id, &state.db).await?,
|
||||
Some(1),
|
||||
CONTENTS_PER_PAGE,
|
||||
);
|
||||
let text_contents = truncate_text_contents(page_contents);
|
||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
@@ -121,6 +170,8 @@ pub async fn patch_text_content(
|
||||
categories,
|
||||
selected_category: None,
|
||||
conversation_archive,
|
||||
pagination,
|
||||
page_query: String::new(),
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -134,8 +185,13 @@ pub async fn delete_text_content(
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
// If it has file info, delete that too
|
||||
if let Some(file_info) = &text_content.file_info {
|
||||
FileInfo::delete_by_id(&file_info.id, &state.db).await?;
|
||||
if let Some(file_info) = text_content.file_info.as_ref() {
|
||||
let file_in_use =
|
||||
TextContent::has_other_with_file(&file_info.id, &text_content.id, &state.db).await?;
|
||||
|
||||
if !file_in_use {
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &state.db, &state.storage).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Delete related knowledge entities and text chunks
|
||||
@@ -146,7 +202,12 @@ pub async fn delete_text_content(
|
||||
state.db.delete_item::<TextContent>(&id).await?;
|
||||
|
||||
// Get updated content, categories and return the refreshed list
|
||||
let text_contents = User::get_text_contents(&user.id, &state.db).await?;
|
||||
let (page_contents, pagination) = paginate_items(
|
||||
User::get_text_contents(&user.id, &state.db).await?,
|
||||
Some(1),
|
||||
CONTENTS_PER_PAGE,
|
||||
);
|
||||
let text_contents = truncate_text_contents(page_contents);
|
||||
let categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
|
||||
@@ -158,6 +219,8 @@ pub async fn delete_text_content(
|
||||
categories,
|
||||
selected_category: None,
|
||||
conversation_archive,
|
||||
pagination,
|
||||
page_query: String::new(),
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -185,13 +248,8 @@ pub async fn show_recent_content(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let text_contents = User::get_latest_text_contents(&user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RecentTextContentData {
|
||||
pub user: User,
|
||||
pub text_contents: Vec<TextContent>,
|
||||
}
|
||||
let text_contents =
|
||||
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/recent_content.html",
|
||||
|
||||
@@ -4,17 +4,20 @@ use axum::{
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::try_join;
|
||||
use serde::Serialize;
|
||||
use tokio::{fs::File, join};
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
utils::text_content_preview::truncate_text_contents,
|
||||
AuthSessionType,
|
||||
};
|
||||
use common::storage::types::user::DashboardStats;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{
|
||||
@@ -24,12 +27,11 @@ use common::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct IndexPageData {
|
||||
user: Option<User>,
|
||||
text_contents: Vec<TextContent>,
|
||||
stats: DashboardStats,
|
||||
active_jobs: Vec<IngestionTask>,
|
||||
conversation_archive: Vec<Conversation>,
|
||||
}
|
||||
@@ -42,26 +44,30 @@ pub async fn index_handler(
|
||||
return Ok(TemplateResponse::redirect("/signin"));
|
||||
};
|
||||
|
||||
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
|
||||
let (text_contents, conversation_archive, stats, active_jobs) = try_join!(
|
||||
User::get_latest_text_contents(&user.id, &state.db),
|
||||
User::get_user_conversations(&user.id, &state.db),
|
||||
User::get_dashboard_stats(&user.id, &state.db),
|
||||
User::get_unfinished_ingestion_tasks(&user.id, &state.db)
|
||||
)?;
|
||||
|
||||
let text_contents = User::get_latest_text_contents(&user.id, &state.db).await?;
|
||||
|
||||
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
|
||||
let text_contents = truncate_text_contents(text_contents);
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/base.html",
|
||||
IndexPageData {
|
||||
user: Some(user),
|
||||
text_contents,
|
||||
active_jobs,
|
||||
stats,
|
||||
conversation_archive,
|
||||
active_jobs,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LatestTextContentData {
|
||||
latest_text_contents: Vec<TextContent>,
|
||||
text_contents: Vec<TextContent>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
@@ -73,30 +79,35 @@ pub async fn delete_text_content(
|
||||
// Get and validate TextContent
|
||||
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
|
||||
|
||||
// Perform concurrent deletions
|
||||
let (_res1, _res2, _res3, _res4, _res5) = join!(
|
||||
async {
|
||||
if let Some(file_info) = text_content.file_info {
|
||||
FileInfo::delete_by_id(&file_info.id, &state.db).await
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
state.db.delete_item::<TextContent>(&text_content.id),
|
||||
TextChunk::delete_by_source_id(&text_content.id, &state.db),
|
||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db),
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db)
|
||||
);
|
||||
// Remove stored assets before deleting the text content record
|
||||
if let Some(file_info) = text_content.file_info.as_ref() {
|
||||
let file_in_use =
|
||||
TextContent::has_other_with_file(&file_info.id, &text_content.id, &state.db).await?;
|
||||
|
||||
if !file_in_use {
|
||||
FileInfo::delete_by_id_with_storage(&file_info.id, &state.db, &state.storage).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the text content and any related data
|
||||
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db).await?;
|
||||
state
|
||||
.db
|
||||
.delete_item::<TextContent>(&text_content.id)
|
||||
.await?;
|
||||
|
||||
// Render updated content
|
||||
let latest_text_contents = User::get_latest_text_contents(&user.id, &state.db).await?;
|
||||
let text_contents =
|
||||
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
"index/signed_in/recent_content.html",
|
||||
"dashboard/recent_content.html",
|
||||
"latest_content_section",
|
||||
LatestTextContentData {
|
||||
user: user.to_owned(),
|
||||
latest_text_contents,
|
||||
user: user.clone(),
|
||||
text_contents,
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -128,6 +139,32 @@ pub struct ActiveJobsData {
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TaskArchiveEntry {
|
||||
id: String,
|
||||
state_label: String,
|
||||
state_raw: String,
|
||||
attempts: u32,
|
||||
max_attempts: u32,
|
||||
created_at: DateTime<Utc>,
|
||||
updated_at: DateTime<Utc>,
|
||||
scheduled_at: DateTime<Utc>,
|
||||
locked_at: Option<DateTime<Utc>>,
|
||||
last_error_at: Option<DateTime<Utc>>,
|
||||
error_message: Option<String>,
|
||||
worker_id: Option<String>,
|
||||
priority: i32,
|
||||
lease_duration_secs: i64,
|
||||
content_kind: String,
|
||||
content_summary: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TaskArchiveData {
|
||||
user: User,
|
||||
tasks: Vec<TaskArchiveEntry>,
|
||||
}
|
||||
|
||||
pub async fn delete_job(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
@@ -153,9 +190,8 @@ pub async fn show_active_jobs(
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let active_jobs = User::get_unfinished_ingestion_tasks(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_partial(
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/active_jobs.html",
|
||||
"active_jobs_section",
|
||||
ActiveJobsData {
|
||||
user: user.clone(),
|
||||
active_jobs,
|
||||
@@ -163,6 +199,70 @@ pub async fn show_active_jobs(
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn show_task_archive(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
|
||||
|
||||
let entries: Vec<TaskArchiveEntry> = tasks
|
||||
.into_iter()
|
||||
.map(|task| {
|
||||
let (content_kind, content_summary) = summarize_task_content(&task);
|
||||
|
||||
TaskArchiveEntry {
|
||||
id: task.id.clone(),
|
||||
state_label: task.state.display_label().to_string(),
|
||||
state_raw: task.state.as_str().to_string(),
|
||||
attempts: task.attempts,
|
||||
max_attempts: task.max_attempts,
|
||||
created_at: task.created_at,
|
||||
updated_at: task.updated_at,
|
||||
scheduled_at: task.scheduled_at,
|
||||
locked_at: task.locked_at,
|
||||
last_error_at: task.last_error_at,
|
||||
error_message: task.error_message.clone(),
|
||||
worker_id: task.worker_id.clone(),
|
||||
priority: task.priority,
|
||||
lease_duration_secs: task.lease_duration_secs,
|
||||
content_kind,
|
||||
content_summary,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/task_archive_modal.html",
|
||||
TaskArchiveData {
|
||||
user,
|
||||
tasks: entries,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
match &task.content {
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Text { text, .. } => {
|
||||
("Text".to_string(), truncate_summary(text, 80))
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
|
||||
("URL".to_string(), url.to_string())
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
|
||||
("File".to_string(), file_info.file_name.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_summary(input: &str, max_chars: usize) -> String {
|
||||
if input.chars().count() <= max_chars {
|
||||
input.to_string()
|
||||
} else {
|
||||
let truncated: String = input.chars().take(max_chars).collect();
|
||||
format!("{truncated}…")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve_file(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
@@ -177,14 +277,10 @@ pub async fn serve_file(
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
}
|
||||
|
||||
let path = std::path::Path::new(&file_info.path);
|
||||
|
||||
let file = match File::open(path).await {
|
||||
Ok(f) => f,
|
||||
Err(_e) => return Ok(TemplateResponse::server_error().into_response()),
|
||||
let stream = match state.storage.get_stream(&file_info.path).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Ok(TemplateResponse::server_error().into_response()),
|
||||
};
|
||||
|
||||
let stream = ReaderStream::new(file);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
@@ -5,7 +5,9 @@ use axum::{
|
||||
routing::{delete, get},
|
||||
Router,
|
||||
};
|
||||
use handlers::{delete_job, delete_text_content, index_handler, serve_file, show_active_jobs};
|
||||
use handlers::{
|
||||
delete_job, delete_text_content, index_handler, serve_file, show_active_jobs, show_task_archive,
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
@@ -24,6 +26,7 @@ where
|
||||
{
|
||||
Router::new()
|
||||
.route("/jobs/{job_id}", delete(delete_job))
|
||||
.route("/jobs/archive", get(show_task_archive))
|
||||
.route("/active-jobs", get(show_active_jobs))
|
||||
.route("/text-content/{id}", delete(delete_text_content))
|
||||
.route("/file/{id}", get(serve_file))
|
||||
|
||||