22 Commits

Author SHA1 Message Date
Per Stark
2964f1a5a5 release: 0.2.4 2025-10-15 09:09:35 +02:00
Per Stark
cb7f625b81 fix: score normalization for vector search 2025-10-14 21:13:58 +02:00
Per Stark
dc40cf7663 feat: hybrid search 2025-10-14 20:38:43 +02:00
Per Stark
aa0b1462a1 feat: task archive
fix: simplified
2025-10-14 10:38:09 +02:00
Per Stark
41fc7bb99c feat: state machine for tasks, multiple workers 2025-10-12 22:21:20 +02:00
Per Stark
61d8d7abe7 release: 0.2.3 2025-10-12 20:15:10 +02:00
Per Stark
b7344644dc fix: embedding dimension change on fresh db 2025-10-12 20:13:23 +02:00
Per Stark
3742598a6d chore: changed image in readme 2025-10-08 12:04:47 +02:00
Per Stark
c6a6080e1c release: 0.2.2 2025-10-07 11:51:33 +02:00
Per Stark
1159712724 fix: convert to surrealdb datetime before conversion 2025-10-03 15:33:28 +02:00
Per Stark
e5e1414f54 chore: clippy magic 2025-10-01 15:39:45 +02:00
Per Stark
fcc49b1954 design: new icons to match new theme 2025-10-01 10:17:43 +02:00
Per Stark
022f4d8575 fix: compliant with gpt-5 models 2025-10-01 10:17:31 +02:00
Per Stark
945a2b7f37 fix: do not log config here 2025-09-30 15:22:14 +02:00
Per Stark
ff4ea55cd5 fix: user guard on knowledge relationship deletion 2025-09-30 11:15:53 +02:00
Per Stark
c4c76efe92 test: startup smoke test 2025-09-29 21:15:34 +02:00
Per Stark
c0fcad5952 fix: deletion of items, shared files etc 2025-09-29 20:28:06 +02:00
Per Stark
b0ed69330d fix: improved concurrency 2025-09-28 22:08:08 +02:00
Per Stark
5cb15dab45 feat: pdf support 2025-09-28 20:53:51 +02:00
Per Stark
7403195df5 release: 0.2.1
chore: remove stale todo

chore: version bump
2025-09-24 10:25:56 +02:00
Per Stark
9faef31387 fix: json response in api to work with ios shortcuts
fix corrected lockfile
2025-09-23 22:01:58 +02:00
Per Stark
110f7b8a8f no graph screenshot in readme, too much image 2025-09-23 09:08:17 +02:00
55 changed files with 4553 additions and 729 deletions

50
CHANGELOG.md Normal file
View File

@@ -0,0 +1,50 @@
# Changelog
## Unreleased
## Version 0.2.4 (2025-10-15)
- Improved retrieval performance. Ingestion and chat now utilize full-text search, vector comparison, and graph traversal.
- Ingestion task archive
## Version 0.2.3 (2025-10-12)
- Fixed changing vector dimensions on a fresh database (#3)
## Version 0.2.2 (2025-10-07)
- Added support for ingesting PDF files
- Improved ingestion speed
- Fixed deletion of items so it works as expected
- Fixed support for GPT-5 models via the OpenAI API
## Version 0.2.1 (2025-09-24)
- Fixed API JSON responses so iOS Shortcuts integrations keep working.
## Version 0.2.0 (2025-09-23)
- Revamped the UI with a neobrutalist theme, better dark mode, and a D3-based knowledge graph.
- Added pagination for entities and content plus new observability metrics on the dashboard.
- Enabled audio ingestion and merged the new storage backend.
- Improved performance, request filtering, and journalctl/systemd compatibility.
## Version 0.1.4 (2025-07-01)
- Added image ingestion with configurable system settings and updated Docker Compose docs.
- Hardened admin flows by fixing concurrent API/database calls and normalizing task statuses.
## Version 0.1.3 (2025-06-08)
- Added support for AI providers beyond OpenAI.
- Made the HTTP port configurable for deployments.
- Smoothed graph mapper failures, long content tiles, and refreshed project documentation.
## Version 0.1.2 (2025-05-26)
- Introduced full-text search across indexed knowledge.
- Polished the UI with consistent titles, icon fallbacks, and improved markdown scrolling.
- Fixed search result links and SurrealDB vector formatting glitches.
## Version 0.1.1 (2025-05-13)
- Added streaming feedback to ingestion tasks for clearer progress updates.
- Made the data storage path configurable.
- Improved release tooling with Chromium-enabled Nix flakes, Docker builds, and migration/template fixes.
## Version 0.1.0 (2025-05-06)
- Initial release with a SurrealDB-backed ingestion pipeline, job queue, vector search, and knowledge graph storage.
- Delivered a chat experience featuring streaming responses, conversation history, markdown rendering, and customizable system prompts.
- Introduced an admin console with analytics, registration and timezone controls, and job monitoring.
- Shipped a Tailwind/daisyUI web UI with responsive layouts, modals, content viewers, and editing flows.
- Provided readability-based content ingestion, API/HTML ingress routes, and Docker/Docker Compose tooling.

299
Cargo.lock generated
View File

@@ -36,6 +36,15 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "adobe-cmap-parser"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae8abfa9a4688de8fc9f42b3f013b6fffec18ed8a554f5f113577e0b9b3212a3"
dependencies = [
"pom",
]
[[package]]
name = "aead"
version = "0.5.2"
@@ -326,15 +335,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "async-convert"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae"
dependencies = [
"async-trait",
]
[[package]]
name = "async-executor"
version = "1.13.2"
@@ -422,30 +422,41 @@ dependencies = [
[[package]]
name = "async-openai"
version = "0.24.1"
version = "0.29.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6db3286b4f52b6556ac5208fb575d035eca61a2bf40d7e75d1db2733ffc599f"
checksum = "d4fc47ec9e669d562e0755f59e1976d157546910e403f3c2da856d0a4d3cdc07"
dependencies = [
"async-convert",
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.8.5",
"rand 0.9.1",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 1.0.69",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-openai-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "async-recursion"
version = "1.1.1"
@@ -881,6 +892,15 @@ dependencies = [
"generic-array",
]
[[package]]
name = "block-padding"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93"
dependencies = [
"generic-array",
]
[[package]]
name = "blowfish"
version = "0.9.1"
@@ -963,6 +983,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "bytecount"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
[[package]]
name = "bytemuck"
version = "1.23.0"
@@ -993,6 +1019,15 @@ dependencies = [
"rustversion",
]
[[package]]
name = "cbc"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6"
dependencies = [
"cipher",
]
[[package]]
name = "cc"
version = "1.2.21"
@@ -1061,6 +1096,12 @@ dependencies = [
"unicode-security",
]
[[package]]
name = "cff-parser"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31f5b6e9141c036f3ff4ce7b2f7e432b0f00dee416ddcd4f17741d189ddc2e9d"
[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -1281,6 +1322,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"state-machines",
"surrealdb",
"surrealdb-migrations",
"tempfile",
@@ -1840,6 +1882,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "ecb"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7"
dependencies = [
"cipher",
]
[[package]]
name = "either"
version = "1.15.0"
@@ -1892,6 +1943,15 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "euclid"
version = "0.20.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bb7ef65b3777a325d1eeefefab5b6d4959da54747e33bd6258e789640f307ad"
dependencies = [
"num-traits",
]
[[package]]
name = "event-listener"
version = "5.4.0"
@@ -2866,6 +2926,8 @@ dependencies = [
"dom_smoothie",
"futures",
"headless_chrome",
"lopdf 0.32.0",
"pdf-extract",
"reqwest",
"serde",
"serde_json",
@@ -2904,6 +2966,7 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"block-padding",
"generic-array",
]
@@ -3133,6 +3196,12 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@@ -3161,6 +3230,51 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "lopdf"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e775e4ee264e8a87d50a9efef7b67b4aa988cf94e75630859875fc347e6c872b"
dependencies = [
"chrono",
"encoding_rs",
"flate2",
"itoa",
"linked-hash-map",
"log",
"md5",
"nom 7.1.3",
"rayon",
"time",
"weezl",
]
[[package]]
name = "lopdf"
version = "0.36.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fa2559e99ba0f26a12458aabc754432c805bbb8cba516c427825a997af1fb7"
dependencies = [
"aes",
"bitflags 2.9.0",
"cbc",
"ecb",
"encoding_rs",
"flate2",
"indexmap 2.9.0",
"itoa",
"log",
"md-5",
"nom 8.0.0",
"nom_locate",
"rand 0.9.1",
"rangemap",
"sha2",
"stringprep",
"thiserror 2.0.12",
"weezl",
]
[[package]]
name = "lru"
version = "0.12.5"
@@ -3178,28 +3292,22 @@ checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
[[package]]
name = "main"
version = "0.1.4"
version = "0.2.4"
dependencies = [
"anyhow",
"api-router",
"async-openai",
"axum",
"axum_session",
"axum_session_surreal",
"common",
"cookie",
"futures",
"headless_chrome",
"html-router",
"ingestion-pipeline",
"reqwest",
"serde",
"serde_json",
"serial_test",
"surrealdb",
"tempfile",
"thiserror 1.0.69",
"tokio",
"tower",
"tracing",
"tracing-subscriber",
"uuid",
@@ -3282,6 +3390,12 @@ dependencies = [
"digest",
]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "memchr"
version = "2.7.4"
@@ -3514,6 +3628,17 @@ dependencies = [
"memchr",
]
[[package]]
name = "nom_locate"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b577e2d69827c4740cba2b52efaad1c4cc7c73042860b199710b3575c68438d"
dependencies = [
"bytecount",
"memchr",
"nom 8.0.0",
]
[[package]]
name = "nonempty"
version = "0.7.0"
@@ -3836,6 +3961,23 @@ dependencies = [
"sha2",
]
[[package]]
name = "pdf-extract"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2f44c6c642e359e2fe7f662bf5438db3811b6b4be60afc6de04b619ce51e1a"
dependencies = [
"adobe-cmap-parser",
"cff-parser",
"encoding_rs",
"euclid",
"log",
"lopdf 0.36.0",
"postscript",
"type1-encoding-parser",
"unicode-normalization",
]
[[package]]
name = "pem"
version = "3.0.5"
@@ -4027,6 +4169,18 @@ dependencies = [
"universal-hash",
]
[[package]]
name = "pom"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60f6ce597ecdcc9a098e7fddacb1065093a3d66446fa16c675e7e71d1b5c28e6"
[[package]]
name = "postscript"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78451badbdaebaf17f053fd9152b3ffb33b516104eacb45e7864aaa9c712f306"
[[package]]
name = "powerfmt"
version = "0.2.0"
@@ -4292,6 +4446,12 @@ dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "rangemap"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f93e7e49bb0bf967717f7bd674458b3d6b0c5f48ec7e3038166026a69fc22223"
[[package]]
name = "rawpointer"
version = "0.2.1"
@@ -4854,9 +5014,9 @@ checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b"
[[package]]
name = "secrecy"
version = "0.8.0"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
@@ -5053,31 +5213,6 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "serial_test"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e56dd856803e253c8f298af3f4d7eb0ae5e23a737252cd90bb4f3b435033b2d"
dependencies = [
"dashmap 5.5.3",
"futures",
"lazy_static",
"log",
"parking_lot",
"serial_test_derive",
]
[[package]]
name = "serial_test_derive"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91d129178576168c589c9ec973feedf7d3126c01ac2bf08795109aa35b69fb8f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "servo_arc"
version = "0.4.0"
@@ -5266,6 +5401,34 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "state-machines"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "806ba0bf43ae158b229036d8a84601649a58d9761e718b5e0e07c2953803f4c1"
dependencies = [
"state-machines-core",
"state-machines-macro",
]
[[package]]
name = "state-machines-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "949cc50e84bed6234117f28a0ba2980dc35e9c17984ffe4e0a3364fba3e77540"
[[package]]
name = "state-machines-macro"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8322f5aa92d31b3c05faa1ec3231b82da479a20706836867d67ae89ce74927bd"
dependencies = [
"proc-macro2",
"quote",
"state-machines-core",
"syn 2.0.101",
]
[[package]]
name = "static_assertions_next"
version = "1.1.2"
@@ -5309,6 +5472,17 @@ dependencies = [
"quote",
]
[[package]]
name = "stringprep"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
"unicode-properties",
]
[[package]]
name = "strsim"
version = "0.11.1"
@@ -6141,6 +6315,15 @@ dependencies = [
"utf-8",
]
[[package]]
name = "type1-encoding-parser"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3d6cc09e1a99c7e01f2afe4953789311a1c50baebbdac5b477ecf78e2e92a5b"
dependencies = [
"pom",
]
[[package]]
name = "typenum"
version = "1.18.0"
@@ -6176,6 +6359,12 @@ version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5"
[[package]]
name = "unicode-ident"
version = "1.0.18"
@@ -6191,6 +6380,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-properties"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0"
[[package]]
name = "unicode-script"
version = "0.5.7"
@@ -6520,6 +6715,12 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "weezl"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
[[package]]
name = "which"
version = "7.0.3"

View File

@@ -12,7 +12,7 @@ resolver = "2"
[workspace.dependencies]
anyhow = "1.0.94"
async-openai = "0.24.1"
async-openai = "0.29.3"
async-stream = "0.3.6"
async-trait = "0.1.88"
axum-htmx = "0.7.0"
@@ -55,6 +55,7 @@ tokio-retry = "0.3.0"
base64 = "0.22.1"
object_store = { version = "0.11.2" }
bytes = "1.7.1"
state-machines = "0.2.0"
[profile.dist]
inherits = "release"

View File

@@ -6,8 +6,7 @@
[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![Latest Release](https://img.shields.io/github/v/release/perstarkse/minne?sort=semver)](https://github.com/perstarkse/minne/releases/latest)
![Screenshot](screenshot-dashboard.webp)
![Graph screenshot](screenshot-graph.webp)
![Screenshot](screenshot-graph.webp)
## Demo deployment
@@ -23,11 +22,13 @@ While developing Minne, I discovered [KaraKeep](https://karakeep.com/) (formerly
Minne is designed to make it incredibly easy to save snippets of text, URLs, and other content (limited, pending demand). Simply send content along with a category tag. Minne then ingests this, leveraging AI to create relevant nodes and relationships within its graph database, alongside your manual categorization. This graph backend allows for discoverable connections between your pieces of knowledge.
You can converse with your knowledge base through an LLM-powered chat interface (via OpenAI compatible API, like Ollama or others). For those who like to see the bigger picture, Minne also includes an **experimental feature to visually explore your knowledge graph.**
You can converse with your knowledge base through an LLM-powered chat interface (via an OpenAI-compatible API, such as Ollama or others). For those who like to see the bigger picture, Minne also includes a feature to visually explore your knowledge graph.
You may switch and choose between models used, and have the possiblity to change the prompts to your liking. There is since release **0.1.3** the option to change embeddings length, making it easy to test another embedding model.
You may switch and choose between the models used, and you have the possibility to change the prompts to your liking. There is also an option to change the embedding length, making it easy to test another embedding model.
The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/9aa960600ec14329837ba4169f57a166) makes sending content to your Minne instance a breeze.
The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) makes sending content to your Minne instance a breeze.
A hybrid retrieval layer blends embeddings, full-text search, and graph signals to surface the best context when augmenting chat responses and when building new relationships during ingestion.
Minne is open source (AGPL), self-hostable, and can be deployed flexibly: via Nix, Docker Compose, pre-built binaries, or by building from source. It can run as a single `main` binary or as separate `server` and `worker` processes for optimized resource allocation.
@@ -191,7 +192,7 @@ Minne can be configured using environment variables or a `config.yaml` file plac
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`).
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`).
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`).
- `OPENAI_API_KEY`: Your API key for OpenAI (e.g., `sk-YourActualOpenAIKeyGoesHere`).
- `OPENAI_API_KEY`: Your API key for an OpenAI-compatible endpoint (e.g., `sk-YourActualOpenAIKeyGoesHere`).
- `HTTP_PORT`: Port for the Minne server to listen on (Default: `3000`).
**Optional Configuration:**
@@ -210,8 +211,8 @@ surrealdb_database: "minne_db"
surrealdb_namespace: "minne_ns"
openai_api_key: "sk-YourActualOpenAIKeyGoesHere"
data_dir: "./minne_app_data"
http_port: 3000
# rust_log: "info"
# http_port: 3000
```
## Application Architecture (Binaries)
@@ -258,8 +259,8 @@ Once you have configured the `OPENAI_BASE_URL` to point to your desired provider
I've developed Minne primarily for my own use, but having been in the self-hosted space for a long time and benefited from the efforts of others, I thought I'd share it with the community. Feature requests are welcome.
The roadmap as of now is:
- Handle uploaded images wisely.
- An updated explorer of the graph database.
- ~~Handle uploaded images wisely.~~
- ~~An updated explorer of the graph database.~~
- A TUI frontend which opens your system default editor for improved writing and document management.
## Contributing

View File

@@ -1,4 +1,4 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension};
use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use common::{
error::AppError,
@@ -8,6 +8,7 @@ use common::{
},
};
use futures::{future::try_join_all, TryFutureExt};
use serde_json::json;
use tempfile::NamedTempFile;
use tracing::info;
@@ -52,5 +53,5 @@ pub async fn ingest_data(
try_join_all(futures).await?;
Ok(StatusCode::OK)
Ok((StatusCode::OK, Json(json!({ "status": "success" }))))
}

View File

@@ -41,6 +41,7 @@ surrealdb-migrations = { workspace = true }
tokio-retry = { workspace = true }
object_store = { workspace = true }
bytes = { workspace = true }
state-machines = { workspace = true }
[features]

View File

@@ -0,0 +1,17 @@
-- Add FTS indexes for searching name and description on entities
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity
FIELDS name
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity
FIELDS description
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk
FIELDS chunk
SEARCH ANALYZER app_en_fts_analyzer BM25;

View File

@@ -0,0 +1,173 @@
-- State machine migration for ingestion_task records
DEFINE FIELD IF NOT EXISTS state ON TABLE ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS attempts ON TABLE ingestion_task TYPE option<number>;
DEFINE FIELD IF NOT EXISTS max_attempts ON TABLE ingestion_task TYPE option<number>;
DEFINE FIELD IF NOT EXISTS scheduled_at ON TABLE ingestion_task TYPE option<datetime>;
DEFINE FIELD IF NOT EXISTS locked_at ON TABLE ingestion_task TYPE option<datetime>;
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON TABLE ingestion_task TYPE option<number>;
DEFINE FIELD IF NOT EXISTS worker_id ON TABLE ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS error_code ON TABLE ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS error_message ON TABLE ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS last_error_at ON TABLE ingestion_task TYPE option<datetime>;
DEFINE FIELD IF NOT EXISTS priority ON TABLE ingestion_task TYPE option<number>;
REMOVE FIELD status ON TABLE ingestion_task;
DEFINE FIELD status ON TABLE ingestion_task TYPE option<object>;
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_state_sched ON TABLE ingestion_task FIELDS state, scheduled_at;
LET $needs_migration = (SELECT count() AS count FROM type::table('ingestion_task') WHERE state = NONE)[0].count;
IF $needs_migration > 0 THEN {
-- Created -> Pending
UPDATE type::table('ingestion_task')
SET
state = "Pending",
attempts = 0,
max_attempts = 3,
scheduled_at = IF created_at != NONE THEN created_at ELSE time::now() END,
locked_at = NONE,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = NONE,
last_error_at = NONE,
priority = 0
WHERE state = NONE
AND status != NONE
AND status.name = "Created";
-- InProgress -> Processing
UPDATE type::table('ingestion_task')
SET
state = "Processing",
attempts = IF status.attempts != NONE THEN status.attempts ELSE 1 END,
max_attempts = 3,
scheduled_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
locked_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = NONE,
last_error_at = NONE,
priority = 0
WHERE state = NONE
AND status != NONE
AND status.name = "InProgress";
-- Completed -> Succeeded
UPDATE type::table('ingestion_task')
SET
state = "Succeeded",
attempts = 1,
max_attempts = 3,
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
locked_at = NONE,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = NONE,
last_error_at = NONE,
priority = 0
WHERE state = NONE
AND status != NONE
AND status.name = "Completed";
-- Error -> DeadLetter (terminal failure)
UPDATE type::table('ingestion_task')
SET
state = "DeadLetter",
attempts = 3,
max_attempts = 3,
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
locked_at = NONE,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = status.message,
last_error_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
priority = 0
WHERE state = NONE
AND status != NONE
AND status.name = "Error";
-- Cancelled -> Cancelled
UPDATE type::table('ingestion_task')
SET
state = "Cancelled",
attempts = 0,
max_attempts = 3,
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
locked_at = NONE,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = NONE,
last_error_at = NONE,
priority = 0
WHERE state = NONE
AND status != NONE
AND status.name = "Cancelled";
-- Fallback for any remaining records missing state
UPDATE type::table('ingestion_task')
SET
state = "Pending",
attempts = 0,
max_attempts = 3,
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
locked_at = NONE,
lease_duration_secs = 300,
worker_id = NONE,
error_code = NONE,
error_message = NONE,
last_error_at = NONE,
priority = 0
WHERE state = NONE;
} END;
-- Ensure defaults for newly added fields
UPDATE type::table('ingestion_task')
SET max_attempts = 3
WHERE max_attempts = NONE;
UPDATE type::table('ingestion_task')
SET lease_duration_secs = 300
WHERE lease_duration_secs = NONE;
UPDATE type::table('ingestion_task')
SET attempts = 0
WHERE attempts = NONE;
UPDATE type::table('ingestion_task')
SET priority = 0
WHERE priority = NONE;
UPDATE type::table('ingestion_task')
SET scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END
WHERE scheduled_at = NONE;
UPDATE type::table('ingestion_task')
SET locked_at = NONE
WHERE locked_at = NONE;
UPDATE type::table('ingestion_task')
SET worker_id = NONE
WHERE worker_id != NONE AND worker_id = "";
UPDATE type::table('ingestion_task')
SET error_code = NONE
WHERE error_code = NONE;
UPDATE type::table('ingestion_task')
SET error_message = NONE
WHERE error_message = NONE;
UPDATE type::table('ingestion_task')
SET last_error_at = NONE
WHERE last_error_at = NONE;
UPDATE type::table('ingestion_task')
SET status = NONE
WHERE status != NONE;

View File

@@ -80,15 +80,18 @@ impl SurrealDbClient {
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
debug!("Rebuilding indexes");
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
.await?;
self.client
.query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
.await?;
self.client
.query("REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content")
.await?;
let rebuild_sql = r#"
BEGIN TRANSACTION;
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
COMMIT TRANSACTION;
"#;
self.client.query(rebuild_sql).await?;
Ok(())
}

View File

@@ -196,7 +196,7 @@ pub fn split_object_path(path: &str) -> AnyResult<(String, String)> {
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::config::StorageKind;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use bytes::Bytes;
use futures::TryStreamExt;
use uuid::Uuid;
@@ -213,6 +213,7 @@ mod tests {
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
}
}

View File

@@ -57,7 +57,6 @@ impl FileInfo {
user_id: &str,
config: &AppConfig,
) -> Result<Self, FileError> {
info!("Data_dir: {:?}", config);
let file = field_data.contents;
let file_name = field_data
.metadata
@@ -230,14 +229,8 @@ impl FileInfo {
config: &AppConfig,
) -> Result<(), AppError> {
// Get the FileInfo from the database
let file_info = match db_client.get_item::<FileInfo>(id).await? {
Some(info) => info,
None => {
return Err(AppError::from(FileError::FileNotFound(format!(
"File with id {} was not found",
id
))))
}
let Some(file_info) = db_client.get_item::<FileInfo>(id).await? else {
return Ok(());
};
// Remove the object's parent prefix in the object store
@@ -277,7 +270,7 @@ impl FileInfo {
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::config::StorageKind;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use axum::http::HeaderMap;
use axum_typed_multipart::FieldMetadata;
use std::io::Write;
@@ -332,6 +325,7 @@ mod tests {
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
// Test file creation
@@ -392,6 +386,7 @@ mod tests {
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
// Store the original file
@@ -448,6 +443,7 @@ mod tests {
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let file_info = FileInfo::new(field_data, &db, user_id, &config).await;
@@ -505,6 +501,7 @@ mod tests {
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let field_data1 = create_test_file(content, file_name);
@@ -669,6 +666,7 @@ mod tests {
http_port: 0,
openai_base_url: "".to_string(),
storage: crate::utils::config::StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
let temp = create_test_file(b"test content", "test_file.txt");
let file_info = FileInfo::new(temp, &db, user_id, &cfg)
@@ -723,18 +721,13 @@ mod tests {
http_port: 0,
openai_base_url: "".to_string(),
storage: crate::utils::config::StorageKind::Local,
pdf_ingest_mode: LlmFirst,
},
)
.await;
// Should fail with FileNotFound error
assert!(result.is_err());
match result {
Err(AppError::File(_)) => {
// Expected error
}
_ => panic!("Expected FileNotFound error"),
}
// Should succeed even if the file record does not exist
assert!(result.is_ok());
}
#[tokio::test]
async fn test_get_by_id() {
@@ -831,6 +824,7 @@ mod tests {
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
};
// Test file creation

View File

@@ -1,116 +1,529 @@
use futures::Stream;
use surrealdb::{opt::PatchOp, Notification};
use std::time::Duration;
use chrono::Duration as ChronoDuration;
use state_machines::state_machine;
use surrealdb::sql::Datetime as SurrealDatetime;
use uuid::Uuid;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::ingestion_payload::IngestionPayload;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "name")]
pub enum IngestionTaskStatus {
Created,
InProgress {
attempts: u32,
last_attempt: DateTime<Utc>,
},
Completed,
Error {
message: String,
},
pub const MAX_ATTEMPTS: u32 = 3;
pub const DEFAULT_LEASE_SECS: i64 = 300;
pub const DEFAULT_PRIORITY: i32 = 0;
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum TaskState {
#[serde(rename = "Pending")]
#[default]
Pending,
#[serde(rename = "Reserved")]
Reserved,
#[serde(rename = "Processing")]
Processing,
#[serde(rename = "Succeeded")]
Succeeded,
#[serde(rename = "Failed")]
Failed,
#[serde(rename = "Cancelled")]
Cancelled,
#[serde(rename = "DeadLetter")]
DeadLetter,
}
impl TaskState {
pub fn as_str(&self) -> &'static str {
match self {
TaskState::Pending => "Pending",
TaskState::Reserved => "Reserved",
TaskState::Processing => "Processing",
TaskState::Succeeded => "Succeeded",
TaskState::Failed => "Failed",
TaskState::Cancelled => "Cancelled",
TaskState::DeadLetter => "DeadLetter",
}
}
pub fn is_terminal(&self) -> bool {
matches!(
self,
TaskState::Succeeded | TaskState::Cancelled | TaskState::DeadLetter
)
}
pub fn display_label(&self) -> &'static str {
match self {
TaskState::Pending => "Pending",
TaskState::Reserved => "Reserved",
TaskState::Processing => "Processing",
TaskState::Succeeded => "Completed",
TaskState::Failed => "Retrying",
TaskState::Cancelled => "Cancelled",
TaskState::DeadLetter => "Dead Letter",
}
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
pub struct TaskErrorInfo {
pub code: Option<String>,
pub message: String,
}
#[derive(Debug, Clone, Copy)]
enum TaskTransition {
StartProcessing,
Succeed,
Fail,
Cancel,
DeadLetter,
Release,
}
impl TaskTransition {
fn as_str(&self) -> &'static str {
match self {
TaskTransition::StartProcessing => "start_processing",
TaskTransition::Succeed => "succeed",
TaskTransition::Fail => "fail",
TaskTransition::Cancel => "cancel",
TaskTransition::DeadLetter => "deadletter",
TaskTransition::Release => "release",
}
}
}
mod lifecycle {
use super::state_machine;
state_machine! {
name: TaskLifecycleMachine,
initial: Pending,
states: [Pending, Reserved, Processing, Succeeded, Failed, Cancelled, DeadLetter],
events {
reserve {
transition: { from: Pending, to: Reserved }
transition: { from: Failed, to: Reserved }
}
start_processing {
transition: { from: Reserved, to: Processing }
}
succeed {
transition: { from: Processing, to: Succeeded }
}
fail {
transition: { from: Processing, to: Failed }
}
cancel {
transition: { from: Pending, to: Cancelled }
transition: { from: Reserved, to: Cancelled }
transition: { from: Processing, to: Cancelled }
}
deadletter {
transition: { from: Failed, to: DeadLetter }
}
release {
transition: { from: Reserved, to: Pending }
}
}
}
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
TaskLifecycleMachine::new(())
}
pub(super) fn reserved() -> TaskLifecycleMachine<(), Reserved> {
pending()
.reserve()
.expect("reserve transition from Pending should exist")
}
pub(super) fn processing() -> TaskLifecycleMachine<(), Processing> {
reserved()
.start_processing()
.expect("start_processing transition from Reserved should exist")
}
pub(super) fn failed() -> TaskLifecycleMachine<(), Failed> {
processing()
.fail()
.expect("fail transition from Processing should exist")
}
}
fn invalid_transition(state: &TaskState, event: TaskTransition) -> AppError {
AppError::Validation(format!(
"Invalid task transition: {} -> {}",
state.as_str(),
event.as_str()
))
}
stored_object!(IngestionTask, "ingestion_task", {
content: IngestionPayload,
status: IngestionTaskStatus,
user_id: String
state: TaskState,
user_id: String,
attempts: u32,
max_attempts: u32,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime")]
scheduled_at: chrono::DateTime<chrono::Utc>,
#[serde(
serialize_with = "serialize_option_datetime",
deserialize_with = "deserialize_option_datetime",
default
)]
locked_at: Option<chrono::DateTime<chrono::Utc>>,
lease_duration_secs: i64,
worker_id: Option<String>,
error_code: Option<String>,
error_message: Option<String>,
#[serde(
serialize_with = "serialize_option_datetime",
deserialize_with = "deserialize_option_datetime",
default
)]
last_error_at: Option<chrono::DateTime<chrono::Utc>>,
priority: i32
});
pub const MAX_ATTEMPTS: u32 = 3;
impl IngestionTask {
pub async fn new(content: IngestionPayload, user_id: String) -> Self {
let now = Utc::now();
let now = chrono::Utc::now();
Self {
id: Uuid::new_v4().to_string(),
content,
status: IngestionTaskStatus::Created,
state: TaskState::Pending,
user_id,
attempts: 0,
max_attempts: MAX_ATTEMPTS,
scheduled_at: now,
locked_at: None,
lease_duration_secs: DEFAULT_LEASE_SECS,
worker_id: None,
error_code: None,
error_message: None,
last_error_at: None,
priority: DEFAULT_PRIORITY,
created_at: now,
updated_at: now,
user_id,
}
}
/// Creates a new job and stores it in the database
pub fn can_retry(&self) -> bool {
self.attempts < self.max_attempts
}
pub fn lease_duration(&self) -> Duration {
Duration::from_secs(self.lease_duration_secs.max(0) as u64)
}
pub async fn create_and_add_to_db(
content: IngestionPayload,
user_id: String,
db: &SurrealDbClient,
) -> Result<IngestionTask, AppError> {
let task = Self::new(content, user_id).await;
db.store_item(task.clone()).await?;
Ok(task)
}
// Update job status
pub async fn update_status(
id: &str,
status: IngestionTaskStatus,
pub async fn claim_next_ready(
db: &SurrealDbClient,
) -> Result<(), AppError> {
let _job: Option<Self> = db
.update((Self::table_name(), id))
.patch(PatchOp::replace("/status", status))
.patch(PatchOp::replace(
"/updated_at",
surrealdb::Datetime::from(Utc::now()),
worker_id: &str,
now: chrono::DateTime<chrono::Utc>,
lease_duration: Duration,
) -> Result<Option<IngestionTask>, AppError> {
debug_assert!(lifecycle::pending().reserve().is_ok());
debug_assert!(lifecycle::failed().reserve().is_ok());
const CLAIM_QUERY: &str = r#"
UPDATE (
SELECT * FROM type::table($table)
WHERE state IN $candidate_states
AND scheduled_at <= $now
AND (
attempts < max_attempts
OR state IN $sticky_states
)
AND (
locked_at = NONE
OR time::unix($now) - time::unix(locked_at) >= lease_duration_secs
)
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
LIMIT 1
)
SET state = $reserved_state,
attempts = if state IN $increment_states THEN
if attempts + 1 > max_attempts THEN max_attempts ELSE attempts + 1 END
ELSE
attempts
END,
locked_at = $now,
worker_id = $worker_id,
lease_duration_secs = $lease_secs,
updated_at = $now
RETURN *;
"#;
let mut result = db
.client
.query(CLAIM_QUERY)
.bind(("table", Self::table_name()))
.bind((
"candidate_states",
vec![
TaskState::Pending.as_str(),
TaskState::Failed.as_str(),
TaskState::Reserved.as_str(),
TaskState::Processing.as_str(),
],
))
.bind((
"sticky_states",
vec![TaskState::Reserved.as_str(), TaskState::Processing.as_str()],
))
.bind((
"increment_states",
vec![TaskState::Pending.as_str(), TaskState::Failed.as_str()],
))
.bind(("reserved_state", TaskState::Reserved.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.bind(("worker_id", worker_id.to_string()))
.bind(("lease_secs", lease_duration.as_secs() as i64))
.await?;
Ok(())
let task: Option<IngestionTask> = result.take(0)?;
Ok(task)
}
/// Listen for new jobs
pub async fn listen_for_tasks(
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
const START_PROCESSING_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $processing,
updated_at = $now,
locked_at = $now
WHERE state = $reserved AND worker_id = $worker_id
RETURN *;
"#;
let now = chrono::Utc::now();
let mut result = db
.client
.query(START_PROCESSING_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("processing", TaskState::Processing.as_str()))
.bind(("reserved", TaskState::Reserved.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::StartProcessing))
}
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
const COMPLETE_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $succeeded,
updated_at = $now,
locked_at = NONE,
worker_id = NONE,
scheduled_at = $now,
error_code = NONE,
error_message = NONE,
last_error_at = NONE
WHERE state = $processing AND worker_id = $worker_id
RETURN *;
"#;
let now = chrono::Utc::now();
let mut result = db
.client
.query(COMPLETE_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("succeeded", TaskState::Succeeded.as_str()))
.bind(("processing", TaskState::Processing.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Succeed))
}
pub async fn mark_failed(
&self,
error: TaskErrorInfo,
retry_delay: Duration,
db: &SurrealDbClient,
) -> Result<impl Stream<Item = Result<Notification<Self>, surrealdb::Error>>, surrealdb::Error>
{
db.listen::<Self>().await
) -> Result<IngestionTask, AppError> {
let now = chrono::Utc::now();
let retry_at = now
+ ChronoDuration::from_std(retry_delay).unwrap_or_else(|_| ChronoDuration::seconds(30));
const FAIL_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $failed,
updated_at = $now,
locked_at = NONE,
worker_id = NONE,
scheduled_at = $retry_at,
error_code = $error_code,
error_message = $error_message,
last_error_at = $now
WHERE state = $processing AND worker_id = $worker_id
RETURN *;
"#;
let mut result = db
.client
.query(FAIL_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("failed", TaskState::Failed.as_str()))
.bind(("processing", TaskState::Processing.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.bind(("retry_at", SurrealDatetime::from(retry_at)))
.bind(("error_code", error.code.clone()))
.bind(("error_message", error.message.clone()))
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Fail))
}
/// Get all unfinished tasks, ie newly created and in progress up two times
pub async fn get_unfinished_tasks(db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
let jobs: Vec<Self> = db
pub async fn mark_dead_letter(
&self,
error: TaskErrorInfo,
db: &SurrealDbClient,
) -> Result<IngestionTask, AppError> {
const DEAD_LETTER_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $dead,
updated_at = $now,
locked_at = NONE,
worker_id = NONE,
scheduled_at = $now,
error_code = $error_code,
error_message = $error_message,
last_error_at = $now
WHERE state = $failed
RETURN *;
"#;
let now = chrono::Utc::now();
let mut result = db
.client
.query(DEAD_LETTER_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("dead", TaskState::DeadLetter.as_str()))
.bind(("failed", TaskState::Failed.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.bind(("error_code", error.code.clone()))
.bind(("error_message", error.message.clone()))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::DeadLetter))
}
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
const CANCEL_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $cancelled,
updated_at = $now,
locked_at = NONE,
worker_id = NONE
WHERE state IN $allow_states
RETURN *;
"#;
let now = chrono::Utc::now();
let mut result = db
.client
.query(CANCEL_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("cancelled", TaskState::Cancelled.as_str()))
.bind((
"allow_states",
vec![
TaskState::Pending.as_str(),
TaskState::Reserved.as_str(),
TaskState::Processing.as_str(),
],
))
.bind(("now", SurrealDatetime::from(now)))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Cancel))
}
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
const RELEASE_QUERY: &str = r#"
UPDATE type::thing($table, $id)
SET state = $pending,
updated_at = $now,
locked_at = NONE,
worker_id = NONE
WHERE state = $reserved
RETURN *;
"#;
let now = chrono::Utc::now();
let mut result = db
.client
.query(RELEASE_QUERY)
.bind(("table", Self::table_name()))
.bind(("id", self.id.clone()))
.bind(("pending", TaskState::Pending.as_str()))
.bind(("reserved", TaskState::Reserved.as_str()))
.bind(("now", SurrealDatetime::from(now)))
.await?;
let updated: Option<IngestionTask> = result.take(0)?;
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Release))
}
pub async fn get_unfinished_tasks(
db: &SurrealDbClient,
) -> Result<Vec<IngestionTask>, AppError> {
let tasks: Vec<IngestionTask> = db
.query(
"SELECT * FROM type::table($table)
WHERE
status.name = 'Created'
OR (
status.name = 'InProgress'
AND status.attempts < $max_attempts
)
ORDER BY created_at ASC",
"SELECT * FROM type::table($table)
WHERE state IN $active_states
ORDER BY scheduled_at ASC, created_at ASC",
)
.bind(("table", Self::table_name()))
.bind(("max_attempts", MAX_ATTEMPTS))
.bind((
"active_states",
vec![
TaskState::Pending.as_str(),
TaskState::Reserved.as_str(),
TaskState::Processing.as_str(),
TaskState::Failed.as_str(),
],
))
.await?
.take(0)?;
Ok(jobs)
Ok(tasks)
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use crate::storage::types::ingestion_payload::IngestionPayload;
// Helper function to create a test ingestion payload
fn create_test_payload(user_id: &str) -> IngestionPayload {
fn create_payload(user_id: &str) -> IngestionPayload {
IngestionPayload::Text {
text: "Test content".to_string(),
context: "Test context".to_string(),
@@ -119,182 +532,197 @@ mod tests {
}
}
#[tokio::test]
async fn test_new_ingestion_task() {
let user_id = "user123";
let payload = create_test_payload(user_id);
async fn memory_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
SurrealDbClient::memory(namespace, &database)
.await
.expect("in-memory surrealdb")
}
#[tokio::test]
async fn test_new_task_defaults() {
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
// Verify task properties
assert_eq!(task.user_id, user_id);
assert_eq!(task.content, payload);
assert!(matches!(task.status, IngestionTaskStatus::Created));
assert!(!task.id.is_empty());
assert_eq!(task.state, TaskState::Pending);
assert_eq!(task.attempts, 0);
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
assert!(task.locked_at.is_none());
assert!(task.worker_id.is_none());
}
#[tokio::test]
async fn test_create_and_add_to_db() {
// Setup in-memory database
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
async fn test_create_and_store_task() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_test_payload(user_id);
let payload = create_payload(user_id);
// Create and store task
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
let created =
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
.await
.expect("store");
let stored: Option<IngestionTask> = db
.get_item::<IngestionTask>(&created.id)
.await
.expect("Failed to create and add task to db");
.expect("fetch");
// Query to verify task was stored
let query = format!(
"SELECT * FROM {} WHERE user_id = '{}'",
IngestionTask::table_name(),
user_id
);
let mut result = db.query(query).await.expect("Query failed");
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
// Verify task is in the database
assert!(!tasks.is_empty(), "Task should exist in the database");
let stored_task = &tasks[0];
assert_eq!(stored_task.user_id, user_id);
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
let stored = stored.expect("task exists");
assert_eq!(stored.id, created.id);
assert_eq!(stored.state, TaskState::Pending);
assert_eq!(stored.attempts, 0);
}
#[tokio::test]
async fn test_update_status() {
// Setup in-memory database
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
async fn test_claim_and_transition() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_test_payload(user_id);
let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string()).await;
db.store_item(task.clone()).await.expect("store");
// Create task manually
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
let task_id = task.id.clone();
let worker_id = "worker-1";
let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await
.expect("claim");
// Store task
db.store_item(task).await.expect("Failed to store task");
let claimed = claimed.expect("task claimed");
assert_eq!(claimed.state, TaskState::Reserved);
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
// Update status to InProgress
let now = Utc::now();
let new_status = IngestionTaskStatus::InProgress {
attempts: 1,
last_attempt: now,
let processing = claimed.mark_processing(&db).await.expect("processing");
assert_eq!(processing.state, TaskState::Processing);
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
assert_eq!(succeeded.state, TaskState::Succeeded);
assert!(succeeded.worker_id.is_none());
assert!(succeeded.locked_at.is_none());
}
#[tokio::test]
async fn test_fail_and_dead_letter() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload, user_id.to_string()).await;
db.store_item(task.clone()).await.expect("store");
let worker_id = "worker-dead";
let now = chrono::Utc::now();
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
.await
.expect("claim")
.expect("claimed");
let processing = claimed.mark_processing(&db).await.expect("processing");
let error_info = TaskErrorInfo {
code: Some("pipeline_error".into()),
message: "failed".into(),
};
IngestionTask::update_status(&task_id, new_status.clone(), &db)
let failed = processing
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
.await
.expect("Failed to update status");
.expect("failed update");
assert_eq!(failed.state, TaskState::Failed);
assert_eq!(failed.error_message.as_deref(), Some("failed"));
assert!(failed.worker_id.is_none());
assert!(failed.locked_at.is_none());
assert!(failed.scheduled_at > now);
// Verify status updated
let updated_task: Option<IngestionTask> = db
.get_item::<IngestionTask>(&task_id)
let dead = failed
.mark_dead_letter(error_info.clone(), &db)
.await
.expect("Failed to get updated task");
.expect("dead letter");
assert_eq!(dead.state, TaskState::DeadLetter);
assert_eq!(dead.error_message.as_deref(), Some("failed"));
}
assert!(updated_task.is_some());
let updated_task = updated_task.unwrap();
#[tokio::test]
async fn test_mark_processing_requires_reservation() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_payload(user_id);
match updated_task.status {
IngestionTaskStatus::InProgress { attempts, .. } => {
assert_eq!(attempts, 1);
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
db.store_item(task.clone()).await.expect("store");
let err = task
.mark_processing(&db)
.await
.expect_err("processing should fail without reservation");
match err {
AppError::Validation(message) => {
assert!(
message.contains("Pending -> start_processing"),
"unexpected message: {message}"
);
}
_ => panic!("Expected InProgress status"),
other => panic!("expected validation error, got {other:?}"),
}
}
#[tokio::test]
async fn test_get_unfinished_tasks() {
// Setup in-memory database
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
async fn test_mark_failed_requires_processing() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_test_payload(user_id);
let payload = create_payload(user_id);
// Create tasks with different statuses
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
db.store_item(task.clone()).await.expect("store");
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
in_progress_task.status = IngestionTaskStatus::InProgress {
attempts: 1,
last_attempt: Utc::now(),
};
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
max_attempts_task.status = IngestionTaskStatus::InProgress {
attempts: MAX_ATTEMPTS,
last_attempt: Utc::now(),
};
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
completed_task.status = IngestionTaskStatus::Completed;
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
error_task.status = IngestionTaskStatus::Error {
message: "Test error".to_string(),
};
// Store all tasks
db.store_item(created_task)
let err = task
.mark_failed(
TaskErrorInfo {
code: None,
message: "boom".into(),
},
Duration::from_secs(30),
&db,
)
.await
.expect("Failed to store created task");
db.store_item(in_progress_task)
.await
.expect("Failed to store in-progress task");
db.store_item(max_attempts_task)
.await
.expect("Failed to store max-attempts task");
db.store_item(completed_task)
.await
.expect("Failed to store completed task");
db.store_item(error_task)
.await
.expect("Failed to store error task");
.expect_err("failing should require processing state");
// Get unfinished tasks
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
match err {
AppError::Validation(message) => {
assert!(
message.contains("Pending -> fail"),
"unexpected message: {message}"
);
}
other => panic!("expected validation error, got {other:?}"),
}
}
#[tokio::test]
async fn test_release_requires_reservation() {
let db = memory_db().await;
let user_id = "user123";
let payload = create_payload(user_id);
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
db.store_item(task.clone()).await.expect("store");
let err = task
.release(&db)
.await
.expect("Failed to get unfinished tasks");
.expect_err("release should require reserved state");
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
assert_eq!(unfinished_tasks.len(), 2);
let statuses: Vec<_> = unfinished_tasks
.iter()
.map(|task| match &task.status {
IngestionTaskStatus::Created => "Created",
IngestionTaskStatus::InProgress { attempts, .. } => {
if *attempts < MAX_ATTEMPTS {
"InProgress<MAX"
} else {
"InProgress>=MAX"
}
}
IngestionTaskStatus::Completed => "Completed",
IngestionTaskStatus::Error { .. } => "Error",
IngestionTaskStatus::Cancelled => "Cancelled",
})
.collect();
assert!(statuses.contains(&"Created"));
assert!(statuses.contains(&"InProgress<MAX"));
assert!(!statuses.contains(&"InProgress>=MAX"));
assert!(!statuses.contains(&"Completed"));
assert!(!statuses.contains(&"Error"));
assert!(!statuses.contains(&"Cancelled"));
match err {
AppError::Validation(message) => {
assert!(
message.contains("Pending -> release"),
"unexpected message: {message}"
);
}
other => panic!("expected validation error, got {other:?}"),
}
}
}

View File

@@ -150,7 +150,18 @@ impl KnowledgeEntity {
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Skipping.");
info!("No knowledge entities to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
return Ok(());
}
info!("Found {} entities to process.", total_entities);

View File

@@ -75,13 +75,36 @@ impl KnowledgeRelationship {
pub async fn delete_relationship_by_id(
id: &str,
user_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!("DELETE relates_to:`{}`", id);
let mut authorized_result = db_client
.query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'",
id, user_id
))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
db_client.query(query).await?;
if authorized.is_empty() {
let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{}`", id))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
Ok(())
if existing.is_some() {
Err(AppError::Auth(
"Not authorized to delete relationship".into(),
))
} else {
Err(AppError::NotFound(format!("Relationship {} not found", id)))
}
} else {
db_client
.query(format!("DELETE relates_to:`{}`", id))
.await?;
Ok(())
}
}
}
@@ -161,7 +184,7 @@ mod tests {
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
user_id,
user_id.clone(),
source_id.clone(),
relationship_type,
);
@@ -209,7 +232,7 @@ mod tests {
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
user_id,
user_id.clone(),
source_id.clone(),
relationship_type,
);
@@ -220,20 +243,107 @@ mod tests {
.await
.expect("Failed to store relationship");
// Ensure relationship exists before deletion attempt
let mut existing_before_delete = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
user_id, source_id
))
.await
.expect("Query failed");
let before_results: Vec<KnowledgeRelationship> =
existing_before_delete.take(0).unwrap_or_default();
assert!(
!before_results.is_empty(),
"Relationship should exist before deletion"
);
// Delete the relationship by ID
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &db)
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
.await
.expect("Failed to delete relationship by ID");
// Query to verify the relationship was deleted
let query = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship.id);
let mut result = db.query(query).await.expect("Query failed");
let mut result = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
user_id, source_id
))
.await
.expect("Query failed");
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
// Verify the relationship no longer exists
assert!(results.is_empty(), "Relationship should be deleted");
}
#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let owner_user_id = "owner-user".to_string();
let source_id = "source123".to_string();
let relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
owner_user_id.clone(),
source_id,
"references".to_string(),
);
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
let mut before_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
owner_user_id
))
.await
.expect("Query failed");
let before_results: Vec<KnowledgeRelationship> = before_attempt.take(0).unwrap_or_default();
assert!(
!before_results.is_empty(),
"Relationship should exist before unauthorized delete attempt"
);
let result = KnowledgeRelationship::delete_relationship_by_id(
&relationship.id,
"different-user",
&db,
)
.await;
match result {
Err(AppError::Auth(_)) => {}
_ => panic!("Expected authorization error when deleting someone else's relationship"),
}
let mut after_attempt = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}'",
owner_user_id
))
.await
.expect("Query failed");
let results: Vec<KnowledgeRelationship> = after_attempt.take(0).unwrap_or_default();
assert!(
!results.is_empty(),
"Relationship should still exist after unauthorized delete attempt"
);
}
#[tokio::test]
async fn test_delete_relationships_by_source_id() {
// Setup in-memory database for testing

View File

@@ -83,6 +83,32 @@ macro_rules! stored_object {
Ok(DateTime::<Utc>::from(dt))
}
#[allow(dead_code)]
fn serialize_option_datetime<S>(
date: &Option<DateTime<Utc>>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match date {
Some(dt) => serializer
.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
None => serializer.serialize_none(),
}
}
#[allow(dead_code)]
fn deserialize_option_datetime<'de, D>(
deserializer: D,
) -> Result<Option<DateTime<Utc>>, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
Ok(value.map(DateTime::<Utc>::from))
}
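// Sketch of how these optional-datetime helpers pair with the field-attribute passthrough
// added below; the `finished_at` field is hypothetical and not part of this change.
// Inside a stored_object! invocation it could look like:
// #[serde(
//     serialize_with = "serialize_option_datetime",
//     deserialize_with = "deserialize_option_datetime",
//     default
// )]
// finished_at: Option<DateTime<Utc>>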
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct $name {
@@ -92,7 +118,7 @@ macro_rules! stored_object {
pub created_at: DateTime<Utc>,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub updated_at: DateTime<Utc>,
$(pub $field: $ty),*
$( $(#[$attr])* pub $field: $ty),*
}
impl StoredObject for $name {

View File

@@ -53,11 +53,60 @@ impl SystemSettings {
#[cfg(test)]
mod tests {
use crate::storage::types::text_chunk::TextChunk;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use async_openai::Client;
use super::*;
use uuid::Uuid;
async fn get_hnsw_index_dimension(
db: &SurrealDbClient,
table_name: &str,
index_name: &str,
) -> u32 {
let query = format!("INFO FOR TABLE {table_name};");
let mut response = db
.client
.query(query)
.await
.expect("Failed to fetch table info");
let info: Option<serde_json::Value> = response
.take(0)
.expect("Failed to extract table info response");
let info = info.expect("Table info result missing");
let indexes = info
.get("indexes")
.or_else(|| {
info.get("tables")
.and_then(|tables| tables.get(table_name))
.and_then(|table| table.get("indexes"))
})
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
let definition = indexes
.get(index_name)
.and_then(|definition| definition.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
let dimension_part = definition
.split("DIMENSION")
.nth(1)
.expect("Index definition missing DIMENSION clause");
let dimension_token = dimension_part
.split_whitespace()
.next()
.expect("Dimension value missing in definition")
.trim_end_matches(';');
dimension_token
.parse::<u32>()
.expect("Dimension value is not a valid number")
}
#[tokio::test]
async fn test_settings_initialization() {
// Setup in-memory database for testing
@@ -255,4 +304,74 @@ mod tests {
assert!(migration_result.is_ok(), "Migrations should not fail");
}
#[tokio::test]
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
.await
.expect("Failed to start DB");
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
db.apply_migrations()
.await
.expect("Initial migration failed");
let mut current_settings = SystemSettings::get_current(&db)
.await
.expect("Failed to load current settings");
let initial_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
assert_eq!(
initial_chunk_dimension, current_settings.embedding_dimensions,
"embedding size should match initial system settings"
);
let new_dimension = 768;
let new_model = "new-test-embedding-model".to_string();
current_settings.embedding_dimensions = new_dimension;
current_settings.embedding_model = new_model.clone();
let updated_settings = SystemSettings::update(&db, current_settings)
.await
.expect("Failed to update settings");
assert_eq!(
updated_settings.embedding_dimensions, new_dimension,
"Settings should reflect the new embedding dimension"
);
let openai_client = Client::new();
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await
.expect("TextChunk re-embedding should succeed on fresh DB");
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
.await
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
let text_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
let knowledge_dimension =
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
assert_eq!(
text_chunk_dimension, new_dimension,
"text_chunk index dimension should update"
);
assert_eq!(
knowledge_dimension, new_dimension,
"knowledge_entity index dimension should update"
);
let persisted_settings = SystemSettings::get_current(&db)
.await
.expect("Failed to reload updated settings");
assert_eq!(
persisted_settings.embedding_dimensions, new_dimension,
"Settings should persist new embedding dimension"
);
}
}

View File

@@ -68,7 +68,17 @@ impl TextChunk {
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Skipping.");
info!("No text chunks to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);

View File

@@ -110,6 +110,26 @@ impl TextContent {
Ok(())
}
pub async fn has_other_with_file(
file_id: &str,
exclude_id: &str,
db: &SurrealDbClient,
) -> Result<bool, AppError> {
let mut response = db
.client
.query(
"SELECT VALUE id FROM type::table($table_name) WHERE file_info.id = $file_id AND id != type::thing($table_name, $exclude_id) LIMIT 1",
)
.bind(("table_name", TextContent::table_name()))
.bind(("file_id", file_id.to_owned()))
.bind(("exclude_id", exclude_id.to_owned()))
.await?;
let existing: Option<surrealdb::sql::Thing> = response.take(0)?;
Ok(existing.is_some())
}
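// Sketch of the caller-side guard this helper enables when deleting content that may share
// a file; `delete_stored_file` is a hypothetical storage helper, not part of this diff.
if let Some(file) = &content.file_info {
    // Only drop the underlying file once no other TextContent still references it.
    if !TextContent::has_other_with_file(&file.id, &content.id, &db).await? {
        delete_stored_file(&file.path).await?; // hypothetical helper
    }
}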
pub async fn search(
db: &SurrealDbClient,
search_terms: &str,
@@ -276,4 +296,64 @@ mod tests {
assert_eq!(updated_content.text, new_text);
assert!(updated_content.updated_at > text_content.updated_at);
}
#[tokio::test]
async fn test_has_other_with_file_detects_shared_usage() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let user_id = "user123".to_string();
let file_info = FileInfo {
id: "file-1".to_string(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
sha256: "sha-test".to_string(),
path: "user123/file-1/test.txt".to_string(),
file_name: "test.txt".to_string(),
mime_type: "text/plain".to_string(),
user_id: user_id.clone(),
};
let content_a = TextContent::new(
"First".to_string(),
Some("ctx-a".to_string()),
"category".to_string(),
Some(file_info.clone()),
None,
user_id.clone(),
);
let content_b = TextContent::new(
"Second".to_string(),
Some("ctx-b".to_string()),
"category".to_string(),
Some(file_info.clone()),
None,
user_id.clone(),
);
db.store_item(content_a.clone())
.await
.expect("Failed to store first content");
db.store_item(content_b.clone())
.await
.expect("Failed to store second content");
let has_other = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await
.expect("Failed to check for shared file usage");
assert!(has_other);
let _removed: Option<TextContent> = db
.delete_item(&content_b.id)
.await
.expect("Failed to delete second content");
let has_other_after = TextContent::has_other_with_file(&file_info.id, &content_a.id, &db)
.await
.expect("Failed to check shared usage after delete");
assert!(!has_other_after);
}
}

View File

@@ -8,7 +8,7 @@ use uuid::Uuid;
use super::text_chunk::TextChunk;
use super::{
conversation::Conversation,
ingestion_task::{IngestionTask, MAX_ATTEMPTS},
ingestion_task::{IngestionTask, TaskState},
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
system_settings::SystemSettings,
@@ -109,7 +109,7 @@ impl User {
)
.bind(("table", T::table_name()))
.bind(("user_id", user_id.to_string()))
.bind(("since", since))
.bind(("since", surrealdb::Datetime::from(since)))
.await?
.take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0))
@@ -535,19 +535,43 @@ impl User {
let jobs: Vec<IngestionTask> = db
.query(
"SELECT * FROM type::table($table)
WHERE user_id = $user_id
AND (
status.name = 'Created'
OR (
status.name = 'InProgress'
AND status.attempts < $max_attempts
)
)
ORDER BY created_at DESC",
WHERE user_id = $user_id
AND (
state IN $active_states
OR (state = $failed_state AND attempts < max_attempts)
)
ORDER BY scheduled_at ASC, created_at DESC",
)
.bind(("table", IngestionTask::table_name()))
.bind(("user_id", user_id.to_owned()))
.bind((
"active_states",
vec![
TaskState::Pending.as_str(),
TaskState::Reserved.as_str(),
TaskState::Processing.as_str(),
],
))
.bind(("failed_state", TaskState::Failed.as_str()))
.await?
.take(0)?;
Ok(jobs)
}
/// Gets all ingestion tasks for the specified user ordered by newest first
pub async fn get_all_ingestion_tasks(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<IngestionTask>, AppError> {
let jobs: Vec<IngestionTask> = db
.query(
"SELECT * FROM type::table($table)
WHERE user_id = $user_id
ORDER BY created_at DESC",
)
.bind(("table", IngestionTask::table_name()))
.bind(("user_id", user_id.to_owned()))
.bind(("max_attempts", MAX_ATTEMPTS))
.await?
.take(0)?;
@@ -605,7 +629,7 @@ impl User {
mod tests {
use super::*;
use crate::storage::types::ingestion_payload::IngestionPayload;
use crate::storage::types::ingestion_task::{IngestionTask, IngestionTaskStatus, MAX_ATTEMPTS};
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
use std::collections::HashSet;
// Helper function to set up a test database with SystemSettings
@@ -710,28 +734,32 @@ mod tests {
.await
.expect("Failed to store created task");
let mut in_progress_allowed =
IngestionTask::new(payload.clone(), user_id.to_string()).await;
in_progress_allowed.status = IngestionTaskStatus::InProgress {
attempts: 1,
last_attempt: chrono::Utc::now(),
};
db.store_item(in_progress_allowed.clone())
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
processing_task.state = TaskState::Processing;
processing_task.attempts = 1;
db.store_item(processing_task.clone())
.await
.expect("Failed to store in-progress task");
.expect("Failed to store processing task");
let mut in_progress_blocked =
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
failed_retry_task.state = TaskState::Failed;
failed_retry_task.attempts = 1;
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
db.store_item(failed_retry_task.clone())
.await
.expect("Failed to store retryable failed task");
let mut failed_blocked_task =
IngestionTask::new(payload.clone(), user_id.to_string()).await;
in_progress_blocked.status = IngestionTaskStatus::InProgress {
attempts: MAX_ATTEMPTS,
last_attempt: chrono::Utc::now(),
};
db.store_item(in_progress_blocked.clone())
failed_blocked_task.state = TaskState::Failed;
failed_blocked_task.attempts = MAX_ATTEMPTS;
failed_blocked_task.error_message = Some("Too many failures".into());
db.store_item(failed_blocked_task.clone())
.await
.expect("Failed to store blocked task");
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
completed_task.status = IngestionTaskStatus::Completed;
completed_task.state = TaskState::Succeeded;
db.store_item(completed_task.clone())
.await
.expect("Failed to store completed task");
@@ -755,10 +783,54 @@ mod tests {
unfinished.iter().map(|task| task.id.clone()).collect();
assert!(unfinished_ids.contains(&created_task.id));
assert!(unfinished_ids.contains(&in_progress_allowed.id));
assert!(!unfinished_ids.contains(&in_progress_blocked.id));
assert!(unfinished_ids.contains(&processing_task.id));
assert!(unfinished_ids.contains(&failed_retry_task.id));
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
assert!(!unfinished_ids.contains(&completed_task.id));
assert_eq!(unfinished_ids.len(), 2);
assert_eq!(unfinished_ids.len(), 3);
}
#[tokio::test]
async fn test_get_all_ingestion_tasks_returns_sorted() {
let db = setup_test_db().await;
let user_id = "archive_user";
let other_user_id = "other_user";
let payload = IngestionPayload::Text {
text: "One".to_string(),
context: "Context".to_string(),
category: "Category".to_string(),
user_id: user_id.to_string(),
};
// Oldest task
let mut first = IngestionTask::new(payload.clone(), user_id.to_string()).await;
first.created_at = first.created_at - chrono::Duration::minutes(1);
first.updated_at = first.created_at;
first.state = TaskState::Succeeded;
db.store_item(first.clone()).await.expect("store first");
// Latest task
let mut second = IngestionTask::new(payload.clone(), user_id.to_string()).await;
second.state = TaskState::Processing;
db.store_item(second.clone()).await.expect("store second");
let other_payload = IngestionPayload::Text {
text: "Other".to_string(),
context: "Context".to_string(),
category: "Category".to_string(),
user_id: other_user_id.to_string(),
};
let other_task = IngestionTask::new(other_payload, other_user_id.to_string()).await;
db.store_item(other_task).await.expect("store other");
let tasks = User::get_all_ingestion_tasks(user_id, &db)
.await
.expect("fetch all tasks");
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0].id, second.id); // newest first
assert_eq!(tasks[1].id, first.id);
}
#[tokio::test]

View File

@@ -11,6 +11,20 @@ fn default_storage_kind() -> StorageKind {
StorageKind::Local
}
/// Selects the strategy used for PDF ingestion.
#[derive(Clone, Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
pub enum PdfIngestMode {
/// Only rely on classic text extraction (no LLM fallbacks).
Classic,
/// Prefer fast text extraction, but fall back to the LLM rendering path when needed.
LlmFirst,
}
fn default_pdf_ingest_mode() -> PdfIngestMode {
PdfIngestMode::LlmFirst
}
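// Illustrative configuration values: with rename_all = "kebab-case" the accepted spellings
// are "classic" and "llm-first". The environment variable name below is an assumption.
// PDF_INGEST_MODE=classic    -> PdfIngestMode::Classic (text extraction only)
// PDF_INGEST_MODE=llm-first  -> PdfIngestMode::LlmFirst (default when the key is absent)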
#[derive(Clone, Deserialize, Debug)]
pub struct AppConfig {
pub openai_api_key: String,
@@ -26,6 +40,8 @@ pub struct AppConfig {
pub openai_base_url: String,
#[serde(default = "default_storage_kind")]
pub storage: StorageKind,
#[serde(default = "default_pdf_ingest_mode")]
pub pdf_ingest_mode: PdfIngestMode,
}
fn default_data_dir() -> String {

View File

@@ -11,7 +11,6 @@ use common::{
storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity,
message::{format_history, Message},
system_settings::SystemSettings,
},
@@ -20,7 +19,7 @@ use common::{
use serde::Deserialize;
use serde_json::{json, Value};
use crate::retrieve_entities;
use crate::{retrieve_entities, RetrievedEntity};
use super::answer_retrieval_helper::get_query_response_schema;
@@ -84,21 +83,31 @@ pub async fn get_answer_with_references(
})
}
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
json!(entities
.iter()
.map(|entity| {
.map(|entry| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
fn round_score(value: f32) -> f64 {
((value as f64) * 1000.0).round() / 1000.0
}
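// Illustrative shape of the JSON handed to the prompt (values invented); round_score keeps
// three decimals, so 0.8126 becomes 0.813.
// [
//   { "KnowledgeEntity": {
//       "id": "…", "name": "…", "description": "…",
//       "score": 0.813,
//       "chunks": [ { "score": 0.774, "content": "…" } ]
//   } }
// ]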
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
format!(
r#"
@@ -154,8 +163,6 @@ pub fn create_chat_request(
CreateChatCompletionRequestArgs::default()
.model(&settings.query_model)
.temperature(0.2)
.max_tokens(3048u32)
.messages([
ChatCompletionRequestSystemMessage::from(settings.query_system_prompt.clone()).into(),
ChatCompletionRequestUserMessage::from(user_message).into(),

View File

@@ -0,0 +1,265 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
pub async fn find_items_by_fts<T>(
take: usize,
query: &str,
db_client: &SurrealDbClient,
table: &str,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
let (filter_clause, score_clause) = match table {
"knowledge_entity" => (
"(name @0@ $terms OR description @1@ $terms)",
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
),
"text_chunk" => (
"(chunk @0@ $terms)",
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
),
_ => {
return Err(AppError::Validation(format!(
"FTS not configured for table '{table}'"
)))
}
};
let sql = format!(
"SELECT id, {score_clause} AS fts_score \
FROM {table} \
WHERE {filter_clause} \
AND user_id = $user_id \
ORDER BY fts_score DESC \
LIMIT $limit",
table = table,
filter_clause = filter_clause,
score_clause = score_clause
);
debug!(
table = table,
limit = take,
"Executing FTS query with filter clause: {}",
filter_clause
);
let mut response = db_client
.query(sql)
.bind(("terms", query.to_owned()))
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.await?;
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
if score_rows.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
let thing_ids: Vec<Thing> = ids
.iter()
.map(|id| Thing::from((table, id.as_str())))
.collect();
let mut items_response = db_client
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", table.to_owned()))
.bind(("things", thing_ids.clone()))
.bind(("user_id", user_id.to_owned()))
.await?;
let items: Vec<T> = items_response.take(0)?;
let mut item_map: HashMap<String, T> = items
.into_iter()
.map(|item| (item.get_id().to_owned(), item))
.collect();
let mut results = Vec::with_capacity(score_rows.len());
for row in score_rows {
if let Some(item) = item_map.remove(&row.id) {
let score = row.fts_score.unwrap_or_default();
results.push(Scored::new(item).with_fts_score(score));
}
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
text_chunk::TextChunk,
StoredObject,
};
use uuid::Uuid;
fn dummy_embedding() -> Vec<f32> {
vec![0.0; 1536]
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_name() {
let namespace = "fts_test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts";
let entity = KnowledgeEntity::new(
"source_a".into(),
"Rustacean handbook".into(),
"completely unrelated description".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"rustacean",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the name matched"
);
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_description() {
let namespace = "fts_test_ns_desc";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts_desc";
let entity = KnowledgeEntity::new(
"source_b".into(),
"neutral name".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"async",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the description matched"
);
}
#[tokio::test]
async fn fts_preserves_scores_for_text_chunks() {
let namespace = "fts_test_ns_chunks";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts_chunk";
let chunk = TextChunk::new(
"source_chunk".into(),
"GraphQL documentation reference".into(),
dummy_embedding(),
user_id.into(),
);
db.store_item(chunk.clone())
.await
.expect("failed to insert chunk");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results =
find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when chunk field matched"
);
}
}

View File

@@ -1,7 +1,14 @@
use surrealdb::Error;
use tracing::debug;
use std::collections::{HashMap, HashSet};
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
use surrealdb::{sql::Thing, Error};
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
StoredObject,
},
};
/// Retrieves database entries that match a specific source identifier.
///
@@ -31,18 +38,21 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt
/// * The database query fails to execute
/// * The results cannot be deserialized into type `T`
pub async fn find_entities_by_source_ids<T>(
source_id: Vec<String>,
table_name: String,
source_ids: Vec<String>,
table_name: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<T>, Error>
where
T: for<'de> serde::Deserialize<'de>,
{
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
let query =
"SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";
db.query(query)
.bind(("table", table_name))
.bind(("source_ids", source_id))
.bind(("table", table_name.to_owned()))
.bind(("source_ids", source_ids))
.bind(("user_id", user_id.to_owned()))
.await?
.take(0)
}
@@ -50,16 +60,92 @@ where
/// Find entities by their relationship to the id
pub async fn find_entities_by_relationship_by_id(
db: &SurrealDbClient,
entity_id: String,
entity_id: &str,
user_id: &str,
limit: usize,
) -> Result<Vec<KnowledgeEntity>, Error> {
let query = format!(
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
entity_id
);
let mut relationships_response = db
.query(
"
SELECT * FROM relates_to
WHERE metadata.user_id = $user_id
AND (in = type::thing('knowledge_entity', $entity_id)
OR out = type::thing('knowledge_entity', $entity_id))
",
)
.bind(("entity_id", entity_id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
debug!("{}", query);
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
if relationships.is_empty() {
return Ok(Vec::new());
}
db.query(query).await?.take(0)
let mut neighbor_ids: Vec<String> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
for rel in relationships {
if rel.in_ == entity_id {
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
} else if rel.out == entity_id {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_);
}
} else {
if seen.insert(rel.in_.clone()) {
neighbor_ids.push(rel.in_.clone());
}
if seen.insert(rel.out.clone()) {
neighbor_ids.push(rel.out);
}
}
}
neighbor_ids.retain(|id| id != entity_id);
if neighbor_ids.is_empty() {
return Ok(Vec::new());
}
if limit > 0 && neighbor_ids.len() > limit {
neighbor_ids.truncate(limit);
}
let thing_ids: Vec<Thing> = neighbor_ids
.iter()
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
.collect();
let mut neighbors_response = db
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", KnowledgeEntity::table_name().to_owned()))
.bind(("things", thing_ids))
.bind(("user_id", user_id.to_owned()))
.await?;
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
if neighbors.is_empty() {
return Ok(Vec::new());
}
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
.into_iter()
.map(|entity| (entity.id.clone(), entity))
.collect();
let mut ordered = Vec::new();
for id in neighbor_ids {
if let Some(entity) = neighbor_map.remove(&id) {
ordered.push(entity);
}
if limit > 0 && ordered.len() >= limit {
break;
}
}
Ok(ordered)
}
#[cfg(test)]
@@ -149,7 +235,7 @@ mod tests {
// Test finding entities by multiple source_ids
let source_ids = vec![source_id1.clone(), source_id2.clone()];
let found_entities: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
.await
.expect("Failed to find entities by source_ids");
@@ -180,7 +266,8 @@ mod tests {
let single_source_id = vec![source_id1.clone()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
single_source_id,
KnowledgeEntity::table_name().to_string(),
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
@@ -205,7 +292,8 @@ mod tests {
let non_existent_source_id = vec!["non_existent_source".to_string()];
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
non_existent_source_id,
KnowledgeEntity::table_name().to_string(),
KnowledgeEntity::table_name(),
&user_id,
&db,
)
.await
@@ -330,11 +418,15 @@ mod tests {
.expect("Failed to store relationship 2");
// Test finding entities related to the central entity
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
.await
.expect("Failed to find entities by relationship");
let related_entities =
find_entities_by_relationship_by_id(&db, &central_entity.id, &user_id, usize::MAX)
.await
.expect("Failed to find entities by relationship");
// Check that we found relationships
assert!(related_entities.len() > 0, "Should find related entities");
assert!(
related_entities.len() >= 2,
"Should find related entities in both directions"
);
}
}

View File

@@ -1,90 +1,721 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod scoring;
pub mod vector;
use std::collections::{HashMap, HashSet};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
};
use futures::future::{try_join, try_join_all};
use futures::{stream::FuturesUnordered, StreamExt};
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
use std::collections::HashMap;
use vector::find_items_by_vector_similarity;
use scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
};
use tracing::{debug, instrument, trace};
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
/// to find the most relevant entities for a given query.
///
/// # Strategy
/// The function employs a three-pronged approach to knowledge retrieval:
/// 1. Direct vector similarity search on knowledge entities
/// 2. Text chunk similarity search with source entity lookup
/// 3. Graph relationship traversal from related entities
///
/// This combined approach ensures both semantic similarity matches and structurally
/// related content are included in the results.
///
/// # Arguments
/// * `db_client` - SurrealDB client for database operations
/// * `openai_client` - OpenAI client for vector embeddings generation
/// * `query` - The search query string to find relevant knowledge entities
/// * `user_id` - The user ID of the current user
///
/// # Returns
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
/// knowledge entities, or an error if the retrieval process fails
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.
const ENTITY_VECTOR_TAKE: usize = 15;
const CHUNK_VECTOR_TAKE: usize = 20;
const ENTITY_FTS_TAKE: usize = 10;
const CHUNK_FTS_TAKE: usize = 20;
const SCORE_THRESHOLD: f32 = 0.35;
const FALLBACK_MIN_RESULTS: usize = 10;
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
const AVG_CHARS_PER_TOKEN: usize = 4;
const MAX_CHUNKS_PER_ENTITY: usize = 4;
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
const GRAPH_SCORE_DECAY: f32 = 0.75;
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
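// Back-of-envelope check on the defaults above (illustrative arithmetic only):
// TOKEN_BUDGET_ESTIMATE * AVG_CHARS_PER_TOKEN = 2800 * 4 = 11_200 characters of chunk text,
// spread over at most MAX_CHUNKS_PER_ENTITY = 4 chunks per surviving entity, so a single
// 1_000-character chunk costs roughly 250 of the 2_800 estimated tokens.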
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
pub chunk: TextChunk,
pub score: f32,
}
// Final entity representation returned to callers, enriched with ranked chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
pub entity: KnowledgeEntity,
pub score: f32,
pub chunks: Vec<RetrievedChunk>,
}
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
query: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
find_items_by_vector_similarity(
10,
) -> Result<Vec<RetrievedEntity>, AppError> {
trace!("Generating query embedding for hybrid retrieval");
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
}
pub(crate) async fn retrieve_entities_with_embedding(
db_client: &SurrealDbClient,
query_embedding: Vec<f32>,
query: &str,
user_id: &str,
) -> Result<Vec<RetrievedEntity>, AppError> {
// 1) Gather first-pass candidates from vector search and BM25.
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
ENTITY_VECTOR_TAKE,
query_embedding.clone(),
db_client,
"knowledge_entity",
user_id,
),
find_items_by_vector_similarity_with_embedding(
CHUNK_VECTOR_TAKE,
query_embedding,
db_client,
"text_chunk",
user_id,
),
find_items_by_fts(
ENTITY_FTS_TAKE,
query,
db_client,
"knowledge_entity",
openai_client,
user_id,
user_id
),
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
)?;
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
// Collate raw retrieval results so each ID accumulates all available signals.
merge_scored_by_id(&mut entity_candidates, vector_entities);
merge_scored_by_id(&mut entity_candidates, fts_entities);
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
apply_fusion(&mut entity_candidates, weights);
apply_fusion(&mut chunk_candidates, weights);
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
// 3) Track high-signal chunk sources so we can backfill missing entities.
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if !missing_sources.is_empty() {
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
}
// Boost entities with evidence from high scoring chunks.
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
entity_candidates.into_values().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
.cloned()
.collect();
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
// Low recall scenarios still benefit from some context; take the top N regardless of score.
filtered_entities = entity_results
.into_iter()
.take(FALLBACK_MIN_RESULTS)
.collect();
}
// 4) Re-rank chunks and prepare for attachment to surviving entities.
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&filtered_entities,
db_client,
user_id,
weights,
)
.await?;
let source_ids = closest_chunks
.iter()
.map(|chunk: &TextChunk| chunk.source_id.clone())
.collect::<Vec<String>>();
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
.clone()
.into_iter()
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
.collect();
let items_from_relationships = try_join_all(items_from_relationships_futures)
.await?
.into_iter()
.flatten()
.collect::<Vec<KnowledgeEntity>>();
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
.into_iter()
.chain(items_from_text_chunk_similarity.into_iter())
.chain(items_from_relationships.into_iter())
.fold(HashMap::new(), |mut map, entity| {
map.insert(entity.id.clone(), entity);
map
})
.into_values()
.collect();
Ok(entities)
Ok(assemble_results(filtered_entities, chunk_values))
}
// Minimal record used while seeding graph expansion so we can retain the original fused score.
#[derive(Clone)]
struct GraphSeed {
id: String,
fused: f32,
}
async fn enrich_entities_from_graph(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
if entity_candidates.is_empty() {
return Ok(());
}
// Select a small frontier of high-confidence entities to seed the relationship walk.
let mut seeds: Vec<GraphSeed> = entity_candidates
.values()
.filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
.map(|entity| GraphSeed {
id: entity.item.id.clone(),
fused: entity.fused,
})
.collect();
if seeds.is_empty() {
return Ok(());
}
// Prioritise the strongest seeds so we explore the most grounded context first.
seeds.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
let mut futures = FuturesUnordered::new();
for seed in seeds.clone() {
let user_id = user_id.to_owned();
futures.push(async move {
// Fetch neighbors concurrently to avoid serial graph round trips.
let neighbors = find_entities_by_relationship_by_id(
db_client,
&seed.id,
&user_id,
GRAPH_NEIGHBOR_LIMIT,
)
.await;
(seed, neighbors)
});
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
// Fold neighbors back into the candidate map and let them inherit attenuated signal.
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
let entry = entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph {
entry.scores.graph = Some(graph_score);
} else if entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
Ok(())
}
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
// Scale BM25 outputs into [0,1] to keep fusion weights predictable.
let raw_scores: Vec<f32> = results
.iter()
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
.collect();
let normalized = min_max_normalize(&raw_scores);
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
candidate.scores.fts = Some(normalized_score);
candidate.update_fused(0.0);
}
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
where
T: StoredObject,
{
// Collapse individual signals into a single fused score used for ranking.
for candidate in candidates.values_mut() {
let fused = fuse_scores(&candidate.scores, weights);
candidate.update_fused(fused);
}
}
fn group_chunks_by_source(
chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
// Preserve chunk candidates keyed by their originating source entity.
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.values() {
by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk.clone());
}
by_source
}
async fn enrich_chunks_from_entities(
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
entities: &[Scored<KnowledgeEntity>],
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
// Fetch additional chunks referenced by entities that survived the fusion stage.
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
source_ids.insert(entity.item.source_id.clone());
}
if source_ids.is_empty() {
return Ok(());
}
let chunks = find_entities_by_source_ids::<TextChunk>(
source_ids.into_iter().collect(),
"text_chunk",
user_id,
db_client,
)
.await?;
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
// Cache fused scores per source so chunks inherit the strength of their parent entity.
for entity in entities {
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
}
for chunk in chunks {
// Ensure each chunk is represented so downstream selection sees the latest content.
let entry = chunk_candidates
.entry(chunk.id.clone())
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
let entity_score = entity_score_lookup
.get(&chunk.source_id)
.copied()
.unwrap_or(0.0);
// Lift chunk score toward the entity score so supporting evidence is prioritised.
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
}
Ok(())
}
fn assemble_results(
entities: Vec<Scored<KnowledgeEntity>>,
mut chunks: Vec<Scored<TextChunk>>,
) -> Vec<RetrievedEntity> {
// Re-associate chunk candidates with their parent entity for ranked selection.
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.drain(..) {
chunk_by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk);
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
}
let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
let mut results = Vec::new();
for entity in entities {
// Attach best chunks first while respecting per-entity and global token caps.
let mut selected_chunks = Vec::new();
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
break;
}
let estimated_tokens = estimate_tokens(&candidate.item.chunk);
if estimated_tokens > token_budget_remaining {
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
per_entity_count += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
}
}
results.push(RetrievedEntity {
entity: entity.item.clone(),
score: entity.fused,
chunks: selected_chunks,
});
if token_budget_remaining == 0 {
break;
}
}
results
}
fn estimate_tokens(text: &str) -> usize {
// Simple heuristic to avoid calling a tokenizer in hot code paths.
let chars = text.chars().count().max(1);
(chars / AVG_CHARS_PER_TOKEN).max(1)
}
#[cfg(test)]
mod tests {
use super::*;
use common::storage::types::{
knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
};
use uuid::Uuid;
fn test_embedding() -> Vec<f32> {
vec![0.9, 0.1, 0.0]
}
fn entity_embedding_high() -> Vec<f32> {
vec![0.8, 0.2, 0.0]
}
fn entity_embedding_low() -> Vec<f32> {
vec![0.1, 0.9, 0.0]
}
fn chunk_embedding_primary() -> Vec<f32> {
vec![0.85, 0.15, 0.0]
}
fn chunk_embedding_secondary() -> Vec<f32> {
vec![0.2, 0.8, 0.0]
}
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db.query(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
)
.await
.expect("Failed to redefine vector indexes for tests");
db
}
async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
let entity_relevant = KnowledgeEntity::new(
"source_a".into(),
"Rust Concurrency Patterns".into(),
"Discussion about async concurrency in Rust.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let entity_irrelevant = KnowledgeEntity::new(
"source_b".into(),
"Python Tips".into(),
"General Python programming tips.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_low(),
user_id.into(),
);
db.store_item(entity_relevant.clone())
.await
.expect("Failed to store relevant entity");
db.store_item(entity_irrelevant.clone())
.await
.expect("Failed to store irrelevant entity");
let chunk_primary = TextChunk::new(
entity_relevant.source_id.clone(),
"Tokio enables async concurrency with lightweight tasks.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let chunk_secondary = TextChunk::new(
entity_irrelevant.source_id.clone(),
"Python focuses on readability and dynamic typing.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(chunk_primary)
.await
.expect("Failed to store primary chunk");
db.store_item(chunk_secondary)
.await
.expect("Failed to store secondary chunk");
}
#[tokio::test]
async fn test_hybrid_retrieval_prioritises_relevant_entity() {
let db = setup_test_db().await;
let user_id = "user123";
seed_test_data(&db, user_id).await;
let results = retrieve_entities_with_embedding(
&db,
test_embedding(),
"Rust concurrency async tasks",
user_id,
)
.await
.expect("Hybrid retrieval failed");
assert!(
!results.is_empty(),
"Expected at least one retrieval result"
);
let top = &results[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
);
assert!(
!top.chunks.is_empty(),
"Expected Rust entity to include supporting chunks"
);
let chunk_texts: Vec<&str> = top
.chunks
.iter()
.map(|chunk| chunk.chunk.chunk.as_str())
.collect();
assert!(
chunk_texts.iter().any(|text| text.contains("Tokio")),
"Expected chunk discussing Tokio to be included"
);
}
#[tokio::test]
async fn test_graph_relationship_enriches_results() {
let db = setup_test_db().await;
let user_id = "graph_user";
let primary = KnowledgeEntity::new(
"primary_source".into(),
"Async Rust patterns".into(),
"Explores async runtimes and scheduling strategies.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let neighbor = KnowledgeEntity::new(
"neighbor_source".into(),
"Tokio Scheduler Deep Dive".into(),
"Details on Tokio's cooperative scheduler.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_low(),
user_id.into(),
);
db.store_item(primary.clone())
.await
.expect("Failed to store primary entity");
db.store_item(neighbor.clone())
.await
.expect("Failed to store neighbor entity");
let primary_chunk = TextChunk::new(
primary.source_id.clone(),
"Rust async tasks use Tokio's cooperative scheduler.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let neighbor_chunk = TextChunk::new(
neighbor.source_id.clone(),
"Tokio's scheduler manages task fairness across executors.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(primary_chunk)
.await
.expect("Failed to store primary chunk");
db.store_item(neighbor_chunk)
.await
.expect("Failed to store neighbor chunk");
let relationship = KnowledgeRelationship::new(
primary.id.clone(),
neighbor.id.clone(),
user_id.into(),
"relationship_source".into(),
"references".into(),
);
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
let results = retrieve_entities_with_embedding(
&db,
test_embedding(),
"Rust concurrency async tasks",
user_id,
)
.await
.expect("Hybrid retrieval failed");
let mut neighbor_entry = None;
for entity in &results {
if entity.entity.id == neighbor.id {
neighbor_entry = Some(entity.clone());
}
}
let neighbor_entry =
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
assert!(
neighbor_entry.score > 0.2,
"Graph-enriched entity should have a meaningful fused score"
);
assert!(
neighbor_entry
.chunks
.iter()
.all(|chunk| chunk.chunk.source_id == neighbor.source_id),
"Neighbor entity should surface its own supporting chunks"
);
}
}

View File

@@ -0,0 +1,180 @@
use std::cmp::Ordering;
use common::storage::types::StoredObject;
/// Holds optional subscores gathered from different retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
pub fts: Option<f32>,
pub vector: Option<f32>,
pub graph: Option<f32>,
}
/// Generic wrapper combining an item with its accumulated retrieval scores.
#[derive(Debug, Clone)]
pub struct Scored<T> {
pub item: T,
pub scores: Scores,
pub fused: f32,
}
impl<T> Scored<T> {
pub fn new(item: T) -> Self {
Self {
item,
scores: Scores::default(),
fused: 0.0,
}
}
pub fn with_vector_score(mut self, score: f32) -> Self {
self.scores.vector = Some(score);
self
}
pub fn with_fts_score(mut self, score: f32) -> Self {
self.scores.fts = Some(score);
self
}
pub fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
}
pub fn update_fused(&mut self, fused: f32) {
self.fused = fused;
}
}
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
pub vector: f32,
pub fts: f32,
pub graph: f32,
pub multi_bonus: f32,
}
impl Default for FusionWeights {
fn default() -> Self {
Self {
vector: 0.5,
fts: 0.3,
graph: 0.2,
multi_bonus: 0.02,
}
}
}
pub fn clamp_unit(value: f32) -> f32 {
value.max(0.0).min(1.0)
}
pub fn distance_to_similarity(distance: f32) -> f32 {
if !distance.is_finite() {
return 0.0;
}
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
}
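// A few concrete values for the mapping above (illustrative):
// distance_to_similarity(0.0)      == 1.0
// distance_to_similarity(1.0)      == 0.5
// distance_to_similarity(3.0)      == 0.25
// distance_to_similarity(f32::NAN) == 0.0  (non-finite distances collapse to zero)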
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return Vec::new();
}
let mut min = f32::MAX;
let mut max = f32::MIN;
for s in scores {
if !s.is_finite() {
continue;
}
if *s < min {
min = *s;
}
if *s > max {
max = *s;
}
}
if !min.is_finite() || !max.is_finite() {
return scores.iter().map(|_| 0.0).collect();
}
if (max - min).abs() < f32::EPSILON {
return vec![1.0; scores.len()];
}
scores
.iter()
.map(|score| {
if !score.is_finite() {
0.0
} else {
clamp_unit((score - min) / (max - min))
}
})
.collect()
}
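// Concrete behaviour of the normalizer above (illustrative):
// min_max_normalize(&[2.0, 5.0, 8.0]) == [0.0, 0.5, 1.0]
// min_max_normalize(&[0.7, 0.7])      == [1.0, 1.0]  (a degenerate range maps to all ones)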
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
let vector = scores.vector.unwrap_or(0.0);
let fts = scores.fts.unwrap_or(0.0);
let graph = scores.graph.unwrap_or(0.0);
let mut fused = vector * weights.vector + fts * weights.fts + graph * weights.graph;
let signals_present = scores
.vector
.iter()
.chain(scores.fts.iter())
.chain(scores.graph.iter())
.count();
if signals_present >= 2 {
fused += weights.multi_bonus;
}
clamp_unit(fused)
}
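// Worked example with the default weights (illustrative scores):
// vector 0.8 + fts 0.6 -> 0.8 * 0.5 + 0.6 * 0.3 + 0.02 (multi-signal bonus) = 0.60
// vector 0.8 alone     -> 0.8 * 0.5                                         = 0.40
// Candidates confirmed by several signals pick up the bonus and outrank
// single-signal hits of similar strength.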
pub fn merge_scored_by_id<T>(
target: &mut std::collections::HashMap<String, Scored<T>>,
incoming: Vec<Scored<T>>,
) where
T: StoredObject + Clone,
{
for scored in incoming {
let id = scored.item.get_id().to_owned();
target
.entry(id)
.and_modify(|existing| {
if let Some(score) = scored.scores.vector {
existing.scores.vector = Some(score);
}
if let Some(score) = scored.scores.fts {
existing.scores.fts = Some(score);
}
if let Some(score) = scored.scores.graph {
existing.scores.graph = Some(score);
}
})
.or_insert_with(|| Scored {
item: scored.item.clone(),
scores: scored.scores,
fused: scored.fused,
});
}
}
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
T: StoredObject,
{
items.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
}

View File

@@ -1,4 +1,15 @@
use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding};
use std::collections::HashMap;
use common::storage::types::file_info::deserialize_flexible_id;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
use surrealdb::sql::Thing;
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
/// Compares vectors and retrieves a number of items from the specified table.
///
@@ -22,24 +33,125 @@ use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::ge
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
take: u8,
take: usize,
input_text: &str,
db_client: &SurrealDbClient,
table: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
) -> Result<Vec<T>, AppError>
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de>,
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
// Generate embeddings
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
// Construct the query
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding);
// Perform query and deserialize to struct
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
Ok(closest_entities)
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
.await
}
#[derive(Debug, Deserialize)]
struct DistanceRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
distance: Option<f32>,
}
pub async fn find_items_by_vector_similarity_with_embedding<T>(
take: usize,
query_embedding: Vec<f32>,
db_client: &SurrealDbClient,
table: &str,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
let embedding_literal = serde_json::to_string(&query_embedding)
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
let closest_query = format!(
"SELECT id, vector::distance::knn() AS distance \
FROM {table} \
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
LIMIT $limit",
table = table,
take = take,
embedding = embedding_literal
);
let mut response = db_client
.query(closest_query)
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.await?;
let distance_rows: Vec<DistanceRow> = response.take(0)?;
if distance_rows.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
let thing_ids: Vec<Thing> = ids
.iter()
.map(|id| Thing::from((table, id.as_str())))
.collect();
let mut items_response = db_client
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", table.to_owned()))
.bind(("things", thing_ids.clone()))
.bind(("user_id", user_id.to_owned()))
.await?;
let items: Vec<T> = items_response.take(0)?;
let mut item_map: HashMap<String, T> = items
.into_iter()
.map(|item| (item.get_id().to_owned(), item))
.collect();
let mut min_distance = f32::MAX;
let mut max_distance = f32::MIN;
for row in &distance_rows {
if let Some(distance) = row.distance {
if distance.is_finite() {
if distance < min_distance {
min_distance = distance;
}
if distance > max_distance {
max_distance = distance;
}
}
}
}
let normalize = min_distance.is_finite()
&& max_distance.is_finite()
&& (max_distance - min_distance).abs() > f32::EPSILON;
let mut scored = Vec::with_capacity(distance_rows.len());
for row in distance_rows {
if let Some(item) = item_map.remove(&row.id) {
let similarity = row
.distance
.map(|distance| {
if normalize {
let span = max_distance - min_distance;
if span.abs() < f32::EPSILON {
1.0
} else {
let normalized = 1.0 - ((distance - min_distance) / span);
clamp_unit(normalized)
}
} else {
distance_to_similarity(distance)
}
})
.unwrap_or_default();
scored.push(Scored::new(item).with_vector_score(similarity));
}
}
Ok(scored)
}
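For intuition, a small self-contained check of the similarity mapping above (illustrative values only): when the distances in a batch have spread they are min-max inverted, so the nearest row lands at 1.0 and the farthest at 0.0; a degenerate batch falls back to distance_to_similarity.

#[cfg(test)]
mod similarity_mapping_sketch {
    use super::*;

    #[test]
    fn spread_batches_are_min_max_inverted() {
        let (min, max) = (0.10_f32, 0.50_f32);
        let span = max - min;
        // Nearest distance maps to ~1.0, farthest to 0.0, a middle one to ~0.5.
        assert!((clamp_unit(1.0 - ((0.10_f32 - min) / span)) - 1.0).abs() < 1e-4);
        assert!(clamp_unit(1.0 - ((0.50_f32 - min) / span)).abs() < 1e-4);
        assert!((clamp_unit(1.0 - ((0.30_f32 - min) / span)) - 0.5).abs() < 1e-4);
    }

    #[test]
    fn degenerate_batches_use_the_reciprocal_mapping() {
        // 1 / (1 + 0.25) = 0.8
        assert!((distance_to_similarity(0.25) - 0.8).abs() < 1e-4);
        // Non-finite distances are treated as no signal.
        assert_eq!(distance_to_similarity(f32::NAN), 0.0);
    }
}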

Binary image files changed (diffs not shown).

File diff suppressed because one or more lines are too long

View File

@@ -185,8 +185,13 @@ pub async fn delete_text_content(
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
// If it has file info, delete that too
if let Some(file_info) = &text_content.file_info {
FileInfo::delete_by_id(&file_info.id, &state.db, &state.config).await?;
if let Some(file_info) = text_content.file_info.as_ref() {
let file_in_use =
TextContent::has_other_with_file(&file_info.id, &text_content.id, &state.db).await?;
if !file_in_use {
FileInfo::delete_by_id(&file_info.id, &state.db, &state.config).await?;
}
}
// Delete related knowledge entities and text chunks

View File

@@ -4,9 +4,9 @@ use axum::{
http::{header, HeaderMap, HeaderValue, StatusCode},
response::IntoResponse,
};
use chrono::{DateTime, Utc};
use futures::try_join;
use serde::Serialize;
use tokio::join;
use crate::{
html_state::HtmlState,
@@ -68,7 +68,7 @@ pub async fn index_handler(
#[derive(Serialize)]
pub struct LatestTextContentData {
latest_text_contents: Vec<TextContent>,
text_contents: Vec<TextContent>,
user: User,
}
@@ -80,31 +80,35 @@ pub async fn delete_text_content(
// Get and validate TextContent
let text_content = get_and_validate_text_content(&state, &id, &user).await?;
// Perform concurrent deletions
let (_res1, _res2, _res3, _res4, _res5) = join!(
async {
if let Some(file_info) = text_content.file_info {
FileInfo::delete_by_id(&file_info.id, &state.db, &state.config).await
} else {
Ok(())
}
},
state.db.delete_item::<TextContent>(&text_content.id),
TextChunk::delete_by_source_id(&text_content.id, &state.db),
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db),
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db)
);
// Remove stored assets before deleting the text content record
if let Some(file_info) = text_content.file_info.as_ref() {
let file_in_use =
TextContent::has_other_with_file(&file_info.id, &text_content.id, &state.db).await?;
if !file_in_use {
FileInfo::delete_by_id(&file_info.id, &state.db, &state.config).await?;
}
}
// Delete the text content and any related data
TextChunk::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeEntity::delete_by_source_id(&text_content.id, &state.db).await?;
KnowledgeRelationship::delete_relationships_by_source_id(&text_content.id, &state.db).await?;
state
.db
.delete_item::<TextContent>(&text_content.id)
.await?;
// Render updated content
let latest_text_contents =
let text_contents =
truncate_text_contents(User::get_latest_text_contents(&user.id, &state.db).await?);
Ok(TemplateResponse::new_partial(
"index/signed_in/recent_content.html",
"dashboard/recent_content.html",
"latest_content_section",
LatestTextContentData {
user: user.to_owned(),
latest_text_contents,
text_contents,
},
))
}
@@ -136,6 +140,32 @@ pub struct ActiveJobsData {
pub user: User,
}
#[derive(Serialize)]
struct TaskArchiveEntry {
id: String,
state_label: String,
state_raw: String,
attempts: u32,
max_attempts: u32,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
scheduled_at: DateTime<Utc>,
locked_at: Option<DateTime<Utc>>,
last_error_at: Option<DateTime<Utc>>,
error_message: Option<String>,
worker_id: Option<String>,
priority: i32,
lease_duration_secs: i64,
content_kind: String,
content_summary: String,
}
#[derive(Serialize)]
struct TaskArchiveData {
user: User,
tasks: Vec<TaskArchiveEntry>,
}
pub async fn delete_job(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
@@ -170,6 +200,70 @@ pub async fn show_active_jobs(
))
}
pub async fn show_task_archive(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
let entries: Vec<TaskArchiveEntry> = tasks
.into_iter()
.map(|task| {
let (content_kind, content_summary) = summarize_task_content(&task);
TaskArchiveEntry {
id: task.id.clone(),
state_label: task.state.display_label().to_string(),
state_raw: task.state.as_str().to_string(),
attempts: task.attempts,
max_attempts: task.max_attempts,
created_at: task.created_at,
updated_at: task.updated_at,
scheduled_at: task.scheduled_at,
locked_at: task.locked_at,
last_error_at: task.last_error_at,
error_message: task.error_message.clone(),
worker_id: task.worker_id.clone(),
priority: task.priority,
lease_duration_secs: task.lease_duration_secs,
content_kind,
content_summary,
}
})
.collect();
Ok(TemplateResponse::new_template(
"dashboard/task_archive_modal.html",
TaskArchiveData {
user,
tasks: entries,
},
))
}
fn summarize_task_content(task: &IngestionTask) -> (String, String) {
match &task.content {
common::storage::types::ingestion_payload::IngestionPayload::Text { text, .. } => {
("Text".to_string(), truncate_summary(text, 80))
}
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
("URL".to_string(), url.to_string())
}
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
("File".to_string(), file_info.file_name.clone())
}
}
}
fn truncate_summary(input: &str, max_chars: usize) -> String {
if input.chars().count() <= max_chars {
input.to_string()
} else {
let truncated: String = input.chars().take(max_chars).collect();
format!("{truncated}")
}
}
pub async fn serve_file(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,

View File

@@ -5,7 +5,9 @@ use axum::{
routing::{delete, get},
Router,
};
use handlers::{delete_job, delete_text_content, index_handler, serve_file, show_active_jobs};
use handlers::{
delete_job, delete_text_content, index_handler, serve_file, show_active_jobs, show_task_archive,
};
use crate::html_state::HtmlState;
@@ -24,6 +26,7 @@ where
{
Router::new()
.route("/jobs/{job_id}", delete(delete_job))
.route("/jobs/archive", get(show_task_archive))
.route("/active-jobs", get(show_active_jobs))
.route("/text-content/{id}", delete(delete_text_content))
.route("/file/{id}", get(serve_file))

View File

@@ -20,7 +20,7 @@ use common::{
storage::types::{
file_info::FileInfo,
ingestion_payload::IngestionPayload,
ingestion_task::{IngestionTask, IngestionTaskStatus},
ingestion_task::{IngestionTask, TaskState},
user::User,
},
};
@@ -178,40 +178,54 @@ pub async fn get_task_updates_stream(
Ok(Some(updated_task)) => {
consecutive_db_errors = 0; // Reset error count on success
// Format the status message based on the task state
let status_message = match &updated_task.status {
IngestionTaskStatus::Created => "Created".to_string(),
IngestionTaskStatus::InProgress { attempts, .. } => {
// Following your template's current display
format!("In progress, attempt {}", attempts)
let status_message = match updated_task.state {
TaskState::Pending => "Pending".to_string(),
TaskState::Reserved => format!(
"Reserved (attempt {} of {})",
updated_task.attempts,
updated_task.max_attempts
),
TaskState::Processing => format!(
"Processing (attempt {} of {})",
updated_task.attempts,
updated_task.max_attempts
),
TaskState::Succeeded => "Completed".to_string(),
TaskState::Failed => {
let mut base = format!(
"Retry scheduled (attempt {} of {})",
updated_task.attempts,
updated_task.max_attempts
);
if let Some(message) = updated_task.error_message.as_ref() {
base.push_str(": ");
base.push_str(message);
}
base
}
IngestionTaskStatus::Completed => "Completed".to_string(),
IngestionTaskStatus::Error { message } => {
// Providing a user-friendly error message from the status
format!("Error: {}", message)
TaskState::Cancelled => "Cancelled".to_string(),
TaskState::DeadLetter => {
let mut base = "Failed permanently".to_string();
if let Some(message) = updated_task.error_message.as_ref() {
base.push_str(": ");
base.push_str(message);
}
base
}
IngestionTaskStatus::Cancelled => "Cancelled".to_string(),
};
yield Ok(Event::default().event("status").data(status_message));
// Check for terminal states to close the stream
match updated_task.status {
IngestionTaskStatus::Completed
| IngestionTaskStatus::Error { .. }
| IngestionTaskStatus::Cancelled => {
// Send a specific event that HTMX uses to close the connection
// Send an event to reload the recent content
// Send an event to remove the loading indicator
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or("Ok".to_string());
yield Ok(Event::default().event("stop_loading").data(check_icon));
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
yield Ok(Event::default().event("close_stream").data("Stream complete"));
break; // Exit loop on terminal states
}
_ => {
// Not a terminal state, continue polling
}
if updated_task.state.is_terminal() {
// Send a specific event that HTMX uses to close the connection
// Send an event to reload the recent content
// Send an event to remove the loading indicator
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or("Ok".to_string());
yield Ok(Event::default().event("stop_loading").data(check_icon));
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
yield Ok(Event::default().event("close_stream").data("Stream complete"));
break; // Exit loop on terminal states
}
},
Ok(None) => {

View File

@@ -385,9 +385,7 @@ pub async fn delete_knowledge_relationship(
RequireUser(user): RequireUser,
Path(id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
// GOTTA ADD AUTH VALIDATION
KnowledgeRelationship::delete_relationship_by_id(&id, &state.db).await?;
KnowledgeRelationship::delete_relationship_by_id(&id, &user.id, &state.db).await?;
let entities = User::get_knowledge_entities(&user.id, &state.db).await?;

View File

@@ -3,14 +3,10 @@ use common::storage::types::text_content::TextContent;
const TEXT_PREVIEW_LENGTH: usize = 50;
fn maybe_truncate(value: &str) -> Option<String> {
let mut char_count = 0;
for (idx, _) in value.char_indices() {
for (char_count, (idx, _)) in value.char_indices().enumerate() {
if char_count == TEXT_PREVIEW_LENGTH {
return Some(value[..idx].to_string());
}
char_count += 1;
}
None

View File

@@ -2,10 +2,16 @@
<section id="active_jobs_section" class="nb-panel p-4 space-y-4 mt-6 sm:mt-8">
<header class="flex flex-wrap items-center justify-between gap-3">
<h2 class="text-xl font-extrabold tracking-tight">Active Tasks</h2>
<button class="nb-btn btn-square btn-sm" hx-get="/active-jobs" hx-target="#active_jobs_section" hx-swap="outerHTML"
aria-label="Refresh active tasks">
{% include "icons/refresh_icon.html" %}
</button>
<div class="flex gap-2">
<button class="nb-btn btn-square btn-sm" hx-get="/active-jobs" hx-target="#active_jobs_section" hx-swap="outerHTML"
aria-label="Refresh active tasks">
{% include "icons/refresh_icon.html" %}
</button>
<button class="nb-btn btn-sm" hx-get="/jobs/archive" hx-target="#modal" hx-swap="innerHTML"
aria-label="View task archive">
View Archive
</button>
</div>
</header>
{% if active_jobs %}
<ul class="flex flex-col gap-3 list-none p-0 m-0">
@@ -23,12 +29,18 @@
</div>
<div class="space-y-1">
<div class="text-sm font-semibold">
{% if item.status.name == "InProgress" %}
In progress, attempt {{ item.status.attempts }}
{% elif item.status.name == "Error" %}
Error: {{ item.status.message }}
{% if item.state == "Processing" %}
Processing, attempt {{ item.attempts }} of {{ item.max_attempts }}
{% elif item.state == "Reserved" %}
Reserved, attempt {{ item.attempts }} of {{ item.max_attempts }}
{% elif item.state == "Failed" %}
Retry scheduled (attempt {{ item.attempts }} of {{ item.max_attempts }}){% if item.error_message %}: {{ item.error_message }}{% endif %}
{% elif item.state == "DeadLetter" %}
Failed permanently{% if item.error_message %}: {{ item.error_message }}{% endif %}
{% elif item.state == "Succeeded" %}
Completed
{% else %}
{{ item.status.name }}
{{ item.state }}
{% endif %}
</div>
<div class="text-xs font-semibold opacity-60">
@@ -60,4 +72,4 @@
</ul>
{% endif %}
</section>
{% endblock %}
{% endblock %}

View File

@@ -8,7 +8,7 @@
</div>
<div class="space-y-1">
<div class="text-sm font-semibold flex gap-2 items-center">
<span sse-swap="status" hx-swap="innerHTML">Created</span>
<span sse-swap="status" hx-swap="innerHTML">Pending</span>
<div hx-get="/content/recent" hx-target="#latest_content_section" hx-swap="outerHTML"
hx-trigger="sse:update_latest_content"></div>
</div>

View File

@@ -1,6 +1,6 @@
{% block latest_content_section %}
<div id="latest_content_section" class="list">
<h2 class="text-2xl mb-2 font-extrabold">Recent content</h2>
{% include "content/content_list.html" %}
{% include "dashboard/recent_content_list.html" %}
</div>
{% endblock %}
{% endblock %}

View File

@@ -0,0 +1,65 @@
<div id="latest_text_content_cards" class="space-y-6">
{% if text_contents|length > 0 %}
<div class="nb-masonry w-full">
{% for text_content in text_contents %}
<article class="nb-card cursor-pointer mx-auto mb-4 w-full max-w-[92vw] space-y-3 sm:max-w-none"
hx-get="/content/{{ text_content.id }}/read" hx-target="#modal" hx-swap="innerHTML">
{% if text_content.url_info %}
<figure class="-mx-4 -mt-4 border-b-2 border-neutral bg-base-200">
<img class="w-full h-auto" src="/file/{{ text_content.url_info.image_id }}" alt="website screenshot" />
</figure>
{% endif %}
{% if text_content.file_info and (text_content.file_info.mime_type == "image/png" or text_content.file_info.mime_type == "image/jpeg") %}
<figure class="-mx-4 -mt-4 border-b-2 border-neutral bg-base-200">
<img class="w-full h-auto" src="/file/{{ text_content.file_info.id }}" alt="{{ text_content.file_info.file_name }}" />
</figure>
{% endif %}
<div class="space-y-3 break-words">
<h2 class="text-lg font-extrabold tracking-tight truncate">
{% if text_content.url_info %}
{{ text_content.url_info.title }}
{% elif text_content.file_info %}
{{ text_content.file_info.file_name }}
{% else %}
{{ text_content.text }}
{% endif %}
</h2>
<div class="flex flex-wrap items-center justify-between gap-3">
<p class="text-xs opacity-60 shrink-0">
{{ text_content.created_at | datetimeformat(format="short", tz=user.timezone) }}
</p>
<span class="nb-badge">{{ text_content.category }}</span>
<div class="flex gap-2" hx-on:click="event.stopPropagation()">
{% if text_content.url_info %}
<a href="{{ text_content.url_info.url }}" target="_blank" rel="noopener noreferrer"
class="nb-btn btn-square btn-sm" aria-label="Open source link">
{% include "icons/link_icon.html" %}
</a>
{% endif %}
<button hx-get="/content/{{ text_content.id }}/read" hx-target="#modal" hx-swap="innerHTML"
class="nb-btn btn-square btn-sm" aria-label="Read content">
{% include "icons/read_icon.html" %}
</button>
<button hx-get="/content/{{ text_content.id }}" hx-target="#modal" hx-swap="innerHTML"
class="nb-btn btn-square btn-sm" aria-label="Edit content">
{% include "icons/edit_icon.html" %}
</button>
<button hx-delete="/text-content/{{ text_content.id }}" hx-target="#latest_content_section"
hx-swap="outerHTML" class="nb-btn btn-square btn-sm" aria-label="Delete content">
{% include "icons/delete_icon.html" %}
</button>
</div>
</div>
<p class="text-sm leading-relaxed">
{{ text_content.instructions }}
</p>
</div>
</article>
{% endfor %}
</div>
{% else %}
<div class="nb-card p-8 text-center text-sm opacity-70">
No content found.
</div>
{% endif %}
</div>

View File

@@ -0,0 +1,152 @@
{% extends "modal_base.html" %}
{% block modal_class %}w-11/12 max-w-[90ch] max-h-[95%] overflow-y-auto{% endblock %}
{% block form_attributes %}onsubmit="event.preventDefault();"{% endblock %}
{% block modal_content %}
<h3 class="text-xl font-extrabold tracking-tight flex items-center gap-2">
Ingestion Task Archive
<span class="badge badge-neutral text-xs font-normal">{{ tasks|length }} total</span>
</h3>
<p class="text-sm opacity-70">A history of all ingestion tasks for {{ user.email }}.</p>
{% if tasks %}
<div class="hidden lg:block overflow-x-auto nb-card mt-4">
<table class="nb-table">
<thead>
<tr>
<th class="text-left">Content</th>
<th class="text-left">State</th>
<th class="text-left">Attempts</th>
<th class="text-left">Scheduled</th>
<th class="text-left">Updated</th>
<th class="text-left">Worker</th>
<th class="text-left">Error</th>
</tr>
</thead>
<tbody>
{% for task in tasks %}
<tr>
<td>
<div class="flex flex-col gap-1">
<div class="text-sm font-semibold">{{ task.content_kind }}</div>
<div class="text-xs opacity-70 break-words">{{ task.content_summary }}</div>
<div class="text-[11px] opacity-60 lowercase tracking-wider">{{ task.id }}</div>
</div>
</td>
<td>
<span class="badge badge-primary badge-outline tracking-wide">{{ task.state_label }}</span>
</td>
<td>
<div class="text-sm font-semibold">{{ task.attempts }} / {{ task.max_attempts }}</div>
<div class="text-xs opacity-60">Priority {{ task.priority }}</div>
</td>
<td>
<div class="text-sm">
{{ task.scheduled_at|datetimeformat(format="short", tz=user.timezone) }}
</div>
{% if task.locked_at %}
<div class="text-xs opacity-60">Locked {{ task.locked_at|datetimeformat(format="short", tz=user.timezone) }}
</div>
{% endif %}
</td>
<td>
<div class="text-sm">
{{ task.updated_at|datetimeformat(format="short", tz=user.timezone) }}
</div>
<div class="text-xs opacity-60">Created {{ task.created_at|datetimeformat(format="short", tz=user.timezone) }}
</div>
</td>
<td>
{% if task.worker_id %}
<span class="text-sm font-semibold">{{ task.worker_id }}</span>
<div class="text-xs opacity-60">Lease {{ task.lease_duration_secs }}s</div>
{% else %}
<span class="text-xs opacity-60">Not assigned</span>
{% endif %}
</td>
<td>
{% if task.error_message %}
<div class="text-sm text-error font-semibold">{{ task.error_message }}</div>
{% if task.last_error_at %}
<div class="text-xs opacity-60">{{ task.last_error_at|datetimeformat(format="short", tz=user.timezone) }}
</div>
{% endif %}
{% else %}
<span class="text-xs opacity-60"></span>
{% endif %}
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<div class="lg:hidden flex flex-col gap-3 mt-4">
{% for task in tasks %}
<details class="nb-panel p-3 space-y-3">
<summary class="flex items-center justify-between gap-2 text-sm font-semibold cursor-pointer">
<span>{{ task.content_kind }}</span>
<span class="badge badge-primary badge-outline tracking-wide">{{ task.state_label }}</span>
</summary>
<div class="text-xs opacity-70 break-words">{{ task.content_summary }}</div>
<div class="text-[11px] opacity-60 lowercase tracking-wider">{{ task.id }}</div>
<div class="grid grid-cols-1 gap-2 text-xs">
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Attempts</span>
<span class="text-sm font-semibold">{{ task.attempts }} / {{ task.max_attempts }}</span>
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Priority</span>
<span class="text-sm font-semibold">{{ task.priority }}</span>
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Scheduled</span>
<span>{{ task.scheduled_at|datetimeformat(format="short", tz=user.timezone) }}</span>
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Updated</span>
<span>{{ task.updated_at|datetimeformat(format="short", tz=user.timezone) }}</span>
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Created</span>
<span>{{ task.created_at|datetimeformat(format="short", tz=user.timezone) }}</span>
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Worker</span>
{% if task.worker_id %}
<span class="text-sm font-semibold">{{ task.worker_id }}</span>
{% else %}
<span class="opacity-60">Unassigned</span>
{% endif %}
</div>
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Lease</span>
<span>{{ task.lease_duration_secs }}s</span>
</div>
{% if task.locked_at %}
<div class="flex justify-between">
<span class="opacity-60 uppercase tracking-wide">Locked</span>
<span>{{ task.locked_at|datetimeformat(format="short", tz=user.timezone) }}</span>
</div>
{% endif %}
</div>
{% if task.error_message or task.last_error_at %}
<div class="border-t border-base-200 pt-2 text-xs space-y-1">
{% if task.error_message %}
<div class="text-sm text-error font-semibold">{{ task.error_message }}</div>
{% endif %}
{% if task.last_error_at %}
<div class="opacity-60">Last error {{ task.last_error_at|datetimeformat(format="short", tz=user.timezone) }}</div>
{% endif %}
</div>
{% endif %}
</details>
{% endfor %}
</div>
{% else %}
<p class="text-sm opacity-70 mt-4">No tasks yet. Start an ingestion to populate the archive.</p>
{% endif %}
{% endblock %}
{% block primary_actions %}{% endblock %}

View File

@@ -23,6 +23,8 @@ url = { workspace = true }
uuid = { workspace = true }
headless_chrome = { workspace = true }
base64 = { workspace = true }
pdf-extract = "0.9"
lopdf = "0.32"
common = { path = "../common" }
composite-retrieval = { path = "../composite-retrieval" }

View File

@@ -7,13 +7,11 @@ use async_openai::types::{
};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, system_settings::SystemSettings},
},
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
use composite_retrieval::{
answer_retrieval::format_entities_json, retrieve_entities, RetrievedEntity,
};
use composite_retrieval::retrieve_entities;
use serde_json::json;
use tracing::{debug, info};
use crate::{
@@ -61,7 +59,7 @@ impl IngestionEnricher {
context: Option<&str>,
text: &str,
user_id: &str,
) -> Result<Vec<KnowledgeEntity>, AppError> {
) -> Result<Vec<RetrievedEntity>, AppError> {
let input_text = format!(
"content: {}, category: {}, user_context: {:?}",
text, category, context
@@ -75,22 +73,11 @@ impl IngestionEnricher {
category: &str,
context: Option<&str>,
text: &str,
similar_entities: &[KnowledgeEntity],
similar_entities: &[RetrievedEntity],
) -> Result<CreateChatCompletionRequest, AppError> {
let settings = SystemSettings::get_current(&self.db_client).await?;
let entities_json = json!(similar_entities
.iter()
.map(|entity| {
json!({
"KnowledgeEntity": {
"id": entity.id,
"name": entity.name,
"description": entity.description
}
})
})
.collect::<Vec<_>>());
let entities_json = format_entities_json(similar_entities);
let user_message = format!(
"Category:\n{}\ncontext:\n{:?}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
@@ -110,8 +97,6 @@ impl IngestionEnricher {
let request = CreateChatCompletionRequestArgs::default()
.model(&settings.processing_model)
.temperature(0.2)
.max_tokens(6048u32)
.messages([
ChatCompletionRequestSystemMessage::from(INGRESS_ANALYSIS_SYSTEM_MESSAGE).into(),
ChatCompletionRequestUserMessage::from(user_message).into(),

View File

@@ -3,101 +3,47 @@ pub mod pipeline;
pub mod types;
pub mod utils;
use chrono::Utc;
use common::storage::{
db::SurrealDbClient,
types::ingestion_task::{IngestionTask, IngestionTaskStatus},
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS},
};
use futures::StreamExt;
use pipeline::IngestionPipeline;
use std::sync::Arc;
use surrealdb::Action;
use tracing::{error, info};
use tokio::time::{sleep, Duration};
use tracing::{error, info, warn};
use uuid::Uuid;
pub async fn run_worker_loop(
db: Arc<SurrealDbClient>,
ingestion_pipeline: Arc<IngestionPipeline>,
) -> Result<(), Box<dyn std::error::Error>> {
let worker_id = format!("ingestion-worker-{}", Uuid::new_v4());
let lease_duration = Duration::from_secs(DEFAULT_LEASE_SECS as u64);
let idle_backoff = Duration::from_millis(500);
loop {
// First, check for any unfinished tasks
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db).await?;
if !unfinished_tasks.is_empty() {
info!("Found {} unfinished jobs", unfinished_tasks.len());
for task in unfinished_tasks {
ingestion_pipeline.process_task(task).await?;
}
}
// If no unfinished jobs, start listening for new ones
info!("Listening for new jobs...");
let mut job_stream = IngestionTask::listen_for_tasks(&db).await?;
while let Some(notification) = job_stream.next().await {
match notification {
Ok(notification) => {
info!("Received notification: {:?}", notification);
match notification.action {
Action::Create => {
if let Err(e) = ingestion_pipeline.process_task(notification.data).await
{
error!("Error processing task: {}", e);
}
}
Action::Update => {
match notification.data.status {
IngestionTaskStatus::Completed
| IngestionTaskStatus::Error { .. }
| IngestionTaskStatus::Cancelled => {
info!(
"Skipping already completed/error/cancelled task: {}",
notification.data.id
);
continue;
}
IngestionTaskStatus::InProgress { attempts, .. } => {
// Only process if this is a retry after an error, not our own update
if let Ok(Some(current_task)) =
db.get_item::<IngestionTask>(&notification.data.id).await
{
match current_task.status {
IngestionTaskStatus::Error { .. }
if attempts
< common::storage::types::ingestion_task::MAX_ATTEMPTS =>
{
// This is a retry after an error
if let Err(e) =
ingestion_pipeline.process_task(current_task).await
{
error!("Error processing task retry: {}", e);
}
}
_ => {
info!(
"Skipping in-progress update for task: {}",
notification.data.id
);
continue;
}
}
}
}
IngestionTaskStatus::Created => {
// Shouldn't happen with Update action, but process if it does
if let Err(e) =
ingestion_pipeline.process_task(notification.data).await
{
error!("Error processing task: {}", e);
}
}
}
}
_ => {} // Ignore other actions
}
match IngestionTask::claim_next_ready(&db, &worker_id, Utc::now(), lease_duration).await {
Ok(Some(task)) => {
let task_id = task.id.clone();
info!(
%worker_id,
%task_id,
attempt = task.attempts,
"claimed ingestion task"
);
if let Err(err) = ingestion_pipeline.process_task(task).await {
error!(%worker_id, %task_id, error = %err, "ingestion task failed");
}
Err(e) => error!("Error in job notification: {}", e),
}
Ok(None) => {
sleep(idle_backoff).await;
}
Err(err) => {
error!(%worker_id, error = %err, "failed to claim ingestion task");
warn!("Backing off for 1s after claim error");
sleep(Duration::from_secs(1)).await;
}
}
// If we reach here, the stream has ended (connection lost?)
error!("Database stream ended unexpectedly, reconnecting...");
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
}

View File

@@ -1,16 +1,15 @@
use std::{sync::Arc, time::Instant};
use chrono::Utc;
use futures::future::try_join_all;
use text_splitter::TextSplitter;
use tracing::info;
use tokio::time::{sleep, Duration};
use tracing::{info, info_span, warn};
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{
ingestion_task::{IngestionTask, IngestionTaskStatus, MAX_ATTEMPTS},
ingestion_task::{IngestionTask, TaskErrorInfo},
knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
@@ -44,47 +43,81 @@ impl IngestionPipeline {
})
}
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
let current_attempts = match task.status {
IngestionTaskStatus::InProgress { attempts, .. } => attempts + 1,
_ => 1,
};
let task_id = task.id.clone();
let attempt = task.attempts;
let worker_label = task
.worker_id
.clone()
.unwrap_or_else(|| "unknown-worker".to_string());
let span = info_span!(
"ingestion_task",
%task_id,
attempt,
worker_id = %worker_label,
state = %task.state.as_str()
);
let _enter = span.enter();
let processing_task = task.mark_processing(&self.db).await?;
// Update status to InProgress with attempt count
IngestionTask::update_status(
&task.id,
IngestionTaskStatus::InProgress {
attempts: current_attempts,
last_attempt: Utc::now(),
},
let text_content = to_text_content(
processing_task.content.clone(),
&self.db,
&self.config,
&self.openai_client,
)
.await?;
let text_content =
to_text_content(task.content, &self.db, &self.config, &self.openai_client).await?;
match self.process(&text_content).await {
Ok(_) => {
IngestionTask::update_status(&task.id, IngestionTaskStatus::Completed, &self.db)
.await?;
processing_task.mark_succeeded(&self.db).await?;
info!(%task_id, attempt, "ingestion task succeeded");
Ok(())
}
Err(e) => {
if current_attempts >= MAX_ATTEMPTS {
IngestionTask::update_status(
&task.id,
IngestionTaskStatus::Error {
message: format!("Max attempts reached: {}", e),
},
&self.db,
)
.await?;
Err(err) => {
let reason = err.to_string();
let error_info = TaskErrorInfo {
code: None,
message: reason.clone(),
};
if processing_task.can_retry() {
let delay = Self::retry_delay(processing_task.attempts);
processing_task
.mark_failed(error_info, delay, &self.db)
.await?;
warn!(
%task_id,
attempt = processing_task.attempts,
retry_in_secs = delay.as_secs(),
"ingestion task failed; scheduled retry"
);
} else {
processing_task
.mark_dead_letter(error_info, &self.db)
.await?;
warn!(
%task_id,
attempt = processing_task.attempts,
"ingestion task failed; moved to dead letter queue"
);
}
Err(AppError::Processing(e.to_string()))
Err(AppError::Processing(reason))
}
}
}
fn retry_delay(attempt: u32) -> Duration {
const BASE_SECONDS: u64 = 30;
const MAX_SECONDS: u64 = 15 * 60;
let capped_attempt = attempt.saturating_sub(1).min(5) as u32;
let multiplier = 2_u64.pow(capped_attempt);
let delay = BASE_SECONDS * multiplier;
Duration::from_secs(delay.min(MAX_SECONDS))
}
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
let now = Instant::now();
@@ -135,17 +168,73 @@ impl IngestionPipeline {
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
) -> Result<(), AppError> {
let entities = Arc::new(entities);
let relationships = Arc::new(relationships);
let entity_count = entities.len();
let relationship_count = relationships.len();
let entity_futures = entities
.iter()
.map(|entitity| self.db.store_item(entitity.to_owned()));
const STORE_GRAPH_MUTATION: &str = r#"
BEGIN TRANSACTION;
LET $entities = $entities;
LET $relationships = $relationships;
try_join_all(entity_futures).await?;
FOR $entity IN $entities {
CREATE type::thing('knowledge_entity', $entity.id) CONTENT $entity;
};
for relationship in &relationships {
relationship.store_relationship(&self.db).await?;
FOR $relationship IN $relationships {
LET $in_node = type::thing('knowledge_entity', $relationship.in);
LET $out_node = type::thing('knowledge_entity', $relationship.out);
RELATE $in_node->relates_to->$out_node CONTENT {
id: type::thing('relates_to', $relationship.id),
metadata: $relationship.metadata
};
};
COMMIT TRANSACTION;
"#;
const MAX_ATTEMPTS: usize = 3;
const INITIAL_BACKOFF_MS: u64 = 50;
const MAX_BACKOFF_MS: u64 = 800;
let mut backoff_ms = INITIAL_BACKOFF_MS;
let mut success = false;
for attempt in 0..MAX_ATTEMPTS {
let result = self
.db
.client
.query(STORE_GRAPH_MUTATION)
.bind(("entities", entities.clone()))
.bind(("relationships", relationships.clone()))
.await;
match result {
Ok(_) => {
success = true;
break;
}
Err(err) => {
if Self::is_retryable_conflict(&err) && attempt + 1 < MAX_ATTEMPTS {
warn!(
attempt = attempt + 1,
"Transient SurrealDB conflict while storing graph data; retrying"
);
sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
continue;
}
return Err(AppError::from(err));
}
}
}
if !success {
return Err(AppError::InternalError(
"Failed to store graph entities after retries".to_string(),
));
}
info!(
@@ -173,4 +262,10 @@ impl IngestionPipeline {
Ok(())
}
fn is_retryable_conflict(error: &surrealdb::Error) -> bool {
error
.to_string()
.contains("Failed to commit transaction due to a read or write conflict")
}
}
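A quick sanity sketch of the schedule retry_delay produces (illustrative test only, not part of the diff): the delay doubles from 30 seconds per attempt and is capped at 15 minutes from the sixth attempt onward.

#[cfg(test)]
mod retry_delay_sketch {
    use super::*;

    #[test]
    fn backoff_doubles_then_caps_at_fifteen_minutes() {
        let secs: Vec<u64> = (1..=7)
            .map(|attempt| IngestionPipeline::retry_delay(attempt).as_secs())
            .collect();
        assert_eq!(secs, vec![30, 60, 120, 240, 480, 900, 900]);
    }
}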

View File

@@ -9,10 +9,13 @@ use chrono::Utc;
use common::storage::db::SurrealDbClient;
use common::{
error::AppError,
storage::types::{
file_info::FileInfo,
ingestion_payload::IngestionPayload,
text_content::{TextContent, UrlInfo},
storage::{
store,
types::{
file_info::FileInfo,
ingestion_payload::IngestionPayload,
text_content::{TextContent, UrlInfo},
},
},
utils::config::AppConfig,
};
@@ -24,6 +27,7 @@ use tracing::{error, info};
use crate::utils::{
audio_transcription::transcribe_audio_file, image_parsing::extract_text_from_image,
pdf_ingestion::extract_pdf_content,
};
pub async fn to_text_content(
@@ -72,7 +76,7 @@ pub async fn to_text_content(
category,
user_id,
} => {
let text = extract_text_from_file(&file_info, db, openai_client).await?;
let text = extract_text_from_file(&file_info, db, openai_client, config).await?;
Ok(TextContent::new(
text,
Some(context),
@@ -199,43 +203,55 @@ async fn fetch_article_from_url(
Ok((article, file_info))
}
/// Extracts text from a file based on its MIME type.
/// Extracts text from a stored file by MIME type.
async fn extract_text_from_file(
file_info: &FileInfo,
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
config: &AppConfig,
) -> Result<String, AppError> {
let base_path = store::resolve_base_dir(config);
let absolute_path = base_path.join(&file_info.path);
match file_info.mime_type.as_str() {
"text/plain" => {
// Read the file and return its content
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
"text/markdown" => {
// Read the file and return its content
let content = tokio::fs::read_to_string(&file_info.path).await?;
"text/plain" | "text/markdown" | "application/octet-stream" | "text/x-rust" => {
let content = tokio::fs::read_to_string(&absolute_path).await?;
Ok(content)
}
"application/pdf" => {
// TODO: Implement PDF text extraction using a crate like `pdf-extract` or `lopdf`
Err(AppError::NotFound(file_info.mime_type.clone()))
extract_pdf_content(
&absolute_path,
db_client,
openai_client,
&config.pdf_ingest_mode,
)
.await
}
"image/png" | "image/jpeg" => {
let content =
extract_text_from_image(&file_info.path, db_client, openai_client).await?;
Ok(content)
}
"application/octet-stream" => {
let content = tokio::fs::read_to_string(&file_info.path).await?;
Ok(content)
}
"text/x-rust" => {
let content = tokio::fs::read_to_string(&file_info.path).await?;
let path_str = absolute_path
.to_str()
.ok_or_else(|| {
AppError::Processing(format!(
"Encountered a non-UTF8 path while reading image {}",
file_info.id
))
})?
.to_string();
let content = extract_text_from_image(&path_str, db_client, openai_client).await?;
Ok(content)
}
"audio/mpeg" | "audio/mp3" | "audio/wav" | "audio/x-wav" | "audio/webm" | "audio/mp4"
| "audio/ogg" | "audio/flac" => {
transcribe_audio_file(&file_info.path, db_client, openai_client).await
let path_str = absolute_path
.to_str()
.ok_or_else(|| {
AppError::Processing(format!(
"Encountered a non-UTF8 path while reading audio {}",
file_info.id
))
})?
.to_string();
transcribe_audio_file(&path_str, db_client, openai_client).await
}
// Handle other MIME types as needed
_ => Err(AppError::NotFound(file_info.mime_type.clone())),

View File

@@ -23,7 +23,6 @@ pub async fn extract_text_from_image(
let request = CreateChatCompletionRequestArgs::default()
.model(system_settings.image_processing_model)
.max_tokens(6400_u32)
.messages([ChatCompletionRequestUserMessageArgs::default()
.content(vec![
ChatCompletionRequestMessageContentPartTextArgs::default()

View File

@@ -1,6 +1,7 @@
pub mod audio_transcription;
pub mod image_parsing;
pub mod llm_instructions;
pub mod pdf_ingestion;
use common::error::AppError;
use std::collections::HashMap;

View File

@@ -0,0 +1,793 @@
use std::{
path::{Path, PathBuf},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_openai::types::{
ChatCompletionRequestMessageContentPartImageArgs,
ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequestArgs, ImageDetail, ImageUrlArgs,
};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use headless_chrome::{
protocol::cdp::{Emulation, Page, DOM},
Browser,
};
use lopdf::Document;
use serde_json::Value;
use tokio::time::sleep;
use tracing::{debug, warn};
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
utils::config::PdfIngestMode,
};
const FAST_PATH_MIN_LEN: usize = 150;
const FAST_PATH_MIN_ASCII_RATIO: f64 = 0.7;
const MAX_VISION_PAGES: usize = 50;
const PAGES_PER_VISION_CHUNK: usize = 4;
const MAX_VISION_ATTEMPTS: usize = 2;
const PDF_MARKDOWN_PROMPT: &str = "Convert these PDF pages to clean Markdown. Preserve headings, lists, tables, blockquotes, code fences, and inline formatting. Keep the original reading order, avoid commentary, and do NOT wrap the entire response in a Markdown code block.";
const PDF_MARKDOWN_PROMPT_RETRY: &str = "You must transcribe the provided PDF page images into accurate Markdown. The images are already supplied, so do not respond that you cannot view them. Extract all visible text, tables, and structure, and do NOT wrap the overall response in a Markdown code block.";
const NAVIGATION_RETRY_INTERVAL_MS: u64 = 120;
const NAVIGATION_RETRY_ATTEMPTS: usize = 10;
const MIN_PAGE_IMAGE_BYTES: usize = 1_024;
const DEFAULT_VIEWPORT_WIDTH: u32 = 1_248; // generous width to reduce horizontal clipping
const DEFAULT_VIEWPORT_HEIGHT: u32 = 1_800; // tall enough to capture full page at fit-to-width scale
const DEFAULT_DEVICE_SCALE_FACTOR: f64 = 1.0;
const CANVAS_VIEWPORT_ATTEMPTS: usize = 12;
const CANVAS_VIEWPORT_WAIT_MS: u64 = 200;
const DEBUG_IMAGE_ENV_VAR: &str = "MINNE_PDF_DEBUG_DIR";
/// Attempts to extract PDF content, using a fast text layer first and falling back to
/// rendering the document for a vision-enabled LLM when needed.
pub async fn extract_pdf_content(
file_path: &Path,
db: &SurrealDbClient,
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
mode: &PdfIngestMode,
) -> Result<String, AppError> {
let pdf_bytes = tokio::fs::read(file_path).await?;
if let Some(candidate) = try_fast_path(pdf_bytes.clone()).await? {
return Ok(candidate);
}
if matches!(mode, PdfIngestMode::Classic) {
return Err(AppError::Processing(
"PDF text extraction failed and LLM-first mode is disabled".into(),
));
}
let page_numbers = load_page_numbers(pdf_bytes.clone()).await?;
if page_numbers.is_empty() {
return Err(AppError::Processing("PDF appears to have no pages".into()));
}
if page_numbers.len() > MAX_VISION_PAGES {
return Err(AppError::Processing(format!(
"PDF has {} pages which exceeds the configured vision processing limit of {}",
page_numbers.len(),
MAX_VISION_PAGES
)));
}
let rendered_pages = render_pdf_pages(file_path, &page_numbers).await?;
let combined_markdown = vision_markdown(rendered_pages, db, client).await?;
Ok(post_process(&combined_markdown))
}
/// Runs `pdf-extract` on the PDF bytes and validates the result with simple heuristics.
/// Returns `Ok(None)` when the text layer is missing or too noisy.
async fn try_fast_path(pdf_bytes: Vec<u8>) -> Result<Option<String>, AppError> {
let extraction = tokio::task::spawn_blocking(move || {
pdf_extract::extract_text_from_mem(&pdf_bytes).map(|s| s.trim().to_string())
})
.await?
.map_err(|err| AppError::Processing(format!("Failed to extract text from PDF: {err}")))?;
if extraction.is_empty() {
return Ok(None);
}
if !looks_good_enough(&extraction) {
return Ok(None);
}
Ok(Some(normalize_fast_text(&extraction)))
}
/// Parses the PDF structure to discover the available page numbers while keeping work off
/// the async executor.
async fn load_page_numbers(pdf_bytes: Vec<u8>) -> Result<Vec<u32>, AppError> {
let pages = tokio::task::spawn_blocking(move || -> Result<Vec<u32>, AppError> {
let document = Document::load_mem(&pdf_bytes)
.map_err(|err| AppError::Processing(format!("Failed to parse PDF: {err}")))?;
let mut page_numbers: Vec<u32> = document.get_pages().keys().copied().collect();
page_numbers.sort_unstable();
Ok(page_numbers)
})
.await??;
Ok(pages)
}
/// Uses the existing headless Chrome dependency to rasterize the requested PDF pages into PNGs.
async fn render_pdf_pages(file_path: &Path, pages: &[u32]) -> Result<Vec<Vec<u8>>, AppError> {
let file_url = url::Url::from_file_path(file_path)
.map_err(|_| AppError::Processing("Unable to construct PDF file URL".into()))?;
let browser = create_browser()?;
let tab = browser
.new_tab()
.map_err(|err| AppError::Processing(format!("Failed to create Chrome tab: {err}")))?;
tab.set_default_timeout(Duration::from_secs(10));
configure_tab(&tab)?;
set_pdf_viewport(&tab)?;
let mut captures = Vec::with_capacity(pages.len());
for (idx, page) in pages.iter().enumerate() {
let target = format!(
"{}#page={}&toolbar=0&statusbar=0&zoom=page-fit",
file_url, page
);
tab.navigate_to(&target)
.map_err(|err| AppError::Processing(format!("Failed to navigate to PDF page: {err}")))?
.wait_until_navigated()
.map_err(|err| AppError::Processing(format!("Navigation to PDF page failed: {err}")))?;
let mut loaded = false;
for attempt in 0..NAVIGATION_RETRY_ATTEMPTS {
if tab
.wait_for_element("embed, canvas, body")
.map(|_| ())
.is_ok()
{
loaded = true;
break;
}
if attempt + 1 < NAVIGATION_RETRY_ATTEMPTS {
sleep(Duration::from_millis(NAVIGATION_RETRY_INTERVAL_MS)).await;
}
}
if !loaded {
return Err(AppError::Processing(
"Timed out waiting for Chrome to render PDF page".into(),
));
}
wait_for_pdf_ready(&tab, *page)?;
tokio::time::sleep(Duration::from_millis(350)).await;
prepare_pdf_viewer(&tab, *page);
let mut viewport: Option<Page::Viewport> = None;
for attempt in 0..CANVAS_VIEWPORT_ATTEMPTS {
match canvas_viewport_for_page(&tab, *page) {
Ok(Some(vp)) => {
viewport = Some(vp);
break;
}
Ok(None) => {
if attempt + 1 < CANVAS_VIEWPORT_ATTEMPTS {
tokio::time::sleep(Duration::from_millis(CANVAS_VIEWPORT_WAIT_MS)).await;
}
}
Err(err) => {
warn!(page = *page, error = %err, "Failed to derive canvas viewport");
break;
}
}
}
let png = if let Some(clip) = viewport {
match tab.call_method(Page::CaptureScreenshot {
format: Some(Page::CaptureScreenshotFormatOption::Png),
quality: None,
clip: Some(clip),
from_surface: Some(true),
capture_beyond_viewport: Some(true),
optimize_for_speed: Some(false),
}) {
Ok(data) => match STANDARD.decode(data.data) {
Ok(bytes) => bytes,
Err(err) => {
warn!(error = %err, page = *page, "Failed to decode clipped screenshot; falling back to full page capture");
capture_full_page_png(&tab)?
}
},
Err(err) => {
warn!(error = %err, page = *page, "Clipped screenshot failed; falling back to full page capture");
capture_full_page_png(&tab)?
}
}
} else {
warn!(
page = *page,
"Unable to determine canvas viewport; capturing full page"
);
capture_full_page_png(&tab)?
};
debug!(
page = *page,
bytes = png.len(),
page_index = idx,
"Captured PDF page screenshot"
);
if is_suspicious_image(png.len()) {
warn!(
page = *page,
bytes = png.len(),
"Screenshot size below threshold; check rendering output"
);
}
if let Err(err) = maybe_dump_debug_image(*page, &png).await {
warn!(
page = *page,
error = %err,
"Failed to write debug screenshot to disk"
);
}
captures.push(png);
}
Ok(captures)
}
/// Launches a headless Chrome instance that respects the existing feature flags.
fn create_browser() -> Result<Browser, AppError> {
#[cfg(feature = "docker")]
{
let options = headless_chrome::LaunchOptionsBuilder::default()
.sandbox(false)
.build()
.map_err(|err| AppError::Processing(format!("Failed to launch Chrome: {err}")))?;
Browser::new(options)
.map_err(|err| AppError::Processing(format!("Failed to start Chrome: {err}")))
}
#[cfg(not(feature = "docker"))]
{
Browser::default()
.map_err(|err| AppError::Processing(format!("Failed to start Chrome: {err}")))
}
}
/// Sends one or more rendered pages to the configured multimodal model and stitches the resulting Markdown chunks together.
async fn vision_markdown(
rendered_pages: Vec<Vec<u8>>,
db: &SurrealDbClient,
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
) -> Result<String, AppError> {
let settings = SystemSettings::get_current(db).await?;
let prompt = PDF_MARKDOWN_PROMPT;
debug!(
pages = rendered_pages.len(),
"Preparing vision batches for PDF conversion"
);
let mut markdown_sections = Vec::with_capacity(rendered_pages.len());
for (batch_idx, chunk) in rendered_pages.chunks(PAGES_PER_VISION_CHUNK).enumerate() {
let total_image_bytes: usize = chunk.iter().map(|bytes| bytes.len()).sum();
debug!(
batch = batch_idx,
pages = chunk.len(),
bytes = total_image_bytes,
"Encoding PDF images for vision batch"
);
let encoded_images: Vec<String> = chunk
.iter()
.enumerate()
.map(|(idx, png_bytes)| {
let encoded = STANDARD.encode(png_bytes);
if encoded.len() < 80 {
warn!(
batch = batch_idx,
page_index = idx,
encoded_bytes = encoded.len(),
"Encoded PDF image payload unusually small"
);
}
encoded
})
.collect();
let mut batch_markdown: Option<String> = None;
for attempt in 0..MAX_VISION_ATTEMPTS {
let prompt_text = prompt_for_attempt(attempt, prompt);
let mut content_parts = Vec::with_capacity(encoded_images.len() + 1);
content_parts.push(
ChatCompletionRequestMessageContentPartTextArgs::default()
.text(prompt_text)
.build()?
.into(),
);
for encoded in &encoded_images {
let image_url = format!("data:image/png;base64,{}", encoded);
content_parts.push(
ChatCompletionRequestMessageContentPartImageArgs::default()
.image_url(
ImageUrlArgs::default()
.url(image_url)
.detail(ImageDetail::High)
.build()?,
)
.build()?
.into(),
);
}
let request = CreateChatCompletionRequestArgs::default()
.model(settings.image_processing_model.clone())
.messages([ChatCompletionRequestUserMessageArgs::default()
.content(content_parts)
.build()?
.into()])
.build()?;
let response = client.chat().create(request).await?;
let Some(choice) = response.choices.first() else {
warn!(
batch = batch_idx,
attempt, "Vision response contained zero choices"
);
continue;
};
let Some(content) = choice.message.content.as_ref() else {
warn!(
batch = batch_idx,
attempt, "Vision response missing content field"
);
continue;
};
debug!(
batch = batch_idx,
attempt,
response_chars = content.len(),
"Received Markdown response for PDF batch"
);
let preview: String = if content.len() > 500 {
let mut snippet = content.chars().take(500).collect::<String>();
snippet.push('…');
snippet
} else {
content.clone()
};
debug!(batch = batch_idx, attempt, preview = %preview, "Vision response content preview");
if is_low_quality_response(content) {
warn!(
batch = batch_idx,
attempt, "Vision model returned low quality response"
);
if attempt + 1 == MAX_VISION_ATTEMPTS {
return Err(AppError::Processing(
"Vision model failed to transcribe PDF page contents".into(),
));
}
continue;
}
batch_markdown = Some(content.trim().to_string());
break;
}
if let Some(markdown) = batch_markdown {
markdown_sections.push(markdown);
} else {
return Err(AppError::Processing(
"Vision model did not return usable Markdown".into(),
));
}
}
Ok(markdown_sections.join("\n\n"))
}
/// Heuristic that determines whether the fast-path text looks like well-formed prose.
fn looks_good_enough(text: &str) -> bool {
if text.len() < FAST_PATH_MIN_LEN {
return false;
}
let total_chars = text.chars().count() as f64;
if total_chars == 0.0 {
return false;
}
let ascii_chars = text.chars().filter(|c| c.is_ascii()).count() as f64;
let ascii_ratio = ascii_chars / total_chars;
if ascii_ratio < FAST_PATH_MIN_ASCII_RATIO {
return false;
}
let letters = text.chars().filter(|c| c.is_alphabetic()).count() as f64;
let letter_ratio = letters / total_chars;
letter_ratio > 0.3
}
/// Normalizes fast-path output so downstream consumers see consistent Markdown.
fn normalize_fast_text(text: &str) -> String {
reflow_markdown(text)
}
/// Cleans, trims, and reflows Markdown created by the LLM path.
fn post_process(markdown: &str) -> String {
let cleaned = markdown.replace('\r', "");
let trimmed = cleaned.trim();
reflow_markdown(trimmed)
}
/// Joins hard-wrapped paragraph text while preserving structural Markdown lines.
fn reflow_markdown(input: &str) -> String {
let mut paragraphs = Vec::new();
let mut buffer: Vec<String> = Vec::new();
for line in input.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
if !buffer.is_empty() {
paragraphs.push(buffer.join(" "));
buffer.clear();
}
continue;
}
if is_structural_line(trimmed) {
if !buffer.is_empty() {
paragraphs.push(buffer.join(" "));
buffer.clear();
}
paragraphs.push(trimmed.to_string());
continue;
}
buffer.push(trimmed.to_string());
}
if !buffer.is_empty() {
paragraphs.push(buffer.join(" "));
}
paragraphs.join("\n\n")
}
/// Detects whether a line is structural Markdown that should remain on its own.
fn is_structural_line(line: &str) -> bool {
let lowered = line.to_ascii_lowercase();
line.starts_with('#')
|| line.starts_with('-')
|| line.starts_with('*')
|| line.starts_with('>')
|| line.starts_with("```")
|| line.starts_with('~')
|| line.starts_with("| ")
|| line.starts_with("+-")
|| lowered
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
&& lowered.contains('.')
}
fn debug_dump_directory() -> Option<PathBuf> {
std::env::var(DEBUG_IMAGE_ENV_VAR)
.ok()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.map(PathBuf::from)
}
fn configure_tab(tab: &headless_chrome::Tab) -> Result<(), AppError> {
tab.call_method(Emulation::SetDefaultBackgroundColorOverride {
color: Some(DOM::RGBA {
r: 255,
g: 255,
b: 255,
a: Some(1.0),
}),
})
.map_err(|err| {
AppError::Processing(format!("Failed to configure Chrome page background: {err}"))
})?;
Ok(())
}
fn set_pdf_viewport(tab: &headless_chrome::Tab) -> Result<(), AppError> {
tab.call_method(Emulation::SetDeviceMetricsOverride {
width: DEFAULT_VIEWPORT_WIDTH,
height: DEFAULT_VIEWPORT_HEIGHT,
device_scale_factor: DEFAULT_DEVICE_SCALE_FACTOR,
mobile: false,
scale: None,
screen_width: Some(DEFAULT_VIEWPORT_WIDTH),
screen_height: Some(DEFAULT_VIEWPORT_HEIGHT),
position_x: None,
position_y: None,
dont_set_visible_size: Some(false),
screen_orientation: None,
viewport: None,
display_feature: None,
device_posture: None,
})
.map_err(|err| AppError::Processing(format!("Failed to configure Chrome viewport: {err}")))?;
tab.call_method(Emulation::SetVisibleSize {
width: DEFAULT_VIEWPORT_WIDTH,
height: DEFAULT_VIEWPORT_HEIGHT,
})
.map_err(|err| AppError::Processing(format!("Failed to apply Chrome visible size: {err}")))?;
Ok(())
}
fn wait_for_pdf_ready(
tab: &headless_chrome::Tab,
page_number: u32,
) -> Result<headless_chrome::Element<'_>, AppError> {
let embed_selector = "embed[type='application/pdf']";
let element = tab
.wait_for_element_with_custom_timeout(embed_selector, Duration::from_secs(8))
.or_else(|_| tab.wait_for_element_with_custom_timeout("embed", Duration::from_secs(8)))
.map_err(|err| AppError::Processing(format!("Timed out waiting for PDF content: {err}")))?;
if let Err(err) = element.scroll_into_view() {
debug!("Failed to scroll PDF element into view: {err}");
}
debug!(page = page_number, "PDF viewer element located");
Ok(element)
}
fn prepare_pdf_viewer(tab: &headless_chrome::Tab, page_number: u32) {
let script = format!(
r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return false;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return false;
const app = viewer.shadowRoot.querySelector('viewer-app');
if (app && app.shadowRoot) {{
const toolbar = app.shadowRoot.querySelector('#toolbar');
if (toolbar) {{ toolbar.style.display = 'none'; }}
}}
const page = viewer.shadowRoot.querySelector('viewer-page:nth-of-type({page})');
if (page && page.scrollIntoView) {{
page.scrollIntoView({{ block: 'start', inline: 'center' }});
}}
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page}"]');
return !!canvas;
}})()"#,
page = page_number
);
match tab.evaluate(&script, false) {
Ok(result) => {
let ready = result
.value
.as_ref()
.and_then(Value::as_bool)
.unwrap_or(false);
debug!(page = page_number, ready, "Prepared PDF viewer page");
}
Err(err) => {
debug!(page = page_number, error = %err, "Unable to run PDF viewer preparation script");
}
}
}
fn canvas_viewport_for_page(
tab: &headless_chrome::Tab,
page_number: u32,
) -> Result<Option<Page::Viewport>, AppError> {
let script = format!(
r#"(function() {{
const embed = document.querySelector('embed[type="application/pdf"]') || document.querySelector('embed');
if (!embed || !embed.shadowRoot) return null;
const viewer = embed.shadowRoot.querySelector('pdf-viewer');
if (!viewer || !viewer.shadowRoot) return null;
const canvas = viewer.shadowRoot.querySelector('canvas[aria-label="Page {page}"]');
if (!canvas) return null;
const rect = canvas.getBoundingClientRect();
return {{ x: rect.x, y: rect.y, width: rect.width, height: rect.height }};
}})()"#,
page = page_number
);
let result = tab
.evaluate(&script, false)
.map_err(|err| AppError::Processing(format!("Failed to inspect PDF canvas: {err}")))?;
let Some(value) = result.value else {
return Ok(None);
};
if value.is_null() {
return Ok(None);
}
let x = value
.get("x")
.and_then(Value::as_f64)
.unwrap_or_default()
.max(0.0);
let y = value
.get("y")
.and_then(Value::as_f64)
.unwrap_or_default()
.max(0.0);
let width = value
.get("width")
.and_then(Value::as_f64)
.unwrap_or_default();
let height = value
.get("height")
.and_then(Value::as_f64)
.unwrap_or_default();
if width <= 0.0 || height <= 0.0 {
return Ok(None);
}
debug!(
page = page_number,
x, y, width, height, "Derived canvas viewport"
);
Ok(Some(Page::Viewport {
x,
y,
width,
height,
scale: 1.0,
}))
}
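// Illustrative sketch, not part of the original file: the viewport derived above is
// intended to be passed as the `clip` of Page::CaptureScreenshot so that only the page
// canvas is captured. The call mirrors the fallback in capture_full_page_png below;
// the function name here is hypothetical.
fn example_capture_clipped_png(
    tab: &headless_chrome::Tab,
    clip: Page::Viewport,
) -> Result<Vec<u8>, AppError> {
    let screenshot = tab
        .call_method(Page::CaptureScreenshot {
            format: Some(Page::CaptureScreenshotFormatOption::Png),
            quality: None,
            clip: Some(clip),
            from_surface: Some(true),
            capture_beyond_viewport: Some(true),
            optimize_for_speed: Some(false),
        })
        .map_err(|err| AppError::Processing(format!("Failed to capture PDF canvas: {err}")))?;
    STANDARD.decode(screenshot.data).map_err(|err| {
        AppError::Processing(format!("Failed to decode PDF canvas screenshot: {err}"))
    })
}
/// Fallback capture used when no canvas clip could be derived: screenshots the whole
/// page (beyond the visible viewport) as PNG and decodes the base64 payload.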
fn capture_full_page_png(tab: &headless_chrome::Tab) -> Result<Vec<u8>, AppError> {
let screenshot = tab
.call_method(Page::CaptureScreenshot {
format: Some(Page::CaptureScreenshotFormatOption::Png),
quality: None,
clip: None,
from_surface: Some(true),
capture_beyond_viewport: Some(true),
optimize_for_speed: Some(false),
})
.map_err(|err| {
AppError::Processing(format!("Failed to capture PDF page (fallback): {err}"))
})?;
STANDARD.decode(screenshot.data).map_err(|err| {
AppError::Processing(format!("Failed to decode PDF screenshot (fallback): {err}"))
})
}
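/// Treats screenshots smaller than MIN_PAGE_IMAGE_BYTES as likely blank or truncated captures.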
fn is_suspicious_image(len: usize) -> bool {
len < MIN_PAGE_IMAGE_BYTES
}
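/// When the debug environment variable points at a directory, persists the captured
/// page PNG there under a timestamped filename for troubleshooting.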
async fn maybe_dump_debug_image(page_index: u32, bytes: &[u8]) -> Result<(), AppError> {
if let Some(dir) = debug_dump_directory() {
tokio::fs::create_dir_all(&dir).await?;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let file_path = dir.join(format!("page-{page_index:04}-{timestamp}.png"));
tokio::fs::write(&file_path, bytes).await?;
debug!(?file_path, size = bytes.len(), "Wrote PDF debug screenshot");
}
Ok(())
}
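/// Flags empty output or refusal phrasing ("unable to", "cannot") as a low-quality
/// model response.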
fn is_low_quality_response(content: &str) -> bool {
let trimmed = content.trim();
if trimmed.is_empty() {
return true;
}
let lowered = trimmed.to_ascii_lowercase();
lowered.contains("unable to") || lowered.contains("cannot")
}
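/// Uses the base prompt on the first attempt and the stricter retry prompt on every
/// subsequent one.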
fn prompt_for_attempt(attempt: usize, base_prompt: &str) -> &str {
if attempt == 0 {
base_prompt
} else {
PDF_MARKDOWN_PROMPT_RETRY
}
}
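// Illustrative sketch, not part of the original file: how prompt_for_attempt and
// is_low_quality_response are intended to combine into a retry loop. `describe_page`
// is a hypothetical stand-in for the model call made elsewhere in this module.
fn example_retry_markdown<F>(mut describe_page: F, max_attempts: usize) -> Option<String>
where
    F: FnMut(&str) -> String,
{
    for attempt in 0..max_attempts {
        let prompt = prompt_for_attempt(attempt, PDF_MARKDOWN_PROMPT);
        let markdown = describe_page(prompt);
        if !is_low_quality_response(&markdown) {
            return Some(markdown);
        }
    }
    None
}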
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_looks_good_enough_short_text() {
assert!(!looks_good_enough("too short"));
}
#[test]
fn test_looks_good_enough_ascii_text() {
let text = "This is a reasonably long ASCII text that should pass the heuristic. \
It contains multiple sentences and a decent amount of letters to satisfy the threshold.";
assert!(looks_good_enough(text));
}
#[test]
fn test_reflow_markdown_preserves_lists() {
let input = "Item one\nItem two\n\n- Bullet\n- Another";
let output = reflow_markdown(input);
assert!(output.contains("Item one Item two"));
assert!(output.contains("- Bullet"));
}
#[test]
fn test_debug_dump_directory_env_var() {
std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
assert!(debug_dump_directory().is_none());
std::env::set_var(DEBUG_IMAGE_ENV_VAR, "/tmp/minne_pdf_debug");
let dir = debug_dump_directory().expect("expected debug directory");
assert_eq!(dir, PathBuf::from("/tmp/minne_pdf_debug"));
std::env::remove_var(DEBUG_IMAGE_ENV_VAR);
}
#[test]
fn test_is_suspicious_image_threshold() {
assert!(is_suspicious_image(0));
assert!(is_suspicious_image(MIN_PAGE_IMAGE_BYTES - 1));
assert!(!is_suspicious_image(MIN_PAGE_IMAGE_BYTES + 1));
}
#[test]
fn test_is_low_quality_response_detection() {
assert!(is_low_quality_response(""));
assert!(is_low_quality_response("I'm unable to help."));
assert!(is_low_quality_response("I cannot read this."));
assert!(!is_low_quality_response("# Heading\nValid content"));
}
#[test]
fn test_prompt_for_attempt_variants() {
assert_eq!(
prompt_for_attempt(0, PDF_MARKDOWN_PROMPT),
PDF_MARKDOWN_PROMPT
);
assert_eq!(
prompt_for_attempt(1, PDF_MARKDOWN_PROMPT),
PDF_MARKDOWN_PROMPT_RETRY
);
assert_eq!(
prompt_for_attempt(5, PDF_MARKDOWN_PROMPT),
PDF_MARKDOWN_PROMPT_RETRY
);
}
#[test]
fn test_markdown_prompts_discourage_code_blocks() {
assert!(!PDF_MARKDOWN_PROMPT.contains("```"));
assert!(!PDF_MARKDOWN_PROMPT_RETRY.contains("```"));
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "main"
version = "0.2.0"
version = "0.2.4"
edition = "2021"
repository = "https://github.com/perstarkse/minne"
license = "AGPL-3.0-or-later"
@@ -23,6 +23,10 @@ api-router = { path = "../api-router" }
html-router = { path = "../html-router" }
common = { path = "../common" }
[dev-dependencies]
tower = "0.5"
uuid = { workspace = true }
[[bin]]
name = "server"
path = "src/server.rs"
@@ -34,4 +38,3 @@ path = "src/worker.rs"
[[bin]]
name = "main"
path = "src/main.rs"

View File

@@ -129,3 +129,98 @@ struct AppState {
api_state: ApiState,
html_state: HtmlState,
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, http::Request, http::StatusCode, Router};
use common::utils::config::{AppConfig, PdfIngestMode, StorageKind};
use std::{path::Path, sync::Arc};
use tower::ServiceExt;
use uuid::Uuid;
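/// Minimal AppConfig wired to an in-memory SurrealDB and a temporary data directory,
/// so the full router can be constructed without any external services.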
fn smoke_test_config(namespace: &str, database: &str, data_dir: &Path) -> AppConfig {
AppConfig {
openai_api_key: "test-key".into(),
surrealdb_address: "mem://".into(),
surrealdb_username: "root".into(),
surrealdb_password: "root".into(),
surrealdb_namespace: namespace.into(),
surrealdb_database: database.into(),
data_dir: data_dir.to_string_lossy().into_owned(),
http_port: 0,
openai_base_url: "https://example.com".into(),
storage: StorageKind::Local,
pdf_ingest_mode: PdfIngestMode::LlmFirst,
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn smoke_startup_with_in_memory_surrealdb() {
let namespace = "test_ns";
let database = format!("test_db_{}", Uuid::new_v4());
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
tokio::fs::create_dir_all(&data_dir)
.await
.expect("failed to create temp data directory");
let config = smoke_test_config(namespace, &database, &data_dir);
let db = Arc::new(
SurrealDbClient::memory(namespace, &database)
.await
.expect("failed to start in-memory surrealdb"),
);
db.apply_migrations()
.await
.expect("failed to apply migrations");
let session_store = Arc::new(db.create_session_store().await.expect("session store"));
let openai_client = Arc::new(async_openai::Client::with_config(
async_openai::config::OpenAIConfig::new()
.with_api_key(&config.openai_api_key)
.with_api_base(&config.openai_base_url),
));
let html_state =
HtmlState::new_with_resources(db.clone(), openai_client, session_store, config.clone())
.expect("failed to build html state");
let api_state = ApiState {
db: html_state.db.clone(),
config: config.clone(),
};
let app = Router::new()
.nest("/api/v1", api_routes_v1(&api_state))
.merge(html_routes(&html_state))
.with_state(AppState {
api_state,
html_state,
});
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/api/v1/live")
.body(Body::empty())
.expect("request"),
)
.await
.expect("router response");
assert_eq!(response.status(), StatusCode::OK);
let ready_response = app
.oneshot(
Request::builder()
.uri("/api/v1/ready")
.body(Body::empty())
.expect("request"),
)
.await
.expect("ready response");
assert_eq!(ready_response.status(), StatusCode::OK);
tokio::fs::remove_dir_all(&data_dir).await.ok();
}
}

52
todo.md
View File

@@ -1,52 +0,0 @@
[] implement prompt and model choice for image processing?
[x] ollama and changing of openai_base_url
[x] allow changing of port the server listens to
[] archive ingressed webpage, pdf would be easy
[] embed surrealdb for the main binary
[] three js graph explorer
[x] add user_id to ingress objects
[x] admin controls re registration
[x] allow setting of data storage folder, via envs and config
[x] build docker container on release plan
[x] change to smoothie dom
[x] chat functionality
[x] chat history
[x] chat styling overhaul
[x] configs primarily get envs
[x] debug vector search
[x] debug why not automatic retrieval of chrome binary works
[x] filtering on categories
[x] fix card image in content
[x] fix patch_text_content
[x] fix redirect for non hx
[x] full text search
[x] html ingression
[x] hx-redirect
[x] implement migrations
[x] integrate assets folder in release build
[x] integrate templates in release build
[x] ios shortcut generation
[x] job queue
[x] link to ingressed urls or archives
[x] macro for pagedata?
[x] make sure error messages render correctly
[x] markdown rendering in client
[x] marked loading conditions
[x] on updates of knowledgeentity create new embeddings
[x] openai api key in config
[x] option to set models, query and processing
[x] page screenshot?
[x] redirects
[x] rename ingestion instructions to context
[x] restrict retrieval to users own objects
[x] scroll reading window to top
[x] smoothie_dom test
[x] sse ingestion updates
[x] store page title
[x] template customization?
[x] templating
[x] testing core functions
[x] user id to fileinfo and data path?
[x] view content
[x] view graph map
[x] view latest