43 Commits

Author SHA1 Message Date
Per Stark
8fe4ac9fec release: 1.0.0 2026-01-11 18:37:07 +01:00
Per Stark
db43be1606 fix: schemafull and textcontent 2026-01-02 15:41:22 +01:00
Per Stark
8e8370b080 docs: more complete and correct 2025-12-24 23:36:58 +01:00
Per Stark
84695fa0cc chore: wording 2025-12-22 23:03:33 +01:00
Per Stark
654add98bc fix: never block fts, rely on rrf 2025-12-22 22:56:57 +01:00
Per Stark
244ec0ea25 fix: migrating embeddings to new dimensions
changing order
2025-12-22 22:39:14 +01:00
Per Stark
d8416ac711 fix: ordering of index creation 2025-12-22 21:59:35 +01:00
Per Stark
f9f48d1046 docs: evaluations instructions and readme refactoring 2025-12-22 18:55:47 +01:00
Per Stark
30b8a65377 fix: migrations
schemafull
2025-12-22 18:32:08 +01:00
Per Stark
04faa38ee6 fix: admin page sorted 2025-12-21 21:35:52 +01:00
Per Stark
cdc62dda30 Merge branch 'main' into benchmarks 2025-12-20 23:09:16 +01:00
Per Stark
ab8ff8b07a changelog 2025-12-20 23:03:06 +01:00
Per Stark
79ea007b0a tidying stuff up, dto for search 2025-12-20 22:30:31 +01:00
Per Stark
a5bc72aedf passed wide smoke check 2025-12-10 13:54:08 +01:00
Per Stark
2e2ea0c4ff faster index creation 2025-12-09 21:32:23 +01:00
Per Stark
a090a8c76e retrieval simplfied 2025-12-09 20:35:42 +01:00
Per Stark
a8d10f265c benchmarks: fin 2025-12-08 21:57:53 +01:00
Per Stark
0cb1abc6db beir-rff 2025-12-08 20:39:12 +01:00
Per Stark
d1a6d9abdf dataset: beir 2025-12-04 17:50:35 +01:00
Per Stark
d3fa3be3e5 retrieval: hybrid search, linear fusion 2025-12-04 12:48:59 +01:00
Per Stark
a2c9bb848d release: 0.2.7 2025-12-04 12:25:46 +01:00
Per Stark
dd881efbf9 benchmarks: ready for hybrid revised 2025-12-03 11:38:07 +01:00
Per Stark
2939e4c2a4 fix: removed stale embeddings handler 2025-11-29 20:07:48 +01:00
Per Stark
1039ec32a4 fix: all tests now in sync 2025-11-29 18:59:08 +01:00
Per Stark
cb906c5b53 ndcg fix 2025-11-29 16:24:09 +01:00
Per Stark
08b1612fcb refactored to clap, mrr and ndcg 2025-11-28 21:26:51 +01:00
Per Stark
67004c9646 fix: index creation at init 2025-11-26 21:49:20 +01:00
Per Stark
030f0fc17d evals: v3, ebeddings at the side
additional indexes
2025-11-26 15:15:10 +01:00
Per Stark
226b2db43a retrieval-pipeline: v1 2025-11-19 12:58:27 +01:00
Per Stark
6f88d87e74 fix: add dockerfile changes related to retrieval-pipeline 2025-11-18 22:51:48 +01:00
Per Stark
bd519ab269 benchmarks: v2
Minor refactor
2025-11-18 22:51:06 +01:00
Per Stark
f535df7e61 retrieval-pipeline: v0 2025-11-18 22:46:35 +01:00
Per Stark
6b7befbd04 upsert relationship and creation 2025-11-18 21:18:09 +01:00
Per Stark
0eda65b07e benchmarks: v1
Benchmarking ingestion, retrieval precision and performance
2025-11-18 11:50:15 +01:00
Per Stark
04ee225732 design: improved admin page, new structure 2025-11-04 20:42:24 +01:00
Per Stark
13b7ad6f3a fix: added cargo lock to crane build 2025-11-04 12:59:32 +01:00
Per Stark
112a6965a4 Merge branch 'main' into development 2025-11-03 12:48:04 +01:00
Per Stark
911e830be5 Merge branch 'development' of github.com:perstarkse/minne into development 2025-11-03 12:40:36 +01:00
Per Stark
3196e65172 fix: improved storage manager, prep for s3 2025-11-03 12:39:15 +01:00
Per Stark
f13791cfcf fix: better default naming of relationships 2025-10-27 20:46:00 +01:00
Per Stark
75c200b2ba fix: update graph view when changes in knowledge store 2025-10-27 18:22:15 +01:00
Per Stark
1b7c24747a fix: in memory object store handler for testing 2025-10-27 17:03:03 +01:00
Per Stark
241ad9a089 fix: scratchpad tz aware datetime 2025-10-27 14:00:22 +01:00
180 changed files with 20936 additions and 4019 deletions

2
.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[alias]
eval = "run -p evaluations --"

4
.gitignore vendored
View File

@@ -10,6 +10,9 @@ result
data
database
evaluations/cache/
evaluations/reports/
# Devenv
.devenv*
devenv.local.nix
@@ -21,3 +24,4 @@ devenv.local.nix
.pre-commit-config.yaml
# html-router/assets/style.css
html-router/node_modules
.fastembed_cache/

View File

@@ -1,8 +1,18 @@
# Changelog
## Unreleased
## 1.0.0 (2026-01-02)
- **Locally generated embeddings are now the default.** To keep using API embeddings, set `EMBEDDING_BACKEND` to `openai`. Switching to the local backend downloads an ONNX model and recreates all embeddings, but in most cases it is well worth it: the network-bound call to create embeddings goes away. Embedding generation on my N100 device is extremely fast, and a search response typically returns in under 50ms. (A minimal backend-selection sketch follows this entry.)
- Added a benchmarks crate for evaluating the retrieval process.
- Added FastEmbed embedding support, enabling locally generated CPU embeddings and greatly improved latency if the machine can handle it. Quick search has vastly better accuracy and is much faster: roughly 50ms latency in testing compared to a minimum of 300ms.
- Embeddings are now stored in their own tables.
- Refactored the retrieval pipeline to use the new, faster and more accurate strategy. Read the [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
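The backend switch above is controlled by the `EMBEDDING_BACKEND` setting. As a minimal sketch only (the enum and helper below are hypothetical, not Minne's actual API), selecting between the local ONNX backend and an OpenAI-compatible API backend could look like this:

```rust
use std::env;

/// Which embedding backend to use. Locally generated (FastEmbed/ONNX)
/// embeddings are the default as of 1.0.0; `EMBEDDING_BACKEND=openai`
/// keeps API-generated embeddings.
#[derive(Debug, PartialEq)]
enum EmbeddingBackend {
    /// Embeddings produced locally by an ONNX model (no network call).
    Local,
    /// Embeddings fetched from an OpenAI-compatible API.
    OpenAi,
}

/// Hypothetical helper: read EMBEDDING_BACKEND and default to local.
fn embedding_backend_from_env() -> EmbeddingBackend {
    match env::var("EMBEDDING_BACKEND").as_deref() {
        Ok("openai") => EmbeddingBackend::OpenAi,
        _ => EmbeddingBackend::Local,
    }
}

fn main() {
    // With EMBEDDING_BACKEND unset, the local backend is chosen.
    println!("{:?}", embedding_backend_from_env());
}
```

Whatever the real implementation looks like, the point from the entry above stands: changing backends on an existing install triggers a one-time re-embedding of existing content.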
## Version 0.2.7 (2025-12-04)
- Improved admin page, now only loads models when specifically requested. Groundwork for coming configuration features.
- Fix: timezone aware info in scratchpad
## Version 0.2.6 (2025-10-29)
- Added an opt-in FastEmbed-based reranking stage behind `reranking_enabled`. It improves retrieval accuracy by re-scoring hybrid results.
- Fix: default name for relationships harmonized across application
## Version 0.2.5 (2025-10-24)
- Added manual knowledge entity creation flows using a modal, with the option for suggested relationships

390
Cargo.lock generated
View File

@@ -184,6 +184,12 @@ dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.18"
@@ -1090,6 +1096,12 @@ dependencies = [
"serde",
]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.3"
@@ -1405,6 +1417,7 @@ dependencies = [
"chrono-tz",
"config",
"dom_smoothie",
"fastembed",
"futures",
"include_dir",
"mime",
@@ -1445,26 +1458,6 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "composite-retrieval"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"axum",
"common",
"fastembed",
"futures",
"serde",
"serde_json",
"state-machines",
"surrealdb",
"thiserror 1.0.69",
"tokio",
"tracing",
"uuid",
]
[[package]]
name = "compression-codecs"
version = "0.4.30"
@@ -1626,6 +1619,42 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"is-terminal",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
@@ -1900,13 +1929,22 @@ dependencies = [
"subtle",
]
[[package]]
name = "dirs"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys 0.4.1",
]
[[package]]
name = "dirs"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
dependencies = [
"dirs-sys",
"dirs-sys 0.5.0",
]
[[package]]
@@ -1919,6 +1957,18 @@ dependencies = [
"dirs-sys-next",
]
[[package]]
name = "dirs-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
"libc",
"option-ext",
"redox_users 0.4.6",
"windows-sys 0.48.0",
]
[[package]]
name = "dirs-sys"
version = "0.5.0"
@@ -2130,6 +2180,9 @@ name = "esaxx-rs"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
dependencies = [
"cc",
]
[[package]]
name = "euclid"
@@ -2140,6 +2193,39 @@ dependencies = [
"num-traits",
]
[[package]]
name = "evaluations"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"chrono",
"clap",
"common",
"criterion",
"fastembed",
"futures",
"ingestion-pipeline",
"object_store 0.11.2",
"once_cell",
"rand 0.8.5",
"retrieval-pipeline",
"serde",
"serde_json",
"serde_yaml",
"sha2",
"state-machines",
"surrealdb",
"tempfile",
"text-splitter",
"tokio",
"tracing",
"tracing-subscriber",
"unicode-normalization",
"uuid",
]
[[package]]
name = "event-listener"
version = "5.4.0"
@@ -2217,12 +2303,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2e9bf3ea201e5d338450555088e02cff23b00be92bead3eff7ed341c68f5ac6"
dependencies = [
"anyhow",
"hf-hub",
"hf-hub 0.4.3",
"image",
"ndarray 0.16.1",
"ort",
"serde_json",
"tokenizers",
"tokenizers 0.22.1",
]
[[package]]
@@ -2735,19 +2821,42 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hf-hub"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
dependencies = [
"dirs 5.0.1",
"indicatif",
"log",
"native-tls",
"rand 0.8.5",
"serde",
"serde_json",
"thiserror 1.0.69",
"ureq",
]
[[package]]
name = "hf-hub"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
dependencies = [
"dirs",
"dirs 6.0.0",
"http",
"indicatif",
"libc",
@@ -2795,7 +2904,6 @@ dependencies = [
"chrono",
"chrono-tz",
"common",
"composite-retrieval",
"futures",
"include_dir",
"json-stream-parser",
@@ -2803,6 +2911,7 @@ dependencies = [
"minijinja-autoreload",
"minijinja-contrib",
"minijinja-embed",
"retrieval-pipeline",
"serde",
"serde_json",
"surrealdb",
@@ -3248,26 +3357,29 @@ dependencies = [
name = "ingestion-pipeline"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"axum",
"axum_typed_multipart",
"base64 0.22.1",
"bytes",
"chrono",
"common",
"composite-retrieval",
"dom_smoothie",
"futures",
"headless_chrome",
"lopdf 0.32.0",
"pdf-extract",
"reqwest",
"retrieval-pipeline",
"serde",
"serde_json",
"state-machines",
"surrealdb",
"tempfile",
"text-splitter",
"tokenizers 0.20.4",
"tokio",
"tracing",
"url",
@@ -3330,6 +3442,17 @@ version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is-terminal"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.60.2",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
@@ -3697,17 +3820,17 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
[[package]]
name = "main"
version = "0.2.6"
version = "1.0.0"
dependencies = [
"anyhow",
"api-router",
"async-openai",
"axum",
"common",
"composite-retrieval",
"futures",
"html-router",
"ingestion-pipeline",
"retrieval-pipeline",
"serde",
"serde_json",
"surrealdb",
@@ -4240,7 +4363,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"hermit-abi 0.3.9",
"libc",
]
@@ -4330,6 +4453,12 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@@ -4705,6 +4834,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.18.0"
@@ -4872,6 +5029,17 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "pulldown-cmark"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f86ba2052aebccc42cbbb3ed234b8b13ce76f75c3551a303cb2bcffcff12bb14"
dependencies = [
"bitflags 2.9.0",
"memchr",
"unicase",
]
[[package]]
name = "pxfm"
version = "0.1.25"
@@ -5137,6 +5305,17 @@ dependencies = [
"rayon-core",
]
[[package]]
name = "rayon-cond"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
dependencies = [
"either",
"itertools 0.11.0",
"rayon",
]
[[package]]
name = "rayon-cond"
version = "0.4.0"
@@ -5343,6 +5522,27 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "retrieval-pipeline"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-trait",
"axum",
"clap",
"common",
"fastembed",
"futures",
"serde",
"serde_json",
"surrealdb",
"thiserror 1.0.69",
"tokio",
"tracing",
"uuid",
]
[[package]]
name = "revision"
version = "0.10.0"
@@ -5920,6 +6120,19 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "servo_arc"
version = "0.4.0"
@@ -6598,9 +6811,11 @@ dependencies = [
"either",
"itertools 0.13.0",
"once_cell",
"pulldown-cmark",
"regex",
"strum",
"thiserror 1.0.69",
"tokenizers 0.20.4",
"unicode-segmentation",
]
@@ -6735,6 +6950,16 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.9.0"
@@ -6750,6 +6975,39 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom 0.2.16",
"hf-hub 0.3.2",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand 0.8.5",
"rayon",
"rayon-cond 0.3.0",
"regex",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.22.1"
@@ -6771,7 +7029,7 @@ dependencies = [
"paste",
"rand 0.9.1",
"rayon",
"rayon-cond",
"rayon-cond 0.4.0",
"regex",
"regex-syntax 0.8.5",
"serde",
@@ -7263,6 +7521,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -7755,6 +8019,15 @@ dependencies = [
"windows-link 0.1.1",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
@@ -7782,6 +8055,21 @@ dependencies = [
"windows-targets 0.53.5",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
@@ -7815,6 +8103,12 @@ dependencies = [
"windows_x86_64_msvc 0.53.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
@@ -7827,6 +8121,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
@@ -7839,6 +8139,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@@ -7863,6 +8169,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
@@ -7875,6 +8187,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
@@ -7887,6 +8205,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
@@ -7899,6 +8223,12 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"

View File

@@ -5,8 +5,9 @@ members = [
"api-router",
"html-router",
"ingestion-pipeline",
"composite-retrieval",
"json-stream-parser"
"retrieval-pipeline",
"json-stream-parser",
"evaluations"
]
resolver = "2"
@@ -41,7 +42,9 @@ sha2 = "0.10.8"
surrealdb-migrations = "2.2.2"
surrealdb = { version = "2", features = ["kv-mem"] }
tempfile = "3.12.0"
text-splitter = "0.18.1"
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
tokenizers = { version = "0.20.4", features = ["http"] }
unicode-normalization = "0.1.24"
thiserror = "1.0.63"
tokio-util = { version = "0.7.15", features = ["io"] }
tokio = { version = "1", features = ["full"] }
@@ -77,7 +80,7 @@ implicit_clone = "warn"
redundant_clone = "warn"
# Security-focused lints
integer_arithmetic = "warn"
arithmetic_side_effects = "warn"
indexing_slicing = "warn"
unwrap_used = "warn"
expect_used = "warn"
@@ -87,7 +90,7 @@ todo = "warn"
# Async/Network lints
async_yields_async = "warn"
await_holding_invalid_state = "warn"
await_holding_invalid_type = "warn"
rc_buffer = "warn"
# Maintainability-focused lints
@@ -106,6 +109,8 @@ wildcard_dependencies = "warn"
missing_docs_in_private_items = "warn"
# Allow noisy lints that don't add value for this project
manual_must_use = "allow"
needless_raw_string_hashes = "allow"
multiple_bound_locations = "allow"
cargo_common_metadata = "allow"
multiple-crate-versions = "allow"
module_name_repetition = "allow"

View File

@@ -6,10 +6,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Cache deps
COPY Cargo.toml Cargo.lock ./
RUN mkdir -p api-router common composite-retrieval html-router ingestion-pipeline json-stream-parser main worker
RUN mkdir -p api-router common retrieval-pipeline html-router ingestion-pipeline json-stream-parser main worker
COPY api-router/Cargo.toml ./api-router/
COPY common/Cargo.toml ./common/
COPY composite-retrieval/Cargo.toml ./composite-retrieval/
COPY retrieval-pipeline/Cargo.toml ./retrieval-pipeline/
COPY html-router/Cargo.toml ./html-router/
COPY ingestion-pipeline/Cargo.toml ./ingestion-pipeline/
COPY json-stream-parser/Cargo.toml ./json-stream-parser/

261
README.md
View File

@@ -1,265 +1,66 @@
# Minne - A Graph-Powered Personal Knowledge Base
# Minne
**Minne (Swedish for "memory")** is a personal knowledge management system and save-for-later application for capturing, organizing, and accessing your information. Inspired by the Zettelkasten method, it uses a graph database to automatically create connections between your notes without manual linking overhead.
**A graph-powered personal knowledge base that makes storing easy.**
Capture content effortlessly, let AI discover connections, and explore your knowledge visually. Self-hosted and privacy-focused.
[![Release Status](https://github.com/perstarkse/minne/actions/workflows/release.yml/badge.svg)](https://github.com/perstarkse/minne/actions/workflows/release.yml)
[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![Latest Release](https://img.shields.io/github/v/release/perstarkse/minne?sort=semver)](https://github.com/perstarkse/minne/releases/latest)
![Screenshot](screenshot-graph.webp)
![Screenshot](./screenshot-graph.webp)
## Demo deployment
## Try It
To try _Minne_ out, visit [this](https://minne-demo.stark.pub) read-only demo deployment and explore the functionality.
## Noteworthy Features
- **Search & Chat Interface** - Find content or knowledge instantly with full-text search, or use the chat mode and conversational AI to find and reason about content
- **Manual and AI-assisted connections** - Build entities and relationships manually with full control, let AI create entities and relationships automatically, or blend both approaches with AI suggestions for manual approval
- **Hybrid Retrieval System** - Search combining vector similarity, full-text search, and graph traversal for highly relevant results (a rough fusion sketch follows this list)
- **Scratchpad Feature** - Quickly capture thoughts and convert them to permanent content when ready
- **Visual Graph Explorer** - Interactive D3-based navigation of your knowledge entities and connections
- **Multi-Format Support** - Ingest text, URLs, PDFs, audio files, and images into your knowledge base
- **Performance Focus** - Built with Rust and server-side rendering for speed and efficiency
- **Self-Hosted & Privacy-Focused** - Full control over your data, and compatible with any OpenAI-compatible API that supports structured outputs
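The hybrid retrieval feature above is what the later commits rework (see `fix: never block fts, rely on rrf` and the `beir-rff` benchmark commits in the list). As a minimal, self-contained sketch of reciprocal rank fusion (RRF) over two ranked ID lists, not Minne's actual implementation:

```rust
use std::collections::HashMap;

/// Reciprocal rank fusion: combine ranked ID lists (e.g. from vector search
/// and full-text search) by summing 1 / (k + rank) per document.
/// `k` dampens the impact of top ranks; 60 is a common default in the literature.
fn rrf(rankings: &[Vec<&str>], k: f64) -> Vec<(String, f64)> {
    let mut scores: HashMap<String, f64> = HashMap::new();
    for ranking in rankings {
        for (rank, id) in ranking.iter().enumerate() {
            *scores.entry((*id).to_string()).or_insert(0.0) += 1.0 / (k + rank as f64 + 1.0);
        }
    }
    let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    fused
}

fn main() {
    // Hypothetical document IDs from a vector search and a full-text (FTS) search.
    let vector_hits = vec!["doc_a", "doc_b", "doc_c"];
    let fts_hits = vec!["doc_b", "doc_d"];
    for (id, score) in rrf(&[vector_hits, fts_hits], 60.0) {
        println!("{id}: {score:.4}");
    }
}
```

A nice property of this kind of fusion, and presumably what "never block fts" is about, is that a search arm returning no results simply contributes zero score instead of filtering out the other arm's hits.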
## The "Why" Behind Minne
For a while I've been fascinated by personal knowledge management systems. I wanted something that made it incredibly easy to capture content - snippets of text, URLs, and other media - while automatically discovering connections between ideas. But I also wanted to maintain control over my knowledge structure.
Traditional tools like Logseq and Obsidian are excellent, but the manual linking process often became a hindrance. Meanwhile, fully automated systems sometimes miss important context or create relationships I wouldn't have chosen myself.
So I built Minne to offer the best of both worlds: effortless content capture with AI-assisted relationship discovery, but with the flexibility to manually curate, edit, or override any connections. You can let AI handle the heavy lifting of extracting entities and finding relationships, take full control yourself, or use a hybrid approach where AI suggests connections that you can approve or modify.
While developing Minne, I discovered [KaraKeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder), which is an excellent application in a similar space that you probably want to check out! However, if you're interested in a PKM that offers both intelligent automation and manual curation, with the ability to chat with your knowledge base, then Minne might be worth testing.
## Table of Contents
- [Quick Start](#quick-start)
- [Features in Detail](#features-in-detail)
- [Configuration](#configuration)
- [Tech Stack](#tech-stack)
- [Application Architecture](#application-architecture)
- [AI Configuration](#ai-configuration--model-selection)
- [Roadmap](#roadmap)
- [Development](#development)
- [Contributing](#contributing)
- [License](#license)
**[Live Demo](https://minne-demo.stark.pub)** — Read-only demo deployment
## Quick Start
The fastest way to get Minne running is with Docker Compose:
```bash
# Clone the repository
git clone https://github.com/perstarkse/minne.git
cd minne
# Start Minne and its database
# Set your OpenAI API key in docker-compose.yml, then:
docker compose up -d
# Access at http://localhost:3000
# Open http://localhost:3000
```
**Required Setup:**
- Replace `your_openai_api_key_here` in `docker-compose.yml` with your actual API key
- Configure `OPENAI_BASE_URL` if using a custom AI provider (like Ollama)
For detailed installation options, see [Configuration](#configuration).
## Features in Detail
### Search vs. Chat mode
**Search** - Use when you know roughly what you're looking for. Full-text search finds items quickly by matching your query terms.
**Chat Mode** - Use when you want to explore concepts, find connections, or reason about your knowledge. The AI analyzes your query and finds relevant context across your entire knowledge base.
### Content Processing
Minne automatically processes content you save (a rough sketch of these stages follows the list):
1. **Web scraping** extracts readable text from URLs
2. **Text analysis** identifies key concepts and relationships
3. **Graph creation** builds connections between related content
4. **Embedding generation** enables semantic search capabilities
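A very rough sketch of those four stages wired together; every name below is illustrative, not Minne's actual crate API:

```rust
/// Illustrative result of the embedding stage for one chunk of text.
struct Chunk {
    text: String,
    embedding: Vec<f32>,
}

// 1. Web scraping: extract readable text from the URL (stubbed here).
fn scrape(url: &str) -> String {
    format!("readable text extracted from {url}")
}

// 2. Text analysis: identify key concepts (stubbed as naive keyword picking).
fn analyze(text: &str) -> Vec<String> {
    text.split_whitespace().take(3).map(|w| w.to_string()).collect()
}

// 3. Graph creation: connect related concepts (returned as simple edges).
fn build_edges(concepts: &[String]) -> Vec<(String, String)> {
    concepts.windows(2).map(|w| (w[0].clone(), w[1].clone())).collect()
}

// 4. Embedding generation: produce a vector per chunk (stubbed with zeros).
fn embed(text: &str) -> Chunk {
    Chunk { text: text.to_string(), embedding: vec![0.0; 4] }
}

fn main() {
    let text = scrape("https://example.com/article");
    let concepts = analyze(&text);
    let edges = build_edges(&concepts);
    let chunk = embed(&text);
    println!(
        "{} chars scraped, {} concepts, {} edges, {}-dim embedding",
        chunk.text.len(),
        concepts.len(),
        edges.len(),
        chunk.embedding.len()
    );
}
```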
### Visual Knowledge Graph
Explore your knowledge as an interactive network with flexible curation options:
**Manual Curation** - Create knowledge entities and relationships yourself with full control over your graph structure
**AI Automation** - Let AI automatically extract entities and discover relationships from your content
**Hybrid Approach** - Get AI-suggested relationships and entities that you can manually review, edit, or approve
The graph visualization shows:
- Knowledge entities as nodes (manually created or AI-extracted)
- Relationships as connections (manually defined, AI-discovered, or suggested)
- Interactive navigation for discovery and editing
### Optional FastEmbed Reranking
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
⚠️ **Resource notes**
- Enabling reranking downloads and caches ~1.1GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a single-user setup, but a pool size of 1 can also work.
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if you're comfortable with the additional footprint.
Example configuration:
```yaml
reranking_enabled: true
reranking_pool_size: 2
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to .fastembed_cache
```
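To make the rescoring step concrete, here is a minimal sketch of an opt-in rerank over hybrid results. The scorer is a toy stand-in for the FastEmbed cross-encoder, and none of the names below are the actual crate API:

```rust
/// A retrieved candidate from the hybrid (vector + full-text) stage.
struct Candidate {
    id: String,
    text: String,
    hybrid_score: f64,
}

/// Toy stand-in for a cross-encoder scoring a (query, document) pair jointly;
/// a real reranker (e.g. via fastembed-rs) would run an ONNX model here.
fn cross_encode(query: &str, doc: &str) -> f64 {
    let doc_lower = doc.to_lowercase();
    query
        .split_whitespace()
        .filter(|term| doc_lower.contains(term.to_lowercase().as_str()))
        .count() as f64
}

/// When reranking is enabled, re-score every candidate against the query and
/// sort best-first; otherwise return the hybrid ordering untouched.
fn rerank(query: &str, mut candidates: Vec<Candidate>, reranking_enabled: bool) -> Vec<Candidate> {
    if reranking_enabled {
        candidates.sort_by(|a, b| {
            cross_encode(query, &b.text)
                .partial_cmp(&cross_encode(query, &a.text))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    candidates
}

fn main() {
    // Candidates arrive already ordered by their hybrid score.
    let candidates = vec![
        Candidate { id: "a".into(), text: "notes on SurrealDB indexes".into(), hybrid_score: 0.9 },
        Candidate { id: "b".into(), text: "hybrid retrieval with reranking".into(), hybrid_score: 0.8 },
    ];
    for c in rerank("hybrid reranking", candidates, true) {
        println!("{} (hybrid score {:.1})", c.id, c.hybrid_score);
    }
}
```

The design point mirrored here is that reranking only reorders candidates the hybrid stage already found; with the flag off, results pass through unchanged.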
## Tech Stack
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
- **Database:** SurrealDB (graph, document, and vector search)
- **AI Integration:** OpenAI-compatible API with structured outputs
- **Web Processing:** Headless Chrome for robust webpage content extraction
## Configuration
Minne can be configured using environment variables or a `config.yaml` file. Environment variables take precedence over `config.yaml`.
### Required Configuration
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000`)
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`)
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`)
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`)
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`)
- `OPENAI_API_KEY`: Your API key for OpenAI compatible endpoint
- `HTTP_PORT`: Port for the Minne server (Default: `3000`)
### Optional Configuration
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
### Example config.yaml
```yaml
surrealdb_address: "ws://127.0.0.1:8000"
surrealdb_username: "root_user"
surrealdb_password: "root_password"
surrealdb_database: "minne_db"
surrealdb_namespace: "minne_ns"
openai_api_key: "sk-YourActualOpenAIKeyGoesHere"
data_dir: "./minne_app_data"
http_port: 3000
# rust_log: "info"
```
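The precedence rule (environment variables override `config.yaml`) can be illustrated with a small sketch; the helper below is hypothetical and stands in for the real config loader:

```rust
use std::env;

/// Hypothetical lookup: prefer the environment variable, then the value parsed
/// from config.yaml, then a hard-coded default.
fn config_value(env_key: &str, file_value: Option<&str>, default: &str) -> String {
    env::var(env_key)
        .ok()
        .or_else(|| file_value.map(|v| v.to_string()))
        .unwrap_or_else(|| default.to_string())
}

fn main() {
    // If HTTP_PORT is set in the environment it wins over the file value.
    let port = config_value("HTTP_PORT", Some("3000"), "3000");
    println!("listening on port {port}");
}
```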
## Installation Options
### 1. Docker Compose (Recommended)
```bash
# Clone and run
git clone https://github.com/perstarkse/minne.git
cd minne
docker compose up -d
```
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
### 2. Nix
Or with Nix (with environment variables set):
```bash
nix run 'github:perstarkse/minne#main'
```
This fetches Minne and all dependencies, including Chromium.
Pre-built binaries for Windows, macOS, and Linux are available on the [Releases](https://github.com/perstarkse/minne/releases/latest) page.
### 3. Pre-built Binaries
## Features
Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
- **Fast** — Rust backend with server-side rendering and HTMX for snappy interactions
- **Search & Chat** — Search or use conversational AI to find and reason about content
- **Knowledge Graph** — Visual exploration with automatic or manual relationship curation
- **Hybrid Retrieval** — Vector similarity + full-text for relevant results
- **Multi-Format** — Ingest text, URLs, PDFs, audio, and images
- **Self-Hosted** — Your data, your server, any OpenAI-compatible API
**Requirements:** You'll need to provide SurrealDB and Chromium separately.
## Documentation
### 4. Build from Source
| Guide | Description |
|-------|-------------|
| [Installation](docs/installation.md) | Docker, Nix, binaries, source builds |
| [Configuration](docs/configuration.md) | Environment variables, config.yaml, AI setup |
| [Features](docs/features.md) | Search, Chat, Graph, Reranking, Ingestion |
| [Architecture](docs/architecture.md) | Tech stack, crate structure, data flow |
| [Vision](docs/vision.md) | Philosophy, roadmap, related projects |
```bash
git clone https://github.com/perstarkse/minne.git
cd minne
cargo run --release --bin main
```
## Tech Stack
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
## Application Architecture
Minne offers flexible deployment options:
- **`main`**: Combined server and worker in one process (recommended for most users)
- **`server`**: Web interface and API only
- **`worker`**: Background processing only (for resource optimization)
## Usage
Once Minne is running at `http://localhost:3000`:
1. **Web Interface**: Full-featured experience for desktop and mobile
2. **iOS Shortcut**: Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture
3. **Content Types**: Save notes, URLs, audio files, and more
4. **Knowledge Graph**: Explore automatic connections between your content
5. **Chat Interface**: Query your knowledge base conversationally
## AI Configuration & Model Selection
### Setting Up AI Providers
Minne uses OpenAI-compatible APIs. Configure via environment variables or `config.yaml`:
- `OPENAI_API_KEY` (required): Your API key
- `OPENAI_BASE_URL` (optional): Custom provider URL (e.g., Ollama: `http://localhost:11434/v1`)
### Model Selection
1. Access the `/admin` page in your Minne instance
2. Select models for content processing and chat from your configured provider
3. **Content Processing Requirements**: The model must support structured outputs
4. **Embedding Dimensions**: Update this setting when changing embedding models (e.g., 1536 for `text-embedding-3-small`, 768 for `nomic-embed-text`)
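Since a stale `embedding_dimensions` value is an easy misconfiguration when switching models, a small hypothetical helper makes the mapping from point 4 explicit:

```rust
/// Expected vector size for a couple of embedding models; the two values come
/// from the documentation above. Unknown models return None, meaning the
/// dimension has to be set manually on the admin page.
fn expected_dimensions(model: &str) -> Option<usize> {
    match model {
        "text-embedding-3-small" => Some(1536),
        "nomic-embed-text" => Some(768),
        _ => None,
    }
}

fn main() {
    assert_eq!(expected_dimensions("text-embedding-3-small"), Some(1536));
    assert_eq!(expected_dimensions("nomic-embed-text"), Some(768));
    println!("dimension checks passed");
}
```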
## Roadmap
Current development focus:
- TUI frontend with system editor integration
- Enhanced reranking for improved retrieval recall
- Additional content type support
Feature requests and contributions are welcome!
## Development
```bash
# Run tests
cargo test
# Development build
cargo build
# Comprehensive linting
cargo clippy --workspace --all-targets --all-features
```
The codebase includes extensive unit tests. Integration tests and additional contributions are welcome.
Rust • Axum • HTMX • SurrealDB • FastEmbed
## Contributing
I've developed Minne primarily for my own use, but having been in the self-hosted space for a long time and benefited from the efforts of others, I thought I'd share it with the community. Feature requests are welcome.
Feature requests and contributions welcome. See [Vision](docs/vision.md) for roadmap.
## License
Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details.
[AGPL-3.0](LICENSE)

View File

@@ -1,15 +1,22 @@
use std::sync::Arc;
use common::{storage::db::SurrealDbClient, utils::config::AppConfig};
use common::{
storage::{db::SurrealDbClient, store::StorageManager},
utils::config::AppConfig,
};
#[derive(Clone)]
pub struct ApiState {
pub db: Arc<SurrealDbClient>,
pub config: AppConfig,
pub storage: StorageManager,
}
impl ApiState {
pub async fn new(config: &AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
pub async fn new(
config: &AppConfig,
storage: StorageManager,
) -> Result<Self, Box<dyn std::error::Error>> {
let surreal_db_client = Arc::new(
SurrealDbClient::new(
&config.surrealdb_address,
@@ -26,6 +33,7 @@ impl ApiState {
let app_state = Self {
db: surreal_db_client.clone(),
config: config.clone(),
storage,
};
Ok(app_state)

View File

@@ -30,9 +30,11 @@ pub async fn ingest_data(
TypedMultipart(input): TypedMultipart<IngestParams>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let user_id = user.id;
let file_infos = try_join_all(input.files.into_iter().map(|file| {
FileInfo::new(file, &state.db, &user.id, &state.config).map_err(AppError::from)
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)
.map_err(AppError::from)
}))
.await?;
@@ -41,12 +43,12 @@ pub async fn ingest_data(
input.context,
input.category,
file_infos,
user.id.as_str(),
&user_id,
)?;
let futures: Vec<_> = payloads
.into_iter()
.map(|object| IngestionTask::create_and_add_to_db(object, user.id.clone(), &state.db))
.map(|object| IngestionTask::create_and_add_to_db(object, user_id.clone(), &state.db))
.collect();
try_join_all(futures).await?;

View File

@@ -45,6 +45,7 @@ tokio-retry = { workspace = true }
object_store = { workspace = true }
bytes = { workspace = true }
state-machines = { workspace = true }
fastembed = { workspace = true }
[features]

View File

@@ -14,6 +14,9 @@ CREATE system_settings:current CONTENT {
query_model: "gpt-4o-mini",
processing_model: "gpt-4o-mini",
embedding_model: "text-embedding-3-small",
voice_processing_model: "whisper-1",
image_processing_model: "gpt-4o-mini",
image_processing_prompt: "Analyze this image and respond based on its primary content:\n - If the image is mainly text (document, screenshot, sign), transcribe the text verbatim.\n - If the image is mainly visual (photograph, art, landscape), provide a concise description of the scene.\n - For hybrid images (diagrams, ads), briefly describe the visual, then transcribe the text under a Text: heading.\n\n Respond directly with the analysis.",
embedding_dimensions: 1536,
query_system_prompt: "You are a knowledgeable assistant with access to a specialized knowledge base. You will be provided with relevant knowledge entities from the database as context. Each knowledge entity contains a name, description, and type, representing different concepts, ideas, and information.\nYour task is to:\n1. Carefully analyze the provided knowledge entities in the context\n2. Answer user questions based on this information\n3. Provide clear, concise, and accurate responses\n4. When referencing information, briefly mention which knowledge entity it came from\n5. If the provided context doesn't contain enough information to answer the question confidently, clearly state this\n6. If only partial information is available, explain what you can answer and what information is missing\n7. Avoid making assumptions or providing information not supported by the context\n8. Output the references to the documents. Use the UUIDs and make sure they are correct!\nRemember:\n- Be direct and honest about the limitations of your knowledge\n- Cite the relevant knowledge entities when providing information, but only provide the UUIDs in the reference array\n- If you need to combine information from multiple entities, explain how they connect\n- Don't speculate beyond what's provided in the context\nExample response formats:\n\"Based on [Entity Name], [answer...]\"\n\"I found relevant information in multiple entries: [explanation...]\"\n\"I apologize, but the provided context doesn't contain information about [topic]\"",
ingestion_system_prompt: "You are an AI assistant. You will receive a text content, along with user context and a category. Your task is to provide a structured JSON object representing the content in a graph format suitable for a graph database. You will also be presented with some existing knowledge_entities from the database, do not replicate these! Your task is to create meaningful knowledge entities from the submitted content. Try and infer as much as possible from the users context and category when creating these. If the user submits a large content, create more general entities. If the user submits a narrow and precise content, try and create precise knowledge entities.\nThe JSON should have the following structure:\n{\n\"knowledge_entities\": [\n{\n\"key\": \"unique-key-1\",\n\"name\": \"Entity Name\",\n\"description\": \"A detailed description of the entity.\",\n\"entity_type\": \"TypeOfEntity\"\n},\n// More entities...\n],\n\"relationships\": [\n{\n\"type\": \"RelationshipType\",\n\"source\": \"unique-key-1 or UUID from existing database\",\n\"target\": \"unique-key-1 or UUID from existing database\"\n},\n// More relationships...\n]\n}\nGuidelines:\n1. Do NOT generate any IDs or UUIDs. Use a unique `key` for each knowledge entity.\n2. Each KnowledgeEntity should have a unique `key`, a meaningful `name`, and a descriptive `description`.\n3. Define the type of each KnowledgeEntity using the following categories: Idea, Project, Document, Page, TextSnippet.\n4. Establish relationships between entities using types like RelatedTo, RelevantTo, SimilarTo.\n5. Use the `source` key to indicate the originating entity and the `target` key to indicate the related entity\"\n6. You will be presented with a few existing KnowledgeEntities that are similar to the current ones. They will have an existing UUID. When creating relationships to these entities, use their UUID.\n7. Only create relationships between existing KnowledgeEntities.\n8. Entities that exist already in the database should NOT be created again. If there is only a minor overlap, skip creating a new entity.\n9. A new relationship MUST include a newly created KnowledgeEntity."

View File

@@ -1,27 +1,2 @@
DEFINE ANALYZER IF NOT EXISTS app_default_fts_analyzer
TOKENIZERS class
FILTERS lowercase, ascii;
DEFINE INDEX IF NOT EXISTS text_content_fts_text_idx ON TABLE text_content
FIELDS text
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
DEFINE INDEX IF NOT EXISTS text_content_fts_category_idx ON TABLE text_content
FIELDS category
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
DEFINE INDEX IF NOT EXISTS text_content_fts_context_idx ON TABLE text_content
FIELDS context
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
DEFINE INDEX IF NOT EXISTS text_content_fts_file_name_idx ON TABLE text_content
FIELDS file_info.file_name
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
DEFINE INDEX IF NOT EXISTS text_content_fts_url_idx ON TABLE text_content
FIELDS url_info.url
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
DEFINE INDEX IF NOT EXISTS text_content_fts_url_title_idx ON TABLE text_content
FIELDS url_info.title
SEARCH ANALYZER app_default_fts_analyzer BM25 HIGHLIGHTS;
-- Runtime-managed: text_content FTS indexes now created at startup via the shared Surreal helper.
-- This migration is intentionally left as a no-op to avoid heavy index builds during migration.

View File

@@ -1 +1 @@
REMOVE TABLE job;
-- No-op: legacy `job` table was superseded by `ingestion_task`; kept for migration order compatibility.

View File

@@ -1,17 +1 @@
-- Add FTS indexes for searching name and description on entities
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity
FIELDS name
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity
FIELDS description
SEARCH ANALYZER app_en_fts_analyzer BM25;
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk
FIELDS chunk
SEARCH ANALYZER app_en_fts_analyzer BM25;
-- Runtime-managed: FTS indexes now built at startup; migration retained as a no-op.

View File

@@ -0,0 +1,18 @@
-- Remove HNSW indexes from base tables (now created at runtime on *_embedding tables)
REMOVE INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
-- Remove FTS indexes (now created at runtime via indexes.rs)
REMOVE INDEX IF EXISTS text_content_fts_text_idx ON text_content;
REMOVE INDEX IF EXISTS text_content_fts_category_idx ON text_content;
REMOVE INDEX IF EXISTS text_content_fts_context_idx ON text_content;
REMOVE INDEX IF EXISTS text_content_fts_file_name_idx ON text_content;
REMOVE INDEX IF EXISTS text_content_fts_url_idx ON text_content;
REMOVE INDEX IF EXISTS text_content_fts_url_title_idx ON text_content;
REMOVE INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REMOVE INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REMOVE INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
-- Remove legacy analyzers (recreated at runtime with updated configuration)
REMOVE ANALYZER IF EXISTS app_default_fts_analyzer;
REMOVE ANALYZER IF EXISTS app_en_fts_analyzer;

View File

@@ -0,0 +1,23 @@
-- Move chunk/entity embeddings to dedicated tables for index efficiency.
-- Text chunk embeddings table
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;
-- Knowledge entity embeddings table
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;

View File

@@ -0,0 +1,23 @@
-- Copy embeddings from base tables to dedicated tables
-- This runs BEFORE the field removal migration
FOR $chunk IN (SELECT * FROM text_chunk WHERE embedding != NONE AND array::len(embedding) > 0) {
CREATE text_chunk_embedding CONTENT {
chunk_id: $chunk.id,
embedding: $chunk.embedding,
user_id: $chunk.user_id,
source_id: $chunk.source_id,
created_at: $chunk.created_at,
updated_at: $chunk.updated_at
};
};
FOR $entity IN (SELECT * FROM knowledge_entity WHERE embedding != NONE AND array::len(embedding) > 0) {
CREATE knowledge_entity_embedding CONTENT {
entity_id: $entity.id,
embedding: $entity.embedding,
user_id: $entity.user_id,
created_at: $entity.created_at,
updated_at: $entity.updated_at
};
};

View File

@@ -0,0 +1,3 @@
-- Drop legacy embedding fields from base tables; embeddings now live in *_embedding tables.
REMOVE FIELD IF EXISTS embedding ON TABLE text_chunk;
REMOVE FIELD IF EXISTS embedding ON TABLE knowledge_entity;

View File

@@ -0,0 +1,8 @@
-- Add embedding_backend field to system_settings for visibility of active backend
DEFINE FIELD IF NOT EXISTS embedding_backend ON system_settings TYPE option<string>;
-- Set default to 'openai' for existing installs to preserve backward compatibility
UPDATE system_settings:current SET
embedding_backend = 'openai'
WHERE embedding_backend == NONE;

View File

@@ -0,0 +1,97 @@
-- Enforce SCHEMAFULL on all tables and define missing fields
-- 1. Define missing fields for ingestion_task (formerly job, but now ingestion_task)
DEFINE TABLE OVERWRITE ingestion_task SCHEMAFULL;
-- Core Fields
DEFINE FIELD IF NOT EXISTS id ON ingestion_task TYPE record<ingestion_task>;
DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime DEFAULT time::now();
DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime DEFAULT time::now();
DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;
-- State Machine Fields
DEFINE FIELD IF NOT EXISTS state ON ingestion_task TYPE string ASSERT $value IN ['Pending', 'Reserved', 'Processing', 'Succeeded', 'Failed', 'Cancelled', 'DeadLetter'];
DEFINE FIELD IF NOT EXISTS attempts ON ingestion_task TYPE int DEFAULT 0;
DEFINE FIELD IF NOT EXISTS max_attempts ON ingestion_task TYPE int DEFAULT 3;
DEFINE FIELD IF NOT EXISTS scheduled_at ON ingestion_task TYPE datetime DEFAULT time::now();
DEFINE FIELD IF NOT EXISTS locked_at ON ingestion_task TYPE option<datetime>;
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON ingestion_task TYPE int DEFAULT 300;
DEFINE FIELD IF NOT EXISTS worker_id ON ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS error_code ON ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS error_message ON ingestion_task TYPE option<string>;
DEFINE FIELD IF NOT EXISTS last_error_at ON ingestion_task TYPE option<datetime>;
DEFINE FIELD IF NOT EXISTS priority ON ingestion_task TYPE int DEFAULT 0;
-- Content Payload (IngestionPayload Enum)
DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;
DEFINE FIELD IF NOT EXISTS content.Url ON ingestion_task TYPE option<object>;
DEFINE FIELD IF NOT EXISTS content.Text ON ingestion_task TYPE option<object>;
DEFINE FIELD IF NOT EXISTS content.File ON ingestion_task TYPE option<object>;
-- Content: Url Variant
DEFINE FIELD IF NOT EXISTS content.Url.url ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Url.context ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Url.category ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Url.user_id ON ingestion_task TYPE string;
-- Content: Text Variant
DEFINE FIELD IF NOT EXISTS content.Text.text ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Text.context ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Text.category ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.Text.user_id ON ingestion_task TYPE string;
-- Content: File Variant
DEFINE FIELD IF NOT EXISTS content.File.context ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.category ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.user_id ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info ON ingestion_task TYPE object;
-- Content: File.file_info (FileInfo Struct)
DEFINE FIELD IF NOT EXISTS content.File.file_info.id ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info.created_at ON ingestion_task TYPE datetime;
DEFINE FIELD IF NOT EXISTS content.File.file_info.updated_at ON ingestion_task TYPE datetime;
DEFINE FIELD IF NOT EXISTS content.File.file_info.sha256 ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info.path ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info.file_name ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info.mime_type ON ingestion_task TYPE string;
DEFINE FIELD IF NOT EXISTS content.File.file_info.user_id ON ingestion_task TYPE string;
-- 2. Enforce SCHEMAFULL on all other tables
DEFINE TABLE OVERWRITE analytics SCHEMAFULL;
DEFINE TABLE OVERWRITE conversation SCHEMAFULL;
DEFINE TABLE OVERWRITE file SCHEMAFULL;
DEFINE TABLE OVERWRITE knowledge_entity SCHEMAFULL;
DEFINE TABLE OVERWRITE message SCHEMAFULL;
DEFINE TABLE OVERWRITE relates_to SCHEMAFULL TYPE RELATION;
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL;
DEFINE TABLE OVERWRITE system_settings SCHEMAFULL;
DEFINE TABLE OVERWRITE text_chunk SCHEMAFULL;
-- text_content must have fields defined before enforcing SCHEMAFULL
DEFINE TABLE OVERWRITE text_content SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
DEFINE TABLE OVERWRITE user SCHEMAFULL;

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -98,7 +98,7 @@\n DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;\n DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;\n-DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at; # For get_latest_knowledge_entities\n+DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;\n\n # Defines the schema for the 'message' table.\n\n@@ -157,6 +157,8 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n\n","events":null}

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -51,23 +51,23 @@\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n\n-DEFINE TABLE IF NOT EXISTS job SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON job TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n # Custom fields from the IngestionTask struct\n # IngestionPayload is complex, store as object\n-DEFINE FIELD IF NOT EXISTS content ON job TYPE object;\n+DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n # IngestionTaskStatus can hold data (InProgress), store as object\n-DEFINE FIELD IF NOT EXISTS status ON job TYPE object;\n-DEFINE FIELD IF NOT EXISTS user_id ON job TYPE string;\n+DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n+DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n # Indexes explicitly defined in build_indexes and useful for get_unfinished_tasks\n-DEFINE INDEX IF NOT EXISTS idx_job_status ON job FIELDS status;\n-DEFINE INDEX IF NOT EXISTS idx_job_user ON job FIELDS user_id;\n-DEFINE INDEX IF NOT EXISTS idx_job_created ON job FIELDS created_at;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_status ON ingestion_task FIELDS status;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_user ON ingestion_task FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS idx_ingestion_task_created ON ingestion_task FIELDS created_at;\n\n # Defines the schema for the 'knowledge_entity' table.\n\n","events":null}

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -57,10 +57,7 @@\n DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n\n-# Custom fields from the IngestionTask struct\n-# IngestionPayload is complex, store as object\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n-# IngestionTaskStatus can hold data (InProgress), store as object\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS user_id ON ingestion_task TYPE string;\n\n@@ -157,10 +154,12 @@\n DEFINE FIELD IF NOT EXISTS require_email_verification ON system_settings TYPE bool;\n DEFINE FIELD IF NOT EXISTS query_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS processing_model ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_model ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS embedding_dimensions ON system_settings TYPE int;\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -160,6 +160,7 @@\n DEFINE FIELD IF NOT EXISTS query_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS ingestion_system_prompt ON system_settings TYPE string;\n DEFINE FIELD IF NOT EXISTS image_processing_prompt ON system_settings TYPE string;\n+DEFINE FIELD IF NOT EXISTS voice_processing_model ON system_settings TYPE string;\n\n # Defines the schema for the 'text_chunk' table.\n\n","events":null}

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -18,8 +18,8 @@\n DEFINE TABLE IF NOT EXISTS conversation SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON conversation TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON conversation TYPE datetime;\n\n # Custom fields from the Conversation struct\n DEFINE FIELD IF NOT EXISTS user_id ON conversation TYPE string;\n@@ -34,8 +34,8 @@\n DEFINE TABLE IF NOT EXISTS file SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON file TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON file TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON file TYPE datetime;\n\n # Custom fields from the FileInfo struct\n DEFINE FIELD IF NOT EXISTS sha256 ON file TYPE string;\n@@ -54,8 +54,8 @@\n DEFINE TABLE IF NOT EXISTS ingestion_task SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON ingestion_task TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON ingestion_task TYPE datetime;\n\n DEFINE FIELD IF NOT EXISTS content ON ingestion_task TYPE object;\n DEFINE FIELD IF NOT EXISTS status ON ingestion_task TYPE object;\n@@ -71,8 +71,8 @@\n DEFINE TABLE IF NOT EXISTS knowledge_entity SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity TYPE datetime;\n\n # Custom fields from the KnowledgeEntity struct\n DEFINE FIELD IF NOT EXISTS source_id ON knowledge_entity TYPE string;\n@@ -102,8 +102,8 @@\n DEFINE TABLE IF NOT EXISTS message SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON message TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON message TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON message TYPE datetime;\n\n # Custom fields from the Message struct\n DEFINE FIELD IF NOT EXISTS conversation_id ON message TYPE string;\n@@ -167,8 +167,8 @@\n DEFINE TABLE IF NOT EXISTS text_chunk SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_chunk TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;\n\n # Custom fields from the TextChunk struct\n DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;\n@@ -191,8 +191,8 @@\n DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;\n\n # Custom fields from the TextContent struct\n DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;\n@@ -215,8 +215,8 @@\n DEFINE TABLE IF NOT EXISTS user SCHEMALESS;\n\n # Standard fields\n-DEFINE FIELD IF NOT 
EXISTS created_at ON user TYPE string;\n-DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE string;\n+DEFINE FIELD IF NOT EXISTS created_at ON user TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON user TYPE datetime;\n\n # Custom fields from the User struct\n DEFINE FIELD IF NOT EXISTS email ON user TYPE string;\n","events":null}

View File

@@ -1 +0,0 @@
{"schemas":"--- original\n+++ modified\n@@ -137,6 +137,30 @@\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_source_id_idx ON relates_to FIELDS metadata.source_id;\n DEFINE INDEX IF NOT EXISTS relates_to_metadata_user_id_idx ON relates_to FIELDS metadata.user_id;\n\n+# Defines the schema for the 'scratchpad' table.\n+\n+DEFINE TABLE IF NOT EXISTS scratchpad SCHEMALESS;\n+\n+# Standard fields from stored_object! macro\n+DEFINE FIELD IF NOT EXISTS created_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS updated_at ON scratchpad TYPE datetime;\n+\n+# Custom fields from the Scratchpad struct\n+DEFINE FIELD IF NOT EXISTS user_id ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS title ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS content ON scratchpad TYPE string;\n+DEFINE FIELD IF NOT EXISTS last_saved_at ON scratchpad TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS is_dirty ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS is_archived ON scratchpad TYPE bool DEFAULT false;\n+DEFINE FIELD IF NOT EXISTS archived_at ON scratchpad TYPE option<datetime>;\n+DEFINE FIELD IF NOT EXISTS ingested_at ON scratchpad TYPE option<datetime>;\n+\n+# Indexes based on query patterns\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_idx ON scratchpad FIELDS user_id;\n+DEFINE INDEX IF NOT EXISTS scratchpad_user_archived_idx ON scratchpad FIELDS user_id, is_archived;\n+DEFINE INDEX IF NOT EXISTS scratchpad_updated_idx ON scratchpad FIELDS updated_at;\n+DEFINE INDEX IF NOT EXISTS scratchpad_archived_idx ON scratchpad FIELDS archived_at;\n+\n DEFINE TABLE OVERWRITE script_migration SCHEMAFULL\n PERMISSIONS\n FOR select FULL\n","events":null}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -242,7 +242,7 @@\n\n # Defines the schema for the 'text_content' table.\n\n-DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n@@ -254,10 +254,24 @@\n DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;\n # UrlInfo is a struct, store as object\n DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;\n+DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;\n+\n DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;\n DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n+# FileInfo fields\n+DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;\n+\n # Indexes based on query patterns\n DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\n","events":null}

File diff suppressed because one or more lines are too long

View File

@@ -15,16 +15,12 @@ DEFINE FIELD IF NOT EXISTS entity_type ON knowledge_entity TYPE string;
# metadata is Option<serde_json::Value>, store as object
DEFINE FIELD IF NOT EXISTS metadata ON knowledge_entity TYPE option<object>;
# Define embedding as a standard array of floats for schema definition
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity TYPE array<float>;
# The specific vector nature is handled by the index definition below
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity TYPE string;
# Indexes based on build_indexes and query patterns
# The INDEX definition correctly specifies the vector properties
DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
-- Indexes based on build_indexes and query patterns
-- HNSW index now defined on knowledge_entity_embedding table for better memory usage
-- DEFINE INDEX IF NOT EXISTS idx_embedding_entities ON knowledge_entity FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_source_id_idx ON knowledge_entity FIELDS source_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_user_id_idx ON knowledge_entity FIELDS user_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_entity_type_idx ON knowledge_entity FIELDS entity_type;
DEFINE INDEX IF NOT EXISTS knowledge_entity_created_at_idx ON knowledge_entity FIELDS created_at;

View File

@@ -0,0 +1,18 @@
-- Defines the schema for the 'knowledge_entity_embedding' table.
-- Separate table to optimize HNSW index creation memory usage
DEFINE TABLE IF NOT EXISTS knowledge_entity_embedding SCHEMAFULL;
-- Standard fields
DEFINE FIELD IF NOT EXISTS created_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON knowledge_entity_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON knowledge_entity_embedding TYPE string;
-- Custom fields
DEFINE FIELD IF NOT EXISTS entity_id ON knowledge_entity_embedding TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS embedding ON knowledge_entity_embedding TYPE array<float>;
-- Indexes
-- DEFINE INDEX IF NOT EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_entity_id_idx ON knowledge_entity_embedding FIELDS entity_id;
DEFINE INDEX IF NOT EXISTS knowledge_entity_embedding_user_id_idx ON knowledge_entity_embedding FIELDS user_id;
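Read-path sketch (not taken from the application code): a KNN lookup against this side table, relying only on the fields above and on the HNSW index that is created at runtime. The helper name, the K/EF values in the <|10,40|> operator and the binding names are illustrative assumptions.

use anyhow::{Context, Result};
use serde_json::Value;

use crate::storage::db::SurrealDbClient;

// Hypothetical helper: nearest embeddings for a query vector, scoped to one user.
async fn nearest_entity_embeddings(
    db: &SurrealDbClient,
    query_embedding: Vec<f32>,
    user_id: String,
) -> Result<Vec<Value>> {
    let mut response = db
        .client
        .query(
            "SELECT entity_id, vector::similarity::cosine(embedding, $vec) AS score \
             FROM knowledge_entity_embedding \
             WHERE embedding <|10,40|> $vec AND user_id = $user_id \
             ORDER BY score DESC;",
        )
        .bind(("vec", query_embedding))
        .bind(("user_id", user_id))
        .await
        .context("running KNN lookup on knowledge_entity_embedding")?;
    response
        .take(0)
        .context("deserializing KNN lookup results")
}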

View File

@@ -10,14 +10,8 @@ DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk TYPE datetime;
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk TYPE string;
DEFINE FIELD IF NOT EXISTS chunk ON text_chunk TYPE string;
# Define embedding as a standard array of floats for schema definition
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk TYPE array<float>;
# The specific vector nature is handled by the index definition below
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk TYPE string;
# Indexes based on build_indexes and query patterns (delete_by_source_id)
# The INDEX definition correctly specifies the vector properties
DEFINE INDEX IF NOT EXISTS idx_embedding_chunks ON text_chunk FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS text_chunk_source_id_idx ON text_chunk FIELDS source_id;
DEFINE INDEX IF NOT EXISTS text_chunk_user_id_idx ON text_chunk FIELDS user_id;

View File

@@ -0,0 +1,20 @@
-- Defines the schema for the 'text_chunk_embedding' table.
-- Separate table to optimize HNSW index creation memory usage
DEFINE TABLE IF NOT EXISTS text_chunk_embedding SCHEMAFULL;
# Standard fields
DEFINE FIELD IF NOT EXISTS created_at ON text_chunk_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON text_chunk_embedding TYPE datetime;
DEFINE FIELD IF NOT EXISTS user_id ON text_chunk_embedding TYPE string;
DEFINE FIELD IF NOT EXISTS source_id ON text_chunk_embedding TYPE string;
# Custom fields
DEFINE FIELD IF NOT EXISTS chunk_id ON text_chunk_embedding TYPE record<text_chunk>;
DEFINE FIELD IF NOT EXISTS embedding ON text_chunk_embedding TYPE array<float>;
-- Indexes
-- DEFINE INDEX IF NOT EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_chunk_id_idx ON text_chunk_embedding FIELDS chunk_id;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_user_id_idx ON text_chunk_embedding FIELDS user_id;
DEFINE INDEX IF NOT EXISTS text_chunk_embedding_source_id_idx ON text_chunk_embedding FIELDS source_id;

View File

@@ -1,6 +1,6 @@
# Defines the schema for the 'text_content' table.
DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;
DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;
# Standard fields
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
@@ -12,10 +12,24 @@ DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
# UrlInfo is a struct, store as object
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
# FileInfo fields
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
# Indexes based on query patterns
DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;
DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;
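For illustration only: with text_content now SCHEMAFULL, any stored record has to match the field definitions above, including the nested file_info object. A create along these lines (helper name, record id and every value are placeholders) would satisfy that shape.

use crate::{error::AppError, storage::db::SurrealDbClient};

// Hypothetical helper: stores one text_content record with a nested file_info
// object shaped like the SCHEMAFULL definitions above. All values are fake.
async fn create_demo_text_content(db: &SurrealDbClient) -> Result<(), AppError> {
    let response = db
        .client
        .query(
            r#"CREATE text_content:demo SET
                created_at = time::now(),
                updated_at = time::now(),
                text = "Example body",
                category = "note",
                user_id = "demo-user",
                file_info = {
                    id: "demo-file",
                    created_at: time::now(),
                    updated_at: time::now(),
                    sha256: "sha256-placeholder",
                    path: "demo-user/demo-file/example.txt",
                    file_name: "example.txt",
                    mime_type: "text/plain",
                    user_id: "demo-user"
                };"#,
        )
        .await
        .map_err(|e| AppError::InternalError(e.to_string()))?;
    // Surface any schema violation reported for the statement.
    response
        .check()
        .map_err(|e| AppError::InternalError(e.to_string()))?;
    Ok(())
}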

View File

@@ -5,6 +5,7 @@ use tokio::task::JoinError;
use crate::storage::types::file_info::FileError;
// Core internal errors
#[allow(clippy::module_name_repetitions)]
#[derive(Error, Debug)]
pub enum AppError {
#[error("Database error: {0}")]

View File

@@ -1,3 +1,5 @@
#![allow(clippy::doc_markdown)]
//! Shared utilities and storage helpers for the workspace crates.
pub mod error;
pub mod storage;
pub mod utils;

View File

@@ -7,18 +7,20 @@ use include_dir::{include_dir, Dir};
use std::{ops::Deref, sync::Arc};
use surrealdb::{
engine::any::{connect, Any},
opt::auth::Root,
opt::auth::{Namespace, Root},
Error, Notification, Surreal,
};
use surrealdb_migrations::MigrationRunner;
use tracing::debug;
/// Embedded SurrealDB migration directory packaged with the crate.
static MIGRATIONS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/");
#[derive(Clone)]
pub struct SurrealDbClient {
pub client: Surreal<Any>,
}
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesDb {
fn db(&self) -> &Arc<SurrealDbClient>;
}
@@ -48,6 +50,24 @@ impl SurrealDbClient {
Ok(SurrealDbClient { client: db })
}
pub async fn new_with_namespace_user(
address: &str,
namespace: &str,
username: &str,
password: &str,
database: &str,
) -> Result<Self, Error> {
let db = connect(address).await?;
db.signin(Namespace {
namespace,
username,
password,
})
.await?;
db.use_ns(namespace).use_db(database).await?;
Ok(SurrealDbClient { client: db })
}
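// Usage sketch (all values below are placeholders): connect with
// namespace-scoped credentials instead of root, then pick a database
// inside that namespace:
//
//     let db = SurrealDbClient::new_with_namespace_user(
//         "ws://localhost:8000",
//         "my_namespace",
//         "ns_user",
//         "ns_password",
//         "my_database",
//     )
//     .await?;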
pub async fn create_session_store(
&self,
) -> Result<SessionStore<SessionSurrealPool<Any>>, SessionError> {
@@ -77,24 +97,6 @@ impl SurrealDbClient {
Ok(())
}
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
debug!("Rebuilding indexes");
let rebuild_sql = r#"
BEGIN TRANSACTION;
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
COMMIT TRANSACTION;
"#;
self.client.query(rebuild_sql).await?;
Ok(())
}
/// Operation to store an object in SurrealDB; requires the struct to implement StoredObject
///
/// # Arguments
@@ -112,6 +114,19 @@ impl SurrealDbClient {
.await
}
/// Operation to upsert an object in SurrealDB, replacing any existing record
/// with the same ID. Useful for idempotent ingestion flows.
pub async fn upsert_item<T>(&self, item: T) -> Result<Option<T>, Error>
where
T: StoredObject + Send + Sync + 'static,
{
let id = item.get_id().to_string();
self.client
.upsert((T::table_name(), id))
.content(item)
.await
}
/// Operation to retrieve all objects from a given table; requires the struct to implement StoredObject
///
/// # Returns
@@ -250,6 +265,56 @@ mod tests {
assert!(fetch_post.is_none());
}
#[tokio::test]
async fn upsert_item_overwrites_existing_records() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to initialize schema");
let mut dummy = Dummy {
id: "abc".to_string(),
name: "first".to_string(),
created_at: Utc::now(),
updated_at: Utc::now(),
};
db.store_item(dummy.clone())
.await
.expect("Failed to store initial record");
dummy.name = "updated".to_string();
let upserted = db
.upsert_item(dummy.clone())
.await
.expect("Failed to upsert record");
assert!(upserted.is_some());
let fetched: Option<Dummy> = db.get_item(&dummy.id).await.expect("fetch after upsert");
assert_eq!(fetched.unwrap().name, "updated");
let new_record = Dummy {
id: "def".to_string(),
name: "brand-new".to_string(),
created_at: Utc::now(),
updated_at: Utc::now(),
};
db.upsert_item(new_record.clone())
.await
.expect("Failed to upsert new record");
let fetched_new: Option<Dummy> = db
.get_item(&new_record.id)
.await
.expect("fetch inserted via upsert");
assert_eq!(fetched_new, Some(new_record));
}
#[tokio::test]
async fn test_applying_migrations() {
let namespace = "test_ns";

View File

@@ -0,0 +1,795 @@
use std::time::Duration;
use anyhow::{Context, Result};
use futures::future::try_join_all;
use serde::Deserialize;
use serde_json::{Map, Value};
use tracing::{debug, info, warn};
use crate::{error::AppError, storage::db::SurrealDbClient};
const INDEX_POLL_INTERVAL: Duration = Duration::from_millis(50);
const FTS_ANALYZER_NAME: &str = "app_en_fts_analyzer";
#[derive(Clone, Copy)]
struct HnswIndexSpec {
index_name: &'static str,
table: &'static str,
options: &'static str,
}
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}
impl HnswIndexSpec {
fn definition_if_not_exists(&self, dimension: usize) -> String {
format!(
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} \
FIELDS embedding HNSW DIMENSION {dimension} {options};",
index = self.index_name,
table = self.table,
dimension = dimension,
options = self.options,
)
}
fn definition_overwrite(&self, dimension: usize) -> String {
format!(
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} \
FIELDS embedding HNSW DIMENSION {dimension} {options};",
index = self.index_name,
table = self.table,
dimension = dimension,
options = self.options,
)
}
}
#[derive(Clone, Copy)]
struct FtsIndexSpec {
index_name: &'static str,
table: &'static str,
field: &'static str,
analyzer: Option<&'static str>,
method: &'static str,
}
impl FtsIndexSpec {
fn definition(&self) -> String {
let analyzer_clause = self
.analyzer
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
.unwrap_or_default();
format!(
"DEFINE INDEX IF NOT EXISTS {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
index = self.index_name,
table = self.table,
field = self.field,
)
}
fn overwrite_definition(&self) -> String {
let analyzer_clause = self
.analyzer
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
.unwrap_or_default();
format!(
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
index = self.index_name,
table = self.table,
field = self.field,
)
}
}
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
/// Idempotent: safe to call multiple times and will overwrite HNSW definitions when the dimension changes.
pub async fn ensure_runtime_indexes(
db: &SurrealDbClient,
embedding_dimension: usize,
) -> Result<(), AppError> {
ensure_runtime_indexes_inner(db, embedding_dimension)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
async fn ensure_runtime_indexes_inner(
db: &SurrealDbClient,
embedding_dimension: usize,
) -> Result<()> {
create_fts_analyzer(db).await?;
for spec in fts_index_specs() {
if index_exists(db, spec.table, spec.index_name).await? {
continue;
}
// Create these sequentially; defining them concurrently can make SurrealDB fail with a read/write conflict
create_index_with_polling(
db,
spec.definition(),
spec.index_name,
spec.table,
Some(spec.table),
)
.await?;
}
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
match hnsw_index_state(db, &spec, embedding_dimension).await? {
HnswIndexState::Missing => {
create_index_with_polling(
db,
spec.definition_if_not_exists(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
HnswIndexState::Matches => {
let status = get_index_status(db, spec.index_name, spec.table).await?;
if status.eq_ignore_ascii_case("error") {
warn!(
index = spec.index_name,
table = spec.table,
"HNSW index found in error state; triggering rebuild"
);
create_index_with_polling(
db,
spec.definition_overwrite(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
} else {
Ok(())
}
}
HnswIndexState::Different(existing) => {
info!(
index = spec.index_name,
table = spec.table,
existing_dimension = existing,
target_dimension = embedding_dimension,
"Overwriting HNSW index to match new embedding dimension"
);
create_index_with_polling(
db,
spec.definition_overwrite(embedding_dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
}
}
});
try_join_all(hnsw_tasks).await.map(|_| ())?;
Ok(())
}
async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -> Result<String> {
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
let mut info_res = db
.client
.query(info_query)
.await
.context("checking index status")?;
let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
let info = match info {
Some(i) => i,
None => return Ok("unknown".to_string()),
};
let building = info.get("building");
let status = building
.and_then(|b| b.get("status"))
.and_then(|s| s.as_str())
.unwrap_or("ready")
.to_string();
Ok(status)
}
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?;
for spec in fts_index_specs() {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping FTS rebuild because index is missing"
);
continue;
}
create_index_with_polling(
db,
spec.overwrite_definition(),
spec.index_name,
spec.table,
Some(spec.table),
)
.await?;
}
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping HNSW rebuild because index is missing"
);
return Ok(());
}
let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
warn!(
index = spec.index_name,
table = spec.table,
"HNSW index missing dimension; skipping rebuild"
);
return Ok(());
};
create_index_with_polling(
db,
spec.definition_overwrite(dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
});
try_join_all(hnsw_tasks).await.map(|_| ())
}
async fn existing_hnsw_dimension(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
) -> Result<Option<usize>> {
let Some(indexes) = table_index_definitions(db, spec.table).await? else {
return Ok(None);
};
let Some(definition) = indexes
.get(spec.index_name)
.and_then(|details| details.get("Strand"))
.and_then(|v| v.as_str())
else {
return Ok(None);
};
Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
}
async fn hnsw_index_state(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
expected_dimension: usize,
) -> Result<HnswIndexState> {
match existing_hnsw_dimension(db, spec).await? {
None => Ok(HnswIndexState::Missing),
Some(current_dimension) if current_dimension == expected_dimension => {
Ok(HnswIndexState::Matches)
}
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
}
}
enum HnswIndexState {
Missing,
Matches,
Different(u64),
}
fn extract_dimension(definition: &str) -> Option<u64> {
definition
.split("DIMENSION")
.nth(1)
.and_then(|rest| rest.split_whitespace().next())
.and_then(|token| token.trim_end_matches(';').parse::<u64>().ok())
}
async fn create_fts_analyzer(db: &SurrealDbClient) -> Result<()> {
// Prefer snowball stemming when supported; fall back to ascii-only when the filter
// is unavailable in the running Surreal build. Use IF NOT EXISTS to avoid clobbering
// an existing analyzer definition.
let snowball_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
TOKENIZERS class
FILTERS lowercase, ascii, snowball(english);",
analyzer = FTS_ANALYZER_NAME
);
match db.client.query(snowball_query).await {
Ok(res) => {
if res.check().is_ok() {
return Ok(());
}
warn!(
"Snowball analyzer check failed; attempting ascii fallback definition (analyzer: {})",
FTS_ANALYZER_NAME
);
}
Err(err) => {
warn!(
error = %err,
"Snowball analyzer creation errored; attempting ascii fallback definition"
);
}
}
let fallback_query = format!(
"DEFINE ANALYZER IF NOT EXISTS {analyzer}
TOKENIZERS class
FILTERS lowercase, ascii;",
analyzer = FTS_ANALYZER_NAME
);
let res = db
.client
.query(fallback_query)
.await
.context("creating fallback FTS analyzer")?;
if let Err(err) = res.check() {
warn!(
error = %err,
"Fallback analyzer creation failed; FTS will run without snowball/ascii analyzer ({})",
FTS_ANALYZER_NAME
);
return Err(err).context("failed to create fallback FTS analyzer");
}
warn!(
"Snowball analyzer unavailable; using fallback analyzer ({}) with lowercase+ascii only",
FTS_ANALYZER_NAME
);
Ok(())
}
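// A hypothetical full-text lookup against one of the BM25 indexes defined in
// fts_index_specs above; the helper name, binding names and LIMIT are
// illustrative. The @1@ matches operator pairs with search::score(1) for ranking.
async fn search_chunks_fts(
    db: &SurrealDbClient,
    query: String,
    user_id: String,
) -> Result<Vec<Value>> {
    let mut response = db
        .client
        .query(
            "SELECT id, chunk, search::score(1) AS score \
             FROM text_chunk \
             WHERE chunk @1@ $q AND user_id = $user_id \
             ORDER BY score DESC LIMIT 20;",
        )
        .bind(("q", query))
        .bind(("user_id", user_id))
        .await
        .context("running FTS query on text_chunk")?;
    response.take(0).context("deserializing FTS results")
}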
async fn create_index_with_polling(
db: &SurrealDbClient,
definition: String,
index_name: &str,
table: &str,
progress_table: Option<&str>,
) -> Result<()> {
let expected_total = match progress_table {
Some(table) => Some(count_table_rows(db, table).await.with_context(|| {
format!("counting rows in {table} for index {index_name} progress")
})?),
None => None,
};
let mut attempts = 0;
const MAX_ATTEMPTS: usize = 3;
loop {
attempts += 1;
let res = db
.client
.query(definition.clone())
.await
.with_context(|| format!("creating index {index_name} on table {table}"))?;
match res.check() {
Ok(_) => break,
Err(err) => {
let msg = err.to_string();
let conflict = msg.contains("read or write conflict");
warn!(
index = %index_name,
table = %table,
error = ?err,
attempt = attempts,
definition = %definition,
"Index definition failed"
);
if conflict && attempts < MAX_ATTEMPTS {
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
return Err(err).with_context(|| {
format!("index definition failed for {index_name} on {table}")
});
}
}
}
debug!(
index = %index_name,
table = %table,
expected_rows = ?expected_total,
"Index definition submitted; waiting for build to finish"
);
poll_index_build_status(db, index_name, table, expected_total, INDEX_POLL_INTERVAL).await
}
async fn poll_index_build_status(
db: &SurrealDbClient,
index_name: &str,
table: &str,
total_rows: Option<u64>,
poll_every: Duration,
) -> Result<()> {
let started_at = std::time::Instant::now();
loop {
tokio::time::sleep(poll_every).await;
let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
let mut info_res =
db.client.query(info_query).await.with_context(|| {
format!("checking index build status for {index_name} on {table}")
})?;
let info: Option<Value> = info_res
.take(0)
.context("failed to deserialize INFO FOR INDEX result")?;
let Some(snapshot) = parse_index_build_info(info, total_rows) else {
warn!(
index = %index_name,
table = %table,
"INFO FOR INDEX returned no data; assuming index definition might be missing"
);
break;
};
match snapshot.progress_pct {
Some(pct) => debug!(
index = %index_name,
table = %table,
status = snapshot.status,
initial = snapshot.initial,
pending = snapshot.pending,
updated = snapshot.updated,
processed = snapshot.processed,
total = snapshot.total_rows,
progress_pct = format_args!("{pct:.1}"),
"Index build status"
),
None => debug!(
index = %index_name,
table = %table,
status = snapshot.status,
initial = snapshot.initial,
pending = snapshot.pending,
updated = snapshot.updated,
processed = snapshot.processed,
"Index build status"
),
}
if snapshot.is_ready() {
debug!(
index = %index_name,
table = %table,
elapsed = ?started_at.elapsed(),
processed = snapshot.processed,
total = snapshot.total_rows,
"Index is ready"
);
break;
}
if snapshot.status.eq_ignore_ascii_case("error") {
warn!(
index = %index_name,
table = %table,
status = snapshot.status,
"Index build reported error status; stopping polling"
);
break;
}
}
Ok(())
}
#[derive(Debug, PartialEq)]
struct IndexBuildSnapshot {
status: String,
initial: u64,
pending: u64,
updated: u64,
processed: u64,
total_rows: Option<u64>,
progress_pct: Option<f64>,
}
impl IndexBuildSnapshot {
fn is_ready(&self) -> bool {
self.status.eq_ignore_ascii_case("ready")
}
}
fn parse_index_build_info(
info: Option<Value>,
total_rows: Option<u64>,
) -> Option<IndexBuildSnapshot> {
let info = info?;
let building = info.get("building");
let status = building
.and_then(|b| b.get("status"))
.and_then(|s| s.as_str())
// If there's no `building` block at all, treat as "ready" (index not building anymore)
.unwrap_or("ready")
.to_string();
let initial = building
.and_then(|b| b.get("initial"))
.and_then(|v| v.as_u64())
.unwrap_or(0);
let pending = building
.and_then(|b| b.get("pending"))
.and_then(|v| v.as_u64())
.unwrap_or(0);
let updated = building
.and_then(|b| b.get("updated"))
.and_then(|v| v.as_u64())
.unwrap_or(0);
// `initial` is the number of rows seen when the build started; `updated` accounts for later writes.
let processed = initial.saturating_add(updated);
let progress_pct = total_rows.map(|total| {
if total == 0 {
0.0
} else {
((processed as f64 / total as f64).min(1.0)) * 100.0
}
});
Some(IndexBuildSnapshot {
status,
initial,
pending,
updated,
processed,
total_rows,
progress_pct,
})
}
#[derive(Debug, Deserialize)]
struct CountRow {
count: u64,
}
async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
let query = format!("SELECT count() AS count FROM {table} GROUP ALL;");
let mut response = db
.client
.query(query)
.await
.with_context(|| format!("counting rows in {table}"))?;
let rows: Vec<CountRow> = response
.take(0)
.context("failed to deserialize count() response")?;
Ok(rows.first().map_or(0, |r| r.count))
}
async fn table_index_definitions(
db: &SurrealDbClient,
table: &str,
) -> Result<Option<Map<String, Value>>> {
let info_query = format!("INFO FOR TABLE {table};");
let mut response = db
.client
.query(info_query)
.await
.with_context(|| format!("fetching table info for {}", table))?;
let info: surrealdb::Value = response
.take(0)
.context("failed to take table info response")?;
let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
Ok(info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
.cloned())
}
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
let Some(indexes) = table_index_definitions(db, table).await? else {
return Ok(false);
};
Ok(indexes.contains_key(index_name))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use uuid::Uuid;
#[test]
fn parse_index_build_info_reports_progress() {
let info = json!({
"building": {
"initial": 56894,
"pending": 0,
"status": "indexing",
"updated": 0
}
});
let snapshot = parse_index_build_info(Some(info), Some(61081)).expect("snapshot");
assert_eq!(
snapshot,
IndexBuildSnapshot {
status: "indexing".to_string(),
initial: 56894,
pending: 0,
updated: 0,
processed: 56894,
total_rows: Some(61081),
progress_pct: Some((56894_f64 / 61081_f64) * 100.0),
}
);
assert!(!snapshot.is_ready());
}
#[test]
fn parse_index_build_info_defaults_to_ready_when_no_building_block() {
// Surreal returns `{}` when the index exists but isn't building.
let info = json!({});
let snapshot = parse_index_build_info(Some(info), Some(10)).expect("snapshot");
assert!(snapshot.is_ready());
assert_eq!(snapshot.processed, 0);
assert_eq!(snapshot.progress_pct, Some(0.0));
}
#[test]
fn extract_dimension_parses_value() {
let definition = "DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32 EFC 100 M 8;";
assert_eq!(extract_dimension(definition), Some(1536));
}
#[tokio::test]
async fn ensure_runtime_indexes_is_idempotent() {
let namespace = "indexes_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db");
db.apply_migrations()
.await
.expect("migrations should succeed");
// First run creates everything
ensure_runtime_indexes(&db, 1536)
.await
.expect("initial index creation");
// Second run should be a no-op and still succeed
ensure_runtime_indexes(&db, 1536)
.await
.expect("second index creation");
}
#[tokio::test]
async fn ensure_hnsw_index_overwrites_dimension() {
let namespace = "indexes_dim";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("in-memory db");
db.apply_migrations()
.await
.expect("migrations should succeed");
// Create initial index with default dimension
ensure_runtime_indexes(&db, 1536)
.await
.expect("initial index creation");
// Change dimension and ensure overwrite path is exercised
ensure_runtime_indexes(&db, 128)
.await
.expect("overwritten index creation");
}
}
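A minimal startup sketch, assuming the embedding dimension has already been read from system_settings elsewhere; the function name, the error mapping and the ordering shown are illustrative rather than taken from the application code.

use crate::error::AppError;
use crate::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};

// Hypothetical boot hook: apply migrations first, then the runtime FTS/HNSW
// indexes. Calling this on every start is safe because FTS indexes are only
// created when missing and HNSW indexes are overwritten on dimension changes.
pub async fn prepare_search_indexes(
    db: &SurrealDbClient,
    embedding_dimension: usize,
) -> Result<(), AppError> {
    db.apply_migrations()
        .await
        .map_err(|e| AppError::InternalError(e.to_string()))?;
    ensure_runtime_indexes(db, embedding_dimension).await
}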

View File

@@ -1,3 +1,4 @@
pub mod db;
pub mod indexes;
pub mod store;
pub mod types;

View File

@@ -1,4 +1,5 @@
use std::path::{Path, PathBuf};
use std::io::ErrorKind;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use anyhow::{anyhow, Result as AnyResult};
@@ -6,36 +7,424 @@ use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use object_store::local::LocalFileSystem;
use object_store::memory::InMemory;
use object_store::{path::Path as ObjPath, ObjectStore};
use crate::utils::config::{AppConfig, StorageKind};
pub type DynStore = Arc<dyn ObjectStore>;
/// Build an object store instance anchored at the given filesystem `prefix`.
/// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)]
pub struct StorageManager {
// The object_store backend, wrapped as a trait object
store: DynStore,
// Which storage backend kind is configured
backend_kind: StorageKind,
// Base directory on disk when using the local backend
local_base: Option<PathBuf>,
}
impl StorageManager {
/// Create a new StorageManager with the specified configuration.
///
/// This method validates the configuration and creates the appropriate
/// storage backend with proper initialization.
pub async fn new(cfg: &AppConfig) -> object_store::Result<Self> {
let backend_kind = cfg.storage.clone();
let (store, local_base) = create_storage_backend(cfg).await?;
Ok(Self {
store,
backend_kind,
local_base,
})
}
/// Create a StorageManager with a custom storage backend.
///
/// This method is useful for testing scenarios where you want to inject
/// a specific storage backend.
pub fn with_backend(store: DynStore, backend_kind: StorageKind) -> Self {
Self {
store,
backend_kind,
local_base: None,
}
}
/// Get the storage backend kind.
pub fn backend_kind(&self) -> &StorageKind {
&self.backend_kind
}
/// Access the resolved local base directory when using the local backend.
pub fn local_base_path(&self) -> Option<&Path> {
self.local_base.as_deref()
}
/// Resolve an object location to a filesystem path when using the local backend.
///
/// Returns `None` when the backend is not local or when the provided location includes
/// unsupported components (absolute paths or parent traversals).
pub fn resolve_local_path(&self, location: &str) -> Option<PathBuf> {
let base = self.local_base_path()?;
let relative = Path::new(location);
if relative.is_absolute()
|| relative
.components()
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
{
return None;
}
Some(base.join(relative))
}
/// Store bytes at the specified location.
///
/// This operation persists data using the underlying storage backend.
/// For memory backends, data persists for the lifetime of the StorageManager.
pub async fn put(&self, location: &str, data: Bytes) -> object_store::Result<()> {
let path = ObjPath::from(location);
let payload = object_store::PutPayload::from_bytes(data);
self.store.put(&path, payload).await.map(|_| ())
}
/// Retrieve bytes from the specified location.
///
/// Returns the full contents buffered in memory.
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
let path = ObjPath::from(location);
let result = self.store.get(&path).await?;
result.bytes().await
}
/// Get a streaming handle for large objects.
///
/// Returns a fallible stream of Bytes chunks suitable for large file processing.
pub async fn get_stream(
&self,
location: &str,
) -> object_store::Result<BoxStream<'static, object_store::Result<Bytes>>> {
let path = ObjPath::from(location);
let result = self.store.get(&path).await?;
Ok(result.into_stream())
}
/// Delete all objects below the specified prefix.
///
/// For local filesystem backends, this also attempts to clean up empty directories.
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
let prefix_path = ObjPath::from(prefix);
let locations = self
.store
.list(Some(&prefix_path))
.map_ok(|m| m.location)
.boxed();
self.store
.delete_stream(locations)
.try_collect::<Vec<_>>()
.await?;
// Cleanup filesystem directories only for local backend
if matches!(self.backend_kind, StorageKind::Local) {
self.cleanup_filesystem_directories(prefix).await?;
}
Ok(())
}
/// List all objects below the specified prefix.
pub async fn list(
&self,
prefix: Option<&str>,
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
let prefix_path = prefix.map(ObjPath::from);
self.store.list(prefix_path.as_ref()).try_collect().await
}
/// Check if an object exists at the specified location.
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
let path = ObjPath::from(location);
self.store
.head(&path)
.await
.map(|_| true)
.or_else(|e| match e {
object_store::Error::NotFound { .. } => Ok(false),
_ => Err(e),
})
}
/// Cleanup filesystem directories for local backend.
///
/// This is a best-effort cleanup and ignores errors.
async fn cleanup_filesystem_directories(&self, prefix: &str) -> object_store::Result<()> {
if !matches!(self.backend_kind, StorageKind::Local) {
return Ok(());
}
let Some(base) = &self.local_base else {
return Ok(());
};
let relative = Path::new(prefix);
if relative.is_absolute()
|| relative
.components()
.any(|component| matches!(component, Component::ParentDir | Component::Prefix(_)))
{
tracing::warn!(
prefix = %prefix,
"Skipping directory cleanup for unsupported prefix components"
);
return Ok(());
}
let mut current = base.join(relative);
while current.starts_with(base) && current.as_path() != base.as_path() {
match tokio::fs::remove_dir(&current).await {
Ok(()) => {}
Err(err) => match err.kind() {
ErrorKind::NotFound => {}
ErrorKind::DirectoryNotEmpty => break,
_ => tracing::debug!(
error = %err,
path = %current.display(),
"Failed to remove directory during cleanup"
),
},
}
if let Some(parent) = current.parent() {
current = parent.to_path_buf();
} else {
break;
}
}
Ok(())
}
}
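// A minimal sketch (paths and the helper name are placeholders): resolving a
// logical object location to an absolute path when the local backend is in use.
// Absolute locations and parent traversals return None, per resolve_local_path.
fn describe_location(storage: &StorageManager, location: &str) -> String {
    match storage.resolve_local_path(location) {
        Some(path) => format!("{location} resolves to {}", path.display()),
        None => format!("{location} has no local filesystem path"),
    }
}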
/// Create a storage backend based on configuration.
///
/// - For the `Local` backend, `prefix` is the absolute directory on disk that
/// serves as the root for all object paths passed to the store.
/// - `prefix` must already exist; this function will create it if missing.
///
/// Example (Local):
/// - prefix: `/var/data`
/// - object location: `user/uuid/file.txt`
/// - absolute path: `/var/data/user/uuid/file.txt`
pub async fn build_store(prefix: &Path, cfg: &AppConfig) -> object_store::Result<DynStore> {
/// This factory function handles the creation and initialization of different
/// storage backends with proper error handling and validation.
async fn create_storage_backend(
cfg: &AppConfig,
) -> object_store::Result<(DynStore, Option<PathBuf>)> {
match cfg.storage {
StorageKind::Local => {
if !prefix.exists() {
tokio::fs::create_dir_all(prefix).await.map_err(|e| {
let base = resolve_base_dir(cfg);
if !base.exists() {
tokio::fs::create_dir_all(&base).await.map_err(|e| {
object_store::Error::Generic {
store: "LocalFileSystem",
source: e.into(),
}
})?;
}
let store = LocalFileSystem::new_with_prefix(prefix)?;
Ok(Arc::new(store))
let store = LocalFileSystem::new_with_prefix(base.clone())?;
Ok((Arc::new(store), Some(base)))
}
StorageKind::Memory => {
let store = InMemory::new();
Ok((Arc::new(store), None))
}
}
}
/// Testing utilities for storage operations.
///
/// This module provides specialized utilities for testing scenarios with
/// automatic memory backend setup and proper test isolation.
#[cfg(test)]
pub mod testing {
use super::*;
use crate::utils::config::{AppConfig, PdfIngestMode};
use uuid;
/// Create a test configuration with memory storage.
///
/// This provides a ready-to-use configuration for testing scenarios
/// that don't require filesystem persistence.
pub fn test_config_memory() -> AppConfig {
AppConfig {
openai_api_key: "test".into(),
surrealdb_address: "test".into(),
surrealdb_username: "test".into(),
surrealdb_password: "test".into(),
surrealdb_namespace: "test".into(),
surrealdb_database: "test".into(),
data_dir: "/tmp/unused".into(), // Ignored for memory storage
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::Memory,
pdf_ingest_mode: PdfIngestMode::LlmFirst,
..Default::default()
}
}
/// Create a test configuration with local storage.
///
/// This provides a ready-to-use configuration for testing scenarios
/// that require actual filesystem operations.
pub fn test_config_local() -> AppConfig {
let base = format!("/tmp/minne_test_storage_{}", uuid::Uuid::new_v4());
AppConfig {
openai_api_key: "test".into(),
surrealdb_address: "test".into(),
surrealdb_username: "test".into(),
surrealdb_password: "test".into(),
surrealdb_namespace: "test".into(),
surrealdb_database: "test".into(),
data_dir: base.into(),
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::Local,
pdf_ingest_mode: PdfIngestMode::LlmFirst,
..Default::default()
}
}
/// A specialized StorageManager for testing scenarios.
///
/// This provides automatic setup for memory storage with proper isolation
/// and cleanup capabilities for test environments.
#[derive(Clone)]
pub struct TestStorageManager {
storage: StorageManager,
_temp_dir: Option<(String, std::path::PathBuf)>, // For local storage cleanup
}
impl TestStorageManager {
/// Create a new TestStorageManager with memory backend.
///
/// This is the preferred method for unit tests as it provides
/// fast execution and complete isolation.
pub async fn new_memory() -> object_store::Result<Self> {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg).await?;
Ok(Self {
storage,
_temp_dir: None,
})
}
/// Create a new TestStorageManager with local filesystem backend.
///
/// This method creates a temporary directory that will be automatically
/// cleaned up when the TestStorageManager is dropped.
pub async fn new_local() -> object_store::Result<Self> {
let cfg = test_config_local();
let storage = StorageManager::new(&cfg).await?;
let resolved = storage
.local_base_path()
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()));
Ok(Self {
storage,
_temp_dir: resolved,
})
}
/// Create a TestStorageManager with custom configuration.
pub async fn with_config(cfg: &AppConfig) -> object_store::Result<Self> {
let storage = StorageManager::new(cfg).await?;
let temp_dir = if matches!(cfg.storage, StorageKind::Local) {
storage
.local_base_path()
.map(|path| (cfg.data_dir.clone(), path.to_path_buf()))
} else {
None
};
Ok(Self {
storage,
_temp_dir: temp_dir,
})
}
/// Get a reference to the underlying StorageManager.
pub fn storage(&self) -> &StorageManager {
&self.storage
}
/// Clone the underlying StorageManager.
pub fn clone_storage(&self) -> StorageManager {
self.storage.clone()
}
/// Store test data at the specified location.
pub async fn put(&self, location: &str, data: &[u8]) -> object_store::Result<()> {
self.storage.put(location, Bytes::from(data.to_vec())).await
}
/// Retrieve test data from the specified location.
pub async fn get(&self, location: &str) -> object_store::Result<Bytes> {
self.storage.get(location).await
}
/// Delete test data below the specified prefix.
pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
self.storage.delete_prefix(prefix).await
}
/// Check if test data exists at the specified location.
pub async fn exists(&self, location: &str) -> object_store::Result<bool> {
self.storage.exists(location).await
}
/// List all test objects below the specified prefix.
pub async fn list(
&self,
prefix: Option<&str>,
) -> object_store::Result<Vec<object_store::ObjectMeta>> {
self.storage.list(prefix).await
}
}
impl Drop for TestStorageManager {
fn drop(&mut self) {
// Clean up temporary directories for local storage
if let Some((_, path)) = &self._temp_dir {
if path.exists() {
let _ = std::fs::remove_dir_all(path);
}
}
}
}
/// Convenience macro for creating memory storage tests.
///
/// This macro simplifies the creation of test storage with memory backend.
#[macro_export]
macro_rules! test_storage_memory {
() => {{
async move {
$crate::storage::store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test memory storage")
}
}};
}
/// Convenience macro for creating local storage tests.
///
/// This macro simplifies the creation of test storage with local filesystem backend.
#[macro_export]
macro_rules! test_storage_local {
() => {{
async move {
$crate::storage::store::testing::TestStorageManager::new_local()
.await
.expect("Failed to create test local storage")
}
}};
}
}
@@ -52,124 +441,6 @@ pub fn resolve_base_dir(cfg: &AppConfig) -> PathBuf {
}
}
/// Build an object store rooted at the configured data directory.
///
/// This is the recommended way to obtain a store for logical object operations
/// such as `put_bytes_at`, `get_bytes_at`, and `delete_prefix_at`.
pub async fn build_store_root(cfg: &AppConfig) -> object_store::Result<DynStore> {
let base = resolve_base_dir(cfg);
build_store(&base, cfg).await
}
/// Write bytes to `file_name` within a filesystem `prefix` using the configured store.
///
/// Prefer [`put_bytes_at`] for location-based writes that do not need to compute
/// a separate filesystem prefix.
pub async fn put_bytes(
prefix: &Path,
file_name: &str,
data: Bytes,
cfg: &AppConfig,
) -> object_store::Result<()> {
let store = build_store(prefix, cfg).await?;
let payload = object_store::PutPayload::from_bytes(data);
store.put(&ObjPath::from(file_name), payload).await?;
Ok(())
}
/// Write bytes to the provided logical object `location`, e.g. `"user/uuid/file"`.
///
/// The store root is taken from `AppConfig::data_dir` for the local backend.
/// This performs an atomic write as guaranteed by `object_store`.
pub async fn put_bytes_at(
location: &str,
data: Bytes,
cfg: &AppConfig,
) -> object_store::Result<()> {
let store = build_store_root(cfg).await?;
let payload = object_store::PutPayload::from_bytes(data);
store.put(&ObjPath::from(location), payload).await?;
Ok(())
}
/// Read bytes from `file_name` within a filesystem `prefix` using the configured store.
///
/// Prefer [`get_bytes_at`] for location-based reads.
pub async fn get_bytes(
prefix: &Path,
file_name: &str,
cfg: &AppConfig,
) -> object_store::Result<Bytes> {
let store = build_store(prefix, cfg).await?;
let r = store.get(&ObjPath::from(file_name)).await?;
let b = r.bytes().await?;
Ok(b)
}
/// Read bytes from the provided logical object `location`.
///
/// Returns the full contents buffered in memory.
pub async fn get_bytes_at(location: &str, cfg: &AppConfig) -> object_store::Result<Bytes> {
let store = build_store_root(cfg).await?;
let r = store.get(&ObjPath::from(location)).await?;
r.bytes().await
}
/// Get a streaming body for the provided logical object `location`.
///
/// Returns a fallible `BoxStream` of `Bytes`, suitable for use with
/// `axum::body::Body::from_stream` to stream responses without buffering.
pub async fn get_stream_at(
location: &str,
cfg: &AppConfig,
) -> object_store::Result<BoxStream<'static, object_store::Result<Bytes>>> {
let store = build_store_root(cfg).await?;
let r = store.get(&ObjPath::from(location)).await?;
Ok(r.into_stream())
}
/// Delete all objects below the provided filesystem `prefix`.
///
/// This is a low-level variant for when a dedicated on-disk prefix is used for a
/// particular object grouping. Prefer [`delete_prefix_at`] for location-based stores.
pub async fn delete_prefix(prefix: &Path, cfg: &AppConfig) -> object_store::Result<()> {
let store = build_store(prefix, cfg).await?;
// list everything and delete
let locations = store.list(None).map_ok(|m| m.location).boxed();
store
.delete_stream(locations)
.try_collect::<Vec<_>>()
.await?;
// Best effort remove the directory itself
if tokio::fs::try_exists(prefix).await.unwrap_or(false) {
let _ = tokio::fs::remove_dir_all(prefix).await;
}
Ok(())
}
/// Delete all objects below the provided logical object `prefix`, e.g. `"user/uuid/"`.
///
/// After deleting, attempts a best-effort cleanup of the now-empty directory on disk
/// when using the local backend.
pub async fn delete_prefix_at(prefix: &str, cfg: &AppConfig) -> object_store::Result<()> {
let store = build_store_root(cfg).await?;
let prefix_path = ObjPath::from(prefix);
let locations = store
.list(Some(&prefix_path))
.map_ok(|m| m.location)
.boxed();
store
.delete_stream(locations)
.try_collect::<Vec<_>>()
.await?;
// Best effort remove empty directory on disk for local storage
let base_dir = resolve_base_dir(cfg).join(prefix);
if tokio::fs::try_exists(&base_dir).await.unwrap_or(false) {
let _ = tokio::fs::remove_dir_all(&base_dir).await;
}
Ok(())
}
/// Split an absolute filesystem path into `(parent_dir, file_name)`.
pub fn split_abs_path(path: &str) -> AnyResult<(PathBuf, String)> {
let pb = PathBuf::from(path);
@@ -198,7 +469,6 @@ mod tests {
use super::*;
use crate::utils::config::{PdfIngestMode::LlmFirst, StorageKind};
use bytes::Bytes;
use futures::TryStreamExt;
use uuid::Uuid;
fn test_config(root: &str) -> AppConfig {
@@ -218,68 +488,353 @@ mod tests {
}
}
fn test_config_memory() -> AppConfig {
AppConfig {
openai_api_key: "test".into(),
surrealdb_address: "test".into(),
surrealdb_username: "test".into(),
surrealdb_password: "test".into(),
surrealdb_namespace: "test".into(),
surrealdb_database: "test".into(),
data_dir: "/tmp/unused".into(), // Ignored for memory storage
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::Memory,
pdf_ingest_mode: LlmFirst,
..Default::default()
}
}
#[tokio::test]
async fn test_build_store_root_creates_base() {
let base = format!("/tmp/minne_store_test_{}", Uuid::new_v4());
async fn test_storage_manager_memory_basic_operations() {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
assert!(storage.local_base_path().is_none());
let location = "test/data/file.txt";
let data = b"test data for storage manager";
// Test put and get
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
// Test exists
assert!(storage.exists(location).await.expect("exists check"));
// Test delete
storage.delete_prefix("test/data/").await.expect("delete");
assert!(!storage
.exists(location)
.await
.expect("exists check after delete"));
}
#[tokio::test]
async fn test_storage_manager_local_basic_operations() {
let base = format!("/tmp/minne_storage_test_{}", Uuid::new_v4());
let cfg = test_config(&base);
let _ = build_store_root(&cfg).await.expect("build store root");
assert!(tokio::fs::try_exists(&base).await.unwrap_or(false));
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
let resolved_base = storage
.local_base_path()
.expect("resolved base dir")
.to_path_buf();
assert_eq!(resolved_base, PathBuf::from(&base));
let location = "test/data/file.txt";
let data = b"test data for local storage";
// Test put and get
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
let object_dir = resolved_base.join("test/data");
tokio::fs::metadata(&object_dir)
.await
.expect("object directory exists after write");
// Test exists
assert!(storage.exists(location).await.expect("exists check"));
// Test delete
storage.delete_prefix("test/data/").await.expect("delete");
assert!(!storage
.exists(location)
.await
.expect("exists check after delete"));
assert!(
tokio::fs::metadata(&object_dir).await.is_err(),
"object directory should be removed"
);
tokio::fs::metadata(&resolved_base)
.await
.expect("base directory remains intact");
// Clean up
let _ = tokio::fs::remove_dir_all(&base).await;
}
#[tokio::test]
async fn test_storage_manager_memory_persistence() {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
let location = "persistence/test.txt";
let data1 = b"first data";
let data2 = b"second data";
// Put first data
storage
.put(location, Bytes::from(data1.to_vec()))
.await
.expect("put first");
// Retrieve and verify first data
let retrieved1 = storage.get(location).await.expect("get first");
assert_eq!(retrieved1.as_ref(), data1);
// Overwrite with second data
storage
.put(location, Bytes::from(data2.to_vec()))
.await
.expect("put second");
// Retrieve and verify second data
let retrieved2 = storage.get(location).await.expect("get second");
assert_eq!(retrieved2.as_ref(), data2);
// Data persists across multiple operations using the same StorageManager
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
}
#[tokio::test]
async fn test_storage_manager_list_operations() {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
// Create multiple files
let files = vec![
("dir1/file1.txt", b"content1"),
("dir1/file2.txt", b"content2"),
("dir2/file3.txt", b"content3"),
];
for (location, data) in &files {
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
}
// Test listing without prefix
let all_files = storage.list(None).await.expect("list all");
assert_eq!(all_files.len(), 3);
// Test listing with prefix
let dir1_files = storage.list(Some("dir1/")).await.expect("list dir1");
assert_eq!(dir1_files.len(), 2);
assert!(dir1_files
.iter()
.any(|meta| meta.location.as_ref().contains("file1.txt")));
assert!(dir1_files
.iter()
.any(|meta| meta.location.as_ref().contains("file2.txt")));
// Test listing non-existent prefix
let empty_files = storage
.list(Some("nonexistent/"))
.await
.expect("list nonexistent");
assert_eq!(empty_files.len(), 0);
}
#[tokio::test]
async fn test_storage_manager_stream_operations() {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
let location = "stream/test.bin";
let content = vec![42u8; 1024 * 64]; // 64KB of data
// Put large data
storage
.put(location, Bytes::from(content.clone()))
.await
.expect("put large data");
// Get as stream
let mut stream = storage.get_stream(location).await.expect("get stream");
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk.expect("stream chunk");
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, content);
}
#[tokio::test]
async fn test_storage_manager_with_custom_backend() {
use object_store::memory::InMemory;
// Create custom memory backend
let custom_store = InMemory::new();
let storage = StorageManager::with_backend(Arc::new(custom_store), StorageKind::Memory);
let location = "custom/test.txt";
let data = b"custom backend test";
// Test operations with custom backend
storage
.put(location, Bytes::from(data.to_vec()))
.await
.expect("put");
let retrieved = storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
assert!(storage.exists(location).await.expect("exists"));
assert_eq!(*storage.backend_kind(), StorageKind::Memory);
}
#[tokio::test]
async fn test_storage_manager_error_handling() {
let cfg = test_config_memory();
let storage = StorageManager::new(&cfg)
.await
.expect("create storage manager");
// Test getting non-existent file
let result = storage.get("nonexistent.txt").await;
assert!(result.is_err());
// Test checking existence of non-existent file
let exists = storage
.exists("nonexistent.txt")
.await
.expect("exists check");
assert!(!exists);
// Test listing with invalid location (should not panic)
let _result = storage.get("").await;
// This may or may not error depending on the backend implementation
// The important thing is that it doesn't panic
}
// TestStorageManager tests
#[tokio::test]
async fn test_test_storage_manager_memory() {
let test_storage = testing::TestStorageManager::new_memory()
.await
.expect("create test storage");
let location = "test/storage/file.txt";
let data = b"test data with TestStorageManager";
// Test put and get
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
// Test existence check
assert!(test_storage.exists(location).await.expect("exists"));
// Test list
let files = test_storage
.list(Some("test/storage/"))
.await
.expect("list");
assert_eq!(files.len(), 1);
// Test delete
test_storage
.delete_prefix("test/storage/")
.await
.expect("delete");
assert!(!test_storage
.exists(location)
.await
.expect("exists after delete"));
}
#[tokio::test]
async fn test_test_storage_manager_local() {
let test_storage = testing::TestStorageManager::new_local()
.await
.expect("create test storage");
let location = "test/local/file.txt";
let data = b"test data with local TestStorageManager";
// Test put and get
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
// Test existence check
assert!(test_storage.exists(location).await.expect("exists"));
// The storage should be automatically cleaned up when test_storage is dropped
}
#[tokio::test]
async fn test_test_storage_manager_isolation() {
let storage1 = testing::TestStorageManager::new_memory()
.await
.expect("create test storage 1");
let storage2 = testing::TestStorageManager::new_memory()
.await
.expect("create test storage 2");
let location = "isolation/test.txt";
let data1 = b"storage 1 data";
let data2 = b"storage 2 data";
// Put different data in each storage
storage1.put(location, data1).await.expect("put storage 1");
storage2.put(location, data2).await.expect("put storage 2");
// Verify isolation
let retrieved1 = storage1.get(location).await.expect("get storage 1");
let retrieved2 = storage2.get(location).await.expect("get storage 2");
assert_eq!(retrieved1.as_ref(), data1);
assert_eq!(retrieved2.as_ref(), data2);
assert_ne!(retrieved1.as_ref(), retrieved2.as_ref());
}
#[tokio::test]
async fn test_test_storage_manager_config() {
let cfg = testing::test_config_memory();
let test_storage = testing::TestStorageManager::with_config(&cfg)
.await
.expect("create test storage with config");
let location = "config/test.txt";
let data = b"test data with custom config";
test_storage.put(location, data).await.expect("put");
let retrieved = test_storage.get(location).await.expect("get");
assert_eq!(retrieved.as_ref(), data);
// Verify it's using memory backend
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
}
}

View File

@@ -71,6 +71,7 @@ impl Analytics {
// We need to use a direct query for COUNT aggregation
#[derive(Debug, Deserialize)]
struct CountResult {
/// Total user count.
count: i64,
}
@@ -81,7 +82,7 @@ impl Analytics {
.await?
.take(0)?;
Ok(result.map_or(0, |r| r.count))
}
}

View File

@@ -1,4 +1,5 @@
use axum_typed_multipart::FieldData;
use bytes;
use mime_guess::from_path;
use object_store::Error as ObjectStoreError;
use sha2::{Digest, Sha256};
@@ -8,14 +9,14 @@ use std::{
};
use tempfile::NamedTempFile;
use thiserror::Error;
use tokio::task;
use tracing::info;
use uuid::Uuid;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, store, store::StorageManager},
stored_object,
utils::config::AppConfig,
};
#[derive(Error, Debug)]
@@ -51,54 +52,6 @@ stored_object!(FileInfo, "file", {
});
impl FileInfo {
pub async fn new(
field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
user_id: &str,
config: &AppConfig,
) -> Result<Self, FileError> {
let file = field_data.contents;
let file_name = field_data
.metadata
.file_name
.ok_or(FileError::MissingFileName)?;
// Calculate SHA256
let sha256 = Self::get_sha(&file).await?;
// Early return if file already exists
match Self::get_by_sha(&sha256, db_client).await {
Ok(existing_file) => {
info!("File already exists with SHA256: {}", sha256);
return Ok(existing_file);
}
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
Err(e) => return Err(e), // Propagate unexpected errors
}
// Generate UUID and prepare paths
let uuid = Uuid::new_v4();
let sanitized_file_name = Self::sanitize_file_name(&file_name);
let now = Utc::now();
// Create new FileInfo instance
let file_info = Self {
id: uuid.to_string(),
created_at: now,
updated_at: now,
file_name,
sha256,
path: Self::persist_file(&uuid, file, &sanitized_file_name, user_id, config).await?,
mime_type: Self::guess_mime_type(Path::new(&sanitized_file_name)),
user_id: user_id.to_string(),
};
// Store in database
db_client.store_item(file_info.clone()).await?;
Ok(file_info)
}
/// Guesses the MIME type based on the file extension.
///
/// # Arguments
@@ -119,21 +72,29 @@ impl FileInfo {
///
/// # Returns
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
#[allow(clippy::indexing_slicing)]
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
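// Hash on a dedicated blocking thread: the whole temp file is read and digested,
// which would otherwise stall the async runtime while the CPU-bound work runs.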
let mut file_clone = file.as_file().try_clone()?;
let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> {
let mut reader = BufReader::new(&mut file_clone);
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192]; // 8KB buffer
loop {
let n = reader.read(&mut buffer)?;
if n == 0 {
break;
}
hasher.update(&buffer[..n]);
}
Ok::<_, std::io::Error>(hasher.finalize())
})
.await
.map_err(std::io::Error::other)??;
Ok(format!("{digest:x}"))
}
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
@@ -151,7 +112,7 @@ impl FileInfo {
}
})
.collect();
format!("{}{}", sanitized_name, ext)
format!("{sanitized_name}{ext}")
} else {
// No extension
file_name
@@ -167,36 +128,6 @@ impl FileInfo {
}
}
/// Persists the file under the logical location `{user_id}/{uuid}/{file_name}`.
///
/// # Arguments
/// * `uuid` - The UUID of the file.
/// * `file` - The temporary file to persist.
/// * `file_name` - The sanitized file name.
/// * `user-id` - User id
/// * `config` - Application configuration containing data directory path
///
/// # Returns
/// * `Result<String, FileError>` - The logical object location or an error.
async fn persist_file(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
user_id: &str,
config: &AppConfig,
) -> Result<String, FileError> {
// Logical object location relative to the store root
let location = format!("{}/{}/{}", user_id, uuid, file_name);
info!("Persisting to object location: {}", location);
let bytes = tokio::fs::read(file.path()).await?;
store::put_bytes_at(&location, bytes.into(), config)
.await
.map_err(FileError::from)?;
Ok(location)
}
/// Retrieves a `FileInfo` by SHA256.
///
/// # Arguments
@@ -215,41 +146,6 @@ impl FileInfo {
.ok_or(FileError::FileNotFound(sha256.to_string()))
}
/// Removes FileInfo from database and file from disk
///
/// # Arguments
/// * `id` - Id of the FileInfo
/// * `db_client` - Reference to SurrealDbClient
///
/// # Returns
/// `Result<(), FileError>`
pub async fn delete_by_id(
id: &str,
db_client: &SurrealDbClient,
config: &AppConfig,
) -> Result<(), AppError> {
// Get the FileInfo from the database
let Some(file_info) = db_client.get_item::<FileInfo>(id).await? else {
return Ok(());
};
// Remove the object's parent prefix in the object store
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
.map_err(|e| AppError::from(anyhow::anyhow!(e)))?;
store::delete_prefix_at(&parent_prefix, config)
.await
.map_err(|e| AppError::from(anyhow::anyhow!(e)))?;
info!(
"Removed object prefix {} and its contents via object_store",
parent_prefix
);
// Delete the FileInfo from the database
db_client.delete_item::<FileInfo>(id).await?;
Ok(())
}
/// Retrieves a `FileInfo` by its ID.
///
/// # Arguments
@@ -265,34 +161,168 @@ impl FileInfo {
Err(e) => Err(FileError::SurrealError(e)),
}
}
/// Create a new FileInfo using StorageManager for persistent storage operations.
///
/// # Arguments
/// * `field_data` - The uploaded file data
/// * `db_client` - Reference to the SurrealDbClient
/// * `user_id` - The user ID
/// * `storage` - A StorageManager instance for storage operations
///
/// # Returns
/// * `Result<Self, FileError>` - The created FileInfo or an error
pub async fn new_with_storage(
field_data: FieldData<NamedTempFile>,
db_client: &SurrealDbClient,
user_id: &str,
storage: &StorageManager,
) -> Result<Self, FileError> {
let file = field_data.contents;
let file_name = field_data
.metadata
.file_name
.ok_or(FileError::MissingFileName)?;
let original_file_name = file_name.clone();
// Calculate SHA256
let sha256 = Self::get_sha(&file).await?;
// Early return if file already exists
match Self::get_by_sha(&sha256, db_client).await {
Ok(existing_file) => {
info!("File already exists with SHA256: {}", sha256);
return Ok(existing_file);
}
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
Err(e) => return Err(e), // Propagate unexpected errors
}
// Generate UUID and prepare paths
let uuid = Uuid::new_v4();
let sanitized_file_name = Self::sanitize_file_name(&file_name);
let now = Utc::now();
let path =
Self::persist_file_with_storage(&uuid, file, &sanitized_file_name, user_id, storage)
.await?;
// Create FileInfo struct
let file_info = FileInfo {
id: uuid.to_string(),
user_id: user_id.to_string(),
sha256,
file_name: original_file_name,
path,
mime_type: Self::guess_mime_type(Path::new(&file_name)),
created_at: now,
updated_at: now,
};
// Store in database
db_client
.store_item(file_info.clone())
.await
.map_err(FileError::SurrealError)?;
Ok(file_info)
}
/// Delete a FileInfo by ID using StorageManager for storage operations.
///
/// # Arguments
/// * `id` - ID of the FileInfo
/// * `db_client` - Reference to SurrealDbClient
/// * `storage` - A StorageManager instance for storage operations
///
/// # Returns
/// * `Result<(), AppError>` - Success or error
pub async fn delete_by_id_with_storage(
id: &str,
db_client: &SurrealDbClient,
storage: &StorageManager,
) -> Result<(), AppError> {
// Get the FileInfo from the database
let Some(file_info) = db_client.get_item::<FileInfo>(id).await? else {
return Ok(());
};
// Remove the object's parent prefix in the object store
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
.map_err(|e| AppError::from(anyhow::anyhow!(e)))?;
storage
.delete_prefix(&parent_prefix)
.await
.map_err(|e| AppError::from(anyhow::anyhow!(e)))?;
info!(
"Removed object prefix {} and its contents via StorageManager",
parent_prefix
);
// Delete the FileInfo from the database
db_client.delete_item::<FileInfo>(id).await?;
Ok(())
}
/// Retrieve file content using StorageManager for storage operations.
///
/// # Arguments
/// * `storage` - A StorageManager instance for storage operations
///
/// # Returns
/// * `Result<bytes::Bytes, AppError>` - The file content or an error
pub async fn get_content_with_storage(
&self,
storage: &StorageManager,
) -> Result<bytes::Bytes, AppError> {
storage
.get(&self.path)
.await
.map_err(|e: object_store::Error| AppError::from(anyhow::anyhow!(e)))
}
/// Persist file to storage using StorageManager.
///
/// # Arguments
/// * `uuid` - The UUID for the file
/// * `file` - The temporary file to persist
/// * `file_name` - The name of the file
/// * `user_id` - The user ID
/// * `storage` - A StorageManager instance for storage operations
///
/// # Returns
/// * `Result<String, FileError>` - The logical object location or an error.
async fn persist_file_with_storage(
uuid: &Uuid,
file: NamedTempFile,
file_name: &str,
user_id: &str,
storage: &StorageManager,
) -> Result<String, FileError> {
// Logical object location relative to the store root
let location = format!("{user_id}/{uuid}/{file_name}");
info!("Persisting to object location: {}", location);
let bytes = tokio::fs::read(file.path()).await?;
storage
.put(&location, bytes.into())
.await
.map_err(FileError::from)?;
Ok(location)
}
}
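// Sketch of how the storage-backed API composes in an upload path. The helper name and
// the way `field_data` arrives are hypothetical; the `FileInfo` and `StorageManager`
// calls are the ones defined above.
#[allow(dead_code)]
async fn store_upload_sketch(
field_data: FieldData<NamedTempFile>,
db: &SurrealDbClient,
user_id: &str,
storage: &StorageManager,
) -> Result<bytes::Bytes, AppError> {
// Persist the upload (deduplicated by SHA256) and read it back through the same backend.
let file_info = FileInfo::new_with_storage(field_data, db, user_id, storage)
.await
.map_err(|e| AppError::from(anyhow::anyhow!(e)))?;
file_info.get_content_with_storage(storage).await
}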
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::config::{AppConfig, PdfIngestMode::LlmFirst, StorageKind};
use crate::storage::store::testing::TestStorageManager;
use axum::http::HeaderMap;
use axum_typed_multipart::FieldMetadata;
use std::{io::Write, path::Path};
use tempfile::NamedTempFile;
fn test_config(data_dir: &str) -> AppConfig {
AppConfig {
data_dir: data_dir.to_string(),
openai_api_key: "test_key".to_string(),
surrealdb_address: "test_address".to_string(),
surrealdb_username: "test_user".to_string(),
surrealdb_password: "test_pass".to_string(),
surrealdb_namespace: "test_ns".to_string(),
surrealdb_database: "test_db".to_string(),
http_port: 3000,
openai_base_url: "..".to_string(),
storage: StorageKind::Local,
pdf_ingest_mode: LlmFirst,
..Default::default()
}
}
/// Creates a test temporary file with the given content
fn create_test_file(content: &[u8], file_name: &str) -> FieldData<NamedTempFile> {
let mut temp_file = NamedTempFile::new().expect("Failed to create temp file");
@@ -316,33 +346,39 @@ mod tests {
}
#[tokio::test]
async fn test_fileinfo_create_read_delete_with_storage_manager() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
// Create a test file
let content = b"This is a test file for StorageManager operations";
let file_name = "storage_manager_test.txt";
let field_data = create_test_file(content, file_name);
// Create test storage manager (memory backend)
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
// Create a FileInfo instance with storage manager
let user_id = "test_user";
// Test file creation with StorageManager
let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await
.expect("Failed to create file with StorageManager");
assert_eq!(file_info.file_name, file_name);
// Verify the file exists via StorageManager and has correct content
let bytes = file_info
.get_content_with_storage(test_storage.storage())
.await
.expect("Failed to read file content via StorageManager");
assert_eq!(bytes.as_ref(), content);
// Test file reading
let retrieved = FileInfo::get_by_id(&file_info.id, &db)
@@ -350,51 +386,89 @@ mod tests {
.expect("Failed to retrieve file info");
assert_eq!(retrieved.id, file_info.id);
assert_eq!(retrieved.sha256, file_info.sha256);
assert_eq!(retrieved.file_name, file_name);
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage())
.await
.expect("Failed to delete file with StorageManager");
let deleted_result = file_info
.get_content_with_storage(test_storage.storage())
.await;
assert!(deleted_result.is_err(), "File should be deleted");
// No cleanup needed - TestStorageManager handles it automatically
}
#[tokio::test]
async fn test_fileinfo_preserves_original_filename_and_sanitizes_path() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
// Create a test file
let content = b"This is a test file for cross-filesystem duplicate detection";
let file_name = "cross_fs_duplicate.txt";
let content = b"filename sanitization";
let original_name = "Complex name (1).txt";
let expected_sanitized = "Complex_name__1_.txt";
let field_data = create_test_file(content, original_name);
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
let file_info =
FileInfo::new_with_storage(field_data, &db, "sanitized_user", test_storage.storage())
.await
.expect("Failed to create file via storage manager");
assert_eq!(file_info.file_name, original_name);
let stored_name = Path::new(&file_info.path)
.file_name()
.and_then(|name| name.to_str())
.expect("stored name");
assert_eq!(stored_name, expected_sanitized);
}
#[tokio::test]
async fn test_fileinfo_duplicate_detection_with_storage_manager() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.unwrap();
let content = b"This is a test file for StorageManager duplicate detection";
let file_name = "storage_manager_duplicate.txt";
let field_data = create_test_file(content, file_name);
// Create test storage manager
let test_storage = store::testing::TestStorageManager::new_memory()
.await
.expect("Failed to create test storage manager");
// Create a FileInfo instance with storage manager
let user_id = "test_user";
let config = test_config("/tmp/minne_test_data");
// Store the original file
let original_file_info = FileInfo::new(field_data, &db, user_id, &config)
.await
.expect("Failed to create original file");
let original_file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage())
.await
.expect("Failed to create original file with StorageManager");
// Create another file with the same content but different name
let duplicate_name = "cross_fs_duplicate_2.txt";
let duplicate_name = "storage_manager_duplicate_2.txt";
let field_data2 = create_test_file(content, duplicate_name);
// The system should detect it's the same file and return the original FileInfo
let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await
.expect("Failed to process duplicate file with StorageManager");
// Verify duplicate detection worked
assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -402,34 +476,48 @@ mod tests {
assert_eq!(duplicate_file_info.file_name, file_name);
assert_ne!(duplicate_file_info.file_name, duplicate_name);
// Verify both files have the same content (they should point to the same file)
let original_content = original_file_info
.get_content_with_storage(test_storage.storage())
.await
.unwrap();
let duplicate_content = duplicate_file_info
.get_content_with_storage(test_storage.storage())
.await
.unwrap();
assert_eq!(original_content.as_ref(), content);
assert_eq!(duplicate_content.as_ref(), content);
// Clean up
FileInfo::delete_by_id_with_storage(&original_file_info.id, &db, test_storage.storage())
.await
.expect("Failed to delete original file with StorageManager");
}
#[tokio::test]
async fn test_file_creation() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// Create a test file
let content = b"This is a test file content";
let file_name = "test_file.txt";
let field_data = create_test_file(content, file_name);
// Create a FileInfo instance with StorageManager
let user_id = "test_user";
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
let file_info =
FileInfo::new_with_storage(field_data, &db, user_id, test_storage.storage()).await;
// Verify the FileInfo was created successfully
assert!(file_info.is_ok());
let file_info = file_info.unwrap();
@@ -459,33 +547,39 @@ mod tests {
#[tokio::test]
async fn test_file_duplicate_detection() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// First, store a file with known content
let content = b"This is a test file for duplicate detection";
let file_name = "original.txt";
let user_id = "test_user";
let config = test_config("./data");
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
let field_data1 = create_test_file(content, file_name);
let original_file_info =
FileInfo::new_with_storage(field_data1, &db, user_id, test_storage.storage())
.await
.expect("Failed to create original file");
// Now try to store another file with the same content but different name
let duplicate_name = "duplicate.txt";
let field_data2 = create_test_file(content, duplicate_name);
// The system should detect it's the same file and return the original FileInfo
let duplicate_file_info =
FileInfo::new_with_storage(field_data2, &db, user_id, test_storage.storage())
.await
.expect("Failed to process duplicate file");
// The returned FileInfo should match the original
assert_eq!(duplicate_file_info.id, original_file_info.id);
@@ -553,7 +647,6 @@ mod tests {
#[tokio::test]
async fn test_get_by_sha_not_found() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
@@ -574,7 +667,6 @@ mod tests {
#[tokio::test]
async fn test_manual_file_info_creation() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
@@ -615,23 +707,28 @@ mod tests {
#[tokio::test]
async fn test_delete_by_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// Create and persist a test file via FileInfo::new_with_storage
let user_id = "user123";
let cfg = test_config("./data");
let test_storage = TestStorageManager::new_memory()
.await
.expect("create test storage manager");
let temp = create_test_file(b"test content", "test_file.txt");
let file_info = FileInfo::new_with_storage(temp, &db, user_id, test_storage.storage())
.await
.expect("create file");
// Delete the file using StorageManager
let delete_result =
FileInfo::delete_by_id_with_storage(&file_info.id, &db, test_storage.storage()).await;
// Delete should be successful
assert!(
@@ -650,13 +747,12 @@ mod tests {
"FileInfo should be deleted from the database"
);
// Verify content no longer retrievable from storage
assert!(test_storage.storage().get(&file_info.path).await.is_err());
}
#[tokio::test]
async fn test_delete_by_id_not_found() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
@@ -664,19 +760,16 @@ mod tests {
.expect("Failed to start in-memory surrealdb");
// Try to delete a file that doesn't exist
let test_storage = TestStorageManager::new_memory().await.unwrap();
let result =
FileInfo::delete_by_id_with_storage("nonexistent_id", &db, test_storage.storage())
.await;
// Should succeed even if the file record does not exist
assert!(result.is_ok());
}
#[tokio::test]
async fn test_get_by_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
@@ -717,7 +810,6 @@ mod tests {
#[tokio::test]
async fn test_get_by_id_not_found() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
@@ -741,43 +833,197 @@ mod tests {
}
}
// StorageManager-based tests
#[tokio::test]
async fn test_file_info_new_with_storage_memory() {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_memory")
.await
.unwrap();
db.apply_migrations().await.unwrap();
let content = b"This is a test file for StorageManager";
let field_data = create_test_file(content, "test_storage.txt");
let user_id = "test_user";
// Create test storage manager
let storage = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
// Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file with StorageManager");
// Verify the file was created correctly
assert_eq!(file_info.user_id, user_id);
assert_eq!(file_info.file_name, "test_storage.txt");
assert!(!file_info.sha256.is_empty());
assert!(!file_info.path.is_empty());
// Test content retrieval with StorageManager
let retrieved_content = file_info
.get_content_with_storage(storage.storage())
.await
.expect("Failed to get file content with StorageManager");
assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await
.expect("Failed to delete file with StorageManager");
// Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err());
}
#[tokio::test]
async fn test_file_info_new_with_storage_local() {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_storage_local")
.await
.unwrap();
db.apply_migrations().await.unwrap();
let content = b"This is a test file for StorageManager with local storage";
let field_data = create_test_file(content, "test_local.txt");
let user_id = "test_user";
// Create test storage manager with local backend
let storage = store::testing::TestStorageManager::new_local()
.await
.unwrap();
// Test file creation with StorageManager
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file with StorageManager");
// Verify the file was created correctly
assert_eq!(file_info.user_id, user_id);
assert_eq!(file_info.file_name, "test_local.txt");
assert!(!file_info.sha256.is_empty());
assert!(!file_info.path.is_empty());
// Test content retrieval with StorageManager
let retrieved_content = file_info
.get_content_with_storage(storage.storage())
.await
.expect("Failed to get file content with StorageManager");
assert_eq!(retrieved_content.as_ref(), content);
// Test file deletion with StorageManager
FileInfo::delete_by_id_with_storage(&file_info.id, &db, storage.storage())
.await
.expect("Failed to delete file with StorageManager");
// Verify file is deleted
let deleted_content_result = file_info.get_content_with_storage(storage.storage()).await;
assert!(deleted_content_result.is_err());
}
#[tokio::test]
async fn test_file_info_storage_manager_persistence() {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_persistence")
.await
.unwrap();
db.apply_migrations().await.unwrap();
let content = b"Test content for persistence";
let field_data = create_test_file(content, "persistence_test.txt");
let user_id = "test_user";
// Create test storage manager
let storage = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
// Create file
let file_info = FileInfo::new_with_storage(field_data, &db, user_id, storage.storage())
.await
.expect("Failed to create file");
// Test that data persists across multiple operations with the same StorageManager
let retrieved_content_1 = file_info
.get_content_with_storage(storage.storage())
.await
.unwrap();
let retrieved_content_2 = file_info
.get_content_with_storage(storage.storage())
.await
.unwrap();
assert_eq!(retrieved_content_1.as_ref(), content);
assert_eq!(retrieved_content_2.as_ref(), content);
// Test that different StorageManager instances don't share data (memory storage isolation)
let storage2 = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
let isolated_content_result = file_info.get_content_with_storage(storage2.storage()).await;
assert!(
isolated_content_result.is_err(),
"Different StorageManager should not have access to same data"
);
}
#[tokio::test]
async fn test_file_info_storage_manager_equivalence() {
// Setup
let db = SurrealDbClient::memory("test_ns", "test_file_equivalence")
.await
.unwrap();
db.apply_migrations().await.unwrap();
let content = b"Test content for equivalence testing";
let field_data1 = create_test_file(content, "equivalence_test_1.txt");
let field_data2 = create_test_file(content, "equivalence_test_2.txt");
let user_id = "test_user";
// Create single storage manager and reuse it
let storage_manager = store::testing::TestStorageManager::new_memory()
.await
.unwrap();
let storage = storage_manager.storage();
// Create multiple files with the same storage manager
let file_info_1 = FileInfo::new_with_storage(field_data1, &db, user_id, &storage)
.await
.expect("Failed to create file 1");
let file_info_2 = FileInfo::new_with_storage(field_data2, &db, user_id, &storage)
.await
.expect("Failed to create file 2");
// Test that both files can be retrieved with the same storage backend
let content_1 = file_info_1
.get_content_with_storage(&storage)
.await
.unwrap();
let content_2 = file_info_2
.get_content_with_storage(&storage)
.await
.unwrap();
assert_eq!(content_1.as_ref(), content);
assert_eq!(content_2.as_ref(), content);
// Test that files can be deleted with the same storage manager
FileInfo::delete_by_id_with_storage(&file_info_1.id, &db, &storage)
.await
.unwrap();
FileInfo::delete_by_id_with_storage(&file_info_2.id, &db, &storage)
.await
.unwrap();
// Verify files are deleted
let deleted_content_1 = file_info_1.get_content_with_storage(&storage).await;
let deleted_content_2 = file_info_2.get_content_with_storage(&storage).await;
assert!(deleted_content_1.is_err());
assert!(deleted_content_2.is_err());
}
}

View File

@@ -1,3 +1,9 @@
#![allow(
clippy::result_large_err,
clippy::needless_pass_by_value,
clippy::implicit_clone,
clippy::semicolon_if_nothing_returned
)]
use crate::{error::AppError, storage::types::file_info::FileInfo};
use serde::{Deserialize, Serialize};
use tracing::info;
@@ -38,6 +44,7 @@ impl IngestionPayload {
/// # Returns
/// * `Result<Vec<IngestionPayload>, AppError>` - On success, returns a vector of ingress objects
/// (one per file/content type). On failure, returns an `AppError`.
#[allow(clippy::similar_names)]
pub fn create_ingestion_payload(
content: Option<String>,
context: String,

View File

@@ -1,3 +1,12 @@
#![allow(
clippy::cast_possible_wrap,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_sign_loss,
clippy::missing_docs_in_private_items,
clippy::trivially_copy_pass_by_ref,
clippy::expect_used
)]
use std::time::Duration;
use chrono::Duration as ChronoDuration;

View File

@@ -1,7 +1,19 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::match_same_arms,
clippy::format_push_string,
clippy::uninlined_format_args,
clippy::explicit_iter_loop,
clippy::items_after_statements,
clippy::get_first,
clippy::redundant_closure_for_method_calls
)]
use std::collections::HashMap;
use crate::{
error::AppError, storage::db::SurrealDbClient,
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding, stored_object,
utils::embedding::generate_embedding,
};
use async_openai::{config::OpenAIConfig, Client};
@@ -78,10 +90,16 @@ stored_object!(KnowledgeEntity, "knowledge_entity", {
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String
});
/// Vector search result including the hydrated entity.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct KnowledgeEntityVectorResult {
pub entity: KnowledgeEntity,
pub score: f32,
}
impl KnowledgeEntity {
pub fn new(
source_id: String,
@@ -89,7 +107,6 @@ impl KnowledgeEntity {
description: String,
entity_type: KnowledgeEntityType,
metadata: Option<serde_json::Value>,
embedding: Vec<f32>,
user_id: String,
) -> Self {
let now = Utc::now();
@@ -102,7 +119,6 @@ impl KnowledgeEntity {
description,
entity_type,
metadata,
embedding,
user_id,
}
}
@@ -165,6 +181,89 @@ impl KnowledgeEntity {
Ok(())
}
/// Atomically store a knowledge entity and its embedding.
/// Writes the entity to `knowledge_entity` and the embedding to `knowledge_entity_embedding`.
pub async fn store_with_embedding(
entity: KnowledgeEntity,
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = KnowledgeEntityEmbedding::new(&entity.id, embedding, entity.user_id.clone());
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{entity_table}', $entity_id) CONTENT $entity;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
entity_table = Self::table_name(),
emb_table = KnowledgeEntityEmbedding::table_name(),
);
db.client
.query(query)
.bind(("entity_id", entity.id.clone()))
.bind(("entity", entity))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
pub async fn vector_search(
take: usize,
query_embedding: Vec<f32>,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
score: f32,
}
let sql = format!(
r#"
SELECT
entity_id,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH entity_id;
"#,
emb_table = KnowledgeEntityEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
response = response.check().map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
Ok(rows
.into_iter()
.map(|r| KnowledgeEntityVectorResult {
entity: r.entity_id,
score: r.score,
})
.collect())
}
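// Sketch of the intended flow (hypothetical helper, not part of the API): store an entity
// together with its embedding, then retrieve it by vector similarity. Assumes `embedding`
// matches the dimension of the HNSW index on the embedding table.
#[allow(dead_code)]
async fn example_store_then_search(
db: &SurrealDbClient,
user_id: &str,
embedding: Vec<f32>,
) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
let entity = KnowledgeEntity::new(
"source-1".to_string(),
"Example entity".to_string(),
"Example description".to_string(),
KnowledgeEntityType::Document,
None,
user_id.to_string(),
);
// The entity row and its embedding row are written in one transaction.
KnowledgeEntity::store_with_embedding(entity, embedding.clone(), db).await?;
// Top-5 nearest entities for this user, hydrated with their full rows.
KnowledgeEntity::vector_search(5, embedding, db, user_id).await
}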
pub async fn patch(
id: &str,
name: &str,
@@ -178,32 +277,55 @@ impl KnowledgeEntity {
name, description, entity_type
);
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
let user_id = Self::get_user_id_by_id(id, db_client).await?;
let emb = KnowledgeEntityEmbedding::new(id, embedding, user_id);
let now = Utc::now();
db_client
.client
.query(
"UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type,
embedding = $embedding
RETURN AFTER",
"BEGIN TRANSACTION;
UPDATE type::thing($table, $id)
SET name = $name,
description = $description,
updated_at = $updated_at,
entity_type = $entity_type;
UPSERT type::thing($emb_table, $emb_id) CONTENT $emb;
COMMIT TRANSACTION;",
)
.bind(("table", Self::table_name()))
.bind(("emb_table", KnowledgeEntityEmbedding::table_name()))
.bind(("id", id.to_string()))
.bind(("name", name.to_string()))
.bind(("updated_at", surrealdb::Datetime::from(now)))
.bind(("entity_type", entity_type.to_owned()))
.bind(("embedding", embedding))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.bind(("description", description.to_string()))
.await?;
Ok(())
}
async fn get_user_id_by_id(id: &str, db_client: &SurrealDbClient) -> Result<String, AppError> {
let mut response = db_client
.client
.query("SELECT user_id FROM type::thing($table, $id) LIMIT 1")
.bind(("table", Self::table_name()))
.bind(("id", id.to_string()))
.await
.map_err(AppError::Database)?;
#[derive(Deserialize)]
struct Row {
user_id: String,
}
let rows: Vec<Row> = response.take(0).map_err(AppError::Database)?;
rows.get(0)
.map(|r| r.user_id.clone())
.ok_or_else(|| AppError::InternalError("user not found for entity".to_string()))
}
/// Re-creates embeddings for all knowledge entities in the database.
///
/// This is a costly operation that should be run in the background. It follows the same
@@ -228,22 +350,13 @@ impl KnowledgeEntity {
if total_entities == 0 {
info!("No knowledge entities to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {} entities to process.", total_entities);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for entity in all_entities.iter() {
let embedding_input = format!(
@@ -271,17 +384,16 @@ impl KnowledgeEntity {
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements to the embedding table
for (id, (embedding, user_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
@@ -291,18 +403,22 @@ impl KnowledgeEntity {
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('knowledge_entity', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id
));
}
// Re-create the index after updating the data that it will index
transaction_query.push_str(&format!(
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
@@ -312,12 +428,146 @@ impl KnowledgeEntity {
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
/// Re-creates embeddings for all knowledge entities using an `EmbeddingProvider`.
///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
) -> Result<(), AppError> {
let new_dimensions = provider.dimension();
info!(
dimensions = new_dimensions,
backend = provider.backend_label(),
"Starting re-embedding process for all knowledge entities"
);
// Fetch all entities first
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
let total_entities = all_entities.len();
if total_entities == 0 {
info!("No knowledge entities to update. Just updating the index.");
KnowledgeEntityEmbedding::redefine_hnsw_index(db, new_dimensions).await?;
return Ok(());
}
info!(entities = total_entities, "Found entities to process");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String)> = HashMap::new();
info!("Generating new embeddings for all entities...");
for (i, entity) in all_entities.iter().enumerate() {
if i > 0 && i % 100 == 0 {
info!(
progress = i,
total = total_entities,
"Re-embedding progress"
);
}
let embedding_input = format!(
"name: {}, description: {}, type: {:?}",
entity.name, entity.description, entity.entity_type
);
let embedding = provider
.embed(&embedding_input)
.await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for entity {} has incorrect dimension ({}). Expected {}. Aborting.",
entity.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
}
info!("Successfully generated all new embeddings.");
info!("Successfully generated all new embeddings.");
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
info!("Removing old index and clearing embeddings...");
// Explicitly remove the index first. This prevents background HNSW maintenance from crashing
// when we delete/replace data, dealing with a known SurrealDB panic.
db.client
.query(format!(
"REMOVE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
db.client
.query(format!(
"DELETE FROM {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
// Perform DB updates in a single transaction
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
entity_id = type::thing('knowledge_entity', '{id}'), \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = time::now(), \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id
));
}
transaction_query.push_str(&format!(
"DEFINE INDEX OVERWRITE idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.client
.query(transaction_query)
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
info!("Re-embedding process for knowledge entities completed successfully.");
Ok(())
}
}
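// Startup-time sketch (hypothetical wiring): when the configured embedding dimension changes,
// rebuild every stored embedding with the active provider. Only
// `update_all_embeddings_with_provider` is defined above; the surrounding function is illustrative.
#[allow(dead_code)]
async fn reembed_on_dimension_change(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
stored_dimension: usize,
) -> Result<(), AppError> {
if provider.dimension() != stored_dimension {
// Regenerates every embedding row and recreates the HNSW index at the new dimension.
KnowledgeEntity::update_all_embeddings_with_provider(db, provider).await?;
}
Ok(())
}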
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding;
use serde_json::json;
use uuid::Uuid;
#[tokio::test]
async fn test_knowledge_entity_creation() {
@@ -327,7 +577,6 @@ mod tests {
let description = "Test Description".to_string();
let entity_type = KnowledgeEntityType::Document;
let metadata = Some(json!({"key": "value"}));
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new(
@@ -336,7 +585,6 @@ mod tests {
description.clone(),
entity_type.clone(),
metadata.clone(),
embedding.clone(),
user_id.clone(),
);
@@ -346,7 +594,6 @@ mod tests {
assert_eq!(entity.description, description);
assert_eq!(entity.entity_type, entity_type);
assert_eq!(entity.metadata, metadata);
assert_eq!(entity.embedding, embedding);
assert_eq!(entity.user_id, user_id);
assert!(!entity.id.is_empty());
}
@@ -410,20 +657,25 @@ mod tests {
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// Create two entities with the same source_id
let source_id = "source123".to_string();
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("Failed to redefine index length");
let entity1 = KnowledgeEntity::new(
source_id.clone(),
"Entity 1".to_string(),
"Description 1".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -433,7 +685,6 @@ mod tests {
"Description 2".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -445,18 +696,18 @@ mod tests {
"Different Description".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
let emb = vec![0.1, 0.2, 0.3, 0.4, 0.5];
// Store the entities
KnowledgeEntity::store_with_embedding(entity1.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 1");
KnowledgeEntity::store_with_embedding(entity2.clone(), emb.clone(), &db)
.await
.expect("Failed to store entity 2");
KnowledgeEntity::store_with_embedding(different_entity.clone(), emb.clone(), &db)
.await
.expect("Failed to store different entity");
@@ -505,6 +756,162 @@ mod tests {
assert_eq!(different_remaining[0].id, different_entity.id);
}
// Note: We can't easily test the patch method without mocking the OpenAI client
// and the generate_embedding function. This would require more complex setup.
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.expect("vector search");
assert!(results.is_empty());
}
#[tokio::test]
async fn test_vector_search_single_result() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let user_id = "user".to_string();
let source_id = "src".to_string();
let entity = KnowledgeEntity::new(
source_id.clone(),
"hello".to_string(),
"world".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store entity with embedding");
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
assert!(stored_entity.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {}",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.take(0)
.expect("take embeddings");
assert_eq!(stored_embeddings.len(), 1);
let rid = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched_emb = KnowledgeEntityEmbedding::get_by_entity_id(&rid, &db)
.await
.expect("fetch embedding");
assert!(fetched_emb.is_some());
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.expect("vector search");
assert_eq!(results.len(), 1);
let res = &results[0];
assert_eq!(res.entity.id, entity.id);
assert_eq!(res.entity.source_id, source_id);
assert_eq!(res.entity.name, "hello");
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let user_id = "user".to_string();
let e1 = KnowledgeEntity::new(
"s1".to_string(),
"entity one".to_string(),
"desc".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
let e2 = KnowledgeEntity::new(
"s2".to_string(),
"entity two".to_string(),
"desc".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
KnowledgeEntity::store_with_embedding(e1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store e1");
KnowledgeEntity::store_with_embedding(e2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store e2");
let stored_e1: Option<KnowledgeEntity> = db.get_item(&e1.id).await.unwrap();
let stored_e2: Option<KnowledgeEntity> = db.get_item(&e2.id).await.unwrap();
assert!(stored_e1.is_some() && stored_e2.is_some());
let stored_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {}",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("query embeddings")
.take(0)
.expect("take embeddings");
assert_eq!(stored_embeddings.len(), 2);
let rid_e1 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e1.id);
let rid_e2 = surrealdb::RecordId::from_table_key(KnowledgeEntity::table_name(), &e2.id);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e1, &db)
.await
.unwrap()
.is_some());
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&rid_e2, &db)
.await
.unwrap()
.is_some());
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.expect("vector search");
assert_eq!(results.len(), 2);
assert_eq!(results[0].entity.id, e2.id);
assert_eq!(results[1].entity.id, e1.id);
}
}
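// A minimal end-to-end sketch (illustrative only, mirroring the tests above):
// store an entity together with its embedding in the side table, then retrieve
// it via vector search. The 3-dimensional vectors are test-sized placeholders,
// and the HNSW index is assumed to already be sized for that dimension
// (e.g. via redefine_hnsw_index or ensure_runtime_indexes).
#[allow(dead_code)]
async fn example_store_and_search(
    db: &crate::storage::db::SurrealDbClient,
) -> Result<(), crate::error::AppError> {
    let entity = KnowledgeEntity::new(
        "example-source".to_string(),
        "example name".to_string(),
        "example description".to_string(),
        KnowledgeEntityType::Document,
        None,
        "example-user".to_string(),
    );
    // The embedding is passed separately and written to knowledge_entity_embedding.
    KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], db).await?;
    // Nearest-neighbour lookup scoped to the owning user.
    let results =
        KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], db, "example-user").await?;
    assert!(results.iter().any(|r| r.entity.id == entity.id));
    Ok(())
}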

View File

@@ -0,0 +1,387 @@
use std::collections::HashMap;
use surrealdb::RecordId;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
stored_object!(KnowledgeEntityEmbedding, "knowledge_entity_embedding", {
entity_id: RecordId,
embedding: Vec<f32>,
/// Denormalized user id for query scoping
user_id: String
});
impl KnowledgeEntityEmbedding {
/// Recreate the HNSW index with a new embedding dimension.
pub async fn redefine_hnsw_index(
db: &SurrealDbClient,
dimension: usize,
) -> Result<(), AppError> {
let query = format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE {table};
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
COMMIT TRANSACTION;",
table = Self::table_name(),
);
let res = db.client.query(query).await.map_err(AppError::Database)?;
res.check().map_err(AppError::Database)?;
Ok(())
}
/// Create a new knowledge entity embedding
pub fn new(entity_id: &str, embedding: Vec<f32>, user_id: String) -> Self {
let now = Utc::now();
Self {
id: uuid::Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
entity_id: RecordId::from_table_key("knowledge_entity", entity_id),
embedding,
user_id,
}
}
/// Get embedding by entity ID
pub async fn get_by_entity_id(
entity_id: &RecordId,
db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
let query = format!(
"SELECT * FROM {} WHERE entity_id = $entity_id LIMIT 1",
Self::table_name()
);
let mut result = db
.client
.query(query)
.bind(("entity_id", entity_id.clone()))
.await
.map_err(AppError::Database)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
Ok(embeddings.into_iter().next())
}
/// Get embeddings for multiple entities in batch
pub async fn get_by_entity_ids(
entity_ids: &[RecordId],
db: &SurrealDbClient,
) -> Result<HashMap<String, Vec<f32>>, AppError> {
if entity_ids.is_empty() {
return Ok(HashMap::new());
}
let ids_list: Vec<RecordId> = entity_ids.to_vec();
let query = format!(
"SELECT * FROM {} WHERE entity_id INSIDE $entity_ids",
Self::table_name()
);
let mut result = db
.client
.query(query)
.bind(("entity_ids", ids_list))
.await
.map_err(AppError::Database)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
Ok(embeddings
.into_iter()
.map(|e| (e.entity_id.key().to_string(), e.embedding))
.collect())
}
/// Delete embedding by entity ID
pub async fn delete_by_entity_id(
entity_id: &RecordId,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE FROM {} WHERE entity_id = $entity_id",
Self::table_name()
);
db.client
.query(query)
.bind(("entity_id", entity_id.clone()))
.await
.map_err(AppError::Database)?;
Ok(())
}
/// Delete embeddings by source_id (via joining to knowledge_entity table)
#[allow(clippy::items_after_statements)]
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = "SELECT id FROM knowledge_entity WHERE source_id = $source_id";
let mut res = db
.client
.query(query)
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?;
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
for row in ids {
Self::delete_by_entity_id(&row.id, db).await?;
}
Ok(())
}
}
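// A minimal usage sketch (illustrative only): with a migrated SurrealDbClient,
// size the HNSW index for the active embedding dimension, then look embeddings
// up individually or in batch via their knowledge_entity record links. The
// entity key used here is a made-up placeholder.
#[allow(dead_code)]
async fn example_embedding_lookups(db: &SurrealDbClient) -> Result<(), AppError> {
    // Recreate the vector index for 3-dimensional (test-sized) embeddings.
    KnowledgeEntityEmbedding::redefine_hnsw_index(db, 3).await?;
    // Record link to a hypothetical knowledge_entity key.
    let rid = RecordId::from_table_key("knowledge_entity", "entity-1");
    // Single lookup: Option<KnowledgeEntityEmbedding>.
    let single = KnowledgeEntityEmbedding::get_by_entity_id(&rid, db).await?;
    // Batch lookup: HashMap<entity key, embedding vector>.
    let batch: HashMap<String, Vec<f32>> =
        KnowledgeEntityEmbedding::get_by_entity_ids(&[rid.clone()], db).await?;
    assert_eq!(batch.len(), if single.is_some() { 1 } else { 0 });
    // Cleanup helpers remove embeddings per entity, or per source via delete_by_source_id.
    KnowledgeEntityEmbedding::delete_by_entity_id(&rid, db).await?;
    Ok(())
}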
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::db::SurrealDbClient;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
use chrono::Utc;
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
}
fn build_knowledge_entity_with_id(
key: &str,
source_id: &str,
user_id: &str,
) -> KnowledgeEntity {
KnowledgeEntity {
id: key.to_owned(),
created_at: Utc::now(),
updated_at: Utc::now(),
source_id: source_id.to_owned(),
name: "Test entity".to_owned(),
description: "Desc".to_owned(),
entity_type: KnowledgeEntityType::Document,
metadata: None,
user_id: user_id.to_owned(),
}
}
#[tokio::test]
async fn test_create_and_get_by_entity_id() {
let db = setup_test_db().await;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
let user_id = "user_ke";
let entity_key = "entity-1";
let source_id = "source-ke";
let embedding_vec = vec![0.11_f32, 0.22, 0.33];
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding_vec.clone(), &db)
.await
.expect("Failed to store entity with embedding");
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let fetched = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding by entity_id")
.expect("Expected embedding to exist");
assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.entity_id, entity_rid);
assert_eq!(fetched.embedding, embedding_vec);
}
#[tokio::test]
async fn test_delete_by_entity_id() {
let db = setup_test_db().await;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
let user_id = "user_ke";
let entity_key = "entity-delete";
let source_id = "source-del";
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.5_f32, 0.6, 0.7], &db)
.await
.expect("Failed to store entity with embedding");
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let existing = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding before delete");
assert!(existing.is_some());
KnowledgeEntityEmbedding::delete_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to delete by entity_id");
let after = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to get embedding after delete");
assert!(after.is_none());
}
#[tokio::test]
async fn test_store_with_embedding_creates_entity_and_embedding() {
let db = setup_test_db().await;
let user_id = "user_store";
let source_id = "source_store";
let embedding = vec![0.2_f32, 0.3, 0.4];
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, embedding.len())
.await
.expect("set test index dimension");
let entity = build_knowledge_entity_with_id("entity-store", source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), embedding.clone(), &db)
.await
.expect("Failed to store entity with embedding");
let stored_entity: Option<KnowledgeEntity> = db.get_item(&entity.id).await.unwrap();
assert!(stored_entity.is_some());
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
let stored_embedding = KnowledgeEntityEmbedding::get_by_entity_id(&entity_rid, &db)
.await
.expect("Failed to fetch embedding");
assert!(stored_embedding.is_some());
let stored_embedding = stored_embedding.unwrap();
assert_eq!(stored_embedding.user_id, user_id);
assert_eq!(stored_embedding.entity_id, entity_rid);
}
#[tokio::test]
async fn test_delete_by_source_id() {
let db = setup_test_db().await;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
let user_id = "user_ke";
let source_id = "shared-ke";
let other_source = "other-ke";
let entity1 = build_knowledge_entity_with_id("entity-s1", source_id, user_id);
let entity2 = build_knowledge_entity_with_id("entity-s2", source_id, user_id);
let entity_other = build_knowledge_entity_with_id("entity-other", other_source, user_id);
KnowledgeEntity::store_with_embedding(entity1.clone(), vec![1.0_f32, 1.1, 1.2], &db)
.await
.expect("Failed to store entity with embedding");
KnowledgeEntity::store_with_embedding(entity2.clone(), vec![2.0_f32, 2.1, 2.2], &db)
.await
.expect("Failed to store entity with embedding");
KnowledgeEntity::store_with_embedding(entity_other.clone(), vec![3.0_f32, 3.1, 3.2], &db)
.await
.expect("Failed to store entity with embedding");
let entity1_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity1.id);
let entity2_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity2.id);
let other_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity_other.id);
KnowledgeEntityEmbedding::delete_by_source_id(source_id, &db)
.await
.expect("Failed to delete by source_id");
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity1_rid, &db)
.await
.unwrap()
.is_none()
);
assert!(
KnowledgeEntityEmbedding::get_by_entity_id(&entity2_rid, &db)
.await
.unwrap()
.is_none()
);
assert!(KnowledgeEntityEmbedding::get_by_entity_id(&other_rid, &db)
.await
.unwrap()
.is_some());
}
#[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() {
let db = setup_test_db().await;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 16)
.await
.expect("failed to redefine index");
let mut info_res = db
.client
.query("INFO FOR TABLE knowledge_entity_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_knowledge_entity_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
assert!(
idx_sql.contains("DIMENSION 16"),
"expected index definition to contain new dimension, got: {idx_sql}"
);
}
#[tokio::test]
async fn test_fetch_entity_via_record_id() {
let db = setup_test_db().await;
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("set test index dimension");
let user_id = "user_ke";
let entity_key = "entity-fetch";
let source_id = "source-fetch";
let entity = build_knowledge_entity_with_id(entity_key, source_id, user_id);
KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.7_f32, 0.8, 0.9], &db)
.await
.expect("Failed to store entity with embedding");
let entity_rid = RecordId::from_table_key(KnowledgeEntity::table_name(), &entity.id);
#[derive(Deserialize)]
struct Row {
entity_id: KnowledgeEntity,
}
let mut res = db
.client
.query(
"SELECT entity_id FROM knowledge_entity_embedding WHERE entity_id = $id FETCH entity_id;",
)
.bind(("id", entity_rid.clone()))
.await
.expect("failed to fetch embedding with FETCH");
let rows: Vec<Row> = res.take(0).expect("failed to deserialize fetch rows");
assert_eq!(rows.len(), 1);
let fetched_entity = &rows[0].entity_id;
assert_eq!(fetched_entity.id, entity_key);
assert_eq!(fetched_entity.name, "Test entity");
assert_eq!(fetched_entity.user_id, user_id);
}
}

View File

@@ -41,20 +41,21 @@ impl KnowledgeRelationship {
}
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
let query = format!(
r#"RELATE knowledge_entity:`{}`->relates_to:`{}`->knowledge_entity:`{}`
r#"DELETE relates_to:`{rel_id}`;
RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
SET
metadata.user_id = '{}',
metadata.source_id = '{}',
metadata.relationship_type = '{}'"#,
self.in_,
self.id,
self.out,
self.metadata.user_id,
self.metadata.source_id,
self.metadata.relationship_type
metadata.user_id = '{user_id}',
metadata.source_id = '{source_id}',
metadata.relationship_type = '{relationship_type}'"#,
rel_id = self.id,
in_id = self.in_,
out_id = self.out,
user_id = self.metadata.user_id.as_str(),
source_id = self.metadata.source_id.as_str(),
relationship_type = self.metadata.relationship_type.as_str()
);
db_client.query(query).await?;
db_client.query(query).await?.check()?;
Ok(())
}
@@ -64,8 +65,7 @@ impl KnowledgeRelationship {
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{}'",
source_id
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
);
db_client.query(query).await?;
@@ -80,15 +80,14 @@ impl KnowledgeRelationship {
) -> Result<(), AppError> {
let mut authorized_result = db_client
.query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{}` AND metadata.user_id = '{}'",
id, user_id
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
if authorized.is_empty() {
let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{}`", id))
.query(format!("SELECT * FROM relates_to:`{id}`"))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -97,12 +96,10 @@ impl KnowledgeRelationship {
"Not authorized to delete relationship".into(),
))
} else {
Err(AppError::NotFound(format!("Relationship {} not found", id)))
Err(AppError::NotFound(format!("Relationship {id} not found")))
}
} else {
db_client
.query(format!("DELETE relates_to:`{}`", id))
.await?;
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
Ok(())
}
}
@@ -118,7 +115,6 @@ mod tests {
let source_id = "source123".to_string();
let description = format!("Description for {}", name);
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3];
let user_id = "user123".to_string();
let entity = KnowledgeEntity::new(
@@ -127,7 +123,6 @@ mod tests {
description,
entity_type,
None,
embedding,
user_id,
);
@@ -164,7 +159,7 @@ mod tests {
}
#[tokio::test]
async fn test_store_relationship() {
async fn test_store_and_verify_by_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
@@ -172,6 +167,10 @@ mod tests {
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
@@ -212,7 +211,7 @@ mod tests {
}
#[tokio::test]
async fn test_delete_relationship_by_id() {
async fn test_store_and_delete_relationship() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
@@ -237,7 +236,7 @@ mod tests {
relationship_type,
);
// Store the relationship
// Store relationship
relationship
.store_relationship(&db)
.await
@@ -258,12 +257,12 @@ mod tests {
"Relationship should exist before deletion"
);
// Delete the relationship by ID
// Delete relationship by ID
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
.await
.expect("Failed to delete relationship by ID");
// Query to verify the relationship was deleted
// Query to verify relationship was deleted
let mut result = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
@@ -273,7 +272,7 @@ mod tests {
.expect("Query failed");
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();
// Verify the relationship no longer exists
// Verify relationship no longer exists
assert!(results.is_empty(), "Relationship should be deleted");
}
@@ -345,7 +344,7 @@ mod tests {
}
#[tokio::test]
async fn test_delete_relationships_by_source_id() {
async fn test_store_relationship_exists() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();

View File

@@ -1,3 +1,4 @@
#![allow(clippy::module_name_repetitions)]
use uuid::Uuid;
use crate::stored_object;
@@ -56,7 +57,7 @@ impl fmt::Display for Message {
pub fn format_history(history: &[Message]) -> String {
history
.iter()
.map(|msg| format!("{}", msg))
.map(|msg| format!("{msg}"))
.collect::<Vec<String>>()
.join("\n")
}

View File

@@ -1,3 +1,4 @@
#![allow(clippy::unsafe_derive_deserialize)]
use serde::{Deserialize, Serialize};
pub mod analytics;
pub mod conversation;
@@ -5,12 +6,14 @@ pub mod file_info;
pub mod ingestion_payload;
pub mod ingestion_task;
pub mod knowledge_entity;
pub mod knowledge_entity_embedding;
pub mod knowledge_relationship;
pub mod message;
pub mod scratchpad;
pub mod system_prompts;
pub mod system_settings;
pub mod text_chunk;
pub mod text_chunk_embedding;
pub mod text_content;
pub mod user;
@@ -21,7 +24,7 @@ pub trait StoredObject: Serialize + for<'de> Deserialize<'de> {
#[macro_export]
macro_rules! stored_object {
($name:ident, $table:expr, {$($(#[$attr:meta])* $field:ident: $ty:ty),*}) => {
($(#[$struct_attr:meta])* $name:ident, $table:expr, {$($(#[$field_attr:meta])* $field:ident: $ty:ty),*}) => {
use serde::{Deserialize, Deserializer, Serialize};
use surrealdb::sql::Thing;
use $crate::storage::types::StoredObject;
@@ -85,6 +88,7 @@ macro_rules! stored_object {
}
#[allow(dead_code)]
#[allow(clippy::ref_option)]
fn serialize_option_datetime<S>(
date: &Option<DateTime<Utc>>,
serializer: S,
@@ -100,6 +104,7 @@ macro_rules! stored_object {
}
#[allow(dead_code)]
#[allow(clippy::ref_option)]
fn deserialize_option_datetime<'de, D>(
deserializer: D,
) -> Result<Option<DateTime<Utc>>, D::Error>
@@ -111,6 +116,7 @@ macro_rules! stored_object {
}
$(#[$struct_attr])*
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct $name {
#[serde(deserialize_with = "deserialize_flexible_id")]
@@ -119,7 +125,7 @@ macro_rules! stored_object {
pub created_at: DateTime<Utc>,
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
pub updated_at: DateTime<Utc>,
$( $(#[$attr])* pub $field: $ty),*
$( $(#[$field_attr])* pub $field: $ty),*
}
impl StoredObject for $name {

View File

@@ -459,4 +459,44 @@ mod tests {
let retrieved: Option<Scratchpad> = db.get_item(&scratchpad_id).await.unwrap();
assert!(retrieved.is_some());
}
#[tokio::test]
async fn test_timezone_aware_scratchpad_conversion() {
let db = SurrealDbClient::memory("test_ns", &Uuid::new_v4().to_string())
.await
.expect("Failed to create test database");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
let user_id = "test_user_123";
let scratchpad =
Scratchpad::new(user_id.to_string(), "Test Timezone Scratchpad".to_string());
let scratchpad_id = scratchpad.id.clone();
db.store_item(scratchpad).await.unwrap();
let retrieved = Scratchpad::get_by_id(&scratchpad_id, user_id, &db)
.await
.unwrap();
// Test that datetime fields are preserved and can be used for timezone formatting
assert!(retrieved.created_at.timestamp() > 0);
assert!(retrieved.updated_at.timestamp() > 0);
assert!(retrieved.last_saved_at.timestamp() > 0);
// Test that optional datetime fields work correctly
assert!(retrieved.archived_at.is_none());
assert!(retrieved.ingested_at.is_none());
// Archive the scratchpad to test optional datetime handling
let archived = Scratchpad::archive(&scratchpad_id, user_id, &db, false)
.await
.unwrap();
assert!(archived.archived_at.is_some());
assert!(archived.archived_at.unwrap().timestamp() > 0);
assert!(archived.ingested_at.is_none());
}
}

View File

@@ -13,6 +13,9 @@ pub struct SystemSettings {
pub processing_model: String,
pub embedding_model: String,
pub embedding_dimensions: u32,
/// Active embedding backend ("openai", "fastembed", "hashed"). Read-only, synced from config.
#[serde(default)]
pub embedding_backend: Option<String>,
pub query_system_prompt: String,
pub ingestion_system_prompt: String,
pub image_processing_model: String,
@@ -49,10 +52,62 @@ impl SystemSettings {
"Something went wrong updating the settings".into(),
))
}
/// Syncs SystemSettings with the active embedding provider's properties.
/// Updates embedding_backend, embedding_model, and embedding_dimensions if they differ.
/// Returns the (possibly updated) settings together with a flag indicating whether anything changed.
pub async fn sync_from_embedding_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
) -> Result<(Self, bool), AppError> {
let mut settings = Self::get_current(db).await?;
let mut needs_update = false;
let backend_label = provider.backend_label().to_string();
let provider_dimensions = provider.dimension() as u32;
let provider_model = provider.model_code();
// Sync backend label
if settings.embedding_backend.as_deref() != Some(&backend_label) {
settings.embedding_backend = Some(backend_label);
needs_update = true;
}
// Sync dimensions
if settings.embedding_dimensions != provider_dimensions {
tracing::info!(
old_dimensions = settings.embedding_dimensions,
new_dimensions = provider_dimensions,
"Embedding dimensions changed, updating SystemSettings"
);
settings.embedding_dimensions = provider_dimensions;
needs_update = true;
}
// Sync model if provider has one
if let Some(model) = provider_model {
if settings.embedding_model != model {
tracing::info!(
old_model = %settings.embedding_model,
new_model = %model,
"Embedding model changed, updating SystemSettings"
);
settings.embedding_model = model;
needs_update = true;
}
}
if needs_update {
settings = Self::update(db, settings).await?;
}
Ok((settings, needs_update))
}
}
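// A minimal startup-sync sketch (illustrative only): the embedding provider is
// assumed to be constructed elsewhere from application config; this just keeps
// the persisted settings in step with it and reports whether anything changed.
#[allow(dead_code)]
async fn example_sync_settings(
    db: &SurrealDbClient,
    provider: &crate::utils::embedding::EmbeddingProvider,
) -> Result<(), AppError> {
    let (settings, changed) = SystemSettings::sync_from_embedding_provider(db, provider).await?;
    if changed {
        tracing::info!(
            backend = ?settings.embedding_backend,
            dimensions = settings.embedding_dimensions,
            "SystemSettings resynced to the active embedding provider"
        );
    }
    Ok(())
}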
#[cfg(test)]
mod tests {
use crate::storage::indexes::ensure_runtime_indexes;
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
use async_openai::Client;
@@ -71,25 +126,22 @@ mod tests {
.await
.expect("Failed to fetch table info");
let info: Option<serde_json::Value> = response
let info: surrealdb::Value = response
.take(0)
.expect("Failed to extract table info response");
let info = info.expect("Table info result missing");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("Failed to convert info to json");
let indexes = info
.get("indexes")
.or_else(|| {
info.get("tables")
.and_then(|tables| tables.get(table_name))
.and_then(|table| table.get("indexes"))
})
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
let indexes = info_json["Object"]["indexes"]["Object"]
.as_object()
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info_json:#?}"));
let definition = indexes
.get(index_name)
.and_then(|definition| definition.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
.and_then(|definition| definition.get("Strand"))
.and_then(|v| v.as_str())
.unwrap_or_else(|| panic!("Index definition not found in table info: {info_json:#?}"));
let dimension_part = definition
.split("DIMENSION")
@@ -261,48 +313,56 @@ mod tests {
let initial_chunk = TextChunk::new(
"source1".into(),
"This chunk has the original dimension".into(),
vec![0.1; 1536],
"user1".into(),
);
db.store_item(initial_chunk.clone())
TextChunk::store_with_embedding(initial_chunk.clone(), vec![0.1; 1536], &db)
.await
.expect("Failed to store initial chunk");
.expect("Failed to store initial chunk with embedding");
async fn simulate_reembedding(
db: &SurrealDbClient,
target_dimension: usize,
initial_chunk: TextChunk,
) {
db.query("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;")
.await
.unwrap();
db.query(
"REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;",
)
.await
.unwrap();
let define_index_query = format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
"DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
target_dimension
);
db.query(define_index_query)
.await
.expect("Re-defining index should succeed");
let new_embedding = vec![0.5; target_dimension];
let sql = "UPDATE type::thing('text_chunk', $id) SET embedding = $embedding;";
let sql = "UPSERT type::thing('text_chunk_embedding', $id) SET chunk_id = type::thing('text_chunk', $id), embedding = $embedding, user_id = $user_id;";
let update_result = db
.client
.query(sql)
.bind(("id", initial_chunk.id.clone()))
.bind(("user_id", initial_chunk.user_id.clone()))
.bind(("embedding", new_embedding))
.await;
assert!(update_result.is_ok());
}
simulate_reembedding(&db, 768, initial_chunk).await;
// Re-embed with the existing configured dimension to ensure migrations remain idempotent.
let target_dimension = 1536usize;
simulate_reembedding(&db, target_dimension, initial_chunk).await;
let migration_result = db.apply_migrations().await;
assert!(migration_result.is_ok(), "Migrations should not fail");
assert!(
migration_result.is_ok(),
"Migrations should not fail: {:?}",
migration_result.err()
);
}
#[tokio::test]
@@ -320,8 +380,17 @@ mod tests {
.await
.expect("Failed to load current settings");
let initial_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
// Ensure runtime indexes exist with the current embedding dimension so INFO queries succeed.
ensure_runtime_indexes(&db, current_settings.embedding_dimensions as usize)
.await
.expect("failed to build runtime indexes");
let initial_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
assert_eq!(
initial_chunk_dimension, current_settings.embedding_dimensions,
@@ -352,10 +421,18 @@ mod tests {
.await
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
let text_chunk_dimension =
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
let knowledge_dimension =
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
let text_chunk_dimension = get_hnsw_index_dimension(
&db,
"text_chunk_embedding",
"idx_embedding_text_chunk_embedding",
)
.await;
let knowledge_dimension = get_hnsw_index_dimension(
&db,
"knowledge_entity_embedding",
"idx_embedding_knowledge_entity_embedding",
)
.await;
assert_eq!(
text_chunk_dimension, new_dimension,

View File

@@ -1,5 +1,8 @@
#![allow(clippy::missing_docs_in_private_items, clippy::uninlined_format_args)]
use std::collections::HashMap;
use std::fmt::Write;
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use async_openai::{config::OpenAIConfig, Client};
use tokio_retry::{
@@ -13,12 +16,19 @@ use uuid::Uuid;
stored_object!(TextChunk, "text_chunk", {
source_id: String,
chunk: String,
embedding: Vec<f32>,
user_id: String
});
/// Search result including hydrated chunk.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct TextChunkSearchResult {
pub chunk: TextChunk,
pub score: f32,
}
impl TextChunk {
pub fn new(source_id: String, chunk: String, embedding: Vec<f32>, user_id: String) -> Self {
pub fn new(source_id: String, chunk: String, user_id: String) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
@@ -26,7 +36,6 @@ impl TextChunk {
updated_at: now,
source_id,
chunk,
embedding,
user_id,
}
}
@@ -45,6 +54,167 @@ impl TextChunk {
Ok(())
}
/// Atomically store a text chunk and its embedding.
/// Writes the chunk to `text_chunk` and the embedding to `text_chunk_embedding`.
pub async fn store_with_embedding(
chunk: TextChunk,
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let chunk_id = chunk.id.clone();
let source_id = chunk.source_id.clone();
let user_id = chunk.user_id.clone();
let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone());
// Create both records in a single transaction so we don't orphan embeddings or chunks
let response = db
.client
.query("BEGIN TRANSACTION;")
.query(format!(
"CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;",
chunk_table = Self::table_name(),
))
.query(format!(
"CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;",
emb_table = TextChunkEmbedding::table_name(),
))
.query("COMMIT TRANSACTION;")
.bind(("chunk_id", chunk_id.clone()))
.bind(("chunk", chunk))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?;
response.check().map_err(AppError::Database)?;
Ok(())
}
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
pub async fn vector_search(
take: usize,
query_embedding: Vec<f32>,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct Row {
chunk_id: TextChunk,
score: f32,
}
let sql = format!(
r#"
SELECT
chunk_id,
embedding,
vector::similarity::cosine(embedding, $embedding) AS score
FROM {emb_table}
WHERE user_id = $user_id
AND embedding <|{take},100|> $embedding
ORDER BY score DESC
LIMIT {take}
FETCH chunk_id;
"#,
emb_table = TextChunkEmbedding::table_name(),
take = take
);
let mut response = db
.query(&sql)
.bind(("embedding", query_embedding))
.bind(("user_id", user_id.to_string()))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).unwrap_or_default();
Ok(rows
.into_iter()
.map(|r| TextChunkSearchResult {
chunk: r.chunk_id,
score: r.score,
})
.collect())
}
/// Full-text search over text chunks using the BM25 FTS index.
pub async fn fts_search(
take: usize,
terms: &str,
db: &SurrealDbClient,
user_id: &str,
) -> Result<Vec<TextChunkSearchResult>, AppError> {
#[derive(Deserialize)]
struct Row {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
#[serde(deserialize_with = "deserialize_datetime")]
created_at: DateTime<Utc>,
#[serde(deserialize_with = "deserialize_datetime")]
updated_at: DateTime<Utc>,
source_id: String,
chunk: String,
user_id: String,
score: f32,
}
let limit = i64::try_from(take).unwrap_or(i64::MAX);
let sql = format!(
r#"
SELECT
id,
created_at,
updated_at,
source_id,
chunk,
user_id,
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END AS score
FROM {chunk_table}
WHERE chunk @0@ $terms
AND user_id = $user_id
ORDER BY score DESC
LIMIT $limit;
"#,
chunk_table = Self::table_name(),
);
let mut response = db
.query(&sql)
.bind(("terms", terms.to_owned()))
.bind(("user_id", user_id.to_owned()))
.bind(("limit", limit))
.await
.map_err(|e| AppError::InternalError(format!("Surreal query failed: {e}")))?;
response = response.check().map_err(AppError::Database)?;
let rows: Vec<Row> = response.take::<Vec<Row>>(0).map_err(AppError::Database)?;
Ok(rows
.into_iter()
.map(|r| {
let chunk = TextChunk {
id: r.id,
created_at: r.created_at,
updated_at: r.updated_at,
source_id: r.source_id,
chunk: r.chunk,
user_id: r.user_id,
};
TextChunkSearchResult {
chunk,
score: r.score,
}
})
.collect())
}
/// Re-creates embeddings for all text chunks using a safe, atomic transaction.
///
/// This is a costly operation that should be run in the background. It performs these steps:
@@ -70,23 +240,16 @@ impl TextChunk {
if total_chunks == 0 {
info!("No text chunks to update. Just updating the idx");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
new_dimensions));
transaction_query.push_str("COMMIT TRANSACTION;");
db.query(transaction_query).await?;
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions as usize).await?;
return Ok(());
}
info!("Found {} chunks to process.", total_chunks);
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for chunk in all_chunks.iter() {
for chunk in &all_chunks {
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let embedding = Retry::spawn(retry_strategy, || {
@@ -108,215 +271,615 @@ impl TextChunk {
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(chunk.id.clone(), embedding);
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
}
info!("Successfully generated all new embeddings.");
// Perform DB updates in a single transaction
info!("Applying schema and data changes in a transaction...");
// Perform DB updates in a single transaction against the embedding table
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
// Add all update statements
for (id, embedding) in new_embeddings {
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(|f| f.to_string())
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
transaction_query.push_str(&format!(
"UPDATE type::thing('text_chunk', '{}') SET embedding = {}, updated_at = time::now();",
id, embedding_str
));
// Use the chunk id as the embedding record id to keep a 1:1 mapping
write!(
&mut transaction_query,
"UPSERT type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
}
// Re-create the index inside the same transaction
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
transaction_query.push_str(&format!(
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
));
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
transaction_query.push_str("COMMIT TRANSACTION;");
// Execute the entire atomic operation
db.query(transaction_query).await?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
/// Re-creates embeddings for all text chunks using an `EmbeddingProvider`.
///
/// This variant uses the application's configured embedding provider (FastEmbed, OpenAI, etc.)
/// instead of directly calling OpenAI. Used during startup when embedding configuration changes.
pub async fn update_all_embeddings_with_provider(
db: &SurrealDbClient,
provider: &crate::utils::embedding::EmbeddingProvider,
) -> Result<(), AppError> {
let new_dimensions = provider.dimension();
info!(
dimensions = new_dimensions,
backend = provider.backend_label(),
"Starting re-embedding process for all text chunks"
);
// Fetch all chunks first
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
let total_chunks = all_chunks.len();
if total_chunks == 0 {
info!("No text chunks to update. Just updating the index.");
TextChunkEmbedding::redefine_hnsw_index(db, new_dimensions).await?;
return Ok(());
}
info!(chunks = total_chunks, "Found chunks to process");
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");
for (i, chunk) in all_chunks.iter().enumerate() {
if i > 0 && i % 100 == 0 {
info!(progress = i, total = total_chunks, "Re-embedding progress");
}
let embedding = provider
.embed(&chunk.chunk)
.await
.map_err(|e| AppError::InternalError(format!("Embedding failed: {e}")))?;
// Safety check: ensure the generated embedding has the correct dimension.
if embedding.len() != new_dimensions {
let err_msg = format!(
"CRITICAL: Generated embedding for chunk {} has incorrect dimension ({}). Expected {}. Aborting.",
chunk.id, embedding.len(), new_dimensions
);
error!("{}", err_msg);
return Err(AppError::InternalError(err_msg));
}
new_embeddings.insert(
chunk.id.clone(),
(embedding, chunk.user_id.clone(), chunk.source_id.clone()),
);
}
info!("Successfully generated all new embeddings.");
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
info!("Removing old index and clearing embeddings...");
// Explicitly remove the index first. This prevents background HNSW maintenance from crashing
// when we delete or replace data, working around a known SurrealDB panic.
db.client
.query(format!(
"REMOVE INDEX idx_embedding_text_chunk_embedding ON TABLE {};",
TextChunkEmbedding::table_name()
))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
db.client
.query(format!("DELETE FROM {};", TextChunkEmbedding::table_name()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
// Perform DB updates in a single transaction against the embedding table
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
for (id, (embedding, user_id, source_id)) in new_embeddings {
let embedding_str = format!(
"[{}]",
embedding
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
write!(
&mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = time::now(), \
updated_at = time::now();",
id = id,
embedding = embedding_str,
user_id = user_id,
source_id = source_id
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
}
write!(
&mut transaction_query,
"DEFINE INDEX OVERWRITE idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION {};",
new_dimensions
)
.map_err(|e| AppError::InternalError(e.to_string()))?;
transaction_query.push_str("COMMIT TRANSACTION;");
db.client
.query(transaction_query)
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
info!("Re-embedding process for text chunks completed successfully.");
Ok(())
}
}
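// A minimal usage sketch (illustrative only) tying the API above together:
// size the embedding index, store a chunk together with its vector, then query
// it back through both vector search and BM25 full-text search. The chunk FTS
// analyzer/index is assumed to be defined already (the tests below show one way
// to set it up). When the embedding backend or dimension changes at startup,
// update_all_embeddings_with_provider re-embeds everything and rebuilds the index.
#[allow(dead_code)]
async fn example_chunk_roundtrip(db: &SurrealDbClient) -> Result<(), AppError> {
    // Keep the HNSW index in step with the (test-sized) embedding dimension.
    TextChunkEmbedding::redefine_hnsw_index(db, 3).await?;
    let chunk = TextChunk::new(
        "example-source".to_string(),
        "rustaceans love rust".to_string(),
        "example-user".to_string(),
    );
    // Chunk and embedding are written in one transaction, so neither can be orphaned.
    TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], db).await?;
    // Vector search: nearest neighbours by cosine similarity, scoped to the user.
    let by_vector: Vec<TextChunkSearchResult> =
        TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], db, "example-user").await?;
    // Full-text search: BM25 over the chunk field.
    let by_text: Vec<TextChunkSearchResult> =
        TextChunk::fts_search(5, "rust", db, "example-user").await?;
    assert!(by_vector.iter().any(|r| r.chunk.id == chunk.id));
    assert!(by_text.iter().all(|r| r.chunk.user_id == "example-user"));
    Ok(())
}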
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId;
use uuid::Uuid;
async fn ensure_chunk_fts_index(db: &SurrealDbClient) {
let snowball_sql = r#"
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english);
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
"#;
if let Err(err) = db.client.query(snowball_sql).await {
// Fall back to ascii-only analyzer when snowball is unavailable in the build.
let fallback_sql = r#"
DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii;
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk FIELDS chunk SEARCH ANALYZER app_en_fts_analyzer BM25;
"#;
db.client
.query(fallback_sql)
.await
.unwrap_or_else(|_| panic!("define chunk fts index fallback: {err}"));
}
}
#[tokio::test]
async fn test_text_chunk_creation() {
// Test basic object creation
let source_id = "source123".to_string();
let chunk = "This is a text chunk for testing embeddings".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let text_chunk = TextChunk::new(
source_id.clone(),
chunk.clone(),
embedding.clone(),
user_id.clone(),
);
let text_chunk = TextChunk::new(source_id.clone(), chunk.clone(), user_id.clone());
// Check that the fields are set correctly
assert_eq!(text_chunk.source_id, source_id);
assert_eq!(text_chunk.chunk, chunk);
assert_eq!(text_chunk.embedding, embedding);
assert_eq!(text_chunk.user_id, user_id);
assert!(!text_chunk.id.is_empty());
}
#[tokio::test]
async fn test_delete_by_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
// Create test data
let source_id = "source123".to_string();
let chunk1 = "First chunk from the same source".to_string();
let chunk2 = "Second chunk from the same source".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
// Create two chunks with the same source_id
let text_chunk1 = TextChunk::new(
let chunk1 = TextChunk::new(
source_id.clone(),
chunk1,
embedding.clone(),
"First chunk from the same source".to_string(),
user_id.clone(),
);
let text_chunk2 = TextChunk::new(
let chunk2 = TextChunk::new(
source_id.clone(),
chunk2,
embedding.clone(),
"Second chunk from the same source".to_string(),
user_id.clone(),
);
// Create a chunk with a different source_id
let different_source_id = "different_source".to_string();
let different_chunk = TextChunk::new(
different_source_id.clone(),
"different_source".to_string(),
"Different source chunk".to_string(),
embedding.clone(),
user_id.clone(),
);
// Store the chunks
db.store_item(text_chunk1)
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("Failed to store text chunk 1");
db.store_item(text_chunk2)
.expect("store chunk1");
TextChunk::store_with_embedding(chunk2.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("Failed to store text chunk 2");
db.store_item(different_chunk.clone())
.await
.expect("Failed to store different chunk");
.expect("store chunk2");
TextChunk::store_with_embedding(
different_chunk.clone(),
vec![0.1, 0.2, 0.3, 0.4, 0.5],
&db,
)
.await
.expect("store different chunk");
// Delete by source_id
TextChunk::delete_by_source_id(&source_id, &db)
.await
.expect("Failed to delete chunks by source_id");
// Verify all chunks with the original source_id are deleted
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
source_id
);
let remaining: Vec<TextChunk> = db
.client
.query(query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
source_id
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(
remaining.len(),
0,
"All chunks with the source_id should be deleted"
);
assert_eq!(remaining.len(), 0);
// Verify the different source_id chunk still exists
let different_query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
different_source_id
);
let different_remaining: Vec<TextChunk> = db
.client
.query(different_query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
"different_source"
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(
different_remaining.len(),
1,
"Chunk with different source_id should still exist"
);
assert_eq!(different_remaining.len(), 1);
assert_eq!(different_remaining[0].id, different_chunk.id);
}
#[tokio::test]
async fn test_delete_by_nonexistent_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
// Create a chunk with a real source_id
let real_source_id = "real_source".to_string();
let chunk = "Test chunk".to_string();
let embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let user_id = "user123".to_string();
let text_chunk = TextChunk::new(real_source_id.clone(), chunk, embedding, user_id);
// Store the chunk
db.store_item(text_chunk)
.await
.expect("Failed to store text chunk");
// Delete using nonexistent source_id
let nonexistent_source_id = "nonexistent_source";
TextChunk::delete_by_source_id(nonexistent_source_id, &db)
.await
.expect("Delete operation with nonexistent source_id should not fail");
// Verify the real chunk still exists
let query = format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
real_source_id
let chunk = TextChunk::new(
real_source_id.clone(),
"Test chunk".to_string(),
"user123".to_string(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk");
TextChunk::delete_by_source_id("nonexistent_source", &db)
.await
.expect("Delete should succeed");
let remaining: Vec<TextChunk> = db
.client
.query(query)
.query(format!(
"SELECT * FROM {} WHERE source_id = '{}'",
TextChunk::table_name(),
real_source_id
))
.await
.expect("Query failed")
.take(0)
.expect("Failed to get query results");
assert_eq!(remaining.len(), 1);
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
let source_id = "store-src".to_string();
let user_id = "user_store".to_string();
let chunk = TextChunk::new(source_id.clone(), "chunk body".to_string(), user_id.clone());
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some());
let stored_chunk = stored_chunk.unwrap();
assert_eq!(stored_chunk.source_id, source_id);
assert_eq!(stored_chunk.user_id, user_id);
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some());
let embedding = embedding.unwrap();
assert_eq!(embedding.chunk_id, rid);
assert_eq!(embedding.user_id, user_id);
assert_eq!(embedding.source_id, source_id);
}
#[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() {
let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension)
.await
.expect("ensure runtime indexes");
let chunk = TextChunk::new(
"runtime_src".to_string(),
"runtime chunk body".to_string(),
"runtime_user".to_string(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some(), "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some(), "embedding should exist");
assert_eq!(
remaining.len(),
1,
"Chunk with real source_id should still exist"
embedding.unwrap().embedding.len(),
embedding_dimension,
"embedding dimension should match runtime index"
);
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
.await
.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn test_vector_search_single_result() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let source_id = "src".to_string();
let user_id = "user".to_string();
let chunk = TextChunk::new(
source_id.clone(),
"hello world".to_string(),
user_id.clone(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store");
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
.await
.unwrap();
assert_eq!(results.len(), 1);
let res = &results[0];
assert_eq!(res.chunk.id, chunk.id);
assert_eq!(res.chunk.source_id, source_id);
assert_eq!(res.chunk.chunk, "hello world");
}
#[tokio::test]
async fn test_vector_search_orders_by_similarity() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("redefine index");
let user_id = "user".to_string();
let chunk1 = TextChunk::new("s1".to_string(), "chunk one".to_string(), user_id.clone());
let chunk2 = TextChunk::new("s2".to_string(), "chunk two".to_string(), user_id.clone());
TextChunk::store_with_embedding(chunk1.clone(), vec![1.0, 0.0, 0.0], &db)
.await
.expect("store chunk1");
TextChunk::store_with_embedding(chunk2.clone(), vec![0.0, 1.0, 0.0], &db)
.await
.expect("store chunk2");
let results: Vec<TextChunkSearchResult> =
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
.await
.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].chunk.id, chunk2.id);
assert_eq!(results[1].chunk.id, chunk1.id);
assert!(results[0].score >= results[1].score);
}
#[tokio::test]
async fn test_fts_search_returns_empty_when_no_chunks() {
let namespace = "fts_chunk_ns_empty";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(5, "hello", &db, "user")
.await
.expect("fts search");
assert!(results.is_empty());
}
#[tokio::test]
async fn test_fts_search_single_result() {
let namespace = "fts_chunk_ns_single";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
let user_id = "fts_user";
let chunk = TextChunk::new(
"fts_src".to_string(),
"rustaceans love rust".to_string(),
user_id.to_string(),
);
db.store_item(chunk.clone()).await.expect("store chunk");
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await
.expect("fts search");
assert_eq!(results.len(), 1);
assert_eq!(results[0].chunk.id, chunk.id);
assert!(results[0].score.is_finite(), "expected a finite FTS score");
}
#[tokio::test]
async fn test_fts_search_orders_by_score_and_filters_user() {
let namespace = "fts_chunk_ns_order";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
let user_id = "fts_user_order";
let high_score_chunk = TextChunk::new(
"src1".to_string(),
"apple apple apple pie recipe".to_string(),
user_id.to_string(),
);
let low_score_chunk = TextChunk::new(
"src2".to_string(),
"apple tart".to_string(),
user_id.to_string(),
);
let other_user_chunk = TextChunk::new(
"src3".to_string(),
"apple orchard guide".to_string(),
"other_user".to_string(),
);
db.store_item(high_score_chunk.clone())
.await
.expect("store high score chunk");
db.store_item(low_score_chunk.clone())
.await
.expect("store low score chunk");
db.store_item(other_user_chunk)
.await
.expect("store other user chunk");
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await
.expect("fts search");
assert_eq!(results.len(), 2);
let ids: Vec<_> = results.iter().map(|r| r.chunk.id.as_str()).collect();
assert!(
ids.contains(&high_score_chunk.id.as_str())
&& ids.contains(&low_score_chunk.id.as_str()),
"expected only the two chunks for the same user"
);
assert!(
results[0].score >= results[1].score,
"expected results ordered by descending score"
);
}
}

View File

@@ -0,0 +1,436 @@
use surrealdb::RecordId;
use crate::storage::types::text_chunk::TextChunk;
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
stored_object!(TextChunkEmbedding, "text_chunk_embedding", {
/// Record link to the owning text_chunk
chunk_id: RecordId,
/// Denormalized source id for bulk deletes
source_id: String,
/// Embedding vector
embedding: Vec<f32>,
/// Denormalized user id (for scoping + permissions)
user_id: String
});
impl TextChunkEmbedding {
/// Recreate the HNSW index with a new embedding dimension.
///
/// This is useful when the embedding length changes; Surreal requires the
/// index definition to be recreated with the updated dimension.
pub async fn redefine_hnsw_index(
db: &SurrealDbClient,
dimension: usize,
) -> Result<(), AppError> {
let query = format!(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE {table};
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE {table} FIELDS embedding HNSW DIMENSION {dimension};
COMMIT TRANSACTION;",
table = Self::table_name(),
);
let res = db.client.query(query).await.map_err(AppError::Database)?;
res.check().map_err(AppError::Database)?;
Ok(())
}
/// Create a new text chunk embedding
///
/// `chunk_id` is the **key** part of the text_chunk id (e.g. the UUID),
/// not "text_chunk:uuid".
pub fn new(chunk_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
let now = Utc::now();
Self {
// NOTE: `stored_object!` macro defines `id` as `String`
id: uuid::Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
// Create a record<text_chunk> link: text_chunk:<chunk_id>
chunk_id: RecordId::from_table_key(TextChunk::table_name(), chunk_id),
source_id,
embedding,
user_id,
}
}
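// A minimal sketch of the key convention documented above: pass only the chunk key,
// and the constructor builds the record<text_chunk> link for you. All values here are
// illustrative placeholders.
#[cfg(test)]
#[allow(dead_code)]
fn example_new() -> Self {
Self::new(
"abc-123",
"source-1".to_string(),
vec![0.0_f32; 3],
"user-1".to_string(),
)
}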
/// Get a single embedding by its chunk RecordId
pub async fn get_by_chunk_id(
chunk_id: &RecordId,
db: &SurrealDbClient,
) -> Result<Option<Self>, AppError> {
let query = format!(
"SELECT * FROM {} WHERE chunk_id = $chunk_id LIMIT 1",
Self::table_name()
);
let mut result = db
.client
.query(query)
.bind(("chunk_id", chunk_id.clone()))
.await
.map_err(AppError::Database)?;
let embeddings: Vec<Self> = result.take(0).map_err(AppError::Database)?;
Ok(embeddings.into_iter().next())
}
/// Delete embeddings for a given chunk RecordId
pub async fn delete_by_chunk_id(
chunk_id: &RecordId,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE FROM {} WHERE chunk_id = $chunk_id",
Self::table_name()
);
db.client
.query(query)
.bind(("chunk_id", chunk_id.clone()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
/// Delete all embeddings that belong to chunks with a given `source_id`
///
/// This uses a subquery to the `text_chunk` table:
///
/// DELETE FROM text_chunk_embedding
/// WHERE chunk_id IN (SELECT id FROM text_chunk WHERE source_id = $source_id)
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct IdRow {
id: RecordId,
}
let ids_query = format!(
"SELECT id FROM {} WHERE source_id = $source_id",
TextChunk::table_name()
);
let mut res = db
.client
.query(ids_query)
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?;
let ids: Vec<IdRow> = res.take(0).map_err(AppError::Database)?;
if ids.is_empty() {
return Ok(());
}
let delete_query = format!(
"DELETE FROM {} WHERE chunk_id IN $chunk_ids",
Self::table_name()
);
db.client
.query(delete_query)
.bind((
"chunk_ids",
ids.into_iter().map(|row| row.id).collect::<Vec<_>>(),
))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::db::SurrealDbClient;
use surrealdb::Value as SurrealValue;
use uuid::Uuid;
/// Helper to create an in-memory DB and apply migrations
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
}
/// Helper: create a text_chunk with a known key, return its RecordId
async fn create_text_chunk_with_id(
db: &SurrealDbClient,
key: &str,
source_id: &str,
user_id: &str,
) -> RecordId {
let chunk = TextChunk {
id: key.to_owned(),
created_at: Utc::now(),
updated_at: Utc::now(),
source_id: source_id.to_owned(),
chunk: "Some test chunk text".to_owned(),
user_id: user_id.to_owned(),
};
db.store_item(chunk)
.await
.expect("Failed to create text_chunk");
RecordId::from_table_key(TextChunk::table_name(), key)
}
#[tokio::test]
async fn test_create_and_get_by_chunk_id() {
let db = setup_test_db().await;
let user_id = "user_a";
let chunk_key = "chunk-123";
let source_id = "source-1";
// 1) Create a text_chunk with a known key
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
// 2) Create and store an embedding for that chunk
let embedding_vec = vec![0.1_f32, 0.2, 0.3];
let emb = TextChunkEmbedding::new(
chunk_key,
source_id.to_string(),
embedding_vec.clone(),
user_id.to_string(),
);
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await
.expect("Failed to redefine index length");
let _: Option<TextChunkEmbedding> = db
.client
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
// 3) Fetch it via get_by_chunk_id
let fetched = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding by chunk_id");
assert!(fetched.is_some(), "Expected an embedding to be found");
let fetched = fetched.unwrap();
assert_eq!(fetched.user_id, user_id);
assert_eq!(fetched.chunk_id, chunk_rid);
assert_eq!(fetched.embedding, embedding_vec);
}
#[tokio::test]
async fn test_delete_by_chunk_id() {
let db = setup_test_db().await;
let user_id = "user_b";
let chunk_key = "chunk-delete";
let source_id = "source-del";
let chunk_rid = create_text_chunk_with_id(&db, chunk_key, source_id, user_id).await;
let emb = TextChunkEmbedding::new(
chunk_key,
source_id.to_string(),
vec![0.4_f32, 0.5, 0.6],
user_id.to_string(),
);
TextChunkEmbedding::redefine_hnsw_index(&db, emb.embedding.len())
.await
.expect("Failed to redefine index length");
let _: Option<TextChunkEmbedding> = db
.client
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
// Ensure it exists
let existing = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding before delete");
assert!(existing.is_some(), "Embedding should exist before delete");
// Delete by chunk_id
TextChunkEmbedding::delete_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to delete by chunk_id");
// Ensure it no longer exists
let after = TextChunkEmbedding::get_by_chunk_id(&chunk_rid, &db)
.await
.expect("Failed to get embedding after delete");
assert!(after.is_none(), "Embedding should have been deleted");
}
#[tokio::test]
async fn test_delete_by_source_id() {
let db = setup_test_db().await;
let user_id = "user_c";
let source_id = "shared-source";
let other_source = "other-source";
// Two chunks with the same source_id
let chunk1_rid = create_text_chunk_with_id(&db, "chunk-s1", source_id, user_id).await;
let chunk2_rid = create_text_chunk_with_id(&db, "chunk-s2", source_id, user_id).await;
// One chunk with a different source_id
let chunk_other_rid =
create_text_chunk_with_id(&db, "chunk-other", other_source, user_id).await;
// Create embeddings for all three
let emb1 = TextChunkEmbedding::new(
"chunk-s1",
source_id.to_string(),
vec![0.1],
user_id.to_string(),
);
let emb2 = TextChunkEmbedding::new(
"chunk-s2",
source_id.to_string(),
vec![0.2],
user_id.to_string(),
);
let emb3 = TextChunkEmbedding::new(
"chunk-other",
other_source.to_string(),
vec![0.3],
user_id.to_string(),
);
// Redefine the index to match the test embedding dimension
TextChunkEmbedding::redefine_hnsw_index(&db, emb1.embedding.len())
.await
.expect("Failed to redefine index length");
for emb in [emb1, emb2, emb3] {
let _: Option<TextChunkEmbedding> = db
.client
.create(TextChunkEmbedding::table_name())
.content(emb)
.await
.expect("Failed to store embedding")
.take()
.expect("Failed to deserialize stored embedding");
}
// Sanity check: they all exist
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await
.unwrap()
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await
.unwrap()
.is_some());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await
.unwrap()
.is_some());
// Delete embeddings by source_id (shared-source)
TextChunkEmbedding::delete_by_source_id(source_id, &db)
.await
.expect("Failed to delete by source_id");
// Chunks from shared-source should have no embeddings
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk1_rid, &db)
.await
.unwrap()
.is_none());
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk2_rid, &db)
.await
.unwrap()
.is_none());
// The other chunk should still have its embedding
assert!(TextChunkEmbedding::get_by_chunk_id(&chunk_other_rid, &db)
.await
.unwrap()
.is_some());
}
#[tokio::test]
async fn test_redefine_hnsw_index_updates_dimension() {
let db = setup_test_db().await;
// Change the index dimension from default (1536) to a smaller test value.
TextChunkEmbedding::redefine_hnsw_index(&db, 8)
.await
.expect("failed to redefine index");
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
assert!(
idx_sql.contains("DIMENSION 8"),
"expected index definition to contain new dimension, got: {idx_sql}"
);
}
#[tokio::test]
async fn test_redefine_hnsw_index_is_idempotent() {
let db = setup_test_db().await;
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await
.expect("first redefine failed");
TextChunkEmbedding::redefine_hnsw_index(&db, 4)
.await
.expect("second redefine failed");
let mut info_res = db
.client
.query("INFO FOR TABLE text_chunk_embedding;")
.await
.expect("info query failed");
let info: SurrealValue = info_res.take(0).expect("failed to take info result");
let info_json: serde_json::Value =
serde_json::to_value(info).expect("failed to convert info to json");
let idx_sql = info_json["Object"]["indexes"]["Object"]
["idx_embedding_text_chunk_embedding"]["Strand"]
.as_str()
.unwrap_or_default();
assert!(
idx_sql.contains("DIMENSION 4"),
"expected index definition to retain dimension 4, got: {idx_sql}"
);
}
}

View File

@@ -5,6 +5,7 @@ use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use super::file_info::FileInfo;
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Deserialize, Serialize)]
pub struct TextContentSearchResult {
#[serde(deserialize_with = "deserialize_flexible_id")]
@@ -50,8 +51,11 @@ pub struct TextContentSearchResult {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct UrlInfo {
#[serde(default)]
pub url: String,
#[serde(default)]
pub title: String,
#[serde(default)]
pub image_id: String,
}
@@ -146,12 +150,12 @@ impl TextContent {
search::highlight('<b>', '</b>', 4) AS highlighted_url,
search::highlight('<b>', '</b>', 5) AS highlighted_url_title,
(
search::score(0) +
search::score(1) +
search::score(2) +
search::score(3) +
search::score(4) +
search::score(5)
IF search::score(0) != NONE THEN search::score(0) ELSE 0 END +
IF search::score(1) != NONE THEN search::score(1) ELSE 0 END +
IF search::score(2) != NONE THEN search::score(2) ELSE 0 END +
IF search::score(3) != NONE THEN search::score(3) ELSE 0 END +
IF search::score(4) != NONE THEN search::score(4) ELSE 0 END +
IF search::score(5) != NONE THEN search::score(5) ELSE 0 END
) AS score
FROM text_content
WHERE

View File

@@ -1,4 +1,5 @@
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
use anyhow::anyhow;
use async_trait::async_trait;
use axum_session_auth::Authentication;
use chrono_tz::Tz;
@@ -17,12 +18,16 @@ use super::{
use chrono::Duration;
use futures::try_join;
/// Result row for returning user category.
#[derive(Deserialize)]
pub struct CategoryResponse {
/// Category name tied to the user.
category: String,
}
stored_object!(User, "user", {
stored_object!(
#[allow(clippy::unsafe_derive_deserialize)]
User, "user", {
email: String,
password: String,
anonymous: bool,
@@ -35,11 +40,11 @@ stored_object!(User, "user", {
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
let db = db.unwrap();
let db = db.ok_or_else(|| anyhow!("Database handle missing"))?;
Ok(db
.select((Self::table_name(), userid.as_str()))
.await?
.unwrap())
.ok_or_else(|| anyhow!("User {userid} not found"))?)
}
fn is_authenticated(&self) -> bool {
@@ -55,14 +60,14 @@ impl Authentication<User, String, Surreal<Any>> for User {
}
}
/// Ensures a timezone string parses, defaulting to UTC when invalid.
fn validate_timezone(input: &str) -> String {
match input.parse::<Tz>() {
Ok(_) => input.to_owned(),
Err(_) => {
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
"UTC".to_owned()
}
if input.parse::<Tz>().is_ok() {
return input.to_owned();
}
tracing::warn!("Invalid timezone '{}' received, defaulting to UTC", input);
"UTC".to_owned()
}
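// A minimal sketch of the fallback behaviour above: valid IANA names pass through
// unchanged, anything unparseable degrades to "UTC". The helper name is illustrative.
#[cfg(test)]
#[allow(dead_code)]
fn example_validate_timezone() {
assert_eq!(validate_timezone("Europe/Stockholm"), "Europe/Stockholm");
assert_eq!(validate_timezone("Not/AZone"), "UTC");
}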
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -77,12 +82,15 @@ pub struct DashboardStats {
pub new_text_chunks_week: i64,
}
/// Helper for aggregating `SurrealDB` count responses.
#[derive(Deserialize)]
struct CountResult {
/// Row count returned by the query.
count: i64,
}
impl User {
/// Counts all objects of a given type belonging to the user.
async fn count_total<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient,
user_id: &str,
@@ -94,9 +102,10 @@ impl User {
.bind(("user_id", user_id.to_string()))
.await?
.take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0))
Ok(result.map_or(0, |r| r.count))
}
/// Counts objects of a given type created after a specific timestamp.
async fn count_since<T: crate::storage::types::StoredObject>(
db: &SurrealDbClient,
user_id: &str,
@@ -112,14 +121,16 @@ impl User {
.bind(("since", surrealdb::Datetime::from(since)))
.await?
.take(0)?;
Ok(result.map(|r| r.count).unwrap_or(0))
Ok(result.map_or(0, |r| r.count))
}
pub async fn get_dashboard_stats(
user_id: &str,
db: &SurrealDbClient,
) -> Result<DashboardStats, AppError> {
let since = chrono::Utc::now() - Duration::days(7);
let since = chrono::Utc::now()
.checked_sub_signed(Duration::days(7))
.unwrap_or_else(chrono::Utc::now);
let (
total_documents,
@@ -261,7 +272,7 @@ impl User {
pub async fn set_api_key(id: &str, db: &SurrealDbClient) -> Result<String, AppError> {
// Generate a secure random API key
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace("-", ""));
let api_key = format!("sk_{}", Uuid::new_v4().to_string().replace('-', ""));
// Update the user record with the new API key
let user: Option<Self> = db
@@ -341,6 +352,7 @@ impl User {
) -> Result<Vec<String>, AppError> {
#[derive(Deserialize)]
struct EntityTypeResponse {
/// Raw entity type value from the database.
entity_type: String,
}
@@ -358,7 +370,7 @@ impl User {
.into_iter()
.map(|item| {
let normalized = KnowledgeEntityType::from(item.entity_type);
format!("{:?}", normalized)
format!("{normalized:?}")
})
.collect();

View File

@@ -2,12 +2,27 @@ use config::{Config, ConfigError, Environment, File};
use serde::Deserialize;
use std::env;
#[derive(Clone, Deserialize, Debug)]
/// Selects the embedding backend for vector generation.
#[derive(Clone, Deserialize, Debug, Default, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingBackend {
/// Use OpenAI-compatible API for embeddings.
OpenAI,
/// Use FastEmbed local embeddings (default).
#[default]
FastEmbed,
/// Use deterministic hashed embeddings (for testing).
Hashed,
}
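// A small, test-only sketch of how the attributes above behave: with `#[serde(default)]`
// on `AppConfig`, an omitted `embedding_backend` resolves to `FastEmbed`, and the
// lowercase rename means config values are spelled "openai", "fastembed", or "hashed".
// The helper name is illustrative.
#[cfg(test)]
#[allow(dead_code)]
fn example_default_embedding_backend() {
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
}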
/// Selects the storage backend for persisted files.
#[derive(Clone, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum StorageKind {
/// Store files on the local filesystem.
Local,
/// Keep files in memory (primarily for tests).
Memory,
}
/// Default storage backend when none is configured.
fn default_storage_kind() -> StorageKind {
StorageKind::Local
}
@@ -22,10 +37,13 @@ pub enum PdfIngestMode {
LlmFirst,
}
/// Default PDF ingestion mode when unset.
fn default_pdf_ingest_mode() -> PdfIngestMode {
PdfIngestMode::LlmFirst
}
/// Application configuration loaded from files and environment variables.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Deserialize, Debug)]
pub struct AppConfig {
pub openai_api_key: String,
@@ -53,16 +71,23 @@ pub struct AppConfig {
pub fastembed_show_download_progress: Option<bool>,
#[serde(default)]
pub fastembed_max_length: Option<usize>,
#[serde(default)]
pub retrieval_strategy: Option<String>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
}
/// Default data directory for persisted assets.
fn default_data_dir() -> String {
"./data".to_string()
}
/// Default base URL used for OpenAI-compatible APIs.
fn default_base_url() -> String {
"https://api.openai.com/v1".to_string()
}
/// Whether reranking is enabled by default.
fn default_reranking_enabled() -> bool {
false
}
@@ -116,10 +141,14 @@ impl Default for AppConfig {
fastembed_cache_dir: None,
fastembed_show_download_progress: None,
fastembed_max_length: None,
retrieval_strategy: None,
embedding_backend: EmbeddingBackend::default(),
}
}
}
/// Loads the application configuration from the environment and optional config file.
#[allow(clippy::module_name_repetitions)]
pub fn get_config() -> Result<AppConfig, ConfigError> {
ensure_ort_path();

View File

@@ -1,19 +1,328 @@
use async_openai::types::CreateEmbeddingRequestArgs;
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
str::FromStr,
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use async_openai::{types::CreateEmbeddingRequestArgs, Client};
use fastembed::{EmbeddingModel, ModelTrait, TextEmbedding, TextInitOptions};
use tokio::sync::Mutex;
use tracing::debug;
use crate::{
error::AppError,
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
};
/// Generates an embedding vector for the given input text using OpenAI's embedding model.
/// Supported embedding backends.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
#[default]
OpenAI,
FastEmbed,
Hashed,
}
impl std::str::FromStr for EmbeddingBackend {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"openai" => Ok(Self::OpenAI),
"hashed" => Ok(Self::Hashed),
"fastembed" | "fast-embed" | "fast" => Ok(Self::FastEmbed),
other => Err(anyhow!(
"unknown embedding backend '{other}'. Expected 'openai', 'hashed', or 'fastembed'."
)),
}
}
}
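// A small, test-only sketch of the accepted spellings handled above; parsing is
// case-insensitive and unknown names are rejected. The helper name is illustrative.
#[cfg(test)]
#[allow(dead_code)]
fn example_parse_backend() {
assert_eq!(
"OpenAI".parse::<EmbeddingBackend>().unwrap(),
EmbeddingBackend::OpenAI
);
assert_eq!(
"fast-embed".parse::<EmbeddingBackend>().unwrap(),
EmbeddingBackend::FastEmbed
);
assert!("word2vec".parse::<EmbeddingBackend>().is_err());
}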
/// Wrapper around the chosen embedding backend.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct EmbeddingProvider {
/// Concrete backend implementation.
inner: EmbeddingInner,
}
/// Concrete embedding implementations.
#[derive(Clone)]
enum EmbeddingInner {
/// Uses an `OpenAI`-compatible API.
OpenAI {
/// Client used to issue embedding requests.
client: Arc<Client<async_openai::config::OpenAIConfig>>,
/// Model identifier for the API.
model: String,
/// Expected output dimensions.
dimensions: u32,
},
/// Generates deterministic hashed embeddings without external calls.
Hashed {
/// Output vector length.
dimension: usize,
},
/// Uses `FastEmbed` running locally.
FastEmbed {
/// Shared `FastEmbed` model.
model: Arc<Mutex<TextEmbedding>>,
/// Model metadata used for info logging.
model_name: EmbeddingModel,
/// Output vector length.
dimension: usize,
},
}
impl EmbeddingProvider {
/// Short static label for the active backend ("openai", "fastembed", or "hashed").
pub fn backend_label(&self) -> &'static str {
match self.inner {
EmbeddingInner::Hashed { .. } => "hashed",
EmbeddingInner::FastEmbed { .. } => "fastembed",
EmbeddingInner::OpenAI { .. } => "openai",
}
}
/// Output dimensionality of the vectors produced by this provider.
pub fn dimension(&self) -> usize {
match &self.inner {
EmbeddingInner::Hashed { dimension } | EmbeddingInner::FastEmbed { dimension, .. } => {
*dimension
}
EmbeddingInner::OpenAI { dimensions, .. } => *dimensions as usize,
}
}
/// Model identifier for the active backend, or `None` for hashed embeddings.
pub fn model_code(&self) -> Option<String> {
match &self.inner {
EmbeddingInner::FastEmbed { model_name, .. } => Some(model_name.to_string()),
EmbeddingInner::OpenAI { model, .. } => Some(model.clone()),
EmbeddingInner::Hashed { .. } => None,
}
}
/// Generates an embedding vector for a single input text.
pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(hashed_embedding(text, *dimension)),
EmbeddingInner::FastEmbed { model, .. } => {
let mut guard = model.lock().await;
let embeddings = guard
.embed(vec![text.to_owned()], None)
.context("generating fastembed vector")?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("fastembed returned no embedding for input"))
}
EmbeddingInner::OpenAI {
client,
model,
dimensions,
} => {
let request = CreateEmbeddingRequestArgs::default()
.model(model.clone())
.input([text])
.dimensions(*dimensions)
.build()?;
let response = client.embeddings().create(request).await?;
let embedding = response
.data
.first()
.ok_or_else(|| anyhow!("No embedding data received from OpenAI API"))?
.embedding
.clone();
Ok(embedding)
}
}
}
/// Generates embedding vectors for a batch of input texts.
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
match &self.inner {
EmbeddingInner::Hashed { dimension } => Ok(texts
.into_iter()
.map(|text| hashed_embedding(&text, *dimension))
.collect()),
EmbeddingInner::FastEmbed { model, .. } => {
if texts.is_empty() {
return Ok(Vec::new());
}
let mut guard = model.lock().await;
guard
.embed(texts, None)
.context("generating fastembed batch embeddings")
}
EmbeddingInner::OpenAI {
client,
model,
dimensions,
} => {
if texts.is_empty() {
return Ok(Vec::new());
}
let request = CreateEmbeddingRequestArgs::default()
.model(model.clone())
.input(texts)
.dimensions(*dimensions)
.build()?;
let response = client.embeddings().create(request).await?;
let embeddings: Vec<Vec<f32>> = response
.data
.into_iter()
.map(|item| item.embedding)
.collect();
Ok(embeddings)
}
}
}
/// Constructs a provider backed by an `OpenAI`-compatible embeddings API.
pub fn new_openai(
client: Arc<Client<async_openai::config::OpenAIConfig>>,
model: String,
dimensions: u32,
) -> Result<Self> {
Ok(Self {
inner: EmbeddingInner::OpenAI {
client,
model,
dimensions,
},
})
}
/// Constructs a provider backed by a local `FastEmbed` model, optionally overriding the model code.
pub async fn new_fastembed(model_override: Option<String>) -> Result<Self> {
let model_name = if let Some(code) = model_override {
EmbeddingModel::from_str(&code).map_err(|err| anyhow!(err))?
} else {
EmbeddingModel::default()
};
let options = TextInitOptions::new(model_name.clone()).with_show_download_progress(true);
let model_name_for_task = model_name.clone();
let model_name_code = model_name.to_string();
let (model, dimension) = tokio::task::spawn_blocking(move || -> Result<_> {
let model =
TextEmbedding::try_new(options).context("initialising FastEmbed text model")?;
let info = EmbeddingModel::get_model_info(&model_name_for_task)
.ok_or_else(|| anyhow!("FastEmbed model metadata missing for {model_name_code}"))?;
Ok((model, info.dim))
})
.await
.context("joining FastEmbed initialisation task")??;
Ok(EmbeddingProvider {
inner: EmbeddingInner::FastEmbed {
model: Arc::new(Mutex::new(model)),
model_name,
dimension,
},
})
}
/// Constructs a deterministic hashed provider, intended for tests and offline use.
pub fn new_hashed(dimension: usize) -> Result<Self> {
Ok(EmbeddingProvider {
inner: EmbeddingInner::Hashed {
dimension: dimension.max(1),
},
})
}
/// Creates an embedding provider based on application configuration.
///
/// Dispatches to the appropriate constructor based on `config.embedding_backend`:
/// - `OpenAI`: Requires a valid OpenAI client
/// - `FastEmbed`: Uses local embedding model
/// - `Hashed`: Uses deterministic hashed embeddings (for testing)
pub async fn from_config(
config: &crate::utils::config::AppConfig,
openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
) -> Result<Self> {
use crate::utils::config::EmbeddingBackend;
match config.embedding_backend {
EmbeddingBackend::OpenAI => {
let client = openai_client
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
// Use defaults that match SystemSettings initial values
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
}
EmbeddingBackend::FastEmbed => {
// Use nomic-embed-text-v1.5 as the default FastEmbed model
Self::new_fastembed(Some("nomic-ai/nomic-embed-text-v1.5".to_string())).await
}
EmbeddingBackend::Hashed => Self::new_hashed(384),
}
}
}
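// A minimal end-to-end sketch of the provider surface using the dependency-free hashed
// backend; the OpenAI and FastEmbed variants expose the same `embed`/`embed_batch` API
// but need network access or model files. Assumes tokio's test macro, as used by the
// other tests in this crate; the module name is illustrative.
#[cfg(test)]
mod provider_usage_example {
use super::*;
#[tokio::test]
async fn hashed_provider_round_trip() {
let provider = EmbeddingProvider::new_hashed(16).expect("construct hashed provider");
assert_eq!(provider.backend_label(), "hashed");
assert_eq!(provider.dimension(), 16);
let vector = provider.embed("rustaceans love rust").await.expect("embed");
assert_eq!(vector.len(), 16);
let batch = provider
.embed_batch(vec!["alpha".to_string(), "beta".to_string()])
.await
.expect("embed batch");
assert_eq!(batch.len(), 2);
}
}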
// Helper functions for hashed embeddings
/// Generates a hashed embedding vector without external dependencies.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
let dim = dimension.max(1);
let mut vector = vec![0.0f32; dim];
if text.is_empty() {
return vector;
}
for token in tokens(text) {
let idx = bucket(&token, dim);
if let Some(slot) = vector.get_mut(idx) {
*slot += 1.0;
}
}
let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm > 0.0 {
for value in &mut vector {
*value /= norm;
}
}
vector
}
/// Tokenizes the text into alphanumeric lowercase tokens.
fn tokens(text: &str) -> impl Iterator<Item = String> + '_ {
text.split(|c: char| !c.is_ascii_alphanumeric())
.filter(|token| !token.is_empty())
.map(str::to_ascii_lowercase)
}
/// Buckets a token into the hashed embedding vector.
#[allow(clippy::arithmetic_side_effects)]
fn bucket(token: &str, dimension: usize) -> usize {
let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new();
token.hash(&mut hasher);
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}
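// A small, test-only sketch of the hashing scheme above: tokens are lowercased ASCII
// alphanumerics, bucketed by hash, and the resulting vector is L2-normalised, so the
// output is deterministic and unit length for non-empty input. The module name is
// illustrative.
#[cfg(test)]
mod hashed_embedding_example {
use super::*;
#[test]
fn hashed_vectors_are_deterministic_and_unit_length() {
let a = hashed_embedding("Tokio uses cooperative scheduling", 32);
let b = hashed_embedding("Tokio uses cooperative scheduling", 32);
assert_eq!(a, b);
let norm: f32 = a.iter().map(|v| v * v).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5);
// Empty input short-circuits to the zero vector.
assert!(hashed_embedding("", 8).iter().all(|v| *v == 0.0));
}
}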
/// Backward-compatibility wrapper that delegates to [`EmbeddingProvider::embed`].
pub async fn generate_embedding_with_provider(
provider: &EmbeddingProvider,
input: &str,
) -> Result<Vec<f32>, AppError> {
provider.embed(input).await.map_err(AppError::from)
}
/// Generates an embedding vector for the given input text using `OpenAI`'s embedding model.
///
/// This function takes a text input and converts it into a numerical vector representation (embedding)
/// using OpenAI's text-embedding-3-small model. These embeddings can be used for semantic similarity
/// using `OpenAI`'s text-embedding-3-small model. These embeddings can be used for semantic similarity
/// comparisons, vector search, and other natural language processing tasks.
///
/// # Arguments
///
/// * `client`: The OpenAI client instance used to make API requests.
/// * `client`: The `OpenAI` client instance used to make API requests.
/// * `input`: The text string to generate embeddings for.
///
/// # Returns
@@ -25,9 +334,10 @@ use crate::{
/// # Errors
///
/// This function can return an `AppError` in the following cases:
/// * If the OpenAI API request fails
/// * If the `OpenAI` API request fails
/// * If the request building fails
/// * If no embedding data is received in the response
#[allow(clippy::module_name_repetitions)]
pub async fn generate_embedding(
client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input: &str,

View File

@@ -4,6 +4,7 @@ pub use minijinja_contrib;
pub use minijinja_embed;
use std::sync::Arc;
#[allow(clippy::module_name_repetitions)]
pub trait ProvidesTemplateEngine {
fn template_engine(&self) -> &Arc<TemplateEngine>;
}

View File

@@ -1,265 +0,0 @@
use std::collections::HashMap;
use serde::Deserialize;
use tracing::debug;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
};
use crate::scoring::Scored;
use common::storage::types::file_info::deserialize_flexible_id;
use surrealdb::sql::Thing;
#[derive(Debug, Deserialize)]
struct FtsScoreRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
fts_score: Option<f32>,
}
/// Executes a full-text search query against SurrealDB and returns scored results.
///
/// The function expects FTS indexes to exist for the provided table. Currently supports
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
pub async fn find_items_by_fts<T>(
take: usize,
query: &str,
db_client: &SurrealDbClient,
table: &str,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
let (filter_clause, score_clause) = match table {
"knowledge_entity" => (
"(name @0@ $terms OR description @1@ $terms)",
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
),
"text_chunk" => (
"(chunk @0@ $terms)",
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
),
_ => {
return Err(AppError::Validation(format!(
"FTS not configured for table '{table}'"
)))
}
};
let sql = format!(
"SELECT id, {score_clause} AS fts_score \
FROM {table} \
WHERE {filter_clause} \
AND user_id = $user_id \
ORDER BY fts_score DESC \
LIMIT $limit",
table = table,
filter_clause = filter_clause,
score_clause = score_clause
);
debug!(
table = table,
limit = take,
"Executing FTS query with filter clause: {}",
filter_clause
);
let mut response = db_client
.query(sql)
.bind(("terms", query.to_owned()))
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.await?;
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
if score_rows.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
let thing_ids: Vec<Thing> = ids
.iter()
.map(|id| Thing::from((table, id.as_str())))
.collect();
let mut items_response = db_client
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", table.to_owned()))
.bind(("things", thing_ids.clone()))
.bind(("user_id", user_id.to_owned()))
.await?;
let items: Vec<T> = items_response.take(0)?;
let mut item_map: HashMap<String, T> = items
.into_iter()
.map(|item| (item.get_id().to_owned(), item))
.collect();
let mut results = Vec::with_capacity(score_rows.len());
for row in score_rows {
if let Some(item) = item_map.remove(&row.id) {
let score = row.fts_score.unwrap_or_default();
results.push(Scored::new(item).with_fts_score(score));
}
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
text_chunk::TextChunk,
StoredObject,
};
use uuid::Uuid;
fn dummy_embedding() -> Vec<f32> {
vec![0.0; 1536]
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_name() {
let namespace = "fts_test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts";
let entity = KnowledgeEntity::new(
"source_a".into(),
"Rustacean handbook".into(),
"completely unrelated description".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"rustacean",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the name matched"
);
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_description() {
let namespace = "fts_test_ns_desc";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts_desc";
let entity = KnowledgeEntity::new(
"source_b".into(),
"neutral name".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("failed to insert entity");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results = find_items_by_fts::<KnowledgeEntity>(
5,
"async",
&db,
KnowledgeEntity::table_name(),
user_id,
)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when only the description matched"
);
}
#[tokio::test]
async fn fts_preserves_scores_for_text_chunks() {
let namespace = "fts_test_ns_chunks";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("failed to create in-memory surreal");
db.apply_migrations()
.await
.expect("failed to apply migrations");
let user_id = "user_fts_chunk";
let chunk = TextChunk::new(
"source_chunk".into(),
"GraphQL documentation reference".into(),
dummy_embedding(),
user_id.into(),
);
db.store_item(chunk.clone())
.await
.expect("failed to insert chunk");
db.rebuild_indexes()
.await
.expect("failed to rebuild indexes");
let results =
find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
.await
.expect("fts query failed");
assert!(!results.is_empty(), "expected at least one FTS result");
assert!(
results[0].scores.fts.is_some(),
"expected an FTS score when chunk field matched"
);
}
}

View File

@@ -1,267 +0,0 @@
pub mod answer_retrieval;
pub mod answer_retrieval_helper;
pub mod fts;
pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub mod vector;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
},
};
use reranking::RerankerLease;
use tracing::instrument;
pub use pipeline::{retrieved_entities_to_json, RetrievalConfig, RetrievalTuning};
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
#[derive(Debug, Clone)]
pub struct RetrievedChunk {
pub chunk: TextChunk,
pub score: f32,
}
// Final entity representation returned to callers, enriched with ranked chunks.
#[derive(Debug, Clone)]
pub struct RetrievedEntity {
pub entity: KnowledgeEntity,
pub score: f32,
pub chunks: Vec<RetrievedChunk>,
}
// Primary orchestrator for retrieving KnowledgeEntities related to an input_text.
#[instrument(skip_all, fields(user_id))]
pub async fn retrieve_entities(
db_client: &SurrealDbClient,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
pipeline::run_pipeline(
db_client,
openai_client,
input_text,
user_id,
RetrievalConfig::default(),
reranker,
)
.await
}
#[cfg(test)]
mod tests {
use super::*;
use async_openai::Client;
use common::storage::types::{
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
knowledge_relationship::KnowledgeRelationship,
text_chunk::TextChunk,
};
use pipeline::RetrievalConfig;
use uuid::Uuid;
fn test_embedding() -> Vec<f32> {
vec![0.9, 0.1, 0.0]
}
fn entity_embedding_high() -> Vec<f32> {
vec![0.8, 0.2, 0.0]
}
fn entity_embedding_low() -> Vec<f32> {
vec![0.1, 0.9, 0.0]
}
fn chunk_embedding_primary() -> Vec<f32> {
vec![0.85, 0.15, 0.0]
}
fn chunk_embedding_secondary() -> Vec<f32> {
vec![0.2, 0.8, 0.0]
}
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db.query(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
)
.await
.expect("Failed to configure indices");
db
}
#[tokio::test]
async fn test_retrieve_entities_with_embedding_basic_flow() {
let db = setup_test_db().await;
let user_id = "test_user";
let entity = KnowledgeEntity::new(
"source_1".into(),
"Rust async guide".into(),
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let chunk = TextChunk::new(
entity.source_id.clone(),
"Tokio uses cooperative scheduling for fairness.".into(),
chunk_embedding_primary(),
user_id.into(),
);
db.store_item(entity.clone())
.await
.expect("Failed to store entity");
db.store_item(chunk.clone())
.await
.expect("Failed to store chunk");
let openai_client = Client::new();
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
test_embedding(),
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Hybrid retrieval failed");
assert!(
!results.is_empty(),
"Expected at least one retrieval result"
);
let top = &results[0];
assert!(
top.entity.name.contains("Rust"),
"Expected Rust entity to be ranked first"
);
assert!(
!top.chunks.is_empty(),
"Expected Rust entity to include supporting chunks"
);
}
#[tokio::test]
async fn test_graph_relationship_enriches_results() {
let db = setup_test_db().await;
let user_id = "graph_user";
let primary = KnowledgeEntity::new(
"primary_source".into(),
"Async Rust patterns".into(),
"Explores async runtimes and scheduling strategies.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let neighbor = KnowledgeEntity::new(
"neighbor_source".into(),
"Tokio Scheduler Deep Dive".into(),
"Details on Tokio's cooperative scheduler.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_low(),
user_id.into(),
);
db.store_item(primary.clone())
.await
.expect("Failed to store primary entity");
db.store_item(neighbor.clone())
.await
.expect("Failed to store neighbor entity");
let primary_chunk = TextChunk::new(
primary.source_id.clone(),
"Rust async tasks use Tokio's cooperative scheduler.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let neighbor_chunk = TextChunk::new(
neighbor.source_id.clone(),
"Tokio's scheduler manages task fairness across executors.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(primary_chunk)
.await
.expect("Failed to store primary chunk");
db.store_item(neighbor_chunk)
.await
.expect("Failed to store neighbor chunk");
let openai_client = Client::new();
let relationship = KnowledgeRelationship::new(
primary.id.clone(),
neighbor.id.clone(),
user_id.into(),
"relationship_source".into(),
"references".into(),
);
relationship
.store_relationship(&db)
.await
.expect("Failed to store relationship");
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
test_embedding(),
"Rust concurrency async tasks",
user_id,
RetrievalConfig::default(),
None,
)
.await
.expect("Hybrid retrieval failed");
let mut neighbor_entry = None;
for entity in &results {
if entity.entity.id == neighbor.id {
neighbor_entry = Some(entity.clone());
}
}
let neighbor_entry =
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
assert!(
neighbor_entry.score > 0.2,
"Graph-enriched entity should have a meaningful fused score"
);
assert!(
neighbor_entry
.chunks
.iter()
.all(|chunk| chunk.chunk.source_id == neighbor.source_id),
"Neighbor entity should surface its own supporting chunks"
);
}
}

View File

@@ -1,67 +0,0 @@
use serde::{Deserialize, Serialize};
/// Tunable parameters that govern each retrieval stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalTuning {
pub entity_vector_take: usize,
pub chunk_vector_take: usize,
pub entity_fts_take: usize,
pub chunk_fts_take: usize,
pub score_threshold: f32,
pub fallback_min_results: usize,
pub token_budget_estimate: usize,
pub avg_chars_per_token: usize,
pub max_chunks_per_entity: usize,
pub graph_traversal_seed_limit: usize,
pub graph_neighbor_limit: usize,
pub graph_score_decay: f32,
pub graph_seed_min_score: f32,
pub graph_vector_inheritance: f32,
pub rerank_blend_weight: f32,
pub rerank_scores_only: bool,
pub rerank_keep_top: usize,
}
impl Default for RetrievalTuning {
fn default() -> Self {
Self {
entity_vector_take: 15,
chunk_vector_take: 20,
entity_fts_take: 10,
chunk_fts_take: 20,
score_threshold: 0.35,
fallback_min_results: 10,
token_budget_estimate: 2800,
avg_chars_per_token: 4,
max_chunks_per_entity: 4,
graph_traversal_seed_limit: 5,
graph_neighbor_limit: 6,
graph_score_decay: 0.75,
graph_seed_min_score: 0.4,
graph_vector_inheritance: 0.6,
rerank_blend_weight: 0.65,
rerank_scores_only: false,
rerank_keep_top: 8,
}
}
}
/// Wrapper containing tuning plus future flags for per-request overrides.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
pub tuning: RetrievalTuning,
}
impl RetrievalConfig {
pub fn new(tuning: RetrievalTuning) -> Self {
Self { tuning }
}
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
tuning: RetrievalTuning::default(),
}
}
}
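// A minimal sketch (hypothetical values) of overriding a couple of knobs while keeping
// the remaining defaults via struct-update syntax; the helper name is illustrative.
#[cfg(test)]
#[allow(dead_code)]
fn example_tuning_override() -> RetrievalConfig {
RetrievalConfig::new(RetrievalTuning {
score_threshold: 0.5,
max_chunks_per_entity: 2,
..RetrievalTuning::default()
})
}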

View File

@@ -1,106 +0,0 @@
mod config;
mod stages;
mod state;
pub use config::{RetrievalConfig, RetrievalTuning};
use crate::{reranking::RerankerLease, RetrievedEntity};
use async_openai::Client;
use common::{error::AppError, storage::db::SurrealDbClient};
use tracing::info;
/// Drives the retrieval pipeline from embedding through final assembly.
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let input_chars = input_text.chars().count();
let input_preview: String = input_text.chars().take(120).collect();
let input_preview_clean = input_preview.replace('\n', " ");
let preview_len = input_preview_clean.chars().count();
info!(
%user_id,
input_chars,
preview_truncated = input_chars > preview_len,
preview = %input_preview_clean,
"Starting ingestion retrieval pipeline"
);
let mut ctx = stages::PipelineContext::new(
db_client,
openai_client,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)
}
#[cfg(test)]
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Result<Vec<RetrievedEntity>, AppError> {
let machine = state::ready();
let mut ctx = stages::PipelineContext::with_embedding(
db_client,
openai_client,
query_embedding,
input_text.to_owned(),
user_id.to_owned(),
config,
reranker,
);
let machine = stages::embed(machine, &mut ctx).await?;
let machine = stages::collect_candidates(machine, &mut ctx).await?;
let machine = stages::expand_graph(machine, &mut ctx).await?;
let machine = stages::attach_chunks(machine, &mut ctx).await?;
let machine = stages::rerank(machine, &mut ctx).await?;
let results = stages::assemble(machine, &mut ctx)?;
Ok(results)
}
/// Helper exposed for tests to convert retrieved entities into downstream prompt JSON.
pub fn retrieved_entities_to_json(entities: &[RetrievedEntity]) -> serde_json::Value {
serde_json::json!(entities
.iter()
.map(|entry| {
serde_json::json!({
"KnowledgeEntity": {
"id": entry.entity.id,
"name": entry.entity.name,
"description": entry.entity.description,
"score": round_score(entry.score),
"chunks": entry.chunks.iter().map(|chunk| {
serde_json::json!({
"score": round_score(chunk.score),
"content": chunk.chunk.chunk
})
}).collect::<Vec<_>>()
}
})
})
.collect::<Vec<_>>())
}
fn round_score(value: f32) -> f64 {
(f64::from(value) * 1000.0).round() / 1000.0
}

View File

@@ -1,769 +0,0 @@
use async_openai::Client;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
};
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
use state_machines::core::GuardError;
use std::collections::{HashMap, HashSet};
use tracing::{debug, instrument, warn};
use crate::{
fts::find_items_by_fts,
graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids},
reranking::RerankerLease,
scoring::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
},
vector::find_items_by_vector_similarity_with_embedding,
RetrievedChunk, RetrievedEntity,
};
use super::{
config::RetrievalConfig,
state::{
CandidatesLoaded, ChunksAttached, Embedded, GraphExpanded, HybridRetrievalMachine, Ready,
Reranked,
},
};
pub struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
pub query_embedding: Option<Vec<f32>>,
pub entity_candidates: HashMap<String, Scored<KnowledgeEntity>>,
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
}
impl<'a> PipelineContext<'a> {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
Self {
db_client,
openai_client,
input_text,
user_id,
config,
query_embedding: None,
entity_candidates: HashMap::new(),
chunk_candidates: HashMap::new(),
filtered_entities: Vec::new(),
chunk_values: Vec::new(),
reranker,
}
}
#[cfg(test)]
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
config: RetrievalConfig,
reranker: Option<RerankerLease>,
) -> Self {
let mut ctx = Self::new(
db_client,
openai_client,
input_text,
user_id,
config,
reranker,
);
ctx.query_embedding = Some(query_embedding);
ctx
}
fn ensure_embedding(&self) -> Result<&Vec<f32>, AppError> {
self.query_embedding.as_ref().ok_or_else(|| {
AppError::InternalError(
"query embedding missing before candidate collection".to_string(),
)
})
}
}
#[instrument(level = "trace", skip_all)]
pub async fn embed(
machine: HybridRetrievalMachine<(), Ready>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Embedded>, AppError> {
let embedding_cached = ctx.query_embedding.is_some();
if embedding_cached {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding =
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
ctx.query_embedding = Some(embedding);
}
machine
.embed()
.map_err(|(_, guard)| map_guard_error("embed", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn collect_candidates(
machine: HybridRetrievalMachine<(), Embedded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), CandidatesLoaded>, AppError> {
debug!("Collecting initial candidates via vector and FTS search");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_vector_similarity_with_embedding(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
"text_chunk",
&ctx.user_id,
),
find_items_by_fts(
tuning.entity_fts_take,
&ctx.input_text,
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_fts(
tuning.chunk_fts_take,
&ctx.input_text,
ctx.db_client,
"text_chunk",
&ctx.user_id
),
)?;
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
fts_entities = fts_entities.len(),
fts_chunks = fts_chunks.len(),
"Hybrid retrieval initial candidate counts"
);
normalize_fts_scores(&mut fts_entities);
normalize_fts_scores(&mut fts_chunks);
merge_scored_by_id(&mut ctx.entity_candidates, vector_entities);
merge_scored_by_id(&mut ctx.entity_candidates, fts_entities);
merge_scored_by_id(&mut ctx.chunk_candidates, vector_chunks);
merge_scored_by_id(&mut ctx.chunk_candidates, fts_chunks);
apply_fusion(&mut ctx.entity_candidates, weights);
apply_fusion(&mut ctx.chunk_candidates, weights);
machine
.collect_candidates()
.map_err(|(_, guard)| map_guard_error("collect_candidates", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn expand_graph(
machine: HybridRetrievalMachine<(), CandidatesLoaded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), GraphExpanded>, AppError> {
debug!("Expanding candidates using graph relationships");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
if ctx.entity_candidates.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
let graph_seeds = seeds_from_candidates(
&ctx.entity_candidates,
tuning.graph_seed_min_score,
tuning.graph_traversal_seed_limit,
);
if graph_seeds.is_empty() {
return machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard));
}
let mut futures = FuturesUnordered::new();
for seed in graph_seeds {
let db = ctx.db_client;
let user = ctx.user_id.clone();
futures.push(async move {
let neighbors = find_entities_by_relationship_by_id(
db,
&seed.id,
&user,
tuning.graph_neighbor_limit,
)
.await;
(seed, neighbors)
});
}
while let Some((seed, neighbors_result)) = futures.next().await {
let neighbors = neighbors_result.map_err(AppError::from)?;
if neighbors.is_empty() {
continue;
}
for neighbor in neighbors {
if neighbor.id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
let inherited_vector = clamp_unit(graph_score * tuning.graph_vector_inheritance);
let vector_existing = entry.scores.vector.unwrap_or(0.0);
if inherited_vector > vector_existing {
entry.scores.vector = Some(inherited_vector);
}
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
if graph_score > existing_graph || entry.scores.graph.is_none() {
entry.scores.graph = Some(graph_score);
}
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
}
}
machine
.expand_graph()
.map_err(|(_, guard)| map_guard_error("expand_graph", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn attach_chunks(
machine: HybridRetrievalMachine<(), GraphExpanded>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), ChunksAttached>, AppError> {
debug!("Attaching chunks to surviving entities");
let tuning = &ctx.config.tuning;
let weights = FusionWeights::default();
let chunk_by_source = group_chunks_by_source(&ctx.chunk_candidates);
backfill_entities_from_chunks(
&mut ctx.entity_candidates,
&chunk_by_source,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
boost_entities_with_chunks(&mut ctx.entity_candidates, &chunk_by_source, weights);
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
ctx.entity_candidates.values().cloned().collect();
sort_by_fused_desc(&mut entity_results);
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
.iter()
.filter(|candidate| candidate.fused >= tuning.score_threshold)
.cloned()
.collect();
if filtered_entities.len() < tuning.fallback_min_results {
filtered_entities = entity_results
.into_iter()
.take(tuning.fallback_min_results)
.collect();
}
ctx.filtered_entities = filtered_entities;
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
for chunk in chunk_results {
chunk_by_id.insert(chunk.item.id.clone(), chunk);
}
enrich_chunks_from_entities(
&mut chunk_by_id,
&ctx.filtered_entities,
ctx.db_client,
&ctx.user_id,
weights,
)
.await?;
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
sort_by_fused_desc(&mut chunk_values);
ctx.chunk_values = chunk_values;
machine
.attach_chunks()
.map_err(|(_, guard)| map_guard_error("attach_chunks", guard))
}
#[instrument(level = "trace", skip_all)]
pub async fn rerank(
machine: HybridRetrievalMachine<(), ChunksAttached>,
ctx: &mut PipelineContext<'_>,
) -> Result<HybridRetrievalMachine<(), Reranked>, AppError> {
let mut applied = false;
if let Some(reranker) = ctx.reranker.as_ref() {
if ctx.filtered_entities.len() > 1 {
let documents = build_rerank_documents(ctx, ctx.config.tuning.max_chunks_per_entity);
if documents.len() > 1 {
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_rerank_results(ctx, results);
applied = true;
}
Ok(_) => {
debug!("Reranker returned no results; retaining original ordering");
}
Err(err) => {
warn!(
error = %err,
"Reranking failed; continuing with original ordering"
);
}
}
} else {
debug!(
document_count = documents.len(),
"Skipping reranking stage; insufficient document context"
);
}
} else {
debug!("Skipping reranking stage; less than two entities available");
}
} else {
debug!("No reranker lease provided; skipping reranking stage");
}
if applied {
debug!("Applied reranking adjustments to candidate ordering");
}
machine
.rerank()
.map_err(|(_, guard)| map_guard_error("rerank", guard))
}
#[instrument(level = "trace", skip_all)]
pub fn assemble(
machine: HybridRetrievalMachine<(), Reranked>,
ctx: &mut PipelineContext<'_>,
) -> Result<Vec<RetrievedEntity>, AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in ctx.chunk_values.drain(..) {
chunk_by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk);
}
for chunk_list in chunk_by_source.values_mut() {
sort_by_fused_desc(chunk_list);
}
let mut token_budget_remaining = tuning.token_budget_estimate;
let mut results = Vec::new();
for entity in &ctx.filtered_entities {
let mut selected_chunks = Vec::new();
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
let mut per_entity_count = 0;
candidates.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
for candidate in candidates.iter() {
if per_entity_count >= tuning.max_chunks_per_entity {
break;
}
let estimated_tokens =
estimate_tokens(&candidate.item.chunk, tuning.avg_chars_per_token);
if estimated_tokens > token_budget_remaining {
continue;
}
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
per_entity_count += 1;
selected_chunks.push(RetrievedChunk {
chunk: candidate.item.clone(),
score: candidate.fused,
});
}
}
results.push(RetrievedEntity {
entity: entity.item.clone(),
score: entity.fused,
chunks: selected_chunks,
});
if token_budget_remaining == 0 {
break;
}
}
machine
.assemble()
.map_err(|(_, guard)| map_guard_error("assemble", guard))?;
Ok(results)
}
fn map_guard_error(stage: &'static str, err: GuardError) -> AppError {
AppError::InternalError(format!(
"state machine guard '{stage}' failed: guard={}, event={}, kind={:?}",
err.guard, err.event, err.kind
))
}
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
let raw_scores: Vec<f32> = results
.iter()
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
.collect();
let normalized = min_max_normalize(&raw_scores);
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
candidate.scores.fts = Some(normalized_score);
candidate.update_fused(0.0);
}
}
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
where
T: StoredObject,
{
for candidate in candidates.values_mut() {
let fused = fuse_scores(&candidate.scores, weights);
candidate.update_fused(fused);
}
}
fn group_chunks_by_source(
chunks: &HashMap<String, Scored<TextChunk>>,
) -> HashMap<String, Vec<Scored<TextChunk>>> {
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
for chunk in chunks.values() {
by_source
.entry(chunk.item.source_id.clone())
.or_default()
.push(chunk.clone());
}
by_source
}
async fn backfill_entities_from_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut missing_sources = Vec::new();
for source_id in chunk_by_source.keys() {
if !entity_candidates
.values()
.any(|entity| entity.item.source_id == *source_id)
{
missing_sources.push(source_id.clone());
}
}
if missing_sources.is_empty() {
return Ok(());
}
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
missing_sources.clone(),
"knowledge_entity",
user_id,
db_client,
)
.await
.unwrap_or_default();
if related_entities.is_empty() {
warn!("expected related entities for missing chunk sources, but none were found");
}
for entity in related_entities {
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
entity_candidates.insert(entity.id.clone(), scored);
}
}
Ok(())
}
fn boost_entities_with_chunks(
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
chunk_by_source: &HashMap<String, Vec<Scored<TextChunk>>>,
weights: FusionWeights,
) {
for entity in entity_candidates.values_mut() {
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
let best_chunk_score = chunks
.iter()
.map(|chunk| chunk.fused)
.fold(0.0f32, f32::max);
if best_chunk_score > 0.0 {
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
entity.scores.vector = Some(boosted);
let fused = fuse_scores(&entity.scores, weights);
entity.update_fused(fused);
}
}
}
}
async fn enrich_chunks_from_entities(
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
entities: &[Scored<KnowledgeEntity>],
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
source_ids.insert(entity.item.source_id.clone());
}
if source_ids.is_empty() {
return Ok(());
}
let chunks = find_entities_by_source_ids::<TextChunk>(
source_ids.into_iter().collect(),
"text_chunk",
user_id,
db_client,
)
.await?;
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
for entity in entities {
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
}
for chunk in chunks {
let entry = chunk_candidates
.entry(chunk.id.clone())
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
let entity_score = entity_score_lookup
.get(&chunk.source_id)
.copied()
.unwrap_or(0.0);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
}
Ok(())
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
}
let mut chunk_by_source: HashMap<&str, Vec<&Scored<TextChunk>>> = HashMap::new();
for chunk in &ctx.chunk_values {
chunk_by_source
.entry(chunk.item.source_id.as_str())
.or_default()
.push(chunk);
}
ctx.filtered_entities
.iter()
.map(|entity| {
let mut doc = format!(
"Name: {}\nType: {:?}\nDescription: {}\n",
entity.item.name, entity.item.entity_type, entity.item.description
);
if let Some(chunks) = chunk_by_source.get(entity.item.source_id.as_str()) {
let mut chunk_refs = chunks.clone();
chunk_refs.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut header_added = false;
for chunk in chunk_refs.into_iter().take(max_chunks_per_entity.max(1)) {
let snippet = chunk.item.chunk.trim();
if snippet.is_empty() {
continue;
}
if !header_added {
doc.push_str("Chunks:\n");
header_added = true;
}
doc.push_str("- ");
doc.push_str(snippet);
doc.push('\n');
}
}
doc
})
.collect()
}
fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult>) {
if results.is_empty() || ctx.filtered_entities.is_empty() {
return;
}
let mut remaining: Vec<Option<Scored<KnowledgeEntity>>> =
std::mem::take(&mut ctx.filtered_entities)
.into_iter()
.map(Some)
.collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
let normalized_scores = min_max_normalize(&raw_scores);
let use_only = ctx.config.tuning.rerank_scores_only;
let blend = if use_only {
1.0
} else {
clamp_unit(ctx.config.tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<KnowledgeEntity>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
let original = candidate.fused;
let blended = if use_only {
clamp_unit(normalized)
} else {
clamp_unit(original * (1.0 - blend) + normalized * blend)
};
candidate.update_fused(blended);
reranked.push(candidate);
}
} else {
warn!(
result_index = result.index,
"Reranker returned out-of-range index; skipping"
);
}
if reranked.len() == remaining.len() {
break;
}
}
for slot in remaining.into_iter() {
if let Some(candidate) = slot {
reranked.push(candidate);
}
}
ctx.filtered_entities = reranked;
let keep_top = ctx.config.tuning.rerank_keep_top;
if keep_top > 0 && ctx.filtered_entities.len() > keep_top {
ctx.filtered_entities.truncate(keep_top);
}
}
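/// Rough token estimate from an average characters-per-token ratio; always returns at least 1.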
fn estimate_tokens(text: &str, avg_chars_per_token: usize) -> usize {
let chars = text.chars().count().max(1);
(chars / avg_chars_per_token.max(1)).max(1)
}
#[derive(Clone)]
struct GraphSeed {
id: String,
fused: f32,
}
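/// Picks up to `limit` entity candidates with fused score at or above `min_score`, ordered by fused score, to seed graph expansion.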
fn seeds_from_candidates(
entity_candidates: &HashMap<String, Scored<KnowledgeEntity>>,
min_score: f32,
limit: usize,
) -> Vec<GraphSeed> {
let mut seeds: Vec<GraphSeed> = entity_candidates
.values()
.filter(|entity| entity.fused >= min_score)
.map(|entity| GraphSeed {
id: entity.item.id.clone(),
fused: entity.fused,
})
.collect();
seeds.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(std::cmp::Ordering::Equal)
});
if seeds.len() > limit {
seeds.truncate(limit);
}
seeds
}

View File

@@ -1,27 +0,0 @@
use state_machines::state_machine;
state_machine! {
name: HybridRetrievalMachine,
state: HybridRetrievalState,
initial: Ready,
states: [Ready, Embedded, CandidatesLoaded, GraphExpanded, ChunksAttached, Reranked, Completed, Failed],
events {
embed { transition: { from: Ready, to: Embedded } }
collect_candidates { transition: { from: Embedded, to: CandidatesLoaded } }
expand_graph { transition: { from: CandidatesLoaded, to: GraphExpanded } }
attach_chunks { transition: { from: GraphExpanded, to: ChunksAttached } }
rerank { transition: { from: ChunksAttached, to: Reranked } }
assemble { transition: { from: Reranked, to: Completed } }
abort {
transition: { from: Ready, to: Failed }
transition: { from: CandidatesLoaded, to: Failed }
transition: { from: GraphExpanded, to: Failed }
transition: { from: ChunksAttached, to: Failed }
transition: { from: Reranked, to: Failed }
}
}
}
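/// Returns a fresh machine in the `Ready` state with an empty context.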
pub fn ready() -> HybridRetrievalMachine<(), Ready> {
HybridRetrievalMachine::new(())
}

View File

@@ -1,183 +0,0 @@
use std::cmp::Ordering;
use common::storage::types::StoredObject;
/// Holds optional subscores gathered from different retrieval signals.
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
pub fts: Option<f32>,
pub vector: Option<f32>,
pub graph: Option<f32>,
}
/// Generic wrapper combining an item with its accumulated retrieval scores.
#[derive(Debug, Clone)]
pub struct Scored<T> {
pub item: T,
pub scores: Scores,
pub fused: f32,
}
impl<T> Scored<T> {
pub fn new(item: T) -> Self {
Self {
item,
scores: Scores::default(),
fused: 0.0,
}
}
pub const fn with_vector_score(mut self, score: f32) -> Self {
self.scores.vector = Some(score);
self
}
pub const fn with_fts_score(mut self, score: f32) -> Self {
self.scores.fts = Some(score);
self
}
pub const fn with_graph_score(mut self, score: f32) -> Self {
self.scores.graph = Some(score);
self
}
pub const fn update_fused(&mut self, fused: f32) {
self.fused = fused;
}
}
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
pub vector: f32,
pub fts: f32,
pub graph: f32,
pub multi_bonus: f32,
}
impl Default for FusionWeights {
fn default() -> Self {
Self {
vector: 0.5,
fts: 0.3,
graph: 0.2,
multi_bonus: 0.02,
}
}
}
pub const fn clamp_unit(value: f32) -> f32 {
value.clamp(0.0, 1.0)
}
pub fn distance_to_similarity(distance: f32) -> f32 {
if !distance.is_finite() {
return 0.0;
}
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
}
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return Vec::new();
}
let mut min = f32::MAX;
let mut max = f32::MIN;
for s in scores {
if !s.is_finite() {
continue;
}
if *s < min {
min = *s;
}
if *s > max {
max = *s;
}
}
if !min.is_finite() || !max.is_finite() {
return scores.iter().map(|_| 0.0).collect();
}
if (max - min).abs() < f32::EPSILON {
return vec![1.0; scores.len()];
}
scores
.iter()
.map(|score| {
if score.is_finite() {
clamp_unit((score - min) / (max - min))
} else {
0.0
}
})
.collect()
}
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
let vector = scores.vector.unwrap_or(0.0);
let fts = scores.fts.unwrap_or(0.0);
let graph = scores.graph.unwrap_or(0.0);
let mut fused = graph.mul_add(
weights.graph,
vector.mul_add(weights.vector, fts * weights.fts),
);
let signals_present = scores
.vector
.iter()
.chain(scores.fts.iter())
.chain(scores.graph.iter())
.count();
if signals_present >= 2 {
fused += weights.multi_bonus;
}
clamp_unit(fused)
}
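/// Merges incoming scored items into `target` by id, copying over whichever per-signal scores each incoming item provides.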
pub fn merge_scored_by_id<T>(
target: &mut std::collections::HashMap<String, Scored<T>>,
incoming: Vec<Scored<T>>,
) where
T: StoredObject + Clone,
{
for scored in incoming {
let id = scored.item.get_id().to_owned();
target
.entry(id)
.and_modify(|existing| {
if let Some(score) = scored.scores.vector {
existing.scores.vector = Some(score);
}
if let Some(score) = scored.scores.fts {
existing.scores.fts = Some(score);
}
if let Some(score) = scored.scores.graph {
existing.scores.graph = Some(score);
}
})
.or_insert_with(|| Scored {
item: scored.item.clone(),
scores: scored.scores,
fused: scored.fused,
});
}
}
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
where
T: StoredObject,
{
items.sort_by(|a, b| {
b.fused
.partial_cmp(&a.fused)
.unwrap_or(Ordering::Equal)
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
});
}

View File

@@ -1,157 +0,0 @@
use std::collections::HashMap;
use common::storage::types::file_info::deserialize_flexible_id;
use common::{
error::AppError,
storage::{db::SurrealDbClient, types::StoredObject},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
use surrealdb::sql::Thing;
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
/// Compares vectors and retrieves a number of items from the specified table.
///
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
/// and then deserializes the results into the specified type `T`.
///
/// # Arguments
///
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user ID of the current user.
///
/// # Returns
///
/// A vector of `Scored<T>` containing the closest matches to the input text, or an `AppError` if an error occurs.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
take: usize,
input_text: &str,
db_client: &SurrealDbClient,
table: &str,
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
// Generate embeddings
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
.await
}
#[derive(Debug, Deserialize)]
struct DistanceRow {
#[serde(deserialize_with = "deserialize_flexible_id")]
id: String,
distance: Option<f32>,
}
pub async fn find_items_by_vector_similarity_with_embedding<T>(
take: usize,
query_embedding: Vec<f32>,
db_client: &SurrealDbClient,
table: &str,
user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
T: for<'de> serde::Deserialize<'de> + StoredObject,
{
let embedding_literal = serde_json::to_string(&query_embedding)
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
let closest_query = format!(
"SELECT id, vector::distance::knn() AS distance \
FROM {table} \
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
LIMIT $limit",
table = table,
take = take,
embedding = embedding_literal
);
let mut response = db_client
.query(closest_query)
.bind(("user_id", user_id.to_owned()))
.bind(("limit", take as i64))
.await?;
let distance_rows: Vec<DistanceRow> = response.take(0)?;
if distance_rows.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
let thing_ids: Vec<Thing> = ids
.iter()
.map(|id| Thing::from((table, id.as_str())))
.collect();
let mut items_response = db_client
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
.bind(("table", table.to_owned()))
.bind(("things", thing_ids.clone()))
.bind(("user_id", user_id.to_owned()))
.await?;
let items: Vec<T> = items_response.take(0)?;
let mut item_map: HashMap<String, T> = items
.into_iter()
.map(|item| (item.get_id().to_owned(), item))
.collect();
let mut min_distance = f32::MAX;
let mut max_distance = f32::MIN;
for row in &distance_rows {
if let Some(distance) = row.distance {
if distance.is_finite() {
if distance < min_distance {
min_distance = distance;
}
if distance > max_distance {
max_distance = distance;
}
}
}
}
let normalize = min_distance.is_finite()
&& max_distance.is_finite()
&& (max_distance - min_distance).abs() > f32::EPSILON;
let mut scored = Vec::with_capacity(distance_rows.len());
for row in distance_rows {
if let Some(item) = item_map.remove(&row.id) {
let similarity = row
.distance
.map(|distance| {
if normalize {
let span = max_distance - min_distance;
if span.abs() < f32::EPSILON {
1.0
} else {
let normalized = 1.0 - ((distance - min_distance) / span);
clamp_unit(normalized)
}
} else {
distance_to_similarity(distance)
}
})
.unwrap_or_default();
scored.push(Scored::new(item).with_vector_score(similarity));
}
}
Ok(scored)
}

74
docs/architecture.md Normal file
View File

@@ -0,0 +1,74 @@
# Architecture
## Tech Stack
| Layer | Technology |
|-------|------------|
| Backend | Rust with Axum (SSR) |
| Frontend | HTML + HTMX + minimal JS |
| Database | SurrealDB (graph, document, vector) |
| AI | OpenAI-compatible API |
| Web Processing | Headless Chromium |
## Crate Structure
```
minne/
├── main/ # Combined server + worker binary
├── api-router/ # REST API routes
├── html-router/ # SSR web interface
├── ingestion-pipeline/ # Content processing pipeline
├── retrieval-pipeline/ # Search and retrieval logic
├── common/ # Shared types, storage, utilities
├── evaluations/ # Benchmarking framework
└── json-stream-parser/ # Streaming JSON utilities
```
## Process Modes
| Binary | Purpose |
|--------|---------|
| `main` | All-in-one: serves UI and processes content |
| `server` | UI and API only (no background processing) |
| `worker` | Background processing only (no UI) |
Split deployment is useful for scaling or resource isolation.
## Data Flow
```
Content In → Ingestion Pipeline → SurrealDB
                 ├─ Entity Extraction
                 ├─ Embedding Generation
                 └─ Graph Relationships

Query → Retrieval Pipeline → Results
                 ├─ Vector Search + FTS
                 └─ RRF Fusion → (Optional Rerank) → Response
```
## Database Schema
SurrealDB stores:
- **TextContent** — Raw ingested content
- **TextChunk** — Chunked content with embeddings
- **KnowledgeEntity** — Extracted entities (people, concepts, etc.)
- **KnowledgeRelationship** — Connections between entities
- **User** — Authentication and preferences
- **SystemSettings** — Model configuration
Embeddings are stored in dedicated tables with HNSW indexes for fast vector search.
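The exact index definitions live in Minne's migrations; as a rough illustration (the table name, field, and dimension below are assumptions for this sketch, not the canonical schema), an HNSW index in SurrealDB can be declared like this:
```rust
use surrealdb::{engine::remote::ws::Client, Surreal};

// Illustrative sketch only: Minne creates its real indexes at startup/migration time,
// and the table name, field, and dimension here are assumptions for the example.
async fn define_chunk_hnsw_index(db: &Surreal<Client>) -> Result<(), surrealdb::Error> {
    db.query(
        "DEFINE INDEX text_chunk_embedding_hnsw ON TABLE text_chunk \
         FIELDS embedding HNSW DIMENSION 1536 DIST COSINE",
    )
    .await?
    .check()?;
    Ok(())
}
```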
## Retrieval Strategy
1. **Collect candidates** — Vector similarity + full-text search
2. **Merge ranks** — Reciprocal Rank Fusion (RRF)
3. **Attach context** — Link chunks to parent entities
4. **Rerank** (optional) — Cross-encoder reranking
5. **Return** — Top-k results with metadata
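A minimal sketch of the rank-merge in step 2, assuming the textbook RRF formulation with a single constant `k` (the pipeline's actual fusion also applies per-signal weights and graph scores):
```rust
use std::collections::HashMap;

// Merge two ranked candidate-id lists with Reciprocal Rank Fusion.
// Sketch only; not the pipeline's exact fusion code.
fn rrf_merge(vector_ranked: &[String], fts_ranked: &[String], k: f32) -> Vec<(String, f32)> {
    let mut fused: HashMap<String, f32> = HashMap::new();
    for list in [vector_ranked, fts_ranked] {
        for (rank, id) in list.iter().enumerate() {
            // Ranks are 1-based in the RRF formula.
            *fused.entry(id.clone()).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
        }
    }
    let mut merged: Vec<(String, f32)> = fused.into_iter().collect();
    merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    merged
}
```
Candidates near the top of either list accumulate the largest fused scores, so neither signal has to produce a perfect ranking on its own.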

89
docs/configuration.md Normal file
View File

@@ -0,0 +1,89 @@
# Configuration
Minne can be configured via environment variables or a `config.yaml` file. Environment variables take precedence.
## Required Settings
| Variable | Description | Example |
|----------|-------------|---------|
| `OPENAI_API_KEY` | API key for OpenAI-compatible endpoint | `sk-...` |
| `SURREALDB_ADDRESS` | WebSocket address of SurrealDB | `ws://127.0.0.1:8000` |
| `SURREALDB_USERNAME` | SurrealDB username | `root_user` |
| `SURREALDB_PASSWORD` | SurrealDB password | `root_password` |
| `SURREALDB_DATABASE` | Database name | `minne_db` |
| `SURREALDB_NAMESPACE` | Namespace | `minne_ns` |
## Optional Settings
| Variable | Description | Default |
|----------|-------------|---------|
| `HTTP_PORT` | Server port | `3000` |
| `DATA_DIR` | Local data directory | `./data` |
| `OPENAI_BASE_URL` | Custom AI provider URL | OpenAI default |
| `RUST_LOG` | Logging level | `info` |
| `STORAGE` | Storage backend (`local`, `memory`) | `local` |
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
### Reranking (Optional)
| Variable | Description | Default |
|----------|-------------|---------|
| `RERANKING_ENABLED` | Enable FastEmbed reranking | `false` |
| `RERANKING_POOL_SIZE` | Concurrent reranker workers | - |
> [!NOTE]
> Enabling reranking downloads ~1.1 GB of model data on first startup.
## Example config.yaml
```yaml
surrealdb_address: "ws://127.0.0.1:8000"
surrealdb_username: "root_user"
surrealdb_password: "root_password"
surrealdb_database: "minne_db"
surrealdb_namespace: "minne_ns"
openai_api_key: "sk-your-key-here"
data_dir: "./minne_data"
http_port: 3000
# New settings
storage: "local"
pdf_ingest_mode: "llm-first"
embedding_backend: "fastembed"
# Optional reranking
reranking_enabled: true
reranking_pool_size: 2
```
## AI Provider Setup
Minne works with any OpenAI-compatible API that supports structured outputs.
### OpenAI (Default)
Set `OPENAI_API_KEY` only. The default base URL points to OpenAI.
### Ollama
```bash
OPENAI_API_KEY="ollama"
OPENAI_BASE_URL="http://localhost:11434/v1"
```
### Other Providers
Any provider exposing an OpenAI-compatible endpoint works. Set `OPENAI_BASE_URL` accordingly.
## Model Selection
1. Access `/admin` in your Minne instance
2. Select models for content processing and chat
3. **Content Processing**: Must support structured outputs
4. **Embedding Dimensions**: Update when changing embedding models (e.g., 1536 for `text-embedding-3-small`)

64
docs/features.md Normal file
View File

@@ -0,0 +1,64 @@
# Features
## Search vs Chat
**Search** — Use when you know what you're looking for. Full-text search matches query terms across your content.
**Chat** — Use when exploring concepts or reasoning about your knowledge. The AI analyzes your query and retrieves relevant context from your entire knowledge base.
## Content Processing
Minne automatically processes saved content:
1. **Web scraping** extracts readable text from URLs (via headless Chrome)
2. **Text analysis** identifies key concepts and relationships
3. **Graph creation** builds connections between related content
4. **Embedding generation** enables semantic search
## Knowledge Graph
Explore your knowledge as an interactive network:
- **Manual curation** — Create entities and relationships yourself
- **AI automation** — Let AI extract entities and discover relationships
- **Hybrid approach** — AI suggests connections for your approval
The D3-based graph visualization shows entities as nodes and relationships as edges.
## Hybrid Retrieval
Minne combines multiple retrieval strategies:
- **Vector similarity** — Semantic matching via embeddings
- **Full-text search** — Keyword matching with BM25
- **Graph traversal** — Following relationships between entities
Results are merged using Reciprocal Rank Fusion (RRF) into a single combined ranking.
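For reference, the standard RRF score for a document $d$ sums reciprocal ranks across the individual result lists (the constant $k$ is conventionally around 60; the value used here is a tuning detail):
$$\mathrm{RRF}(d) = \sum_{r \in \text{rankers}} \frac{1}{k + \mathrm{rank}_r(d)}$$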
## Reranking (Optional)
When enabled, retrieval results are rescored with a cross-encoder model for improved relevance. Powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs).
**Trade-offs:**
- Downloads ~1.1 GB of model data
- Adds latency per query
- Potentially improves answer quality, see [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/)
Enable via `RERANKING_ENABLED=true`. See [Configuration](./configuration.md).
## Multi-Format Ingestion
Supported content types:
- Plain text and notes
- URLs (web pages)
- PDF documents
- Audio files
- Images
## Scratchpad
Quickly capture content without committing to permanent storage. Convert to full content when ready.
## iOS Shortcut
Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture from your phone.

67
docs/installation.md Normal file
View File

@@ -0,0 +1,67 @@
# Installation
Minne can be installed through several methods. Choose the one that best fits your setup.
## Docker Compose (Recommended)
The fastest way to get Minne running with all dependencies:
```bash
git clone https://github.com/perstarkse/minne.git
cd minne
docker compose up -d
```
The included `docker-compose.yml` handles SurrealDB and Chromium automatically.
**Required:** Set your `OPENAI_API_KEY` in `docker-compose.yml` before starting.
## Nix
Run Minne directly with Nix (includes Chromium):
```bash
nix run 'github:perstarkse/minne#main'
```
Configure via environment variables or a `config.yaml` file. See [Configuration](./configuration.md).
## Pre-built Binaries
Download binaries for Windows, macOS, and Linux from [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
**Requirements:**
- SurrealDB instance (local or remote)
- Chromium (for web scraping)
## Build from Source
```bash
git clone https://github.com/perstarkse/minne.git
cd minne
cargo build --release --bin main
```
The binary will be at `target/release/main`.
**Requirements:**
- Rust toolchain
- SurrealDB accessible at configured address
- Chromium in PATH
## Process Modes
Minne offers flexible deployment:
| Binary | Description |
|--------|-------------|
| `main` | Combined server + worker (recommended) |
| `server` | Web interface and API only |
| `worker` | Background processing only |
For most users, `main` is the right choice. Split deployments are useful for resource optimization or scaling.
## Next Steps
- [Configuration](./configuration.md) — Environment variables and config.yaml
- [Features](./features.md) — What Minne can do

48
docs/vision.md Normal file
View File

@@ -0,0 +1,48 @@
# Vision
## The "Why" Behind Minne
Personal knowledge management has always fascinated me. I wanted something that made it incredibly easy to capture content—snippets of text, URLs, media—while automatically discovering connections between ideas. But I also wanted control over my knowledge structure.
Traditional tools like Logseq and Obsidian are excellent, but manual linking often becomes a hindrance. Fully automated systems sometimes miss important context or create relationships I wouldn't have chosen.
Minne offers the best of both worlds: effortless capture with AI-assisted relationship discovery, but with flexibility to manually curate, edit, or override connections. Let AI handle the heavy lifting, take full control yourself, or use a hybrid approach where AI suggests and you approve.
## Design Principles
- **Capture should be instant** — No friction between thought and storage
- **Connections should emerge** — AI finds relationships you might miss
- **Control should be optional** — Automate by default, curate when it matters
- **Privacy should be default** — Self-hosted, your data stays yours
## Roadmap
### Near-term
- [ ] TUI frontend with system editor integration
- [ ] Enhanced retrieval recall via improved reranking
- [ ] Additional content type support (e-books, research papers)
### Medium-term
- [ ] Embedded SurrealDB option (zero-config `nix run` with just `OPENAI_API_KEY`)
- [ ] Browser extension for seamless capture
- [ ] Mobile-native apps
### Long-term
- [ ] Federated knowledge sharing (opt-in)
- [ ] Local LLM integration (fully offline operation)
- [ ] Plugin system for custom entity extractors
## Related Projects
If Minne isn't quite right for you, check out:
- [Karakeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder) — Excellent bookmark/read-later with AI tagging
- [Logseq](https://logseq.com/) — Outliner-based PKM with manual linking
- [Obsidian](https://obsidian.md/) — Markdown-based PKM with plugin ecosystem
## Contributing
Feature requests and contributions are welcome. Minne was built for personal use first, but the self-hosted community benefits when we share.

35
evaluations/Cargo.toml Normal file
View File

@@ -0,0 +1,35 @@
[package]
name = "evaluations"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
async-openai = { workspace = true }
chrono = { workspace = true }
common = { path = "../common" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
ingestion-pipeline = { path = "../ingestion-pipeline" }
futures = { workspace = true }
fastembed = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { workspace = true }
text-splitter = { workspace = true }
unicode-normalization = { workspace = true }
rand = "0.8"
sha2 = { workspace = true }
object_store = { workspace = true }
surrealdb = { workspace = true }
serde_json = { workspace = true }
async-trait = { workspace = true }
once_cell = "1.19"
serde_yaml = "0.9"
criterion = "0.5"
state-machines = { workspace = true }
clap = { version = "4.4", features = ["derive", "env"] }
[dev-dependencies]
tempfile = { workspace = true }

212
evaluations/README.md Normal file
View File

@@ -0,0 +1,212 @@
# Evaluations
The `evaluations` crate provides a retrieval evaluation framework for benchmarking Minne's information retrieval pipeline against standard datasets.
## Quick Start
```bash
# Run SQuAD v2.0 evaluation (vector-only, recommended)
cargo run --package evaluations -- --ingest-chunks-only
# Run a specific dataset
cargo run --package evaluations -- --dataset fiqa --ingest-chunks-only
# Convert dataset only (no evaluation)
cargo run --package evaluations -- --convert-only
```
## Prerequisites
### 1. SurrealDB
Start a SurrealDB instance before running evaluations:
```bash
docker-compose up -d surrealdb
```
Or run SurrealDB directly with the credentials the evaluation framework expects by default:
```bash
surreal start --user root_user --pass root_password
```
### 2. Download Raw Datasets
Raw datasets must be downloaded manually and placed in `evaluations/data/raw/`. See [Dataset Sources](#dataset-sources) below for links and formats.
## Directory Structure
```
evaluations/
├── data/
│ ├── raw/ # Downloaded raw datasets (manual)
│ │ ├── squad/ # SQuAD v2.0
│ │ ├── nq-dev/ # Natural Questions
│ │ ├── fiqa/ # BEIR: FiQA-2018
│ │ ├── fever/ # BEIR: FEVER
│ │ ├── hotpotqa/ # BEIR: HotpotQA
│ │ └── ... # Other BEIR subsets
│ └── converted/ # Auto-generated (Minne JSON format)
├── cache/ # Ingestion and embedding caches
├── reports/ # Evaluation output (JSON + Markdown)
├── manifest.yaml # Dataset and slice definitions
└── src/ # Evaluation source code
```
## Dataset Sources
### SQuAD v2.0
Download and place at `data/raw/squad/dev-v2.0.json`:
```bash
mkdir -p evaluations/data/raw/squad
curl -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json \
-o evaluations/data/raw/squad/dev-v2.0.json
```
### Natural Questions (NQ)
Download and place at `data/raw/nq-dev/dev-all.jsonl`:
```bash
mkdir -p evaluations/data/raw/nq-dev
# Download from Google's Natural Questions page or HuggingFace
# File: dev-all.jsonl (simplified JSONL format)
```
Source: [Google Natural Questions](https://ai.google.com/research/NaturalQuestions)
### BEIR Datasets
All BEIR datasets follow the same format structure:
```
data/raw/<dataset>/
├── corpus.jsonl # Document corpus
├── queries.jsonl # Query set
└── qrels/
└── test.tsv # Relevance judgments (or dev.tsv)
```
Download datasets from the [BEIR Benchmark repository](https://github.com/beir-cellar/beir). Each dataset zip extracts to the required directory structure.
| Dataset | Directory |
|------------|---------------|
| FEVER | `fever/` |
| FiQA-2018 | `fiqa/` |
| HotpotQA | `hotpotqa/` |
| NFCorpus | `nfcorpus/` |
| Quora | `quora/` |
| TREC-COVID | `trec-covid/` |
| SciFact | `scifact/` |
| NQ (BEIR) | `nq/` |
Example download:
```bash
cd evaluations/data/raw
curl -L https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip -o fiqa.zip
unzip fiqa.zip && rm fiqa.zip
```
## Dataset Conversion
Raw datasets are automatically converted to Minne's internal JSON format on first run. To force reconversion:
```bash
cargo run --package evaluations -- --force-convert
```
Converted files are saved to `data/converted/` and cached for subsequent runs.
## CLI Reference
### Common Options
| Flag | Description | Default |
|------|-------------|---------|
| `--dataset <NAME>` | Dataset to evaluate | `squad-v2` |
| `--limit <N>` | Max questions to evaluate (0 = all) | `200` |
| `--k <N>` | Precision@k cutoff | `5` |
| `--slice <ID>` | Use a predefined slice from manifest | — |
| `--rerank` | Enable FastEmbed reranking stage | disabled |
| `--embedding-backend <BE>` | `fastembed` or `hashed` | `fastembed` |
| `--ingest-chunks-only` | Skip entity extraction, ingest only text chunks | disabled |
> [!TIP]
> Use `--ingest-chunks-only` when evaluating vector-only retrieval strategies. This skips the LLM-based entity extraction and graph generation, significantly speeding up ingestion while focusing on pure chunk-based vector search.
### Available Datasets
```
squad-v2, natural-questions, beir, fever, fiqa, hotpotqa,
nfcorpus, quora, trec-covid, scifact, nq-beir
```
### Database Configuration
| Flag | Environment | Default |
|------|-------------|---------|
| `--db-endpoint` | `EVAL_DB_ENDPOINT` | `ws://127.0.0.1:8000` |
| `--db-username` | `EVAL_DB_USERNAME` | `root_user` |
| `--db-password` | `EVAL_DB_PASSWORD` | `root_password` |
| `--db-namespace` | `EVAL_DB_NAMESPACE` | auto-generated |
| `--db-database` | `EVAL_DB_DATABASE` | auto-generated |
### Example Runs
```bash
# Vector-only evaluation (recommended for benchmarking)
cargo run --package evaluations -- \
--dataset fiqa \
--ingest-chunks-only \
--limit 200
# Full FiQA evaluation with reranking
cargo run --package evaluations -- \
--dataset fiqa \
--ingest-chunks-only \
--limit 500 \
--rerank \
--k 10
# Use a predefined slice for reproducibility
cargo run --package evaluations -- --slice fiqa-test-200 --ingest-chunks-only
# Run the mixed BEIR benchmark
cargo run --package evaluations -- --dataset beir --slice beir-mix-600 --ingest-chunks-only
```
## Slices
Slices are predefined, reproducible subsets defined in `manifest.yaml`. Each slice specifies:
- **limit**: Number of questions
- **corpus_limit**: Maximum corpus size
- **seed**: Fixed RNG seed for reproducibility
View available slices in [manifest.yaml](./manifest.yaml).
## Reports
Evaluations generate reports in `reports/`:
- **JSON**: Full structured results (`*-report.json`)
- **Markdown**: Human-readable summary with sample mismatches (`*-report.md`)
- **History**: Timestamped run history (`history/`)
## Performance Tuning
```bash
# Log per-stage performance timings
cargo run --package evaluations -- --perf-log-console
# Save telemetry to file
cargo run --package evaluations -- --perf-log-json ./perf.json
```
## License
See [../LICENSE](../LICENSE).

168
evaluations/manifest.yaml Normal file
View File

@@ -0,0 +1,168 @@
default_dataset: squad-v2
datasets:
- id: squad-v2
label: "SQuAD v2.0"
category: "SQuAD v2.0"
entity_suffix: "SQuAD"
source_prefix: "squad"
raw: "data/raw/squad/dev-v2.0.json"
converted: "data/converted/squad-minne.json"
include_unanswerable: false
slices:
- id: squad-dev-200
label: "SQuAD dev (200)"
description: "Deterministic 200-case slice for local eval"
limit: 200
corpus_limit: 2000
seed: 0x5eed2025
- id: natural-questions-dev
label: "Natural Questions (dev)"
category: "Natural Questions"
entity_suffix: "Natural Questions"
source_prefix: "nq"
raw: "data/raw/nq-dev/dev-all.jsonl"
converted: "data/converted/nq-dev-minne.json"
include_unanswerable: true
slices:
- id: nq-dev-200
label: "NQ dev (200)"
description: "200-case slice of the dev set"
limit: 200
corpus_limit: 2000
include_unanswerable: false
seed: 0x5eed2025
- id: beir
label: "BEIR mix"
category: "BEIR"
entity_suffix: "BEIR"
source_prefix: "beir"
raw: "data/raw/beir"
converted: "data/converted/beir-minne.json"
include_unanswerable: false
slices:
- id: beir-mix-600
label: "BEIR mix (600)"
description: "Balanced slice across FEVER, FiQA, HotpotQA, NFCorpus, Quora, TREC-COVID, SciFact, NQ-BEIR"
limit: 600
corpus_limit: 6000
seed: 0x5eed2025
- id: fever
label: "FEVER (BEIR)"
category: "FEVER"
entity_suffix: "FEVER"
source_prefix: "fever"
raw: "data/raw/fever"
converted: "data/converted/fever-minne.json"
include_unanswerable: false
slices:
- id: fever-test-200
label: "FEVER test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: fiqa
label: "FiQA-2018 (BEIR)"
category: "FiQA-2018"
entity_suffix: "FiQA"
source_prefix: "fiqa"
raw: "data/raw/fiqa"
converted: "data/converted/fiqa-minne.json"
include_unanswerable: false
slices:
- id: fiqa-test-200
label: "FiQA test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: hotpotqa
label: "HotpotQA (BEIR)"
category: "HotpotQA"
entity_suffix: "HotpotQA"
source_prefix: "hotpotqa"
raw: "data/raw/hotpotqa"
converted: "data/converted/hotpotqa-minne.json"
include_unanswerable: false
slices:
- id: hotpotqa-test-200
label: "HotpotQA test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: nfcorpus
label: "NFCorpus (BEIR)"
category: "NFCorpus"
entity_suffix: "NFCorpus"
source_prefix: "nfcorpus"
raw: "data/raw/nfcorpus"
converted: "data/converted/nfcorpus-minne.json"
include_unanswerable: false
slices:
- id: nfcorpus-test-200
label: "NFCorpus test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: quora
label: "Quora (IR)"
category: "Quora"
entity_suffix: "Quora"
source_prefix: "quora"
raw: "data/raw/quora"
converted: "data/converted/quora-minne.json"
include_unanswerable: false
slices:
- id: quora-test-200
label: "Quora test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: trec-covid
label: "TREC-COVID (BEIR)"
category: "TREC-COVID"
entity_suffix: "TREC-COVID"
source_prefix: "trec-covid"
raw: "data/raw/trec-covid"
converted: "data/converted/trec-covid-minne.json"
include_unanswerable: false
slices:
- id: trec-covid-test-200
label: "TREC-COVID test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025
- id: scifact
label: "SciFact (BEIR)"
category: "SciFact"
entity_suffix: "SciFact"
source_prefix: "scifact"
raw: "data/raw/scifact"
converted: "data/converted/scifact-minne.json"
include_unanswerable: false
slices:
- id: scifact-test-200
label: "SciFact test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 3000
seed: 0x5eed2025
- id: nq-beir
label: "Natural Questions (BEIR)"
category: "Natural Questions"
entity_suffix: "Natural Questions"
source_prefix: "nq-beir"
raw: "data/raw/nq"
converted: "data/converted/nq-beir-minne.json"
include_unanswerable: false
slices:
- id: nq-beir-test-200
label: "NQ (BEIR) test (200)"
description: "200-case slice from BEIR test qrels"
limit: 200
corpus_limit: 5000
seed: 0x5eed2025

506
evaluations/src/args.rs Normal file
View File

@@ -0,0 +1,506 @@
use std::{
env,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use clap::{Args, Parser, ValueEnum};
use retrieval_pipeline::RetrievalStrategy;
use crate::datasets::DatasetKind;
fn workspace_root() -> PathBuf {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
manifest_dir.parent().unwrap_or(&manifest_dir).to_path_buf()
}
fn default_report_dir() -> PathBuf {
workspace_root().join("evaluations/reports")
}
fn default_cache_dir() -> PathBuf {
workspace_root().join("evaluations/cache")
}
fn default_ingestion_cache_dir() -> PathBuf {
workspace_root().join("evaluations/cache/ingested")
}
pub const DEFAULT_SLICE_SEED: u64 = 0x5eed_2025;
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
#[value(rename_all = "lowercase")]
pub enum EmbeddingBackend {
Hashed,
#[default]
FastEmbed,
}
impl std::fmt::Display for EmbeddingBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Hashed => write!(f, "hashed"),
Self::FastEmbed => write!(f, "fastembed"),
}
}
}
#[derive(Debug, Clone, Args)]
pub struct RetrievalSettings {
/// Override chunk vector candidate cap
#[arg(long)]
pub chunk_vector_take: Option<usize>,
/// Override chunk FTS candidate cap
#[arg(long)]
pub chunk_fts_take: Option<usize>,
/// Override average characters per token used for budgeting
#[arg(long)]
pub chunk_avg_chars_per_token: Option<usize>,
/// Override maximum chunks attached per entity
#[arg(long)]
pub max_chunks_per_entity: Option<usize>,
/// Enable the FastEmbed reranking stage
#[arg(long = "rerank", action = clap::ArgAction::SetTrue, default_value_t = false)]
pub rerank: bool,
/// Reranking engine pool size / parallelism
#[arg(long, default_value_t = 4)]
pub rerank_pool_size: usize,
/// Keep top-N entities after reranking
#[arg(long, default_value_t = 10)]
pub rerank_keep_top: usize,
/// Cap the number of chunks returned by retrieval (revised strategy)
#[arg(long, default_value_t = 5)]
pub chunk_result_cap: usize,
/// Reciprocal rank fusion k value for revised chunk merging
#[arg(long)]
pub chunk_rrf_k: Option<f32>,
/// Weight for vector ranks in revised RRF
#[arg(long)]
pub chunk_rrf_vector_weight: Option<f32>,
/// Weight for chunk FTS ranks in revised RRF
#[arg(long)]
pub chunk_rrf_fts_weight: Option<f32>,
/// Include vector ranks in revised RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_vector: Option<bool>,
/// Include chunk FTS ranks in revised RRF (default: true)
#[arg(long)]
pub chunk_rrf_use_fts: Option<bool>,
/// Require verified chunks (disable with --llm-mode)
#[arg(skip = true)]
pub require_verified_chunks: bool,
/// Select the retrieval pipeline strategy
#[arg(long, default_value_t = RetrievalStrategy::Default)]
pub strategy: RetrievalStrategy,
}
impl Default for RetrievalSettings {
fn default() -> Self {
Self {
chunk_vector_take: None,
chunk_fts_take: None,
chunk_avg_chars_per_token: None,
max_chunks_per_entity: None,
rerank: false,
rerank_pool_size: 4,
rerank_keep_top: 10,
chunk_result_cap: 5,
chunk_rrf_k: None,
chunk_rrf_vector_weight: None,
chunk_rrf_fts_weight: None,
chunk_rrf_use_vector: None,
chunk_rrf_use_fts: None,
require_verified_chunks: true,
strategy: RetrievalStrategy::Default,
}
}
}
#[derive(Debug, Clone, Args)]
pub struct IngestConfig {
/// Directory for ingestion corpora caches
#[arg(long, default_value_os_t = default_ingestion_cache_dir())]
pub ingestion_cache_dir: PathBuf,
/// Minimum tokens per chunk for ingestion
#[arg(long, default_value_t = 256)]
pub ingest_chunk_min_tokens: usize,
/// Maximum tokens per chunk for ingestion
#[arg(long, default_value_t = 512)]
pub ingest_chunk_max_tokens: usize,
/// Overlap between chunks during ingestion (tokens)
#[arg(long, default_value_t = 50)]
pub ingest_chunk_overlap_tokens: usize,
/// Run ingestion in chunk-only mode (skip analyzer/graph generation)
#[arg(long)]
pub ingest_chunks_only: bool,
/// Number of paragraphs to ingest concurrently
#[arg(long, default_value_t = 10)]
pub ingestion_batch_size: usize,
/// Maximum retries for ingestion failures per paragraph
#[arg(long, default_value_t = 3)]
pub ingestion_max_retries: usize,
/// Recompute embeddings for cached corpora without re-running ingestion
#[arg(long, alias = "refresh-embeddings")]
pub refresh_embeddings_only: bool,
/// Delete cached paragraph shards before rebuilding the ingestion corpus
#[arg(long)]
pub slice_reset_ingestion: bool,
}
#[derive(Debug, Clone, Args)]
pub struct DatabaseArgs {
/// SurrealDB server endpoint
#[arg(long, default_value = "ws://127.0.0.1:8000", env = "EVAL_DB_ENDPOINT")]
pub db_endpoint: String,
/// SurrealDB root username
#[arg(long, default_value = "root_user", env = "EVAL_DB_USERNAME")]
pub db_username: String,
/// SurrealDB root password
#[arg(long, default_value = "root_password", env = "EVAL_DB_PASSWORD")]
pub db_password: String,
/// Override the namespace used on the SurrealDB server
#[arg(long, env = "EVAL_DB_NAMESPACE")]
pub db_namespace: Option<String>,
/// Override the database used on the SurrealDB server
#[arg(long, env = "EVAL_DB_DATABASE")]
pub db_database: Option<String>,
/// Path to inspect DB state
#[arg(long)]
pub inspect_db_state: Option<PathBuf>,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Config {
/// Convert the selected dataset and exit
#[arg(long)]
pub convert_only: bool,
/// Regenerate the converted dataset even if it already exists
#[arg(long, alias = "refresh")]
pub force_convert: bool,
/// Dataset to evaluate
#[arg(long, default_value_t = DatasetKind::default())]
pub dataset: DatasetKind,
/// Enable LLM-assisted evaluation features (includes unanswerable cases)
#[arg(long)]
pub llm_mode: bool,
/// Cap the slice corpus size (positives + negatives)
#[arg(long)]
pub corpus_limit: Option<usize>,
/// Path to the raw dataset (defaults per dataset)
#[arg(long)]
pub raw: Option<PathBuf>,
/// Path to write/read the converted dataset (defaults per dataset)
#[arg(long)]
pub converted: Option<PathBuf>,
/// Directory to write evaluation reports
#[arg(long, default_value_os_t = default_report_dir())]
pub report_dir: PathBuf,
/// Precision@k cutoff
#[arg(long, default_value_t = 5)]
pub k: usize,
/// Limit the number of questions evaluated (0 = all)
#[arg(long = "limit", default_value_t = 200)]
pub limit_arg: usize,
/// Number of mismatches to surface in the Markdown summary
#[arg(long, default_value_t = 5)]
pub sample: usize,
/// Disable context cropping when converting datasets (ingest entire documents)
#[arg(long)]
pub full_context: bool,
#[command(flatten)]
pub retrieval: RetrievalSettings,
/// Concurrency level
#[arg(long, default_value_t = 1)]
pub concurrency: usize,
/// Embedding backend
#[arg(long, default_value_t = EmbeddingBackend::FastEmbed)]
pub embedding_backend: EmbeddingBackend,
/// FastEmbed model code
#[arg(long)]
pub embedding_model: Option<String>,
/// Directory for embedding caches
#[arg(long, default_value_os_t = default_cache_dir())]
pub cache_dir: PathBuf,
#[command(flatten)]
pub ingest: IngestConfig,
/// Include entity descriptions and categories in JSON reports
#[arg(long)]
pub detailed_report: bool,
/// Use a cached dataset slice by id or path
#[arg(long)]
pub slice: Option<String>,
/// Ignore cached corpus state and rebuild the slice's SurrealDB corpus
#[arg(long)]
pub reseed_slice: bool,
/// Slice seed
#[arg(skip = DEFAULT_SLICE_SEED)]
pub slice_seed: u64,
/// Grow the slice ledger to contain at least this many answerable cases, then exit
#[arg(long)]
pub slice_grow: Option<usize>,
/// Evaluate questions starting at this offset within the slice
#[arg(long, default_value_t = 0)]
pub slice_offset: usize,
/// Target negative-to-positive paragraph ratio for slice growth
#[arg(long, default_value_t = crate::slice::DEFAULT_NEGATIVE_MULTIPLIER)]
pub negative_multiplier: f32,
/// Annotate the run; label is stored in JSON/Markdown reports
#[arg(long)]
pub label: Option<String>,
/// Write per-query chunk diagnostics JSONL to the provided path
#[arg(long, alias = "chunk-diagnostics")]
pub chunk_diagnostics_path: Option<PathBuf>,
/// Inspect an ingestion cache question and exit
#[arg(long)]
pub inspect_question: Option<String>,
/// Path to an ingestion cache manifest JSON for inspection mode
#[arg(long)]
pub inspect_manifest: Option<PathBuf>,
/// Override the SurrealDB system settings query model
#[arg(long)]
pub query_model: Option<String>,
/// Write structured performance telemetry JSON to the provided path
#[arg(long)]
pub perf_log_json: Option<PathBuf>,
/// Directory that receives timestamped perf JSON copies
#[arg(long)]
pub perf_log_dir: Option<PathBuf>,
/// Print per-stage performance timings to stdout after the run
#[arg(long, alias = "perf-log")]
pub perf_log_console: bool,
#[command(flatten)]
pub database: DatabaseArgs,
// Computed fields (not arguments)
#[arg(skip)]
pub raw_dataset_path: PathBuf,
#[arg(skip)]
pub converted_dataset_path: PathBuf,
#[arg(skip)]
pub limit: Option<usize>,
#[arg(skip)]
pub summary_sample: usize,
}
impl Config {
pub fn context_token_limit(&self) -> Option<usize> {
None
}
pub fn finalize(&mut self) -> Result<()> {
// Handle dataset paths
if let Some(raw) = &self.raw {
self.raw_dataset_path = raw.clone();
} else {
self.raw_dataset_path = self.dataset.default_raw_path();
}
if let Some(converted) = &self.converted {
self.converted_dataset_path = converted.clone();
} else {
self.converted_dataset_path = self.dataset.default_converted_path();
}
// Handle limit
if self.limit_arg == 0 {
self.limit = None;
} else {
self.limit = Some(self.limit_arg);
}
// Handle sample
self.summary_sample = self.sample.max(1);
// Handle retrieval settings
self.retrieval.require_verified_chunks = !self.llm_mode;
if self.dataset == DatasetKind::Beir {
self.negative_multiplier = 9.0;
}
// Validations
if self.ingest.ingest_chunk_min_tokens == 0
|| self.ingest.ingest_chunk_min_tokens >= self.ingest.ingest_chunk_max_tokens
{
return Err(anyhow!(
"--ingest-chunk-min-tokens must be greater than zero and less than --ingest-chunk-max-tokens (got {} >= {})",
self.ingest.ingest_chunk_min_tokens,
self.ingest.ingest_chunk_max_tokens
));
}
if self.ingest.ingest_chunk_overlap_tokens >= self.ingest.ingest_chunk_min_tokens {
return Err(anyhow!(
"--ingest-chunk-overlap-tokens ({}) must be less than --ingest-chunk-min-tokens ({})",
self.ingest.ingest_chunk_overlap_tokens,
self.ingest.ingest_chunk_min_tokens
));
}
if self.retrieval.rerank && self.retrieval.rerank_pool_size == 0 {
return Err(anyhow!(
"--rerank-pool must be greater than zero when reranking is enabled"
));
}
if let Some(k) = self.retrieval.chunk_rrf_k {
if k <= 0.0 || !k.is_finite() {
return Err(anyhow!(
"--chunk-rrf-k must be a positive, finite number (got {k})"
));
}
}
if let Some(weight) = self.retrieval.chunk_rrf_vector_weight {
if weight < 0.0 || !weight.is_finite() {
return Err(anyhow!(
"--chunk-rrf-vector-weight must be a non-negative, finite number (got {weight})"
));
}
}
if let Some(weight) = self.retrieval.chunk_rrf_fts_weight {
if weight < 0.0 || !weight.is_finite() {
return Err(anyhow!(
"--chunk-rrf-fts-weight must be a non-negative, finite number (got {weight})"
));
}
}
if self.concurrency == 0 {
return Err(anyhow!("--concurrency must be greater than zero"));
}
if self.embedding_backend == EmbeddingBackend::Hashed && self.embedding_model.is_some() {
return Err(anyhow!(
"--embedding-model cannot be used with the 'hashed' embedding backend"
));
}
if let Some(query_model) = &self.query_model {
if query_model.trim().is_empty() {
return Err(anyhow!("--query-model requires a non-empty model name"));
}
}
if let Some(grow) = self.slice_grow {
if grow == 0 {
return Err(anyhow!("--slice-grow must be greater than zero"));
}
}
if self.negative_multiplier <= 0.0 || !self.negative_multiplier.is_finite() {
return Err(anyhow!(
"--negative-multiplier must be a positive finite number"
));
}
// Handle corpus limit logic
if let Some(limit) = self.limit {
if let Some(corpus_limit) = self.corpus_limit {
if corpus_limit < limit {
self.corpus_limit = Some(limit);
}
} else {
let default_multiplier = 10usize;
let mut computed = limit.saturating_mul(default_multiplier);
if computed < limit {
computed = limit;
}
let max_cap = 1_000usize;
if computed > max_cap {
computed = max_cap;
}
self.corpus_limit = Some(computed);
}
}
// Handle perf log dir env var fallback
if self.perf_log_dir.is_none() {
if let Ok(dir) = env::var("EVAL_PERF_LOG_DIR") {
if !dir.trim().is_empty() {
self.perf_log_dir = Some(PathBuf::from(dir));
}
}
}
Ok(())
}
}
pub struct ParsedArgs {
pub config: Config,
}
pub fn parse() -> Result<ParsedArgs> {
let mut config = Config::parse();
config.finalize()?;
Ok(ParsedArgs { config })
}
pub fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}

88
evaluations/src/cache.rs Normal file
View File

@@ -0,0 +1,88 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
#[derive(Debug, Default, Serialize, Deserialize)]
struct EmbeddingCacheData {
entities: HashMap<String, Vec<f32>>,
chunks: HashMap<String, Vec<f32>>,
}
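/// JSON-backed cache of entity and chunk embeddings; persisted to disk only when new entries have been inserted.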
#[derive(Clone)]
pub struct EmbeddingCache {
path: Arc<PathBuf>,
data: Arc<Mutex<EmbeddingCacheData>>,
dirty: Arc<AtomicBool>,
}
#[allow(dead_code)]
impl EmbeddingCache {
pub async fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let data = if path.exists() {
let raw = tokio::fs::read(&path)
.await
.with_context(|| format!("reading embedding cache {}", path.display()))?;
serde_json::from_slice(&raw)
.with_context(|| format!("parsing embedding cache {}", path.display()))?
} else {
EmbeddingCacheData::default()
};
Ok(Self {
path: Arc::new(path),
data: Arc::new(Mutex::new(data)),
dirty: Arc::new(AtomicBool::new(false)),
})
}
pub async fn get_entity(&self, id: &str) -> Option<Vec<f32>> {
let guard = self.data.lock().await;
guard.entities.get(id).cloned()
}
pub async fn insert_entity(&self, id: String, embedding: Vec<f32>) {
let mut guard = self.data.lock().await;
guard.entities.insert(id, embedding);
self.dirty.store(true, Ordering::Relaxed);
}
pub async fn get_chunk(&self, id: &str) -> Option<Vec<f32>> {
let guard = self.data.lock().await;
guard.chunks.get(id).cloned()
}
pub async fn insert_chunk(&self, id: String, embedding: Vec<f32>) {
let mut guard = self.data.lock().await;
guard.chunks.insert(id, embedding);
self.dirty.store(true, Ordering::Relaxed);
}
pub async fn persist(&self) -> Result<()> {
if !self.dirty.load(Ordering::Relaxed) {
return Ok(());
}
let guard = self.data.lock().await;
let body = serde_json::to_vec_pretty(&*guard).context("serialising embedding cache")?;
if let Some(parent) = self.path.parent() {
tokio::fs::create_dir_all(parent)
.await
.with_context(|| format!("creating cache directory {}", parent.display()))?;
}
tokio::fs::write(&*self.path, body)
.await
.with_context(|| format!("writing embedding cache {}", self.path.display()))?;
self.dirty.store(false, Ordering::Relaxed);
Ok(())
}
}

187
evaluations/src/cases.rs Normal file
View File

@@ -0,0 +1,187 @@
//! Case generation from corpus manifests.
use std::collections::HashMap;
use crate::corpus;
/// A test case for retrieval evaluation derived from a manifest question.
pub(crate) struct SeededCase {
pub question_id: String,
pub question: String,
pub expected_source: String,
pub answers: Vec<String>,
pub paragraph_id: String,
pub paragraph_title: String,
pub expected_chunk_ids: Vec<String>,
pub is_impossible: bool,
pub has_verified_chunks: bool,
}
/// Convert a corpus manifest into seeded evaluation cases.
pub(crate) fn cases_from_manifest(manifest: &corpus::CorpusManifest) -> Vec<SeededCase> {
let mut title_map = HashMap::new();
for paragraph in &manifest.paragraphs {
title_map.insert(paragraph.paragraph_id.as_str(), paragraph.title.clone());
}
let include_impossible = manifest.metadata.include_unanswerable;
let require_verified_chunks = manifest.metadata.require_verified_chunks;
manifest
.questions
.iter()
.filter(|question| {
should_include_question(question, include_impossible, require_verified_chunks)
})
.map(|question| {
let title = title_map
.get(question.paragraph_id.as_str())
.cloned()
.unwrap_or_else(|| "Untitled".to_string());
SeededCase {
question_id: question.question_id.clone(),
question: question.question_text.clone(),
expected_source: question.text_content_id.clone(),
answers: question.answers.clone(),
paragraph_id: question.paragraph_id.clone(),
paragraph_title: title,
expected_chunk_ids: question.matching_chunk_ids.clone(),
is_impossible: question.is_impossible,
has_verified_chunks: !question.matching_chunk_ids.is_empty(),
}
})
.collect()
}
fn should_include_question(
question: &corpus::CorpusQuestion,
include_impossible: bool,
require_verified_chunks: bool,
) -> bool {
if !include_impossible && question.is_impossible {
return false;
}
if require_verified_chunks && question.matching_chunk_ids.is_empty() {
return false;
}
true
}
#[cfg(test)]
mod tests {
use super::*;
use crate::corpus::store::{CorpusParagraph, EmbeddedKnowledgeEntity, EmbeddedTextChunk};
use crate::corpus::{CorpusManifest, CorpusMetadata, CorpusQuestion, MANIFEST_VERSION};
use chrono::Utc;
use common::storage::types::text_content::TextContent;
fn sample_manifest() -> CorpusManifest {
let paragraphs = vec![
CorpusParagraph {
paragraph_id: "p1".to_string(),
title: "Alpha".to_string(),
text_content: TextContent::new(
"alpha context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
CorpusParagraph {
paragraph_id: "p2".to_string(),
title: "Beta".to_string(),
text_content: TextContent::new(
"beta context".to_string(),
None,
"test".to_string(),
None,
None,
"user".to_string(),
),
entities: Vec::<EmbeddedKnowledgeEntity>::new(),
relationships: Vec::new(),
chunks: Vec::<EmbeddedTextChunk>::new(),
},
];
let questions = vec![
CorpusQuestion {
question_id: "q1".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "What is Alpha?".to_string(),
answers: vec!["Alpha".to_string()],
is_impossible: false,
matching_chunk_ids: vec!["chunk-alpha".to_string()],
},
CorpusQuestion {
question_id: "q2".to_string(),
paragraph_id: "p1".to_string(),
text_content_id: "tc-alpha".to_string(),
question_text: "Unanswerable?".to_string(),
answers: Vec::new(),
is_impossible: true,
matching_chunk_ids: Vec::new(),
},
CorpusQuestion {
question_id: "q3".to_string(),
paragraph_id: "p2".to_string(),
text_content_id: "tc-beta".to_string(),
question_text: "Where is Beta?".to_string(),
answers: vec!["Beta".to_string()],
is_impossible: false,
matching_chunk_ids: Vec::new(),
},
];
CorpusManifest {
version: MANIFEST_VERSION,
metadata: CorpusMetadata {
dataset_id: "ds".to_string(),
dataset_label: "Dataset".to_string(),
slice_id: "slice".to_string(),
include_unanswerable: true,
require_verified_chunks: true,
ingestion_fingerprint: "fp".to_string(),
embedding_backend: "test".to_string(),
embedding_model: None,
embedding_dimension: 3,
converted_checksum: "chk".to_string(),
generated_at: Utc::now(),
paragraph_count: paragraphs.len(),
question_count: questions.len(),
chunk_min_tokens: 1,
chunk_max_tokens: 10,
chunk_only: false,
},
paragraphs,
questions,
}
}
#[test]
fn cases_respect_mode_filters() {
let mut manifest = sample_manifest();
manifest.metadata.include_unanswerable = false;
manifest.metadata.require_verified_chunks = true;
let strict_cases = cases_from_manifest(&manifest);
assert_eq!(strict_cases.len(), 1);
assert_eq!(strict_cases[0].question_id, "q1");
assert_eq!(strict_cases[0].paragraph_title, "Alpha");
let mut llm_manifest = manifest.clone();
llm_manifest.metadata.include_unanswerable = true;
llm_manifest.metadata.require_verified_chunks = false;
let llm_cases = cases_from_manifest(&llm_manifest);
let ids: Vec<_> = llm_cases
.iter()
.map(|case| case.question_id.as_str())
.collect();
assert_eq!(ids, vec!["q1", "q2", "q3"]);
}
}

View File

@@ -0,0 +1,42 @@
use std::path::PathBuf;
use crate::args::Config;
#[derive(Debug, Clone)]
pub struct CorpusCacheConfig {
pub ingestion_cache_dir: PathBuf,
pub force_refresh: bool,
pub refresh_embeddings_only: bool,
pub ingestion_batch_size: usize,
pub ingestion_max_retries: usize,
}
impl CorpusCacheConfig {
pub fn new(
ingestion_cache_dir: impl Into<PathBuf>,
force_refresh: bool,
refresh_embeddings_only: bool,
ingestion_batch_size: usize,
ingestion_max_retries: usize,
) -> Self {
Self {
ingestion_cache_dir: ingestion_cache_dir.into(),
force_refresh,
refresh_embeddings_only,
ingestion_batch_size,
ingestion_max_retries,
}
}
}
impl From<&Config> for CorpusCacheConfig {
fn from(config: &Config) -> Self {
CorpusCacheConfig::new(
config.ingest.ingestion_cache_dir.clone(),
config.force_convert || config.ingest.slice_reset_ingestion,
config.ingest.refresh_embeddings_only,
config.ingest.ingestion_batch_size,
config.ingest.ingestion_max_retries,
)
}
}


@@ -0,0 +1,26 @@
mod config;
mod orchestrator;
pub(crate) mod store;
pub use config::CorpusCacheConfig;
pub use orchestrator::{
cached_corpus_dir, compute_ingestion_fingerprint, corpus_handle_from_manifest, ensure_corpus,
load_cached_manifest,
};
pub use store::{
seed_manifest_into_db, window_manifest, CorpusHandle, CorpusManifest, CorpusMetadata,
CorpusQuestion, EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard,
ParagraphShardStore, MANIFEST_VERSION,
};
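// Builds the pipeline configuration for benchmark ingestion: only the chunk token bounds,
// the overlap, and chunk-only mode come from the CLI; every other tuning knob keeps
// `IngestionTuning::default()`.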
pub fn make_ingestion_config(config: &crate::args::Config) -> ingestion_pipeline::IngestionConfig {
ingestion_pipeline::IngestionConfig {
tuning: ingestion_pipeline::IngestionTuning {
chunk_min_tokens: config.ingest.ingest_chunk_min_tokens,
chunk_max_tokens: config.ingest.ingest_chunk_max_tokens,
chunk_overlap_tokens: config.ingest.ingest_chunk_overlap_tokens,
..Default::default()
},
chunk_only: config.ingest.ingest_chunks_only,
}
}


@@ -0,0 +1,785 @@
use std::{
collections::{HashMap, HashSet},
fs,
io::Read,
path::{Path, PathBuf},
sync::Arc,
};
use anyhow::{anyhow, Context, Result};
use async_openai::Client;
use chrono::Utc;
use common::{
storage::{
db::SurrealDbClient,
store::{DynStore, StorageManager},
types::{ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, StoredObject},
},
utils::config::{AppConfig, StorageKind},
};
use futures::future::try_join_all;
use ingestion_pipeline::{IngestionConfig, IngestionPipeline};
use object_store::memory::InMemory;
use sha2::{Digest, Sha256};
use tracing::{info, warn};
use uuid::Uuid;
use crate::{
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion},
slice::{self, ResolvedSlice, SliceParagraphKind},
};
use crate::corpus::{
CorpusCacheConfig, CorpusHandle, CorpusManifest, CorpusMetadata, CorpusQuestion,
EmbeddedKnowledgeEntity, EmbeddedTextChunk, ParagraphShard, ParagraphShardStore,
MANIFEST_VERSION,
};
const INGESTION_SPEC_VERSION: u32 = 2;
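// Bumping this version changes every ingestion fingerprint (see `build_ingestion_fingerprint`),
// so previously cached paragraph shards are treated as stale and rebuilt.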
type OpenAIClient = Client<async_openai::config::OpenAIConfig>;
#[derive(Clone)]
struct ParagraphShardRecord {
shard: ParagraphShard,
dirty: bool,
needs_reembed: bool,
}
#[derive(Clone)]
struct IngestRequest<'a> {
slot: usize,
paragraph: &'a ConvertedParagraph,
shard_path: String,
question_refs: Vec<&'a ConvertedQuestion>,
}
impl<'a> IngestRequest<'a> {
fn from_entry(
slot: usize,
paragraph: &'a ConvertedParagraph,
entry: &'a slice::SliceParagraphEntry,
) -> Result<Self> {
let shard_path = entry
.shard_path
.clone()
.unwrap_or_else(|| slice::default_shard_path(&entry.id));
let question_refs = match &entry.kind {
SliceParagraphKind::Positive { question_ids } => question_ids
.iter()
.map(|id| {
paragraph
.questions
.iter()
.find(|question| question.id == *id)
.ok_or_else(|| {
anyhow!(
"paragraph '{}' missing question '{}' referenced by slice",
paragraph.id,
id
)
})
})
.collect::<Result<Vec<_>>>()?,
SliceParagraphKind::Negative => Vec::new(),
};
Ok(Self {
slot,
paragraph,
shard_path,
question_refs,
})
}
}
struct ParagraphPlan<'a> {
slot: usize,
entry: &'a slice::SliceParagraphEntry,
paragraph: &'a ConvertedParagraph,
}
#[derive(Default)]
struct IngestionStats {
positive_reused: usize,
positive_ingested: usize,
negative_reused: usize,
negative_ingested: usize,
}
#[allow(clippy::too_many_arguments)]
pub async fn ensure_corpus(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
window: &slice::SliceWindow<'_>,
cache: &CorpusCacheConfig,
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
converted_path: &Path,
ingestion_config: IngestionConfig,
) -> Result<CorpusHandle> {
let checksum = compute_file_checksum(converted_path)
.with_context(|| format!("computing checksum for {}", converted_path.display()))?;
let ingestion_fingerprint =
build_ingestion_fingerprint(dataset, slice, &checksum, &ingestion_config);
let base_dir = cached_corpus_dir(
cache,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if cache.force_refresh && !cache.refresh_embeddings_only {
let _ = fs::remove_dir_all(&base_dir);
}
let store = ParagraphShardStore::new(base_dir.clone());
store.ensure_base_dir()?;
let positive_set: HashSet<&str> = window.positive_ids().collect();
let require_verified_chunks = slice.manifest.require_verified_chunks;
let embedding_backend_label = embedding.backend_label().to_string();
let embedding_model_code = embedding.model_code();
let embedding_dimension = embedding.dimension();
if positive_set.is_empty() {
return Err(anyhow!(
"window selection contains zero positive paragraphs for slice '{}'",
slice.manifest.slice_id
));
}
let desired_negatives =
((positive_set.len() as f32) * slice.manifest.negative_multiplier).ceil() as usize;
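// Build the ingestion plan: positives are limited to the paragraphs selected by the window,
// while negatives are taken in slice order until ceil(positives * negative_multiplier) have been added.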
let mut plan = Vec::new();
let mut negatives_added = 0usize;
for (idx, entry) in slice.manifest.paragraphs.iter().enumerate() {
let include = match &entry.kind {
SliceParagraphKind::Positive { .. } => positive_set.contains(entry.id.as_str()),
SliceParagraphKind::Negative => {
negatives_added < desired_negatives && {
negatives_added += 1;
true
}
}
};
if include {
let paragraph = slice
.paragraphs
.get(idx)
.copied()
.ok_or_else(|| anyhow!("slice missing paragraph index {}", idx))?;
plan.push(ParagraphPlan {
slot: plan.len(),
entry,
paragraph,
});
}
}
if plan.is_empty() {
return Err(anyhow!(
"no paragraphs selected for ingestion (slice '{}')",
slice.manifest.slice_id
));
}
let mut records: Vec<Option<ParagraphShardRecord>> = vec![None; plan.len()];
let mut ingest_requests = Vec::new();
let mut stats = IngestionStats::default();
for plan_entry in &plan {
let shard_path = plan_entry
.entry
.shard_path
.clone()
.unwrap_or_else(|| slice::default_shard_path(&plan_entry.entry.id));
let shard = if cache.force_refresh {
None
} else {
store.load(&shard_path, &ingestion_fingerprint)?
};
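// Reuse a cached shard when its fingerprint matches; if its embedding backend, model, or
// dimension no longer matches the current provider, it is kept but flagged for a metadata
// refresh further down.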
if let Some(shard) = shard {
let model_matches = shard.embedding_model.as_deref() == embedding_model_code.as_deref();
let needs_reembed = shard.embedding_backend != embedding_backend_label
|| shard.embedding_dimension != embedding_dimension
|| !model_matches;
match plan_entry.entry.kind {
SliceParagraphKind::Positive { .. } => stats.positive_reused += 1,
SliceParagraphKind::Negative => stats.negative_reused += 1,
}
records[plan_entry.slot] = Some(ParagraphShardRecord {
shard,
dirty: false,
needs_reembed,
});
} else {
match plan_entry.entry.kind {
SliceParagraphKind::Positive { .. } => stats.positive_ingested += 1,
SliceParagraphKind::Negative => stats.negative_ingested += 1,
}
let request =
IngestRequest::from_entry(plan_entry.slot, plan_entry.paragraph, plan_entry.entry)?;
ingest_requests.push(request);
}
}
if cache.refresh_embeddings_only && !ingest_requests.is_empty() {
return Err(anyhow!(
"--refresh-embeddings requested but {} shard(s) missing for dataset '{}' slice '{}'",
ingest_requests.len(),
dataset.metadata.id,
slice.manifest.slice_id
));
}
if !ingest_requests.is_empty() {
let new_shards = ingest_paragraph_batch(
dataset,
&ingest_requests,
embedding.clone(),
openai.clone(),
user_id,
&ingestion_fingerprint,
&embedding_backend_label,
embedding_model_code.clone(),
embedding_dimension,
cache.ingestion_batch_size,
cache.ingestion_max_retries,
ingestion_config.clone(),
)
.await
.context("ingesting missing slice paragraphs")?;
for (request, shard) in ingest_requests.into_iter().zip(new_shards.into_iter()) {
store.persist(&shard)?;
records[request.slot] = Some(ParagraphShardRecord {
shard,
dirty: false,
needs_reembed: false,
});
}
}
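// Refresh pass: embeddings already come out of the pipeline, so shards flagged above (or all
// shards when --refresh-embeddings was requested) only get their fingerprint and embedding
// metadata rewritten and are marked dirty for persistence.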
for record in &mut records {
let shard_record = record
.as_mut()
.context("shard record missing after ingestion run")?;
if cache.refresh_embeddings_only || shard_record.needs_reembed {
// Embeddings are already produced by the pipeline (FastEmbed), so no re-embedding is needed
// here; only the shard's fingerprint and embedding metadata are refreshed.
shard_record.shard.ingestion_fingerprint = ingestion_fingerprint.clone();
shard_record.shard.ingested_at = Utc::now();
shard_record.shard.embedding_backend = embedding_backend_label.clone();
shard_record.shard.embedding_model = embedding_model_code.clone();
shard_record.shard.embedding_dimension = embedding_dimension;
shard_record.dirty = true;
shard_record.needs_reembed = false;
}
}
let mut record_index = HashMap::new();
for (idx, plan_entry) in plan.iter().enumerate() {
record_index.insert(plan_entry.entry.id.as_str(), idx);
}
let mut corpus_paragraphs = Vec::with_capacity(plan.len());
for record in &records {
let shard = &record.as_ref().expect("record missing").shard;
corpus_paragraphs.push(shard.to_corpus_paragraph());
}
let mut corpus_questions = Vec::with_capacity(window.cases.len());
for case in &window.cases {
let slot = record_index
.get(case.paragraph.id.as_str())
.copied()
.ok_or_else(|| {
anyhow!(
"slice case references paragraph '{}' that is not part of the window",
case.paragraph.id
)
})?;
let record_slot = records
.get_mut(slot)
.context("shard record slot missing for question binding")?;
let record = record_slot
.as_mut()
.context("shard record missing for question binding")?;
let (chunk_ids, updated) = match record.shard.ensure_question_binding(case.question) {
Ok(result) => result,
Err(err) => {
if require_verified_chunks {
return Err(err).context(format!(
"locating answer text for question '{}' in paragraph '{}'",
case.question.id, case.paragraph.id
));
}
warn!(
question_id = %case.question.id,
paragraph_id = %case.paragraph.id,
error = %err,
"Failed to locate answer text in ingested content; recording empty chunk bindings"
);
record
.shard
.question_bindings
.insert(case.question.id.clone(), Vec::new());
record.dirty = true;
(Vec::new(), true)
}
};
if updated {
record.dirty = true;
}
corpus_questions.push(CorpusQuestion {
question_id: case.question.id.clone(),
paragraph_id: case.paragraph.id.clone(),
text_content_id: record.shard.text_content.get_id().to_string(),
question_text: case.question.question.clone(),
answers: case.question.answers.clone(),
is_impossible: case.question.is_impossible,
matching_chunk_ids: chunk_ids,
});
}
for entry in records.iter_mut().flatten() {
if entry.dirty {
store.persist(&entry.shard)?;
}
}
let manifest = CorpusManifest {
version: MANIFEST_VERSION,
metadata: CorpusMetadata {
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
slice_id: slice.manifest.slice_id.clone(),
include_unanswerable: slice.manifest.includes_unanswerable,
require_verified_chunks: slice.manifest.require_verified_chunks,
ingestion_fingerprint: ingestion_fingerprint.clone(),
embedding_backend: embedding.backend_label().to_string(),
embedding_model: embedding.model_code(),
embedding_dimension: embedding.dimension(),
converted_checksum: checksum,
generated_at: Utc::now(),
paragraph_count: corpus_paragraphs.len(),
question_count: corpus_questions.len(),
chunk_min_tokens: ingestion_config.tuning.chunk_min_tokens,
chunk_max_tokens: ingestion_config.tuning.chunk_max_tokens,
chunk_only: ingestion_config.chunk_only,
},
paragraphs: corpus_paragraphs,
questions: corpus_questions,
};
let ingested_count = stats.positive_ingested + stats.negative_ingested;
let reused_ingestion = ingested_count == 0 && !cache.force_refresh;
let reused_embeddings = reused_ingestion && !cache.refresh_embeddings_only;
info!(
dataset = %dataset.metadata.id,
slice = %slice.manifest.slice_id,
fingerprint = %ingestion_fingerprint,
reused_ingestion,
reused_embeddings,
positive_reused = stats.positive_reused,
positive_ingested = stats.positive_ingested,
negative_reused = stats.negative_reused,
negative_ingested = stats.negative_ingested,
shard_dir = %base_dir.display(),
"Corpus cache outcome"
);
let handle = CorpusHandle {
manifest,
path: base_dir,
reused_ingestion,
reused_embeddings,
positive_reused: stats.positive_reused,
positive_ingested: stats.positive_ingested,
negative_reused: stats.negative_reused,
negative_ingested: stats.negative_ingested,
};
persist_manifest(&handle).context("persisting corpus manifest")?;
Ok(handle)
}
#[allow(clippy::too_many_arguments)]
async fn ingest_paragraph_batch(
dataset: &ConvertedDataset,
targets: &[IngestRequest<'_>],
embedding: Arc<common::utils::embedding::EmbeddingProvider>,
openai: Arc<OpenAIClient>,
user_id: &str,
ingestion_fingerprint: &str,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
batch_size: usize,
max_retries: usize,
ingestion_config: IngestionConfig,
) -> Result<Vec<ParagraphShard>> {
if targets.is_empty() {
return Ok(Vec::new());
}
let namespace = format!("ingest_eval_{}", Uuid::new_v4());
let db = Arc::new(
SurrealDbClient::memory(&namespace, "corpus")
.await
.context("creating in-memory surrealdb for ingestion")?,
);
db.apply_migrations()
.await
.context("applying migrations for ingestion")?;
let app_config = AppConfig {
storage: StorageKind::Memory,
..Default::default()
};
let backend: DynStore = Arc::new(InMemory::new());
let storage = StorageManager::with_backend(backend, StorageKind::Memory);
let pipeline_config = ingestion_config.clone();
let pipeline = IngestionPipeline::new_with_config(
db,
openai.clone(),
app_config,
None::<Arc<retrieval_pipeline::reranking::RerankerPool>>,
storage,
embedding.clone(),
pipeline_config,
)?;
let pipeline = Arc::new(pipeline);
let mut shards = Vec::with_capacity(targets.len());
let category = dataset.metadata.category.clone();
for (batch_index, batch) in targets.chunks(batch_size).enumerate() {
info!(
batch = batch_index,
batch_size = batch.len(),
total_batches = targets.len().div_ceil(batch_size),
"Ingesting paragraph batch"
);
let model_clone = embedding_model.clone();
let backend_clone = embedding_backend.to_string();
let pipeline_clone = pipeline.clone();
let category_clone = category.clone();
let tasks = batch.iter().cloned().map(move |request| {
ingest_single_paragraph(
pipeline_clone.clone(),
request,
category_clone.clone(),
user_id,
ingestion_fingerprint,
backend_clone.clone(),
model_clone.clone(),
embedding_dimension,
max_retries,
ingestion_config.tuning.chunk_min_tokens,
ingestion_config.tuning.chunk_max_tokens,
ingestion_config.chunk_only,
)
});
let batch_results: Vec<ParagraphShard> = try_join_all(tasks)
.await
.context("ingesting batch of paragraphs")?;
shards.extend(batch_results);
}
Ok(shards)
}
#[allow(clippy::too_many_arguments)]
async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>,
request: IngestRequest<'_>,
category: String,
user_id: &str,
ingestion_fingerprint: &str,
embedding_backend: String,
embedding_model: Option<String>,
embedding_dimension: usize,
max_retries: usize,
chunk_min_tokens: usize,
chunk_max_tokens: usize,
chunk_only: bool,
) -> Result<ParagraphShard> {
let paragraph = request.paragraph;
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=max_retries {
let payload = IngestionPayload::Text {
text: paragraph.context.clone(),
context: paragraph.title.clone(),
category: category.clone(),
user_id: user_id.to_string(),
};
let task = IngestionTask::new(payload, user_id.to_string());
match pipeline.produce_artifacts(&task).await {
Ok(artifacts) => {
let entities: Vec<EmbeddedKnowledgeEntity> = artifacts
.entities
.into_iter()
.map(|e| EmbeddedKnowledgeEntity {
entity: e.entity,
embedding: e.embedding,
})
.collect();
let chunks: Vec<EmbeddedTextChunk> = artifacts
.chunks
.into_iter()
.map(|c| EmbeddedTextChunk {
chunk: c.chunk,
embedding: c.embedding,
})
.collect();
// No need to re-embed: the pipeline produces embeddings internally via FastEmbed.
let mut shard = ParagraphShard::new(
paragraph,
request.shard_path,
ingestion_fingerprint,
artifacts.text_content,
entities,
artifacts.relationships,
chunks,
&embedding_backend,
embedding_model.clone(),
embedding_dimension,
chunk_min_tokens,
chunk_max_tokens,
chunk_only,
);
for question in &request.question_refs {
if let Err(err) = shard.ensure_question_binding(question) {
warn!(
question_id = %question.id,
paragraph_id = %paragraph.id,
error = %err,
"Failed to locate answer text in ingested content; recording empty chunk bindings"
);
shard
.question_bindings
.insert(question.id.clone(), Vec::new());
}
}
return Ok(shard);
}
Err(err) => {
warn!(
paragraph_id = %paragraph.id,
attempt,
max_attempts = max_retries,
error = ?err,
"ingestion attempt failed for paragraph; retrying"
);
last_err = Some(err.into());
}
}
}
Err(last_err
.unwrap_or_else(|| anyhow!("ingestion failed"))
.context(format!("running ingestion for paragraph {}", paragraph.id)))
}
pub fn cached_corpus_dir(cache: &CorpusCacheConfig, dataset_id: &str, slice_id: &str) -> PathBuf {
cache.ingestion_cache_dir.join(dataset_id).join(slice_id)
}
pub fn build_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
checksum: &str,
ingestion_config: &IngestionConfig,
) -> String {
let config_repr = format!("{:?}", ingestion_config);
let mut hasher = Sha256::new();
hasher.update(config_repr.as_bytes());
let config_hash = format!("{:x}", hasher.finalize());
format!(
"v{INGESTION_SPEC_VERSION}:{}:{}:{}:{}:{}",
dataset.metadata.id,
slice.manifest.slice_id,
slice.manifest.includes_unanswerable,
checksum,
config_hash
)
}
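// Illustrative fingerprint shape (hypothetical values):
// v2:squad-v2:slice-1:false:<sha256 of converted file>:<sha256 of the IngestionConfig Debug repr>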
pub fn compute_ingestion_fingerprint(
dataset: &ConvertedDataset,
slice: &ResolvedSlice<'_>,
converted_path: &Path,
ingestion_config: &IngestionConfig,
) -> Result<String> {
let checksum = compute_file_checksum(converted_path)?;
Ok(build_ingestion_fingerprint(
dataset,
slice,
&checksum,
ingestion_config,
))
}
pub fn load_cached_manifest(base_dir: &Path) -> Result<Option<CorpusManifest>> {
let path = base_dir.join("manifest.json");
if !path.exists() {
return Ok(None);
}
let mut file = fs::File::open(&path)
.with_context(|| format!("opening cached manifest {}", path.display()))?;
let mut buf = Vec::new();
file.read_to_end(&mut buf)
.with_context(|| format!("reading cached manifest {}", path.display()))?;
let manifest: CorpusManifest = serde_json::from_slice(&buf)
.with_context(|| format!("deserialising cached manifest {}", path.display()))?;
Ok(Some(manifest))
}
fn persist_manifest(handle: &CorpusHandle) -> Result<()> {
let path = handle.path.join("manifest.json");
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating manifest directory {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let blob =
serde_json::to_vec_pretty(&handle.manifest).context("serialising corpus manifest")?;
fs::write(&tmp_path, &blob)
.with_context(|| format!("writing temporary manifest {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("replacing manifest {}", path.display()))?;
Ok(())
}
pub fn corpus_handle_from_manifest(manifest: CorpusManifest, base_dir: PathBuf) -> CorpusHandle {
CorpusHandle {
manifest,
path: base_dir,
reused_ingestion: true,
reused_embeddings: true,
positive_reused: 0,
positive_ingested: 0,
negative_reused: 0,
negative_ingested: 0,
}
}
fn compute_file_checksum(path: &Path) -> Result<String> {
let mut file = fs::File::open(path)
.with_context(|| format!("opening file {} for checksum", path.display()))?;
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192];
loop {
let read = file
.read(&mut buffer)
.with_context(|| format!("reading {} for checksum", path.display()))?;
if read == 0 {
break;
}
hasher.update(&buffer[..read]);
}
Ok(format!("{:x}", hasher.finalize()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
datasets::{ConvertedDataset, ConvertedParagraph, ConvertedQuestion, DatasetKind},
slice::{CaseRef, SliceCaseEntry, SliceManifest, SliceParagraphEntry, SliceParagraphKind},
};
use chrono::Utc;
fn dummy_dataset() -> ConvertedDataset {
let question = ConvertedQuestion {
id: "q1".to_string(),
question: "What?".to_string(),
answers: vec!["A".to_string()],
is_impossible: false,
};
let paragraph = ConvertedParagraph {
id: "p1".to_string(),
title: "title".to_string(),
context: "context".to_string(),
questions: vec![question],
};
ConvertedDataset {
generated_at: Utc::now(),
metadata: crate::datasets::DatasetMetadata::for_kind(
DatasetKind::default(),
false,
None,
),
source: "src".to_string(),
paragraphs: vec![paragraph],
}
}
fn dummy_slice<'a>(dataset: &'a ConvertedDataset) -> ResolvedSlice<'a> {
let paragraph = &dataset.paragraphs[0];
let question = &paragraph.questions[0];
let manifest = SliceManifest {
version: 1,
slice_id: "slice-1".to_string(),
dataset_id: dataset.metadata.id.clone(),
dataset_label: dataset.metadata.label.clone(),
dataset_source: dataset.source.clone(),
includes_unanswerable: false,
require_verified_chunks: false,
seed: 1,
requested_limit: Some(1),
requested_corpus: 1,
generated_at: Utc::now(),
case_count: 1,
positive_paragraphs: 1,
negative_paragraphs: 0,
total_paragraphs: 1,
negative_multiplier: 1.0,
cases: vec![SliceCaseEntry {
question_id: question.id.clone(),
paragraph_id: paragraph.id.clone(),
}],
paragraphs: vec![SliceParagraphEntry {
id: paragraph.id.clone(),
kind: SliceParagraphKind::Positive {
question_ids: vec![question.id.clone()],
},
shard_path: None,
}],
};
ResolvedSlice {
manifest,
path: PathBuf::from("cache"),
paragraphs: dataset.paragraphs.iter().collect(),
cases: vec![CaseRef {
paragraph,
question,
}],
}
}
#[test]
fn fingerprint_changes_with_chunk_settings() {
let dataset = dummy_dataset();
let slice = dummy_slice(&dataset);
let checksum = "deadbeef";
let base_config = IngestionConfig::default();
let fp_base = build_ingestion_fingerprint(&dataset, &slice, checksum, &base_config);
let mut token_config = base_config.clone();
token_config.tuning.chunk_min_tokens += 1;
let fp_token = build_ingestion_fingerprint(&dataset, &slice, checksum, &token_config);
assert_ne!(fp_base, fp_token, "token bounds should affect fingerprint");
let mut chunk_only_config = base_config;
chunk_only_config.chunk_only = true;
let fp_chunk_only =
build_ingestion_fingerprint(&dataset, &slice, checksum, &chunk_only_config);
assert_ne!(
fp_base, fp_chunk_only,
"chunk-only mode should affect fingerprint"
);
}
}


@@ -0,0 +1,934 @@
use std::{
collections::{HashMap, HashSet},
fs,
io::BufReader,
path::PathBuf,
};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use common::storage::types::StoredObject;
use common::storage::{
db::SurrealDbClient,
types::{
knowledge_entity::KnowledgeEntity,
knowledge_entity_embedding::KnowledgeEntityEmbedding,
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
text_chunk::TextChunk,
text_chunk_embedding::TextChunkEmbedding,
text_content::TextContent,
},
};
use serde::Deserialize;
use serde::Serialize;
use surrealdb::sql::Thing;
use tracing::{debug, warn};
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
pub const MANIFEST_VERSION: u32 = 3;
pub const PARAGRAPH_SHARD_VERSION: u32 = 3;
const MANIFEST_BATCH_SIZE: usize = 100;
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000; // default cap for non-text batches
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000; // text bodies can be large; limit aggressively
const MAX_BATCHES_PER_REQUEST: usize = 24;
const REQUEST_MAX_BYTES: usize = 800_000; // total payload cap per Surreal query request
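// Seeding payloads are capped twice: each batch is bounded by item count and serialized size
// (text bodies use the tighter TEXT_CONTENT cap), and batches are then grouped into requests of
// at most MAX_BATCHES_PER_REQUEST batches or REQUEST_MAX_BYTES bytes.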
fn current_manifest_version() -> u32 {
MANIFEST_VERSION
}
fn current_paragraph_shard_version() -> u32 {
PARAGRAPH_SHARD_VERSION
}
fn default_chunk_min_tokens() -> usize {
500
}
fn default_chunk_max_tokens() -> usize {
2_000
}
fn default_chunk_only() -> bool {
false
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddedKnowledgeEntity {
pub entity: KnowledgeEntity,
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddedTextChunk {
pub chunk: TextChunk,
pub embedding: Vec<f32>,
}
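// Earlier shard/manifest layouts flattened the record fields and stored the embedding alongside
// them; the Legacy* types plus the untagged enums below accept both that layout and the current
// nested { entity/chunk, embedding } form.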
#[derive(Debug, Clone, serde::Deserialize)]
struct LegacyKnowledgeEntity {
#[serde(flatten)]
pub entity: KnowledgeEntity,
#[serde(default)]
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, serde::Deserialize)]
struct LegacyTextChunk {
#[serde(flatten)]
pub chunk: TextChunk,
#[serde(default)]
pub embedding: Vec<f32>,
}
fn deserialize_embedded_entities<'de, D>(
deserializer: D,
) -> Result<Vec<EmbeddedKnowledgeEntity>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum EntityInput {
Embedded(Vec<EmbeddedKnowledgeEntity>),
Legacy(Vec<LegacyKnowledgeEntity>),
}
match EntityInput::deserialize(deserializer)? {
EntityInput::Embedded(items) => Ok(items),
EntityInput::Legacy(items) => Ok(items
.into_iter()
.map(|legacy| EmbeddedKnowledgeEntity {
entity: legacy.entity,
embedding: legacy.embedding,
})
.collect()),
}
}
fn deserialize_embedded_chunks<'de, D>(deserializer: D) -> Result<Vec<EmbeddedTextChunk>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum ChunkInput {
Embedded(Vec<EmbeddedTextChunk>),
Legacy(Vec<LegacyTextChunk>),
}
match ChunkInput::deserialize(deserializer)? {
ChunkInput::Embedded(items) => Ok(items),
ChunkInput::Legacy(items) => Ok(items
.into_iter()
.map(|legacy| EmbeddedTextChunk {
chunk: legacy.chunk,
embedding: legacy.embedding,
})
.collect()),
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusManifest {
#[serde(default = "current_manifest_version")]
pub version: u32,
pub metadata: CorpusMetadata,
pub paragraphs: Vec<CorpusParagraph>,
pub questions: Vec<CorpusQuestion>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusMetadata {
pub dataset_id: String,
pub dataset_label: String,
pub slice_id: String,
pub include_unanswerable: bool,
#[serde(default)]
pub require_verified_chunks: bool,
pub ingestion_fingerprint: String,
pub embedding_backend: String,
pub embedding_model: Option<String>,
pub embedding_dimension: usize,
pub converted_checksum: String,
pub generated_at: DateTime<Utc>,
pub paragraph_count: usize,
pub question_count: usize,
#[serde(default = "default_chunk_min_tokens")]
pub chunk_min_tokens: usize,
#[serde(default = "default_chunk_max_tokens")]
pub chunk_max_tokens: usize,
#[serde(default = "default_chunk_only")]
pub chunk_only: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusParagraph {
pub paragraph_id: String,
pub title: String,
pub text_content: TextContent,
#[serde(deserialize_with = "deserialize_embedded_entities")]
pub entities: Vec<EmbeddedKnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
#[serde(deserialize_with = "deserialize_embedded_chunks")]
pub chunks: Vec<EmbeddedTextChunk>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusQuestion {
pub question_id: String,
pub paragraph_id: String,
pub text_content_id: String,
pub question_text: String,
pub answers: Vec<String>,
pub is_impossible: bool,
pub matching_chunk_ids: Vec<String>,
}
pub struct CorpusHandle {
pub manifest: CorpusManifest,
pub path: PathBuf,
pub reused_ingestion: bool,
pub reused_embeddings: bool,
pub positive_reused: usize,
pub positive_ingested: usize,
pub negative_reused: usize,
pub negative_ingested: usize,
}
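// Narrows a manifest to the question window [offset, offset + length): keeps the positive
// paragraphs those questions reference plus up to ceil(positives * negative_multiplier) other
// paragraphs as distractors, bounded by how many non-positive paragraphs exist.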
pub fn window_manifest(
manifest: &CorpusManifest,
offset: usize,
length: usize,
negative_multiplier: f32,
) -> Result<CorpusManifest> {
let total = manifest.questions.len();
if total == 0 {
return Err(anyhow!(
"manifest contains no questions; cannot select a window"
));
}
if offset >= total {
return Err(anyhow!(
"window offset {} exceeds manifest questions ({})",
offset,
total
));
}
let end = (offset + length).min(total);
let questions = manifest.questions[offset..end].to_vec();
let selected_positive_ids: HashSet<_> =
questions.iter().map(|q| q.paragraph_id.clone()).collect();
let positives_all: HashSet<_> = manifest
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect();
let available_negatives = manifest
.paragraphs
.len()
.saturating_sub(positives_all.len());
let desired_negatives =
((selected_positive_ids.len() as f32) * negative_multiplier).ceil() as usize;
let desired_negatives = desired_negatives.min(available_negatives);
let mut paragraphs = Vec::new();
let mut negative_count = 0usize;
for paragraph in &manifest.paragraphs {
if selected_positive_ids.contains(&paragraph.paragraph_id) {
paragraphs.push(paragraph.clone());
} else if negative_count < desired_negatives {
paragraphs.push(paragraph.clone());
negative_count += 1;
}
}
let mut narrowed = manifest.clone();
narrowed.questions = questions;
narrowed.paragraphs = paragraphs;
narrowed.metadata.paragraph_count = narrowed.paragraphs.len();
narrowed.metadata.question_count = narrowed.questions.len();
Ok(narrowed)
}
#[derive(Debug, Clone, Serialize)]
struct RelationInsert {
#[serde(rename = "in")]
pub in_: Thing,
#[serde(rename = "out")]
pub out: Thing,
pub id: String,
pub metadata: RelationshipMetadata,
}
#[derive(Debug)]
struct SizedBatch<T> {
approx_bytes: usize,
items: Vec<T>,
}
struct ManifestBatches {
text_contents: Vec<SizedBatch<TextContent>>,
entities: Vec<SizedBatch<KnowledgeEntity>>,
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
relationships: Vec<SizedBatch<RelationInsert>>,
chunks: Vec<SizedBatch<TextChunk>>,
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
}
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
let mut text_contents = Vec::new();
let mut entities = Vec::new();
let mut entity_embeddings = Vec::new();
let mut relationships = Vec::new();
let mut chunks = Vec::new();
let mut chunk_embeddings = Vec::new();
let mut seen_text_content = HashSet::new();
let mut seen_entities = HashSet::new();
let mut seen_relationships = HashSet::new();
let mut seen_chunks = HashSet::new();
for paragraph in &manifest.paragraphs {
if seen_text_content.insert(paragraph.text_content.id.clone()) {
text_contents.push(paragraph.text_content.clone());
}
for embedded_entity in &paragraph.entities {
if seen_entities.insert(embedded_entity.entity.id.clone()) {
let entity = embedded_entity.entity.clone();
entities.push(entity.clone());
entity_embeddings.push(KnowledgeEntityEmbedding::new(
&entity.id,
embedded_entity.embedding.clone(),
entity.user_id.clone(),
));
}
}
for relationship in &paragraph.relationships {
if seen_relationships.insert(relationship.id.clone()) {
let table = KnowledgeEntity::table_name();
let in_id = relationship
.in_
.strip_prefix(&format!("{table}:"))
.unwrap_or(&relationship.in_);
let out_id = relationship
.out
.strip_prefix(&format!("{table}:"))
.unwrap_or(&relationship.out);
let in_thing = Thing::from((table, in_id));
let out_thing = Thing::from((table, out_id));
relationships.push(RelationInsert {
in_: in_thing,
out: out_thing,
id: relationship.id.clone(),
metadata: relationship.metadata.clone(),
});
}
}
for embedded_chunk in &paragraph.chunks {
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
let chunk = embedded_chunk.chunk.clone();
chunks.push(chunk.clone());
chunk_embeddings.push(TextChunkEmbedding::new(
&chunk.id,
chunk.source_id.clone(),
embedded_chunk.embedding.clone(),
chunk.user_id.clone(),
));
}
}
}
Ok(ManifestBatches {
text_contents: chunk_items(
&text_contents,
MANIFEST_BATCH_SIZE,
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
)
.context("chunking text_content payloads")?,
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
.context("chunking knowledge_entity payloads")?,
entity_embeddings: chunk_items(
&entity_embeddings,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking knowledge_entity_embedding payloads")?,
relationships: chunk_items(
&relationships,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking relationship payloads")?,
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
.context("chunking text_chunk payloads")?,
chunk_embeddings: chunk_items(
&chunk_embeddings,
MANIFEST_BATCH_SIZE,
MANIFEST_MAX_BYTES_PER_BATCH,
)
.context("chunking text_chunk_embedding payloads")?,
})
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParagraphShard {
#[serde(default = "current_paragraph_shard_version")]
pub version: u32,
pub paragraph_id: String,
pub shard_path: String,
pub ingestion_fingerprint: String,
pub ingested_at: DateTime<Utc>,
pub title: String,
pub text_content: TextContent,
#[serde(deserialize_with = "deserialize_embedded_entities")]
pub entities: Vec<EmbeddedKnowledgeEntity>,
pub relationships: Vec<KnowledgeRelationship>,
#[serde(deserialize_with = "deserialize_embedded_chunks")]
pub chunks: Vec<EmbeddedTextChunk>,
#[serde(default)]
pub question_bindings: HashMap<String, Vec<String>>,
#[serde(default)]
pub embedding_backend: String,
#[serde(default)]
pub embedding_model: Option<String>,
#[serde(default)]
pub embedding_dimension: usize,
#[serde(default = "default_chunk_min_tokens")]
pub chunk_min_tokens: usize,
#[serde(default = "default_chunk_max_tokens")]
pub chunk_max_tokens: usize,
#[serde(default = "default_chunk_only")]
pub chunk_only: bool,
}
pub struct ParagraphShardStore {
base_dir: PathBuf,
}
impl ParagraphShardStore {
pub fn new(base_dir: PathBuf) -> Self {
Self { base_dir }
}
pub fn ensure_base_dir(&self) -> Result<()> {
fs::create_dir_all(&self.base_dir)
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
}
fn resolve(&self, relative: &str) -> PathBuf {
self.base_dir.join(relative)
}
pub fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
let path = self.resolve(relative);
let file = match fs::File::open(&path) {
Ok(file) => file,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => {
return Err(err).with_context(|| format!("opening shard {}", path.display()))
}
};
let reader = BufReader::new(file);
let mut shard: ParagraphShard = serde_json::from_reader(reader)
.with_context(|| format!("parsing shard {}", path.display()))?;
if shard.ingestion_fingerprint != fingerprint {
debug!(
path = %path.display(),
expected = fingerprint,
found = shard.ingestion_fingerprint,
"Shard fingerprint mismatch; will rebuild"
);
return Ok(None);
}
if shard.version != PARAGRAPH_SHARD_VERSION {
warn!(
path = %path.display(),
version = shard.version,
expected = PARAGRAPH_SHARD_VERSION,
"Upgrading shard to current version"
);
shard.version = PARAGRAPH_SHARD_VERSION;
}
shard.shard_path = relative.to_string();
Ok(Some(shard))
}
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
let mut shard = shard.clone();
shard.version = PARAGRAPH_SHARD_VERSION;
let path = self.resolve(&shard.shard_path);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating shard dir {}", parent.display()))?;
}
let tmp_path = path.with_extension("json.tmp");
let body = serde_json::to_vec_pretty(&shard).context("serialising paragraph shard")?;
fs::write(&tmp_path, &body)
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
fs::rename(&tmp_path, &path)
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
Ok(())
}
}
impl ParagraphShard {
#[allow(clippy::too_many_arguments)]
pub fn new(
paragraph: &ConvertedParagraph,
shard_path: String,
ingestion_fingerprint: &str,
text_content: TextContent,
entities: Vec<EmbeddedKnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
chunks: Vec<EmbeddedTextChunk>,
embedding_backend: &str,
embedding_model: Option<String>,
embedding_dimension: usize,
chunk_min_tokens: usize,
chunk_max_tokens: usize,
chunk_only: bool,
) -> Self {
Self {
version: PARAGRAPH_SHARD_VERSION,
paragraph_id: paragraph.id.clone(),
shard_path,
ingestion_fingerprint: ingestion_fingerprint.to_string(),
ingested_at: Utc::now(),
title: paragraph.title.clone(),
text_content,
entities,
relationships,
chunks,
question_bindings: HashMap::new(),
embedding_backend: embedding_backend.to_string(),
embedding_model,
embedding_dimension,
chunk_min_tokens,
chunk_max_tokens,
chunk_only,
}
}
pub fn to_corpus_paragraph(&self) -> CorpusParagraph {
CorpusParagraph {
paragraph_id: self.paragraph_id.clone(),
title: self.title.clone(),
text_content: self.text_content.clone(),
entities: self.entities.clone(),
relationships: self.relationships.clone(),
chunks: self.chunks.clone(),
}
}
pub fn ensure_question_binding(
&mut self,
question: &ConvertedQuestion,
) -> Result<(Vec<String>, bool)> {
if let Some(existing) = self.question_bindings.get(&question.id) {
return Ok((existing.clone(), false));
}
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
self.question_bindings
.insert(question.id.clone(), chunk_ids.clone());
Ok((chunk_ids, true))
}
}
fn validate_answers(
content: &TextContent,
chunks: &[EmbeddedTextChunk],
question: &ConvertedQuestion,
) -> Result<Vec<String>> {
if question.is_impossible || question.answers.is_empty() {
return Ok(Vec::new());
}
let mut matches = std::collections::BTreeSet::new();
let mut found_any = false;
let haystack = content.text.to_ascii_lowercase();
let haystack_norm = normalize_answer_text(&haystack);
for answer in &question.answers {
let needle: String = answer.to_ascii_lowercase();
let needle_norm = normalize_answer_text(&needle);
let text_match = haystack.contains(&needle)
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
if text_match {
found_any = true;
}
for chunk in chunks {
let chunk_text = chunk.chunk.chunk.to_ascii_lowercase();
let chunk_norm = normalize_answer_text(&chunk_text);
if chunk_text.contains(&needle)
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
{
matches.insert(chunk.chunk.get_id().to_string());
found_any = true;
}
}
}
if !found_any {
Err(anyhow!(
"expected answer for question '{}' was not found in ingested content",
question.id
))
} else {
Ok(matches.into_iter().collect())
}
}
fn normalize_answer_text(text: &str) -> String {
text.chars()
.map(|ch| {
if ch.is_alphanumeric() || ch.is_whitespace() {
ch.to_ascii_lowercase()
} else {
' '
}
})
.collect::<String>()
.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
}
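// e.g. normalize_answer_text("The Eiffel-Tower!") == "the eiffel tower": punctuation becomes
// whitespace, whitespace runs collapse, and ASCII letters are lowercased.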
fn chunk_items<T: Clone + Serialize>(
items: &[T],
max_items: usize,
max_bytes: usize,
) -> Result<Vec<SizedBatch<T>>> {
if items.is_empty() {
return Ok(Vec::new());
}
let mut batches = Vec::new();
let mut current = Vec::new();
let mut current_bytes = 0usize;
for item in items {
let size = serde_json::to_vec(item)
.map(|buf| buf.len())
.context("serialising batch item for sizing")?;
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
if would_overflow_items || would_overflow_bytes {
batches.push(SizedBatch {
approx_bytes: current_bytes.max(1),
items: std::mem::take(&mut current),
});
current_bytes = 0;
}
current_bytes += size;
current.push(item.clone());
}
if !current.is_empty() {
batches.push(SizedBatch {
approx_bytes: current_bytes.max(1),
items: current,
});
}
Ok(batches)
}
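// Example: with max_items = 2 and max_bytes = 100, items serialising to [60, 60, 10, 10, 10]
// bytes split into batches of sizes [60], [60, 10], [10, 10].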
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
db: &SurrealDbClient,
statement: impl AsRef<str>,
prefix: &str,
batches: &[SizedBatch<T>],
) -> Result<()> {
if batches.is_empty() {
return Ok(());
}
let mut start = 0;
while start < batches.len() {
let mut group_bytes = 0usize;
let mut group_end = start;
let mut group_count = 0usize;
while group_end < batches.len() {
let batch_bytes = batches[group_end].approx_bytes.max(1);
if group_count > 0
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|| group_count >= MAX_BATCHES_PER_REQUEST)
{
break;
}
group_bytes += batch_bytes;
group_end += 1;
group_count += 1;
}
let slice = &batches[start..group_end];
let mut query = db.client.query("BEGIN TRANSACTION;");
for (bind_index, batch) in slice.iter().enumerate() {
let name = format!("{prefix}{bind_index}");
query = query
.query(format!("{} ${};", statement.as_ref(), name))
.bind((name, batch.items.clone()));
}
let response = query
.query("COMMIT TRANSACTION;")
.await
.context("executing batched insert transaction")?;
if let Err(err) = response.check() {
return Err(anyhow!(
"batched insert failed for statement '{}': {err:?}",
statement.as_ref()
));
}
start = group_end;
}
Ok(())
}
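// Seeding order: text content, entities, and chunks are inserted before the relation edges that
// reference them, with the embedding tables last; on any failure the cleanup below deletes all
// seeded tables best-effort.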
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
let result = async {
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextContent::table_name()),
"tc",
&batches.text_contents,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
"ke",
&batches.entities,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextChunk::table_name()),
"ch",
&batches.chunks,
)
.await?;
execute_batched_inserts(
db,
"INSERT RELATION INTO relates_to",
"rel",
&batches.relationships,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
"kee",
&batches.entity_embeddings,
)
.await?;
execute_batched_inserts(
db,
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
"tce",
&batches.chunk_embeddings,
)
.await?;
Ok(())
}
.await;
if result.is_err() {
// Best-effort cleanup to avoid leaving partial manifest data behind.
let _ = db
.client
.query(
"BEGIN TRANSACTION;
DELETE text_chunk_embedding;
DELETE knowledge_entity_embedding;
DELETE relates_to;
DELETE text_chunk;
DELETE knowledge_entity;
DELETE text_content;
COMMIT TRANSACTION;",
)
.await;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use common::storage::types::knowledge_entity::KnowledgeEntityType;
use uuid::Uuid;
fn build_manifest() -> CorpusManifest {
let user_id = "user-1".to_string();
let source_id = "source-1".to_string();
let now = Utc::now();
let text_content_id = Uuid::new_v4().to_string();
let text_content = TextContent {
id: text_content_id.clone(),
created_at: now,
updated_at: now,
text: "Hello world".to_string(),
file_info: None,
url_info: None,
context: None,
category: "test".to_string(),
user_id: user_id.clone(),
};
let entity = KnowledgeEntity {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id: source_id.clone(),
name: "Entity".to_string(),
description: "A test entity".to_string(),
entity_type: KnowledgeEntityType::Document,
metadata: None,
user_id: user_id.clone(),
};
let relationship = KnowledgeRelationship::new(
format!("knowledge_entity:{}", entity.id),
format!("knowledge_entity:{}", entity.id),
user_id.clone(),
source_id.clone(),
"related".to_string(),
);
let chunk = TextChunk {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
source_id: source_id.clone(),
chunk: "chunk text".to_string(),
user_id: user_id.clone(),
};
let paragraph_one = CorpusParagraph {
paragraph_id: "p1".to_string(),
title: "Paragraph 1".to_string(),
text_content: text_content.clone(),
entities: vec![EmbeddedKnowledgeEntity {
entity: entity.clone(),
embedding: vec![0.1, 0.2, 0.3],
}],
relationships: vec![relationship],
chunks: vec![EmbeddedTextChunk {
chunk: chunk.clone(),
embedding: vec![0.3, 0.2, 0.1],
}],
};
// Duplicate content/entities should be de-duplicated by the loader.
let paragraph_two = CorpusParagraph {
paragraph_id: "p2".to_string(),
title: "Paragraph 2".to_string(),
text_content: text_content.clone(),
entities: vec![EmbeddedKnowledgeEntity {
entity: entity.clone(),
embedding: vec![0.1, 0.2, 0.3],
}],
relationships: Vec::new(),
chunks: vec![EmbeddedTextChunk {
chunk: chunk.clone(),
embedding: vec![0.3, 0.2, 0.1],
}],
};
let question = CorpusQuestion {
question_id: "q1".to_string(),
paragraph_id: paragraph_one.paragraph_id.clone(),
text_content_id: text_content_id,
question_text: "What is this?".to_string(),
answers: vec!["Hello".to_string()],
is_impossible: false,
matching_chunk_ids: vec![chunk.id.clone()],
};
CorpusManifest {
version: current_manifest_version(),
metadata: CorpusMetadata {
dataset_id: "dataset".to_string(),
dataset_label: "Dataset".to_string(),
slice_id: "slice".to_string(),
include_unanswerable: false,
require_verified_chunks: false,
ingestion_fingerprint: "fp".to_string(),
embedding_backend: "test".to_string(),
embedding_model: Some("model".to_string()),
embedding_dimension: 3,
converted_checksum: "checksum".to_string(),
generated_at: now,
paragraph_count: 2,
question_count: 1,
chunk_min_tokens: 1,
chunk_max_tokens: 10,
chunk_only: false,
},
paragraphs: vec![paragraph_one, paragraph_two],
questions: vec![question],
}
}
#[test]
fn window_manifest_trims_questions_and_negatives() {
// Add extra negatives to simulate a ~4x multiplier
let mut manifest = build_manifest();
let mut extra_paragraphs = Vec::new();
for _ in 0..8 {
let mut p = manifest.paragraphs[0].clone();
p.paragraph_id = Uuid::new_v4().to_string();
p.entities.clear();
p.relationships.clear();
p.chunks.clear();
extra_paragraphs.push(p);
}
manifest.paragraphs.extend(extra_paragraphs);
manifest.metadata.paragraph_count = manifest.paragraphs.len();
let windowed = window_manifest(&manifest, 0, 1, 4.0).expect("window manifest");
assert_eq!(windowed.questions.len(), 1);
// Expect roughly 4x negatives (bounded by available paragraphs)
assert!(
windowed.paragraphs.len() <= manifest.paragraphs.len(),
"windowed paragraphs should never exceed original"
);
let positive_set: std::collections::HashSet<_> = windowed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect();
let positives = windowed
.paragraphs
.iter()
.filter(|p| positive_set.contains(p.paragraph_id.as_str()))
.count();
let negatives = windowed.paragraphs.len().saturating_sub(positives);
assert_eq!(positives, 1);
assert!(negatives >= 1, "should include some negatives");
}
}


@@ -0,0 +1,341 @@
use std::{
collections::{BTreeMap, HashMap},
fs::File,
io::{BufRead, BufReader},
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use serde::Deserialize;
use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion, DatasetKind};
const ANSWER_SNIPPET_CHARS: usize = 240;
#[derive(Debug, Deserialize)]
struct BeirCorpusRow {
#[serde(rename = "_id")]
id: String,
#[serde(default)]
title: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct BeirQueryRow {
#[serde(rename = "_id")]
id: String,
text: String,
}
#[derive(Debug, Clone)]
struct BeirParagraph {
title: String,
context: String,
}
#[derive(Debug, Clone)]
struct BeirQuery {
text: String,
}
#[derive(Debug, Clone)]
struct QrelEntry {
doc_id: String,
score: i32,
}
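// Expects the standard BEIR layout under `raw_dir`: corpus.jsonl, queries.jsonl and
// qrels/{test,dev,train}.tsv (checked in that order). Each query is attached to its single
// best-scoring document, and the "answer" is approximated by the first ANSWER_SNIPPET_CHARS
// characters of that document's context.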
pub fn convert_beir(raw_dir: &Path, dataset: DatasetKind) -> Result<Vec<ConvertedParagraph>> {
let corpus_path = raw_dir.join("corpus.jsonl");
let queries_path = raw_dir.join("queries.jsonl");
let qrels_path = resolve_qrels_path(raw_dir)?;
let corpus = load_corpus(&corpus_path)?;
let queries = load_queries(&queries_path)?;
let qrels = load_qrels(&qrels_path)?;
let mut paragraphs = Vec::with_capacity(corpus.len());
let mut paragraph_index = HashMap::new();
for (doc_id, entry) in corpus.iter() {
let paragraph_id = format!("{}-{doc_id}", dataset.source_prefix());
let paragraph = ConvertedParagraph {
id: paragraph_id.clone(),
title: entry.title.clone(),
context: entry.context.clone(),
questions: Vec::new(),
};
paragraph_index.insert(doc_id.clone(), paragraphs.len());
paragraphs.push(paragraph);
}
let mut missing_queries = 0usize;
let mut missing_docs = 0usize;
let mut skipped_answers = 0usize;
for (query_id, entries) in qrels {
let query = match queries.get(&query_id) {
Some(query) => query,
None => {
missing_queries += 1;
warn!(query_id = %query_id, "Skipping qrels entry for missing query");
continue;
}
};
let best = match select_best_doc(&entries) {
Some(entry) => entry,
None => continue,
};
let paragraph_slot = match paragraph_index.get(&best.doc_id) {
Some(slot) => *slot,
None => {
missing_docs += 1;
warn!(
query_id = %query_id,
doc_id = %best.doc_id,
"Skipping qrels entry referencing missing corpus document"
);
continue;
}
};
let answer = answer_snippet(&paragraphs[paragraph_slot].context);
let answers = match answer {
Some(snippet) => vec![snippet],
None => {
skipped_answers += 1;
warn!(
query_id = %query_id,
doc_id = %best.doc_id,
"Skipping query because no non-empty answer snippet could be derived"
);
continue;
}
};
let question_id = format!("{}-{query_id}", dataset.source_prefix());
paragraphs[paragraph_slot]
.questions
.push(ConvertedQuestion {
id: question_id,
question: query.text.clone(),
answers,
is_impossible: false,
});
}
if missing_queries + missing_docs + skipped_answers > 0 {
warn!(
missing_queries,
missing_docs, skipped_answers, "Skipped some BEIR qrels entries during conversion"
);
}
Ok(paragraphs)
}
fn resolve_qrels_path(raw_dir: &Path) -> Result<PathBuf> {
let qrels_dir = raw_dir.join("qrels");
let candidates = ["test.tsv", "dev.tsv", "train.tsv"];
for name in candidates {
let candidate = qrels_dir.join(name);
if candidate.exists() {
return Ok(candidate);
}
}
Err(anyhow!(
"No qrels file found under {}; expected one of {:?}",
qrels_dir.display(),
candidates
))
}
fn load_corpus(path: &Path) -> Result<BTreeMap<String, BeirParagraph>> {
let file =
File::open(path).with_context(|| format!("opening BEIR corpus at {}", path.display()))?;
let reader = BufReader::new(file);
let mut corpus = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading corpus line {} from {}", idx + 1, path.display()))?;
if raw.trim().is_empty() {
continue;
}
let row: BeirCorpusRow = serde_json::from_str(&raw).with_context(|| {
format!(
"parsing corpus JSON on line {} from {}",
idx + 1,
path.display()
)
})?;
let title = row.title.unwrap_or_else(|| row.id.clone());
let text = row.text.unwrap_or_default();
let context = build_context(&title, &text);
if context.is_empty() {
warn!(doc_id = %row.id, "Skipping empty corpus document");
continue;
}
corpus.insert(row.id, BeirParagraph { title, context });
}
Ok(corpus)
}
fn load_queries(path: &Path) -> Result<BTreeMap<String, BeirQuery>> {
let file = File::open(path)
.with_context(|| format!("opening BEIR queries file at {}", path.display()))?;
let reader = BufReader::new(file);
let mut queries = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading query line {} from {}", idx + 1, path.display()))?;
if raw.trim().is_empty() {
continue;
}
let row: BeirQueryRow = serde_json::from_str(&raw).with_context(|| {
format!(
"parsing query JSON on line {} from {}",
idx + 1,
path.display()
)
})?;
queries.insert(
row.id,
BeirQuery {
text: row.text.trim().to_string(),
},
);
}
Ok(queries)
}
fn load_qrels(path: &Path) -> Result<BTreeMap<String, Vec<QrelEntry>>> {
let file =
File::open(path).with_context(|| format!("opening BEIR qrels at {}", path.display()))?;
let reader = BufReader::new(file);
let mut qrels: BTreeMap<String, Vec<QrelEntry>> = BTreeMap::new();
for (idx, line) in reader.lines().enumerate() {
let raw = line
.with_context(|| format!("reading qrels line {} from {}", idx + 1, path.display()))?;
let trimmed = raw.trim();
if trimmed.is_empty() || trimmed.starts_with("query-id") {
continue;
}
let mut parts = trimmed.split_whitespace();
let query_id = parts
.next()
.ok_or_else(|| anyhow!("missing query id on line {}", idx + 1))?;
let doc_id = parts
.next()
.ok_or_else(|| anyhow!("missing document id on line {}", idx + 1))?;
let score_raw = parts
.next()
.ok_or_else(|| anyhow!("missing score on line {}", idx + 1))?;
let score: i32 = score_raw.parse().with_context(|| {
format!(
"parsing qrels score '{}' on line {} from {}",
score_raw,
idx + 1,
path.display()
)
})?;
qrels
.entry(query_id.to_string())
.or_default()
.push(QrelEntry {
doc_id: doc_id.to_string(),
score,
});
}
Ok(qrels)
}
fn select_best_doc(entries: &[QrelEntry]) -> Option<&QrelEntry> {
entries
.iter()
.max_by(|a, b| a.score.cmp(&b.score).then_with(|| b.doc_id.cmp(&a.doc_id)))
}
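// Ties on score break toward the lexicographically smaller doc_id (the doc_id comparison is
// reversed inside max_by), keeping the selection deterministic.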
fn answer_snippet(text: &str) -> Option<String> {
let trimmed = text.trim();
if trimmed.is_empty() {
return None;
}
let snippet: String = trimmed.chars().take(ANSWER_SNIPPET_CHARS).collect();
let snippet = snippet.trim();
if snippet.is_empty() {
None
} else {
Some(snippet.to_string())
}
}
fn build_context(title: &str, text: &str) -> String {
let title = title.trim();
let text = text.trim();
match (title.is_empty(), text.is_empty()) {
(true, true) => String::new(),
(true, false) => text.to_string(),
(false, true) => title.to_string(),
(false, false) => format!("{title}\n\n{text}"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[test]
fn converts_basic_beir_layout() {
let dir = tempdir().unwrap();
let corpus = r#"
{"_id":"d1","title":"Doc 1","text":"Doc one has some text for testing."}
{"_id":"d2","title":"Doc 2","text":"Second document content."}
"#;
let queries = r#"
{"_id":"q1","text":"What is in doc one?"}
"#;
let qrels = "query-id\tcorpus-id\tscore\nq1\td1\t2\n";
fs::write(dir.path().join("corpus.jsonl"), corpus.trim()).unwrap();
fs::write(dir.path().join("queries.jsonl"), queries.trim()).unwrap();
fs::create_dir_all(dir.path().join("qrels")).unwrap();
fs::write(dir.path().join("qrels/test.tsv"), qrels).unwrap();
let paragraphs = convert_beir(dir.path(), DatasetKind::Fever).unwrap();
assert_eq!(paragraphs.len(), 2);
let doc_one = paragraphs
.iter()
.find(|p| p.id == "fever-d1")
.expect("missing paragraph for d1");
assert_eq!(doc_one.questions.len(), 1);
let question = &doc_one.questions[0];
assert_eq!(question.id, "fever-q1");
assert!(!question.answers.is_empty());
assert!(doc_one.context.contains(&question.answers[0]));
let doc_two = paragraphs
.iter()
.find(|p| p.id == "fever-d2")
.expect("missing paragraph for d2");
assert!(doc_two.questions.is_empty());
}
}


@@ -0,0 +1,623 @@
mod beir;
mod nq;
mod squad;
use std::{
collections::{BTreeMap, HashMap},
fs::{self},
path::{Path, PathBuf},
str::FromStr,
};
use anyhow::{anyhow, bail, Context, Result};
use chrono::{DateTime, TimeZone, Utc};
use clap::ValueEnum;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use tracing::warn;
const MANIFEST_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/manifest.yaml");
static DATASET_CATALOG: OnceCell<DatasetCatalog> = OnceCell::new();
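// The catalog is loaded once per process via `DatasetCatalog::global()`; relative paths in
// manifest.yaml resolve against the crate root (see `resolve_path`).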
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetCatalog {
datasets: BTreeMap<String, DatasetEntry>,
slices: HashMap<String, SliceLocation>,
default_dataset: String,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DatasetEntry {
pub metadata: DatasetMetadata,
pub raw_path: PathBuf,
pub converted_path: PathBuf,
pub include_unanswerable: bool,
pub slices: Vec<SliceEntry>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SliceEntry {
pub id: String,
pub dataset_id: String,
pub label: String,
pub description: Option<String>,
pub limit: Option<usize>,
pub corpus_limit: Option<usize>,
pub include_unanswerable: Option<bool>,
pub seed: Option<u64>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct SliceLocation {
dataset_id: String,
slice_index: usize,
}
#[derive(Debug, Deserialize)]
struct ManifestFile {
default_dataset: Option<String>,
datasets: Vec<ManifestDataset>,
}
#[derive(Debug, Deserialize)]
struct ManifestDataset {
id: String,
label: String,
category: String,
#[serde(default)]
entity_suffix: Option<String>,
#[serde(default)]
source_prefix: Option<String>,
raw: String,
converted: String,
#[serde(default)]
include_unanswerable: bool,
#[serde(default)]
slices: Vec<ManifestSlice>,
}
#[derive(Debug, Deserialize)]
struct ManifestSlice {
id: String,
label: String,
#[serde(default)]
description: Option<String>,
#[serde(default)]
limit: Option<usize>,
#[serde(default)]
corpus_limit: Option<usize>,
#[serde(default)]
include_unanswerable: Option<bool>,
#[serde(default)]
seed: Option<u64>,
}
impl DatasetCatalog {
pub fn load() -> Result<Self> {
let manifest_raw = fs::read_to_string(MANIFEST_PATH)
.with_context(|| format!("reading dataset manifest at {}", MANIFEST_PATH))?;
let manifest: ManifestFile = serde_yaml::from_str(&manifest_raw)
.with_context(|| format!("parsing dataset manifest at {}", MANIFEST_PATH))?;
let root = Path::new(env!("CARGO_MANIFEST_DIR"));
let mut datasets = BTreeMap::new();
let mut slices = HashMap::new();
for dataset in manifest.datasets {
let raw_path = resolve_path(root, &dataset.raw);
let converted_path = resolve_path(root, &dataset.converted);
if !raw_path.exists() {
bail!(
"dataset '{}' raw file missing at {}",
dataset.id,
raw_path.display()
);
}
if !converted_path.exists() {
warn!(
"dataset '{}' converted file missing at {}; the next conversion run will regenerate it",
dataset.id,
converted_path.display()
);
}
let metadata = DatasetMetadata {
id: dataset.id.clone(),
label: dataset.label.clone(),
category: dataset.category.clone(),
entity_suffix: dataset
.entity_suffix
.clone()
.unwrap_or_else(|| dataset.label.clone()),
source_prefix: dataset
.source_prefix
.clone()
.unwrap_or_else(|| dataset.id.clone()),
include_unanswerable: dataset.include_unanswerable,
context_token_limit: None,
};
let mut entry_slices = Vec::with_capacity(dataset.slices.len());
for (index, manifest_slice) in dataset.slices.into_iter().enumerate() {
if slices.contains_key(&manifest_slice.id) {
bail!(
"slice '{}' defined multiple times in manifest",
manifest_slice.id
);
}
entry_slices.push(SliceEntry {
id: manifest_slice.id.clone(),
dataset_id: dataset.id.clone(),
label: manifest_slice.label,
description: manifest_slice.description,
limit: manifest_slice.limit,
corpus_limit: manifest_slice.corpus_limit,
include_unanswerable: manifest_slice.include_unanswerable,
seed: manifest_slice.seed,
});
slices.insert(
manifest_slice.id,
SliceLocation {
dataset_id: dataset.id.clone(),
slice_index: index,
},
);
}
datasets.insert(
metadata.id.clone(),
DatasetEntry {
metadata,
raw_path,
converted_path,
include_unanswerable: dataset.include_unanswerable,
slices: entry_slices,
},
);
}
let default_dataset = manifest
.default_dataset
.or_else(|| datasets.keys().next().cloned())
.ok_or_else(|| anyhow!("dataset manifest does not include any datasets"))?;
Ok(Self {
datasets,
slices,
default_dataset,
})
}
pub fn global() -> Result<&'static Self> {
DATASET_CATALOG.get_or_try_init(Self::load)
}
pub fn dataset(&self, id: &str) -> Result<&DatasetEntry> {
self.datasets
.get(id)
.ok_or_else(|| anyhow!("unknown dataset '{id}' in manifest"))
}
#[allow(dead_code)]
pub fn default_dataset(&self) -> Result<&DatasetEntry> {
self.dataset(&self.default_dataset)
}
#[allow(dead_code)]
pub fn slice(&self, slice_id: &str) -> Result<(&DatasetEntry, &SliceEntry)> {
let location = self
.slices
.get(slice_id)
.ok_or_else(|| anyhow!("unknown slice '{slice_id}' in manifest"))?;
let dataset = self
.datasets
.get(&location.dataset_id)
.ok_or_else(|| anyhow!("slice '{slice_id}' references missing dataset"))?;
let slice = dataset
.slices
.get(location.slice_index)
.ok_or_else(|| anyhow!("slice index out of bounds for '{slice_id}'"))?;
Ok((dataset, slice))
}
}
fn resolve_path(root: &Path, value: &str) -> PathBuf {
let path = Path::new(value);
if path.is_absolute() {
path.to_path_buf()
} else {
root.join(path)
}
}
pub fn catalog() -> Result<&'static DatasetCatalog> {
DatasetCatalog::global()
}
fn dataset_entry_for_kind(kind: DatasetKind) -> Result<&'static DatasetEntry> {
let catalog = catalog()?;
catalog.dataset(kind.id())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum DatasetKind {
#[default]
SquadV2,
NaturalQuestions,
Beir,
#[value(name = "fever")]
Fever,
#[value(name = "fiqa")]
Fiqa,
#[value(name = "hotpotqa", alias = "hotpot-qa")]
HotpotQa,
#[value(name = "nfcorpus", alias = "nf-corpus")]
Nfcorpus,
#[value(name = "quora")]
Quora,
#[value(name = "trec-covid", alias = "treccovid", alias = "trec_covid")]
TrecCovid,
#[value(name = "scifact")]
Scifact,
#[value(name = "nq-beir", alias = "natural-questions-beir")]
NqBeir,
}
impl DatasetKind {
pub fn id(self) -> &'static str {
match self {
Self::SquadV2 => "squad-v2",
Self::NaturalQuestions => "natural-questions-dev",
Self::Beir => "beir",
Self::Fever => "fever",
Self::Fiqa => "fiqa",
Self::HotpotQa => "hotpotqa",
Self::Nfcorpus => "nfcorpus",
Self::Quora => "quora",
Self::TrecCovid => "trec-covid",
Self::Scifact => "scifact",
Self::NqBeir => "nq-beir",
}
}
pub fn label(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions (dev)",
Self::Beir => "BEIR mix",
Self::Fever => "FEVER (BEIR)",
Self::Fiqa => "FiQA-2018 (BEIR)",
Self::HotpotQa => "HotpotQA (BEIR)",
Self::Nfcorpus => "NFCorpus (BEIR)",
Self::Quora => "Quora (IR)",
Self::TrecCovid => "TREC-COVID (BEIR)",
Self::Scifact => "SciFact (BEIR)",
Self::NqBeir => "Natural Questions (BEIR)",
}
}
pub fn category(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD v2.0",
Self::NaturalQuestions => "Natural Questions",
Self::Beir => "BEIR",
Self::Fever => "FEVER",
Self::Fiqa => "FiQA-2018",
Self::HotpotQa => "HotpotQA",
Self::Nfcorpus => "NFCorpus",
Self::Quora => "Quora",
Self::TrecCovid => "TREC-COVID",
Self::Scifact => "SciFact",
Self::NqBeir => "Natural Questions",
}
}
pub fn entity_suffix(self) -> &'static str {
match self {
Self::SquadV2 => "SQuAD",
Self::NaturalQuestions => "Natural Questions",
Self::Beir => "BEIR",
Self::Fever => "FEVER",
Self::Fiqa => "FiQA",
Self::HotpotQa => "HotpotQA",
Self::Nfcorpus => "NFCorpus",
Self::Quora => "Quora",
Self::TrecCovid => "TREC-COVID",
Self::Scifact => "SciFact",
Self::NqBeir => "Natural Questions",
}
}
pub fn source_prefix(self) -> &'static str {
match self {
Self::SquadV2 => "squad",
Self::NaturalQuestions => "nq",
Self::Beir => "beir",
Self::Fever => "fever",
Self::Fiqa => "fiqa",
Self::HotpotQa => "hotpotqa",
Self::Nfcorpus => "nfcorpus",
Self::Quora => "quora",
Self::TrecCovid => "trec-covid",
Self::Scifact => "scifact",
Self::NqBeir => "nq-beir",
}
}
pub fn default_raw_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.raw_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
pub fn default_converted_path(self) -> PathBuf {
dataset_entry_for_kind(self)
.map(|entry| entry.converted_path.clone())
.unwrap_or_else(|err| panic!("dataset manifest missing entry for {:?}: {err}", self))
}
}
impl std::fmt::Display for DatasetKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.id())
}
}
impl FromStr for DatasetKind {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"squad" | "squad-v2" | "squad_v2" => Ok(Self::SquadV2),
"nq" | "natural-questions" | "natural_questions" | "natural-questions-dev" => {
Ok(Self::NaturalQuestions)
}
"beir" => Ok(Self::Beir),
"fever" => Ok(Self::Fever),
"fiqa" | "fiqa-2018" => Ok(Self::Fiqa),
"hotpotqa" | "hotpot-qa" => Ok(Self::HotpotQa),
"nfcorpus" | "nf-corpus" => Ok(Self::Nfcorpus),
"quora" => Ok(Self::Quora),
"trec-covid" | "treccovid" | "trec_covid" => Ok(Self::TrecCovid),
"scifact" => Ok(Self::Scifact),
"nq-beir" | "natural-questions-beir" => Ok(Self::NqBeir),
other => {
anyhow::bail!("unknown dataset '{other}'. Expected one of: squad, natural-questions, beir, fever, fiqa, hotpotqa, nfcorpus, quora, trec-covid, scifact, nq-beir.")
}
}
}
}
pub const BEIR_DATASETS: [DatasetKind; 8] = [
DatasetKind::Fever,
DatasetKind::Fiqa,
DatasetKind::HotpotQa,
DatasetKind::Nfcorpus,
DatasetKind::Quora,
DatasetKind::TrecCovid,
DatasetKind::Scifact,
DatasetKind::NqBeir,
];
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
pub id: String,
pub label: String,
pub category: String,
pub entity_suffix: String,
pub source_prefix: String,
#[serde(default)]
pub include_unanswerable: bool,
#[serde(default)]
pub context_token_limit: Option<usize>,
}
impl DatasetMetadata {
pub fn for_kind(
kind: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Self {
if let Ok(entry) = dataset_entry_for_kind(kind) {
return Self {
id: entry.metadata.id.clone(),
label: entry.metadata.label.clone(),
category: entry.metadata.category.clone(),
entity_suffix: entry.metadata.entity_suffix.clone(),
source_prefix: entry.metadata.source_prefix.clone(),
include_unanswerable,
context_token_limit,
};
}
Self {
id: kind.id().to_string(),
label: kind.label().to_string(),
category: kind.category().to_string(),
entity_suffix: kind.entity_suffix().to_string(),
source_prefix: kind.source_prefix().to_string(),
include_unanswerable,
context_token_limit,
}
}
}
fn default_metadata() -> DatasetMetadata {
DatasetMetadata::for_kind(DatasetKind::default(), false, None)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedDataset {
pub generated_at: DateTime<Utc>,
#[serde(default = "default_metadata")]
pub metadata: DatasetMetadata,
pub source: String,
pub paragraphs: Vec<ConvertedParagraph>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedParagraph {
pub id: String,
pub title: String,
pub context: String,
pub questions: Vec<ConvertedQuestion>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvertedQuestion {
pub id: String,
pub question: String,
pub answers: Vec<String>,
pub is_impossible: bool,
}
pub fn convert(
raw_path: &Path,
dataset: DatasetKind,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
let paragraphs = match dataset {
DatasetKind::SquadV2 => squad::convert_squad(raw_path)?,
DatasetKind::NaturalQuestions => {
nq::convert_nq(raw_path, include_unanswerable, context_token_limit)?
}
DatasetKind::Beir => convert_beir_mix(include_unanswerable, context_token_limit)?,
DatasetKind::Fever
| DatasetKind::Fiqa
| DatasetKind::HotpotQa
| DatasetKind::Nfcorpus
| DatasetKind::Quora
| DatasetKind::TrecCovid
| DatasetKind::Scifact
| DatasetKind::NqBeir => beir::convert_beir(raw_path, dataset)?,
};
let metadata_limit = match dataset {
DatasetKind::NaturalQuestions => None,
_ => context_token_limit,
};
let generated_at = match dataset {
DatasetKind::Beir
| DatasetKind::Fever
| DatasetKind::Fiqa
| DatasetKind::HotpotQa
| DatasetKind::Nfcorpus
| DatasetKind::Quora
| DatasetKind::TrecCovid
| DatasetKind::Scifact
| DatasetKind::NqBeir => base_timestamp(),
_ => Utc::now(),
};
let source_label = match dataset {
DatasetKind::Beir => "beir-mix".to_string(),
_ => raw_path.display().to_string(),
};
Ok(ConvertedDataset {
generated_at,
metadata: DatasetMetadata::for_kind(dataset, include_unanswerable, metadata_limit),
source: source_label,
paragraphs,
})
}
fn convert_beir_mix(
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
if include_unanswerable {
warn!("BEIR mix ignores include_unanswerable flag; all questions are answerable");
}
let mut paragraphs = Vec::new();
for subset in BEIR_DATASETS {
let entry = dataset_entry_for_kind(subset)?;
let subset_paragraphs = beir::convert_beir(&entry.raw_path, subset)?;
paragraphs.extend(subset_paragraphs);
}
Ok(paragraphs)
}
fn ensure_parent(path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("creating parent directory for {}", path.display()))?;
}
Ok(())
}
pub fn write_converted(dataset: &ConvertedDataset, converted_path: &Path) -> Result<()> {
ensure_parent(converted_path)?;
let json =
serde_json::to_string_pretty(dataset).context("serialising converted dataset to JSON")?;
fs::write(converted_path, json)
.with_context(|| format!("writing converted dataset to {}", converted_path.display()))
}
pub fn read_converted(converted_path: &Path) -> Result<ConvertedDataset> {
let raw = fs::read_to_string(converted_path)
.with_context(|| format!("reading converted dataset at {}", converted_path.display()))?;
let mut dataset: ConvertedDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing converted dataset at {}", converted_path.display()))?;
if dataset.metadata.id.trim().is_empty() {
dataset.metadata = default_metadata();
}
if dataset.source.is_empty() {
dataset.source = converted_path.display().to_string();
}
Ok(dataset)
}
pub fn ensure_converted(
dataset_kind: DatasetKind,
raw_path: &Path,
converted_path: &Path,
force: bool,
include_unanswerable: bool,
context_token_limit: Option<usize>,
) -> Result<ConvertedDataset> {
if force || !converted_path.exists() {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
return Ok(dataset);
}
match read_converted(converted_path) {
Ok(dataset)
if dataset.metadata.id == dataset_kind.id()
&& dataset.metadata.include_unanswerable == include_unanswerable
&& dataset.metadata.context_token_limit == context_token_limit =>
{
Ok(dataset)
}
_ => {
let dataset = convert(
raw_path,
dataset_kind,
include_unanswerable,
context_token_limit,
)?;
write_converted(&dataset, converted_path)?;
Ok(dataset)
}
}
}
pub fn base_timestamp() -> DateTime<Utc> {
Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap()
}
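// Illustrative usage sketch (not part of the diff): resolving a dataset through the
// manifest-backed catalog above and making sure its converted form exists. The flag
// values here are hypothetical example choices.
#[allow(dead_code)]
fn example_prepare_scifact() -> Result<ConvertedDataset> {
    let entry = catalog()?.dataset(DatasetKind::Scifact.id())?;
    ensure_converted(
        DatasetKind::Scifact,
        &entry.raw_path,
        &entry.converted_path,
        false, // force: keep an existing converted file when its metadata matches
        false, // include_unanswerable
        None,  // context_token_limit
    )
}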

@@ -0,0 +1,234 @@
use std::{
collections::BTreeSet,
fs::File,
io::{BufRead, BufReader},
path::Path,
};
use anyhow::{Context, Result};
use serde::Deserialize;
use tracing::warn;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_nq(
raw_path: &Path,
include_unanswerable: bool,
_context_token_limit: Option<usize>,
) -> Result<Vec<ConvertedParagraph>> {
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqExample {
question_text: String,
document_title: String,
example_id: i64,
document_tokens: Vec<NqToken>,
long_answer_candidates: Vec<NqLongAnswerCandidate>,
annotations: Vec<NqAnnotation>,
}
#[derive(Debug, Deserialize)]
struct NqToken {
token: String,
#[serde(default)]
html_token: bool,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswerCandidate {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqAnnotation {
short_answers: Vec<NqShortAnswer>,
#[serde(default)]
yes_no_answer: String,
long_answer: NqLongAnswer,
}
#[derive(Debug, Deserialize)]
struct NqShortAnswer {
start_token: i32,
end_token: i32,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct NqLongAnswer {
candidate_index: i32,
}
fn join_tokens(tokens: &[NqToken], start: usize, end: usize) -> String {
let mut buffer = String::new();
let end = end.min(tokens.len());
for token in tokens.iter().skip(start).take(end.saturating_sub(start)) {
if token.html_token {
continue;
}
let text = token.token.trim();
if text.is_empty() {
continue;
}
let attach = matches!(
text,
"," | "." | "!" | "?" | ";" | ":" | ")" | "]" | "}" | "%" | "" | "..."
) || text.starts_with('\'')
|| text == "n't"
|| text == "'s"
|| text == "'re"
|| text == "'ve"
|| text == "'d"
|| text == "'ll";
if buffer.is_empty() || attach {
buffer.push_str(text);
} else {
buffer.push(' ');
buffer.push_str(text);
}
}
buffer.trim().to_string()
}
let file = File::open(raw_path).with_context(|| {
format!(
"opening Natural Questions dataset at {}",
raw_path.display()
)
})?;
let reader = BufReader::new(file);
let mut paragraphs = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"reading Natural Questions line {} from {}",
line_idx + 1,
raw_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let example: NqExample = serde_json::from_str(&line).with_context(|| {
format!(
"parsing Natural Questions JSON (line {}) at {}",
line_idx + 1,
raw_path.display()
)
})?;
let mut answer_texts: Vec<String> = Vec::new();
let mut short_answer_texts: Vec<String> = Vec::new();
let mut has_short_or_yesno = false;
let mut has_short_answer = false;
for annotation in &example.annotations {
for short in &annotation.short_answers {
if short.start_token < 0 || short.end_token <= short.start_token {
continue;
}
let start = short.start_token as usize;
let end = short.end_token as usize;
if start >= example.document_tokens.len() || end > example.document_tokens.len() {
continue;
}
let text = join_tokens(&example.document_tokens, start, end);
if !text.is_empty() {
answer_texts.push(text.clone());
short_answer_texts.push(text);
has_short_or_yesno = true;
has_short_answer = true;
}
}
match annotation
.yes_no_answer
.trim()
.to_ascii_lowercase()
.as_str()
{
"yes" => {
answer_texts.push("yes".to_string());
has_short_or_yesno = true;
}
"no" => {
answer_texts.push("no".to_string());
has_short_or_yesno = true;
}
_ => {}
}
}
let mut answers = dedupe_strings(answer_texts);
let is_unanswerable = !has_short_or_yesno || answers.is_empty();
if is_unanswerable {
if !include_unanswerable {
continue;
}
answers.clear();
}
let paragraph_id = format!("nq-{}", example.example_id);
let question_id = format!("nq-{}", example.example_id);
let context = join_tokens(&example.document_tokens, 0, example.document_tokens.len());
if context.is_empty() {
continue;
}
if has_short_answer && !short_answer_texts.is_empty() {
let normalized_context = context.to_ascii_lowercase();
let missing_answer = short_answer_texts.iter().any(|answer| {
let needle = answer.trim().to_ascii_lowercase();
!needle.is_empty() && !normalized_context.contains(&needle)
});
if missing_answer {
warn!(
question_id = %question_id,
"Skipping Natural Questions example because answers were not found in the assembled context"
);
continue;
}
}
if !include_unanswerable && (!has_short_answer || short_answer_texts.is_empty()) {
// yes/no-only questions are excluded by default unless --llm-mode is used
continue;
}
let question = ConvertedQuestion {
id: question_id,
question: example.question_text.trim().to_string(),
answers,
is_impossible: is_unanswerable,
};
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: example.document_title.trim().to_string(),
context,
questions: vec![question],
});
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}
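// Hypothetical test sketch (not in the diff): `dedupe_strings` trims values, drops
// empties and deduplicates, returning them in BTreeSet (sorted) order.
#[cfg(test)]
mod dedupe_sketch {
    use super::dedupe_strings;

    #[test]
    fn dedupes_and_trims() {
        let out = dedupe_strings(vec![
            " yes ".to_string(),
            "yes".to_string(),
            "".to_string(),
            "no".to_string(),
        ]);
        assert_eq!(out, vec!["no".to_string(), "yes".to_string()]);
    }
}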

@@ -0,0 +1,107 @@
use std::{collections::BTreeSet, fs, path::Path};
use anyhow::{Context, Result};
use serde::Deserialize;
use super::{ConvertedParagraph, ConvertedQuestion};
pub fn convert_squad(raw_path: &Path) -> Result<Vec<ConvertedParagraph>> {
#[derive(Debug, Deserialize)]
struct SquadDataset {
data: Vec<SquadArticle>,
}
#[derive(Debug, Deserialize)]
struct SquadArticle {
title: String,
paragraphs: Vec<SquadParagraph>,
}
#[derive(Debug, Deserialize)]
struct SquadParagraph {
context: String,
qas: Vec<SquadQuestion>,
}
#[derive(Debug, Deserialize)]
struct SquadQuestion {
id: String,
question: String,
answers: Vec<SquadAnswer>,
#[serde(default)]
is_impossible: bool,
}
#[derive(Debug, Deserialize)]
struct SquadAnswer {
text: String,
}
let raw = fs::read_to_string(raw_path)
.with_context(|| format!("reading raw SQuAD dataset at {}", raw_path.display()))?;
let parsed: SquadDataset = serde_json::from_str(&raw)
.with_context(|| format!("parsing SQuAD dataset at {}", raw_path.display()))?;
let mut paragraphs = Vec::new();
for (article_idx, article) in parsed.data.into_iter().enumerate() {
for (paragraph_idx, paragraph) in article.paragraphs.into_iter().enumerate() {
let mut questions = Vec::new();
for qa in paragraph.qas {
let answers = dedupe_strings(qa.answers.into_iter().map(|answer| answer.text));
questions.push(ConvertedQuestion {
id: qa.id,
question: qa.question.trim().to_string(),
answers,
is_impossible: qa.is_impossible,
});
}
let paragraph_id =
format!("{}-{}", slugify(&article.title, article_idx), paragraph_idx);
paragraphs.push(ConvertedParagraph {
id: paragraph_id,
title: article.title.trim().to_string(),
context: paragraph.context.trim().to_string(),
questions,
});
}
}
Ok(paragraphs)
}
fn dedupe_strings<I>(values: I) -> Vec<String>
where
I: IntoIterator<Item = String>,
{
let mut set = BTreeSet::new();
for value in values {
let trimmed = value.trim();
if !trimmed.is_empty() {
set.insert(trimmed.to_string());
}
}
set.into_iter().collect()
}
fn slugify(input: &str, fallback_idx: usize) -> String {
let mut slug = String::new();
let mut last_dash = false;
for ch in input.chars() {
let c = ch.to_ascii_lowercase();
if c.is_ascii_alphanumeric() {
slug.push(c);
last_dash = false;
} else if !last_dash {
slug.push('-');
last_dash = true;
}
}
slug = slug.trim_matches('-').to_string();
if slug.is_empty() {
slug = format!("article-{fallback_idx}");
}
slug
}
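// Hypothetical test sketch (not in the diff): `slugify` lowercases, collapses runs of
// non-alphanumerics into single dashes, and falls back to an indexed name.
#[cfg(test)]
mod slugify_sketch {
    use super::slugify;

    #[test]
    fn slugifies_titles() {
        assert_eq!(slugify("The Norman Conquest!", 0), "the-norman-conquest");
        assert_eq!(slugify("???", 7), "article-7");
    }
}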

@@ -0,0 +1,110 @@
use anyhow::{Context, Result};
use common::storage::{db::SurrealDbClient, indexes::ensure_runtime_indexes};
use tracing::info;
// Helper functions for index management during namespace reseed
pub async fn remove_all_indexes(db: &SurrealDbClient) -> Result<()> {
let _ = db;
info!("Removing ALL indexes before namespace reseed (no-op placeholder)");
Ok(())
}
pub async fn recreate_indexes(db: &SurrealDbClient, dimension: usize) -> Result<()> {
info!("Recreating ALL indexes after namespace reseed via shared runtime helper");
ensure_runtime_indexes(db, dimension)
.await
.context("creating runtime indexes")
}
pub async fn reset_namespace(db: &SurrealDbClient, namespace: &str, database: &str) -> Result<()> {
let query = format!(
"REMOVE NAMESPACE {ns};
DEFINE NAMESPACE {ns};
DEFINE DATABASE {db};",
ns = namespace,
db = database
);
db.client
.query(query)
.await
.context("resetting SurrealDB namespace")?;
db.client
.use_ns(namespace)
.use_db(database)
.await
.context("selecting namespace/database after reset")?;
Ok(())
}
// // Test helper to force index dimension change
// #[allow(dead_code)]
// pub async fn change_embedding_length_in_hnsw_indexes(
// db: &SurrealDbClient,
// dimension: usize,
// ) -> Result<()> {
// recreate_indexes(db, dimension).await
// }
#[cfg(test)]
mod tests {
use super::*;
use serde::Deserialize;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
struct FooRow {
label: String,
}
#[tokio::test]
async fn reset_namespace_drops_existing_rows() {
let namespace = format!("reset_ns_{}", Uuid::new_v4().simple());
let database = format!("reset_db_{}", Uuid::new_v4().simple());
let db = SurrealDbClient::memory(&namespace, &database)
.await
.expect("in-memory db");
db.client
.query(
"DEFINE TABLE foo SCHEMALESS;
CREATE foo:foo SET label = 'before';",
)
.await
.expect("seed namespace")
.check()
.expect("seed response");
let mut before = db
.client
.query("SELECT * FROM foo")
.await
.expect("select before reset");
let existing: Vec<FooRow> = before.take(0).expect("rows before reset");
assert_eq!(existing.len(), 1);
assert_eq!(existing[0].label, "before");
reset_namespace(&db, &namespace, &database)
.await
.expect("namespace reset");
match db.client.query("SELECT * FROM foo").await {
Ok(mut response) => {
let rows: Vec<FooRow> = response.take(0).unwrap_or_default();
assert!(
rows.is_empty(),
"reset namespace should drop rows, found {:?}",
rows
);
}
Err(error) => {
let message = error.to_string();
assert!(
message.to_ascii_lowercase().contains("table")
|| message.to_ascii_lowercase().contains("namespace")
|| message.to_ascii_lowercase().contains("foo"),
"unexpected error after namespace reset: {message}"
);
}
}
}
}
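// Illustrative sketch (not part of the diff): the reseed flow these helpers are meant
// for. The namespace, database and the 1536 embedding dimension are hypothetical
// example values.
#[allow(dead_code)]
async fn example_reseed(db: &SurrealDbClient) -> Result<()> {
    remove_all_indexes(db).await?;
    reset_namespace(db, "eval_squad_v2_all", "retrieval_eval").await?;
    recreate_indexes(db, 1536).await?;
    Ok(())
}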

evaluations/src/eval.rs

@@ -0,0 +1,128 @@
//! Evaluation utilities module - re-exports from focused submodules.
// Re-export types from the root types module
pub use crate::types::*;
// Re-export from focused modules at crate root (crate-internal only)
pub(crate) use crate::cases::{cases_from_manifest, SeededCase};
pub(crate) use crate::namespace::{
can_reuse_namespace, connect_eval_db, default_database, default_namespace, ensure_eval_user,
record_namespace_state,
};
pub(crate) use crate::settings::{enforce_system_settings, load_or_init_system_settings};
use std::path::Path;
use anyhow::{Context, Result};
use common::storage::db::SurrealDbClient;
use tokio::io::AsyncWriteExt;
use tracing::info;
use crate::{
args::{self, Config},
datasets::ConvertedDataset,
slice::{self},
};
/// Grow the slice ledger to contain the target number of cases.
pub async fn grow_slice(dataset: &ConvertedDataset, config: &Config) -> Result<()> {
let ledger_limit = ledger_target(config);
let slice_settings = slice::slice_config_with_limit(config, ledger_limit);
let slice =
slice::resolve_slice(dataset, &slice_settings).context("resolving dataset slice")?;
info!(
slice = slice.manifest.slice_id.as_str(),
cases = slice.manifest.case_count,
positives = slice.manifest.positive_paragraphs,
negatives = slice.manifest.negative_paragraphs,
total_paragraphs = slice.manifest.total_paragraphs,
"Slice ledger ready"
);
println!(
"Slice `{}` now contains {} questions ({} positives, {} negatives)",
slice.manifest.slice_id,
slice.manifest.case_count,
slice.manifest.positive_paragraphs,
slice.manifest.negative_paragraphs
);
Ok(())
}
pub(crate) fn ledger_target(config: &Config) -> Option<usize> {
match (config.slice_grow, config.limit) {
(Some(grow), Some(limit)) => Some(limit.max(grow)),
(Some(grow), None) => Some(grow),
(None, limit) => limit,
}
}
pub(crate) async fn write_chunk_diagnostics(path: &Path, cases: &[CaseDiagnostics]) -> Result<()> {
args::ensure_parent(path)?;
let mut file = tokio::fs::File::create(path)
.await
.with_context(|| format!("creating diagnostics file {}", path.display()))?;
for case in cases {
let line = serde_json::to_vec(case).context("serialising chunk diagnostics entry")?;
file.write_all(&line).await?;
file.write_all(b"\n").await?;
}
file.flush().await?;
Ok(())
}
pub(crate) async fn warm_hnsw_cache(db: &SurrealDbClient, dimension: usize) -> Result<()> {
// Create a dummy embedding for cache warming
let dummy_embedding: Vec<f32> = (0..dimension).map(|i| (i as f32).sin()).collect();
info!("Warming HNSW caches with sample queries");
// Warm up chunk embedding index - just query the embedding table to load HNSW index
let _ = db
.client
.query(
r#"SELECT chunk_id
FROM text_chunk_embedding
WHERE embedding <|1,1|> $embedding
LIMIT 5"#,
)
.bind(("embedding", dummy_embedding.clone()))
.await
.context("warming text chunk HNSW cache")?;
// Warm up entity embedding index
let _ = db
.client
.query(
r#"SELECT entity_id
FROM knowledge_entity_embedding
WHERE embedding <|1,1|> $embedding
LIMIT 5"#,
)
.bind(("embedding", dummy_embedding))
.await
.context("warming knowledge entity HNSW cache")?;
info!("HNSW cache warming completed");
Ok(())
}
use chrono::{DateTime, SecondsFormat, Utc};
pub fn format_timestamp(timestamp: &DateTime<Utc>) -> String {
timestamp.to_rfc3339_opts(SecondsFormat::Secs, true)
}
pub(crate) fn sanitize_model_code(code: &str) -> String {
code.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() {
ch.to_ascii_lowercase()
} else {
'_'
}
})
.collect()
}
// Re-export run_evaluation from the pipeline module at crate root
pub use crate::pipeline::run_evaluation;
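// Hypothetical test sketch (not in the diff): `ledger_target` keeps the larger of
// --slice-grow and --limit, and `sanitize_model_code` folds anything that is not
// ASCII alphanumeric into underscores.
#[cfg(test)]
mod helper_sketch {
    use super::sanitize_model_code;

    #[test]
    fn sanitizes_model_codes() {
        assert_eq!(sanitize_model_code("gpt-4o-mini"), "gpt_4o_mini");
        assert_eq!(sanitize_model_code("text.embedding:3"), "text_embedding_3");
    }
}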

@@ -0,0 +1,184 @@
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
use anyhow::{anyhow, Context, Result};
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
use crate::{args::Config, corpus, eval::connect_eval_db, snapshot::DbSnapshotState};
pub async fn inspect_question(config: &Config) -> Result<()> {
let question_id = config
.inspect_question
.as_ref()
.ok_or_else(|| anyhow!("--inspect-question is required for inspection mode"))?;
let manifest_path = config
.inspect_manifest
.as_ref()
.ok_or_else(|| anyhow!("--inspect-manifest must be provided for inspection mode"))?;
let manifest = load_manifest(manifest_path)?;
let chunk_lookup = build_chunk_lookup(&manifest);
let question = manifest
.questions
.iter()
.find(|q| q.question_id == *question_id)
.ok_or_else(|| {
anyhow!(
"question '{}' not found in manifest {}",
question_id,
manifest_path.display()
)
})?;
println!("Question: {}", question.question_text);
println!("Answers: {:?}", question.answers);
println!(
"matching_chunk_ids ({}):",
question.matching_chunk_ids.len()
);
let mut missing_in_manifest = Vec::new();
for chunk_id in &question.matching_chunk_ids {
if let Some(entry) = chunk_lookup.get(chunk_id) {
println!(
" - {} (paragraph: {})\n snippet: {}",
chunk_id, entry.paragraph_title, entry.snippet
);
} else {
println!(" - {} (missing from manifest)", chunk_id);
missing_in_manifest.push(chunk_id.clone());
}
}
if missing_in_manifest.is_empty() {
println!("All matching_chunk_ids are present in the ingestion manifest");
} else {
println!(
"Missing chunk IDs in manifest {}: {:?}",
manifest_path.display(),
missing_in_manifest
);
}
let db_state_path = config
.database
.inspect_db_state
.clone()
.unwrap_or_else(|| default_state_path(config, &manifest));
if let Some(state) = load_db_state(&db_state_path)? {
if let (Some(ns), Some(db_name)) = (state.namespace.as_deref(), state.database.as_deref()) {
match connect_eval_db(config, ns, db_name).await {
Ok(db) => match verify_chunks_in_db(&db, &question.matching_chunk_ids).await? {
MissingChunks::None => println!(
"All matching_chunk_ids exist in namespace '{}', database '{}'",
ns, db_name
),
MissingChunks::Missing(list) => println!(
"Missing chunks in namespace '{}', database '{}': {:?}",
ns, db_name, list
),
},
Err(err) => {
println!(
"Failed to connect to SurrealDB namespace '{}' / database '{}': {err}",
ns, db_name
);
}
}
} else {
println!(
"State file {} is missing namespace/database fields; skipping live DB validation",
db_state_path.display()
);
}
} else {
println!(
"State file {} not found; skipping live DB validation",
db_state_path.display()
);
}
Ok(())
}
struct ChunkEntry {
paragraph_title: String,
snippet: String,
}
fn load_manifest(path: &Path) -> Result<corpus::CorpusManifest> {
let bytes =
fs::read(path).with_context(|| format!("reading ingestion manifest {}", path.display()))?;
serde_json::from_slice(&bytes)
.with_context(|| format!("parsing ingestion manifest {}", path.display()))
}
fn build_chunk_lookup(manifest: &corpus::CorpusManifest) -> HashMap<String, ChunkEntry> {
let mut lookup = HashMap::new();
for paragraph in &manifest.paragraphs {
for chunk in &paragraph.chunks {
let snippet = chunk
.chunk
.chunk
.chars()
.take(160)
.collect::<String>()
.replace('\n', " ");
lookup.insert(
chunk.chunk.id.clone(),
ChunkEntry {
paragraph_title: paragraph.title.clone(),
snippet,
},
);
}
}
lookup
}
fn default_state_path(config: &Config, manifest: &corpus::CorpusManifest) -> PathBuf {
config
.cache_dir
.join("snapshots")
.join(&manifest.metadata.dataset_id)
.join(&manifest.metadata.slice_id)
.join("db/state.json")
}
fn load_db_state(path: &Path) -> Result<Option<DbSnapshotState>> {
if !path.exists() {
return Ok(None);
}
let bytes = fs::read(path).with_context(|| format!("reading db state {}", path.display()))?;
let state = serde_json::from_slice(&bytes)
.with_context(|| format!("parsing db state {}", path.display()))?;
Ok(Some(state))
}
enum MissingChunks {
None,
Missing(Vec<String>),
}
async fn verify_chunks_in_db(db: &SurrealDbClient, chunk_ids: &[String]) -> Result<MissingChunks> {
let mut missing = Vec::new();
for chunk_id in chunk_ids {
let exists = db
.get_item::<TextChunk>(chunk_id)
.await
.with_context(|| format!("fetching text_chunk {}", chunk_id))?
.is_some();
if !exists {
missing.push(chunk_id.clone());
}
}
if missing.is_empty() {
Ok(MissingChunks::None)
} else {
Ok(MissingChunks::Missing(missing))
}
}

evaluations/src/main.rs

@@ -0,0 +1,247 @@
mod args;
mod cache;
mod cases;
mod corpus;
mod datasets;
mod db_helpers;
mod eval;
mod inspection;
mod namespace;
mod openai;
mod perf;
mod pipeline;
mod report;
mod settings;
mod slice;
mod snapshot;
mod types;
use anyhow::Context;
use tokio::runtime::Builder;
use tracing::info;
use tracing_subscriber::{fmt, EnvFilter};
/// Configure SurrealDB environment variables for optimal performance
fn configure_surrealdb_performance(cpu_count: usize) {
// Default each tuning variable from the CPU count; values already set in the environment are preserved
let indexing_batch_size = std::env::var("SURREAL_INDEXING_BATCH_SIZE")
.unwrap_or_else(|_| (cpu_count * 2).to_string());
std::env::set_var("SURREAL_INDEXING_BATCH_SIZE", indexing_batch_size);
let max_order_queue = std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE")
.unwrap_or_else(|_| (cpu_count * 4).to_string());
std::env::set_var(
"SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE",
max_order_queue,
);
let websocket_concurrent = std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS")
.unwrap_or_else(|_| cpu_count.to_string());
std::env::set_var(
"SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS",
websocket_concurrent,
);
let websocket_buffer = std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE")
.unwrap_or_else(|_| (cpu_count * 8).to_string());
std::env::set_var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE", websocket_buffer);
let transaction_cache = std::env::var("SURREAL_TRANSACTION_CACHE_SIZE")
.unwrap_or_else(|_| (cpu_count * 16).to_string());
std::env::set_var("SURREAL_TRANSACTION_CACHE_SIZE", transaction_cache);
info!(
indexing_batch_size = %std::env::var("SURREAL_INDEXING_BATCH_SIZE").unwrap(),
max_order_queue = %std::env::var("SURREAL_MAX_ORDER_LIMIT_PRIORITY_QUEUE_SIZE").unwrap(),
websocket_concurrent = %std::env::var("SURREAL_WEBSOCKET_MAX_CONCURRENT_REQUESTS").unwrap(),
websocket_buffer = %std::env::var("SURREAL_WEBSOCKET_RESPONSE_BUFFER_SIZE").unwrap(),
transaction_cache = %std::env::var("SURREAL_TRANSACTION_CACHE_SIZE").unwrap(),
"Configured SurrealDB performance variables"
);
}
fn main() -> anyhow::Result<()> {
// Create an explicit multi-threaded runtime with optimized configuration
let runtime = Builder::new_multi_thread()
.enable_all()
.worker_threads(std::thread::available_parallelism()?.get())
.max_blocking_threads(std::thread::available_parallelism()?.get())
.thread_stack_size(10 * 1024 * 1024) // 10MiB stack size
.thread_name("eval-retrieval-worker")
.build()
.context("failed to create tokio runtime")?;
runtime.block_on(async_main())
}
async fn async_main() -> anyhow::Result<()> {
// Log runtime configuration
let cpu_count = std::thread::available_parallelism()?.get();
info!(
cpu_cores = cpu_count,
worker_threads = cpu_count,
blocking_threads = cpu_count,
thread_stack_size = "10MiB",
"Started multi-threaded tokio runtime"
);
// Configure SurrealDB environment variables for better performance
configure_surrealdb_performance(cpu_count);
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
let _ = fmt()
.with_env_filter(EnvFilter::try_new(&filter).unwrap_or_else(|_| EnvFilter::new("info")))
.try_init();
let parsed = args::parse()?;
// Clap handles help automatically, so we don't need to check for it manually
if parsed.config.inspect_question.is_some() {
inspection::inspect_question(&parsed.config).await?;
return Ok(());
}
let dataset_kind = parsed.config.dataset;
if parsed.config.convert_only {
info!(
dataset = dataset_kind.id(),
"Starting dataset conversion only run"
);
let dataset = crate::datasets::convert(
parsed.config.raw_dataset_path.as_path(),
dataset_kind,
parsed.config.llm_mode,
parsed.config.context_token_limit(),
)
.with_context(|| {
format!(
"converting {} dataset at {}",
dataset_kind.label(),
parsed.config.raw_dataset_path.display()
)
})?;
crate::datasets::write_converted(&dataset, parsed.config.converted_dataset_path.as_path())
.with_context(|| {
format!(
"writing converted dataset to {}",
parsed.config.converted_dataset_path.display()
)
})?;
println!(
"Converted dataset written to {}",
parsed.config.converted_dataset_path.display()
);
return Ok(());
}
info!(dataset = dataset_kind.id(), "Preparing converted dataset");
let dataset = crate::datasets::ensure_converted(
dataset_kind,
parsed.config.raw_dataset_path.as_path(),
parsed.config.converted_dataset_path.as_path(),
parsed.config.force_convert,
parsed.config.llm_mode,
parsed.config.context_token_limit(),
)
.with_context(|| {
format!(
"preparing converted dataset at {}",
parsed.config.converted_dataset_path.display()
)
})?;
info!(
questions = dataset
.paragraphs
.iter()
.map(|p| p.questions.len())
.sum::<usize>(),
paragraphs = dataset.paragraphs.len(),
dataset = dataset.metadata.id.as_str(),
"Dataset ready"
);
if parsed.config.slice_grow.is_some() {
eval::grow_slice(&dataset, &parsed.config)
.await
.context("growing slice ledger")?;
return Ok(());
}
info!("Running retrieval evaluation");
let summary = eval::run_evaluation(&dataset, &parsed.config)
.await
.context("running retrieval evaluation")?;
let report = report::write_reports(
&summary,
parsed.config.report_dir.as_path(),
parsed.config.summary_sample,
)
.with_context(|| format!("writing reports to {}", parsed.config.report_dir.display()))?;
let perf_mirrors = perf::mirror_perf_outputs(
&report.record,
&summary,
parsed.config.report_dir.as_path(),
parsed.config.perf_log_json.as_deref(),
parsed.config.perf_log_dir.as_deref(),
)
.with_context(|| {
format!(
"writing perf mirrors under {}",
parsed.config.report_dir.display()
)
})?;
let perf_note = if perf_mirrors.is_empty() {
String::new()
} else {
format!(
" | Perf mirrors: {}",
perf_mirrors
.iter()
.map(|path| path.display().to_string())
.collect::<Vec<_>>()
.join(", ")
)
};
if summary.llm_cases > 0 {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) + LLM: {llm_answered}/{llm_total} ({llm_precision:.3}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
llm_answered = summary.llm_answered,
llm_total = summary.llm_cases,
llm_precision = summary.llm_precision,
json = report.paths.json.display(),
md = report.paths.markdown.display(),
history = report.history_path.display(),
perf_note = perf_note,
);
} else {
println!(
"[{}] Retrieval Precision@{k}: {precision:.3} ({correct}/{retrieval_total}) → JSON: {json} | Markdown: {md} | History: {history}{perf_note}",
summary.dataset_label,
k = summary.k,
precision = summary.precision,
correct = summary.correct,
retrieval_total = summary.retrieval_cases,
json = report.paths.json.display(),
md = report.paths.markdown.display(),
history = report.history_path.display(),
perf_note = perf_note,
);
}
if parsed.config.perf_log_console {
perf::print_console_summary(&report.record);
}
Ok(())
}

@@ -0,0 +1,225 @@
//! Database namespace management utilities.
use anyhow::{anyhow, Context, Result};
use chrono::Utc;
use common::storage::{db::SurrealDbClient, types::user::User, types::StoredObject};
use serde::Deserialize;
use tracing::{info, warn};
use crate::{
args::Config,
datasets,
snapshot::{self, DbSnapshotState},
};
/// Connect to the evaluation database with fallback auth strategies.
pub(crate) async fn connect_eval_db(
config: &Config,
namespace: &str,
database: &str,
) -> Result<SurrealDbClient> {
match SurrealDbClient::new(
&config.database.db_endpoint,
&config.database.db_username,
&config.database.db_password,
namespace,
database,
)
.await
{
Ok(client) => {
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
auth = "root",
"Connected to SurrealDB"
);
Ok(client)
}
Err(root_err) => {
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
"Root authentication failed; trying namespace-level auth"
);
let namespace_client = SurrealDbClient::new_with_namespace_user(
&config.database.db_endpoint,
namespace,
&config.database.db_username,
&config.database.db_password,
database,
)
.await
.map_err(|ns_err| {
anyhow!(
"failed to connect to SurrealDB via root ({root_err}) or namespace ({ns_err}) credentials"
)
})?;
info!(
endpoint = %config.database.db_endpoint,
namespace,
database,
auth = "namespace",
"Connected to SurrealDB"
);
Ok(namespace_client)
}
}
}
/// Check if the namespace contains any corpus data.
pub(crate) async fn namespace_has_corpus(db: &SurrealDbClient) -> Result<bool> {
#[derive(Deserialize)]
struct CountRow {
count: i64,
}
let mut response = db
.client
.query("SELECT count() AS count FROM text_chunk")
.await
.context("checking namespace corpus state")?;
let rows: Vec<CountRow> = response.take(0).unwrap_or_default();
Ok(rows.first().map(|row| row.count).unwrap_or(0) > 0)
}
/// Determine if we can reuse an existing namespace based on cached state.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn can_reuse_namespace(
db: &SurrealDbClient,
descriptor: &snapshot::Descriptor,
namespace: &str,
database: &str,
dataset_id: &str,
slice_id: &str,
ingestion_fingerprint: &str,
slice_case_count: usize,
) -> Result<bool> {
let state = match descriptor.load_db_state().await? {
Some(state) => state,
None => {
info!("No namespace state recorded; reseeding corpus from cached shards");
return Ok(false);
}
};
if state.slice_case_count != slice_case_count {
info!(
requested_cases = slice_case_count,
stored_cases = state.slice_case_count,
"Skipping live namespace reuse; cached state does not match requested window"
);
return Ok(false);
}
if state.dataset_id != dataset_id
|| state.slice_id != slice_id
|| state.ingestion_fingerprint != ingestion_fingerprint
|| state.namespace.as_deref() != Some(namespace)
|| state.database.as_deref() != Some(database)
{
info!(
namespace,
database, "Cached namespace metadata mismatch; rebuilding corpus from ingestion cache"
);
return Ok(false);
}
if namespace_has_corpus(db).await? {
Ok(true)
} else {
info!(
namespace,
database,
"Namespace metadata matches but tables are empty; reseeding from ingestion cache"
);
Ok(false)
}
}
/// Record the current namespace state to allow future reuse checks.
pub(crate) async fn record_namespace_state(
descriptor: &snapshot::Descriptor,
dataset_id: &str,
slice_id: &str,
ingestion_fingerprint: &str,
namespace: &str,
database: &str,
slice_case_count: usize,
) {
let state = DbSnapshotState {
dataset_id: dataset_id.to_string(),
slice_id: slice_id.to_string(),
ingestion_fingerprint: ingestion_fingerprint.to_string(),
snapshot_hash: descriptor.metadata_hash().to_string(),
updated_at: Utc::now(),
namespace: Some(namespace.to_string()),
database: Some(database.to_string()),
slice_case_count,
};
if let Err(err) = descriptor.store_db_state(&state).await {
warn!(error = %err, "Failed to record namespace state");
}
}
fn sanitize_identifier(input: &str) -> String {
let mut cleaned: String = input
.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() {
ch.to_ascii_lowercase()
} else {
'_'
}
})
.collect();
if cleaned.is_empty() {
cleaned.push('x');
}
if cleaned.len() > 64 {
cleaned.truncate(64);
}
cleaned
}
/// Generate a default namespace name based on dataset and limit.
pub(crate) fn default_namespace(dataset_id: &str, limit: Option<usize>) -> String {
let dataset_component = sanitize_identifier(dataset_id);
let limit_component = match limit {
Some(value) if value > 0 => format!("limit{}", value),
_ => "all".to_string(),
};
format!("eval_{}_{}", dataset_component, limit_component)
}
/// Generate the default database name for evaluations.
pub(crate) fn default_database() -> String {
"retrieval_eval".to_string()
}
/// Ensure the evaluation user exists in the database.
pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
let timestamp = datasets::base_timestamp();
let user = User {
id: "eval-user".to_string(),
created_at: timestamp,
updated_at: timestamp,
email: "eval-retrieval@minne.dev".to_string(),
password: "not-used".to_string(),
anonymous: false,
api_key: None,
admin: false,
timezone: "UTC".to_string(),
};
if let Some(existing) = db.get_item::<User>(user.get_id()).await? {
return Ok(existing);
}
db.store_item(user.clone())
.await
.context("storing evaluation user")?;
Ok(user)
}
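// Hypothetical test sketch (not in the diff): namespace names derive from the
// sanitized dataset id plus the requested limit.
#[cfg(test)]
mod naming_sketch {
    use super::{default_database, default_namespace};

    #[test]
    fn builds_namespace_names() {
        assert_eq!(default_namespace("squad-v2", Some(100)), "eval_squad_v2_limit100");
        assert_eq!(default_namespace("squad-v2", None), "eval_squad_v2_all");
        assert_eq!(default_database(), "retrieval_eval");
    }
}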

evaluations/src/openai.rs

@@ -0,0 +1,16 @@
use anyhow::{Context, Result};
use async_openai::{config::OpenAIConfig, Client};
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
pub fn build_client_from_env() -> Result<(Client<OpenAIConfig>, String)> {
let api_key = std::env::var("OPENAI_API_KEY")
.context("OPENAI_API_KEY must be set to run retrieval evaluations")?;
let base_url =
std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
let config = OpenAIConfig::new()
.with_api_key(api_key)
.with_api_base(&base_url);
Ok((Client::with_config(config), base_url))
}

evaluations/src/perf.rs

@@ -0,0 +1,248 @@
use std::{
fs,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use crate::{
args,
eval::EvaluationSummary,
report::{self, EvaluationReport},
};
pub fn mirror_perf_outputs(
record: &EvaluationReport,
summary: &EvaluationSummary,
report_root: &Path,
extra_json: Option<&Path>,
extra_dir: Option<&Path>,
) -> Result<Vec<PathBuf>> {
let mut written = Vec::new();
if let Some(path) = extra_json {
args::ensure_parent(path)?;
let blob = serde_json::to_vec_pretty(record).context("serialising perf log JSON")?;
fs::write(path, blob)
.with_context(|| format!("writing perf log copy to {}", path.display()))?;
written.push(path.to_path_buf());
}
if let Some(dir) = extra_dir {
fs::create_dir_all(dir)
.with_context(|| format!("creating perf log directory {}", dir.display()))?;
let dataset_dir = report::dataset_report_dir(report_root, &summary.dataset_id);
let dataset_slug = dataset_dir
.file_name()
.and_then(|os| os.to_str())
.unwrap_or("dataset");
let timestamp = summary.generated_at.format("%Y%m%dT%H%M%S").to_string();
let filename = format!("perf-{}-{}.json", dataset_slug, timestamp);
let path = dir.join(filename);
let blob = serde_json::to_vec_pretty(record).context("serialising perf log JSON")?;
fs::write(&path, blob)
.with_context(|| format!("writing perf log mirror {}", path.display()))?;
written.push(path);
}
Ok(written)
}
pub fn print_console_summary(record: &EvaluationReport) {
let perf = &record.performance;
println!(
"[perf] retrieval strategy={} | concurrency={} | rerank={} (pool {:?}, keep {})",
record.retrieval.strategy,
record.retrieval.concurrency,
record.retrieval.rerank_enabled,
record.retrieval.rerank_pool_size,
record.retrieval.rerank_keep_top
);
println!(
"[perf] ingestion={}ms | namespace_seed={}",
perf.ingestion_ms,
format_duration(perf.namespace_seed_ms),
);
let stage = &perf.stage_latency;
println!(
"[perf] stage avg ms → embed {:.1} | collect {:.1} | graph {:.1} | chunk {:.1} | rerank {:.1} | assemble {:.1}",
stage.embed.avg,
stage.collect_candidates.avg,
stage.graph_expansion.avg,
stage.chunk_attach.avg,
stage.rerank.avg,
stage.assemble.avg,
);
let eval = &perf.evaluation_stages_ms;
println!(
"[perf] eval stage ms → slice {} | db {} | corpus {} | namespace {} | queries {} | summarize {} | finalize {}",
eval.prepare_slice_ms,
eval.prepare_db_ms,
eval.prepare_corpus_ms,
eval.prepare_namespace_ms,
eval.run_queries_ms,
eval.summarize_ms,
eval.finalize_ms,
);
}
fn format_duration(value: Option<u128>) -> String {
value
.map(|ms| format!("{ms}ms"))
.unwrap_or_else(|| "-".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::eval::{EvaluationStageTimings, PerformanceTimings};
use chrono::Utc;
use tempfile::tempdir;
fn sample_latency() -> crate::eval::LatencyStats {
crate::eval::LatencyStats {
avg: 10.0,
p50: 8,
p95: 15,
}
}
fn sample_stage_latency() -> crate::eval::StageLatencyBreakdown {
crate::eval::StageLatencyBreakdown {
embed: sample_latency(),
collect_candidates: sample_latency(),
graph_expansion: sample_latency(),
chunk_attach: sample_latency(),
rerank: sample_latency(),
assemble: sample_latency(),
}
}
fn sample_eval_stage() -> EvaluationStageTimings {
EvaluationStageTimings {
prepare_slice_ms: 10,
prepare_db_ms: 20,
prepare_corpus_ms: 30,
prepare_namespace_ms: 40,
run_queries_ms: 50,
summarize_ms: 60,
finalize_ms: 70,
}
}
fn sample_summary() -> EvaluationSummary {
EvaluationSummary {
generated_at: Utc::now(),
k: 5,
limit: Some(10),
run_label: Some("test".into()),
total_cases: 2,
correct: 1,
precision: 0.5,
correct_at_1: 1,
correct_at_2: 1,
correct_at_3: 1,
precision_at_1: 0.5,
precision_at_2: 0.5,
precision_at_3: 0.5,
mrr: 0.0,
average_ndcg: 0.0,
duration_ms: 1234,
dataset_id: "squad-v2".into(),
dataset_label: "SQuAD v2".into(),
dataset_includes_unanswerable: false,
dataset_source: "dev".into(),
includes_impossible_cases: false,
require_verified_chunks: true,
filtered_questions: 0,
retrieval_cases: 2,
retrieval_correct: 1,
retrieval_precision: 0.5,
llm_cases: 0,
llm_answered: 0,
llm_precision: 0.0,
slice_id: "slice123".into(),
slice_seed: 42,
slice_total_cases: 400,
slice_window_offset: 0,
slice_window_length: 10,
slice_cases: 10,
slice_positive_paragraphs: 10,
slice_negative_paragraphs: 40,
slice_total_paragraphs: 50,
slice_negative_multiplier: 4.0,
namespace_reused: true,
corpus_paragraphs: 50,
ingestion_cache_path: "/tmp/cache".into(),
ingestion_reused: true,
ingestion_embeddings_reused: true,
ingestion_fingerprint: "fingerprint".into(),
positive_paragraphs_reused: 10,
negative_paragraphs_reused: 40,
latency_ms: sample_latency(),
perf: PerformanceTimings {
openai_base_url: "https://example.com".into(),
ingestion_ms: 1000,
namespace_seed_ms: Some(150),
evaluation_stage_ms: sample_eval_stage(),
stage_latency: sample_stage_latency(),
},
embedding_backend: "fastembed".into(),
embedding_model: Some("test-model".into()),
embedding_dimension: 32,
rerank_enabled: true,
rerank_pool_size: Some(4),
rerank_keep_top: 10,
concurrency: 2,
detailed_report: false,
retrieval_strategy: "initial".into(),
chunk_result_cap: 5,
chunk_rrf_k: 60.0,
chunk_rrf_vector_weight: 1.0,
chunk_rrf_fts_weight: 1.0,
chunk_rrf_use_vector: true,
chunk_rrf_use_fts: true,
ingest_chunk_min_tokens: 256,
ingest_chunk_max_tokens: 512,
ingest_chunks_only: false,
ingest_chunk_overlap_tokens: 50,
chunk_vector_take: 20,
chunk_fts_take: 20,
chunk_avg_chars_per_token: 4,
max_chunks_per_entity: 4,
cases: Vec::new(),
}
}
#[test]
fn writes_perf_mirrors_from_record() {
let tmp = tempdir().unwrap();
let report_root = tmp.path().join("reports");
let summary = sample_summary();
let record = report::EvaluationReport::from_summary(&summary, 5);
let json_path = tmp.path().join("extra.json");
let dir_path = tmp.path().join("copies");
let outputs = mirror_perf_outputs(
&record,
&summary,
&report_root,
Some(json_path.as_path()),
Some(dir_path.as_path()),
)
.expect("perf mirrors");
assert!(json_path.exists());
let content = std::fs::read_to_string(&json_path).expect("reading mirror json");
assert!(
content.contains("\"evaluation_stages_ms\""),
"perf mirror should include evaluation stage timings"
);
assert_eq!(outputs.len(), 2);
let mirrored = outputs
.into_iter()
.filter(|path| path.starts_with(&dir_path))
.collect::<Vec<_>>();
assert_eq!(mirrored.len(), 1, "expected timestamped mirror in dir");
}
}
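// Hypothetical test sketch (not in the diff): missing namespace-seed timings render
// as a dash in the console summary.
#[cfg(test)]
mod duration_sketch {
    use super::format_duration;

    #[test]
    fn formats_optional_durations() {
        assert_eq!(format_duration(Some(150)), "150ms");
        assert_eq!(format_duration(None), "-");
    }
}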

@@ -0,0 +1,198 @@
use std::{
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
};
use async_openai::Client;
use common::{
storage::{
db::SurrealDbClient,
types::{system_settings::SystemSettings, user::User},
},
utils::embedding::EmbeddingProvider,
};
use retrieval_pipeline::{
pipeline::{PipelineStageTimings, RetrievalConfig},
reranking::RerankerPool,
};
use crate::{
args::Config,
cache::EmbeddingCache,
corpus,
datasets::ConvertedDataset,
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
slice, snapshot,
};
pub(super) struct EvaluationContext<'a> {
dataset: &'a ConvertedDataset,
config: &'a Config,
pub stage_timings: EvaluationStageTimings,
pub ledger_limit: Option<usize>,
pub slice_settings: Option<slice::SliceConfig<'a>>,
pub slice: Option<slice::ResolvedSlice<'a>>,
pub window_offset: usize,
pub window_length: usize,
pub window_total_cases: usize,
pub namespace: String,
pub database: String,
pub db: Option<SurrealDbClient>,
pub descriptor: Option<snapshot::Descriptor>,
pub settings: Option<SystemSettings>,
pub settings_missing: bool,
pub must_reapply_settings: bool,
pub embedding_provider: Option<EmbeddingProvider>,
pub embedding_cache: Option<EmbeddingCache>,
pub openai_client: Option<Arc<Client<async_openai::config::OpenAIConfig>>>,
pub openai_base_url: Option<String>,
pub expected_fingerprint: Option<String>,
pub ingestion_duration_ms: u128,
pub namespace_seed_ms: Option<u128>,
pub namespace_reused: bool,
pub evaluation_start: Option<Instant>,
pub eval_user: Option<User>,
pub corpus_handle: Option<corpus::CorpusHandle>,
pub cases: Vec<SeededCase>,
pub filtered_questions: usize,
pub stage_latency_samples: Vec<PipelineStageTimings>,
pub latencies: Vec<u128>,
pub diagnostics_output: Vec<CaseDiagnostics>,
pub query_summaries: Vec<CaseSummary>,
pub rerank_pool: Option<Arc<RerankerPool>>,
pub retrieval_config: Option<Arc<RetrievalConfig>>,
pub summary: Option<EvaluationSummary>,
pub diagnostics_path: Option<PathBuf>,
pub diagnostics_enabled: bool,
}
impl<'a> EvaluationContext<'a> {
pub fn new(dataset: &'a ConvertedDataset, config: &'a Config) -> Self {
Self {
dataset,
config,
stage_timings: EvaluationStageTimings::default(),
ledger_limit: None,
slice_settings: None,
slice: None,
window_offset: 0,
window_length: 0,
window_total_cases: 0,
namespace: String::new(),
database: String::new(),
db: None,
descriptor: None,
settings: None,
settings_missing: false,
must_reapply_settings: false,
embedding_provider: None,
embedding_cache: None,
openai_client: None,
openai_base_url: None,
expected_fingerprint: None,
ingestion_duration_ms: 0,
namespace_seed_ms: None,
namespace_reused: false,
evaluation_start: None,
eval_user: None,
corpus_handle: None,
cases: Vec::new(),
filtered_questions: 0,
stage_latency_samples: Vec::new(),
latencies: Vec::new(),
diagnostics_output: Vec::new(),
query_summaries: Vec::new(),
rerank_pool: None,
retrieval_config: None,
summary: None,
diagnostics_path: config.chunk_diagnostics_path.clone(),
diagnostics_enabled: config.chunk_diagnostics_path.is_some(),
}
}
pub fn dataset(&self) -> &'a ConvertedDataset {
self.dataset
}
pub fn config(&self) -> &'a Config {
self.config
}
pub fn slice(&self) -> &slice::ResolvedSlice<'a> {
self.slice.as_ref().expect("slice has not been prepared")
}
pub fn db(&self) -> &SurrealDbClient {
self.db.as_ref().expect("database connection missing")
}
pub fn descriptor(&self) -> &snapshot::Descriptor {
self.descriptor
.as_ref()
.expect("snapshot descriptor unavailable")
}
pub fn embedding_provider(&self) -> &EmbeddingProvider {
self.embedding_provider
.as_ref()
.expect("embedding provider not initialised")
}
pub fn openai_client(&self) -> Arc<Client<async_openai::config::OpenAIConfig>> {
self.openai_client
.as_ref()
.expect("openai client missing")
.clone()
}
pub fn corpus_handle(&self) -> &corpus::CorpusHandle {
self.corpus_handle.as_ref().expect("corpus handle missing")
}
pub fn evaluation_user(&self) -> &User {
self.eval_user.as_ref().expect("evaluation user missing")
}
pub fn record_stage_duration(&mut self, stage: EvalStage, duration: Duration) {
let elapsed = duration.as_millis();
match stage {
EvalStage::PrepareSlice => self.stage_timings.prepare_slice_ms += elapsed,
EvalStage::PrepareDb => self.stage_timings.prepare_db_ms += elapsed,
EvalStage::PrepareCorpus => self.stage_timings.prepare_corpus_ms += elapsed,
EvalStage::PrepareNamespace => self.stage_timings.prepare_namespace_ms += elapsed,
EvalStage::RunQueries => self.stage_timings.run_queries_ms += elapsed,
EvalStage::Summarize => self.stage_timings.summarize_ms += elapsed,
EvalStage::Finalize => self.stage_timings.finalize_ms += elapsed,
}
}
pub fn into_summary(self) -> EvaluationSummary {
self.summary.expect("evaluation summary missing")
}
}
#[derive(Copy, Clone)]
pub(super) enum EvalStage {
PrepareSlice,
PrepareDb,
PrepareCorpus,
PrepareNamespace,
RunQueries,
Summarize,
Finalize,
}
impl EvalStage {
pub fn label(&self) -> &'static str {
match self {
EvalStage::PrepareSlice => "prepare-slice",
EvalStage::PrepareDb => "prepare-db",
EvalStage::PrepareCorpus => "prepare-corpus",
EvalStage::PrepareNamespace => "prepare-namespace",
EvalStage::RunQueries => "run-queries",
EvalStage::Summarize => "summarize",
EvalStage::Finalize => "finalize",
}
}
}
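// Illustrative sketch (not part of the diff): each stage times itself and reports back
// through `record_stage_duration`; the millisecond values below are hypothetical.
#[allow(dead_code)]
fn example_record(ctx: &mut EvaluationContext<'_>) {
    ctx.record_stage_duration(EvalStage::PrepareSlice, Duration::from_millis(12));
    ctx.record_stage_duration(EvalStage::RunQueries, Duration::from_millis(480));
}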

@@ -0,0 +1,27 @@
mod context;
mod stages;
mod state;
use anyhow::Result;
use crate::{args::Config, datasets::ConvertedDataset, types::EvaluationSummary};
use context::EvaluationContext;
pub async fn run_evaluation(
dataset: &ConvertedDataset,
config: &Config,
) -> Result<EvaluationSummary> {
let mut ctx = EvaluationContext::new(dataset, config);
let machine = state::ready();
let machine = stages::prepare_slice(machine, &mut ctx).await?;
let machine = stages::prepare_db(machine, &mut ctx).await?;
let machine = stages::prepare_corpus(machine, &mut ctx).await?;
let machine = stages::prepare_namespace(machine, &mut ctx).await?;
let machine = stages::run_queries(machine, &mut ctx).await?;
let machine = stages::summarize(machine, &mut ctx).await?;
let _ = stages::finalize(machine, &mut ctx).await?;
Ok(ctx.into_summary())
}

@@ -0,0 +1,59 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::eval::write_chunk_diagnostics;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{Completed, EvaluationMachine, Summarized},
};
use super::{map_guard_error, StageResult};
pub(crate) async fn finalize(
machine: EvaluationMachine<(), Summarized>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<Completed> {
let stage = EvalStage::Finalize;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
if let Some(cache) = ctx.embedding_cache.as_ref() {
cache
.persist()
.await
.context("persisting embedding cache")?;
}
if let Some(path) = ctx.diagnostics_path.as_ref() {
if ctx.diagnostics_enabled {
write_chunk_diagnostics(path.as_path(), &ctx.diagnostics_output)
.await
.with_context(|| format!("writing chunk diagnostics to {}", path.display()))?;
}
}
info!(
total_cases = ctx.summary.as_ref().map(|s| s.total_cases).unwrap_or(0),
correct = ctx.summary.as_ref().map(|s| s.correct).unwrap_or(0),
precision = ctx.summary.as_ref().map(|s| s.precision).unwrap_or(0.0),
dataset = ctx.dataset().metadata.id.as_str(),
"Evaluation complete"
);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.finalize()
.map_err(|(_, guard)| map_guard_error("finalize", guard))
}

@@ -0,0 +1,26 @@
mod finalize;
mod prepare_corpus;
mod prepare_db;
mod prepare_namespace;
mod prepare_slice;
mod run_queries;
mod summarize;
pub(crate) use finalize::finalize;
pub(crate) use prepare_corpus::prepare_corpus;
pub(crate) use prepare_db::prepare_db;
pub(crate) use prepare_namespace::prepare_namespace;
pub(crate) use prepare_slice::prepare_slice;
pub(crate) use run_queries::run_queries;
pub(crate) use summarize::summarize;
use anyhow::Result;
use state_machines::core::GuardError;
use super::state::EvaluationMachine;
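// A rejected transition means a stage was invoked against the wrong state; surface the
// originating event together with the guard details.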
fn map_guard_error(event: &str, guard: GuardError) -> anyhow::Error {
anyhow::anyhow!("invalid evaluation pipeline transition during {event}: {guard:?}")
}
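/// Every stage hands back the machine parked in its post-stage state.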
type StageResult<S> = Result<EvaluationMachine<(), S>>;

View File

@@ -0,0 +1,142 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{corpus, eval::can_reuse_namespace, slice, snapshot};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, DbReady, EvaluationMachine},
};
use super::{map_guard_error, StageResult};
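/// Builds or reuses the ingestion corpus for the selected slice window. If the namespace
/// already matches the expected fingerprint and a cached manifest exists, both are reused;
/// otherwise ingestion runs and its duration is recorded on the context.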
pub(crate) async fn prepare_corpus(
machine: EvaluationMachine<(), DbReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<CorpusReady> {
let stage = EvalStage::PrepareCorpus;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let cache_settings = corpus::CorpusCacheConfig::from(config);
let embedding_provider = ctx.embedding_provider().clone();
let openai_client = ctx.openai_client();
let slice = ctx.slice();
let window = slice::select_window(slice, ctx.config().slice_offset, ctx.config().limit)
.context("selecting slice window for corpus preparation")?;
let descriptor = snapshot::Descriptor::new(config, slice, ctx.embedding_provider());
let ingestion_config = corpus::make_ingestion_config(config);
let expected_fingerprint = corpus::compute_ingestion_fingerprint(
ctx.dataset(),
slice,
config.converted_dataset_path.as_path(),
&ingestion_config,
)?;
let base_dir = corpus::cached_corpus_dir(
&cache_settings,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
);
if !config.reseed_slice {
let requested_cases = window.cases.len();
if can_reuse_namespace(
ctx.db(),
&descriptor,
&ctx.namespace,
&ctx.database,
ctx.dataset().metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
requested_cases,
)
.await?
{
if let Some(manifest) = corpus::load_cached_manifest(&base_dir)? {
info!(
cache = %base_dir.display(),
namespace = ctx.namespace.as_str(),
database = ctx.database.as_str(),
"Namespace already seeded; reusing cached corpus manifest"
);
let corpus_handle = corpus::corpus_handle_from_manifest(manifest, base_dir);
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = 0;
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
return machine
.prepare_corpus()
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard));
} else {
info!(
cache = %base_dir.display(),
"Namespace reusable but cached manifest missing; regenerating corpus"
);
}
}
}
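// No reusable namespace or cached manifest (or reseeding was requested): run ingestion for the window.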
let eval_user_id = "eval-user".to_string();
let ingestion_timer = Instant::now();
let corpus_handle = corpus::ensure_corpus(
ctx.dataset(),
slice,
&window,
&cache_settings,
embedding_provider.clone().into(),
openai_client,
&eval_user_id,
config.converted_dataset_path.as_path(),
ingestion_config.clone(),
)
.await
.context("ensuring ingestion-backed corpus")?;
let expected_fingerprint = corpus_handle
.manifest
.metadata
.ingestion_fingerprint
.clone();
let ingestion_duration_ms = ingestion_timer.elapsed().as_millis();
info!(
cache = %corpus_handle.path.display(),
reused_ingestion = corpus_handle.reused_ingestion,
reused_embeddings = corpus_handle.reused_embeddings,
positive_ingested = corpus_handle.positive_ingested,
negative_ingested = corpus_handle.negative_ingested,
"Ingestion corpus ready"
);
ctx.corpus_handle = Some(corpus_handle);
ctx.expected_fingerprint = Some(expected_fingerprint);
ctx.ingestion_duration_ms = ingestion_duration_ms;
ctx.descriptor = Some(descriptor);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_corpus()
.map_err(|(_, guard)| map_guard_error("prepare_corpus", guard))
}

View File

@@ -0,0 +1,121 @@
use std::{sync::Arc, time::Instant};
use anyhow::{anyhow, Context};
use tracing::info;
use crate::{
args::EmbeddingBackend,
cache::EmbeddingCache,
eval::{
connect_eval_db, enforce_system_settings, load_or_init_system_settings, sanitize_model_code,
},
openai,
};
use common::utils::embedding::EmbeddingProvider;
use super::super::{
context::{EvalStage, EvaluationContext},
state::{DbReady, EvaluationMachine, SlicePrepared},
};
use super::{map_guard_error, StageResult};
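/// Connects to the evaluation database, builds the OpenAI client and embedding provider,
/// and loads (or initialises) system settings plus the on-disk embedding cache.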
pub(crate) async fn prepare_db(
machine: EvaluationMachine<(), SlicePrepared>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<DbReady> {
let stage = EvalStage::PrepareDb;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let config = ctx.config();
let db = connect_eval_db(config, &namespace, &database).await?;
let (raw_openai_client, openai_base_url) =
openai::build_client_from_env().context("building OpenAI client")?;
let openai_client = Arc::new(raw_openai_client);
// Create embedding provider directly from config (eval only supports FastEmbed and Hashed)
let embedding_provider = match config.embedding_backend {
EmbeddingBackend::FastEmbed => {
EmbeddingProvider::new_fastembed(config.embedding_model.clone())
.await
.context("creating FastEmbed provider")?
}
EmbeddingBackend::Hashed => {
EmbeddingProvider::new_hashed(1536).context("creating Hashed provider")?
}
};
let provider_dimension = embedding_provider.dimension();
if provider_dimension == 0 {
return Err(anyhow!(
"embedding provider reported zero dimensions; cannot continue"
));
}
info!(
backend = embedding_provider.backend_label(),
model = embedding_provider
.model_code()
.as_deref()
.unwrap_or("<none>"),
dimension = provider_dimension,
"Embedding provider initialised"
);
info!(openai_base_url = %openai_base_url, "OpenAI client configured");
let (mut settings, settings_missing) =
load_or_init_system_settings(&db, provider_dimension).await?;
let embedding_cache = if config.embedding_backend == EmbeddingBackend::FastEmbed {
if let Some(model_code) = embedding_provider.model_code() {
let sanitized = sanitize_model_code(&model_code);
let path = config.cache_dir.join(format!("{sanitized}.json"));
if config.force_convert && path.exists() {
tokio::fs::remove_file(&path)
.await
.with_context(|| format!("removing stale cache {}", path.display()))
.ok();
}
let cache = EmbeddingCache::load(&path).await?;
info!(path = %path.display(), "Embedding cache ready");
Some(cache)
} else {
None
}
} else {
None
};
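// Freshly initialised settings are enforced immediately only when the slice will be reseeded;
// otherwise enforcement is deferred to prepare_namespace, which re-applies them after any reset.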
let must_reapply_settings = settings_missing;
let defer_initial_enforce = settings_missing && !config.reseed_slice;
if !defer_initial_enforce {
settings = enforce_system_settings(&db, settings, provider_dimension, config).await?;
}
ctx.db = Some(db);
ctx.settings_missing = settings_missing;
ctx.must_reapply_settings = must_reapply_settings;
ctx.settings = Some(settings);
ctx.embedding_provider = Some(embedding_provider);
ctx.embedding_cache = embedding_cache;
ctx.openai_client = Some(openai_client);
ctx.openai_base_url = Some(openai_base_url);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_db()
.map_err(|(_, guard)| map_guard_error("prepare_db", guard))
}

View File

@@ -0,0 +1,203 @@
use std::time::Instant;
use anyhow::{anyhow, Context};
use common::storage::types::system_settings::SystemSettings;
use tracing::{info, warn};
use crate::{
corpus,
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
eval::{
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
record_namespace_state, warm_hnsw_cache,
},
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{CorpusReady, EvaluationMachine, NamespaceReady},
};
use super::{map_guard_error, StageResult};
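/// Seeds the corpus manifest into the evaluation namespace (resetting and re-indexing it when
/// it cannot be reused), re-enforces system settings where needed, ensures the eval user, and
/// derives the evaluation cases from the manifest.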
pub(crate) async fn prepare_namespace(
machine: EvaluationMachine<(), CorpusReady>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<NamespaceReady> {
let stage = EvalStage::PrepareNamespace;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let config = ctx.config();
let dataset = ctx.dataset();
let expected_fingerprint = ctx
.expected_fingerprint
.as_deref()
.unwrap_or_default()
.to_string();
let namespace = ctx.namespace.clone();
let database = ctx.database.clone();
let embedding_provider = ctx.embedding_provider().clone();
let corpus_handle = ctx.corpus_handle();
let base_manifest = &corpus_handle.manifest;
let manifest_for_seed =
if ctx.window_offset == 0 && ctx.window_length >= base_manifest.questions.len() {
base_manifest.clone()
} else {
corpus::window_manifest(
base_manifest,
ctx.window_offset,
ctx.window_length,
ctx.config().negative_multiplier,
)
.context("selecting manifest window for seeding")?
};
let requested_cases = manifest_for_seed.questions.len();
let mut namespace_reused = false;
if !config.reseed_slice {
namespace_reused = {
let slice = ctx.slice();
can_reuse_namespace(
ctx.db(),
ctx.descriptor(),
&namespace,
&database,
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
requested_cases,
)
.await?
};
}
let mut namespace_seed_ms = None;
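// The namespace cannot be reused: wipe it, reapply migrations, seed the windowed manifest with
// indexes removed, then recreate the indexes and warm the HNSW cache once the data is in place.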
if !namespace_reused {
ctx.must_reapply_settings = true;
if let Err(err) = reset_namespace(ctx.db(), &namespace, &database).await {
warn!(
error = %err,
namespace,
database = %database,
"Failed to reset namespace before reseeding; continuing with existing data"
);
} else if let Err(err) = ctx.db().apply_migrations().await {
warn!(error = %err, "Failed to reapply migrations after namespace reset");
}
{
let slice = ctx.slice();
let unique_positives = manifest_for_seed
.questions
.iter()
.map(|q| q.paragraph_id.as_str())
.collect::<std::collections::HashSet<_>>()
.len();
info!(
slice = slice.manifest.slice_id.as_str(),
window_offset = ctx.window_offset,
window_length = ctx.window_length,
positives = unique_positives,
negatives = manifest_for_seed
.paragraphs
.len()
.saturating_sub(unique_positives),
total = manifest_for_seed.paragraphs.len(),
"Seeding ingestion corpus into SurrealDB"
);
}
let indexes_disabled = remove_all_indexes(ctx.db()).await.is_ok();
let seed_start = Instant::now();
corpus::seed_manifest_into_db(ctx.db(), &manifest_for_seed)
.await
.context("seeding ingestion corpus from manifest")?;
namespace_seed_ms = Some(seed_start.elapsed().as_millis());
// Recreate indexes AFTER data is loaded (correct bulk loading pattern)
if indexes_disabled {
info!("Recreating indexes after seeding data");
recreate_indexes(ctx.db(), embedding_provider.dimension())
.await
.context("recreating indexes with correct dimension")?;
warm_hnsw_cache(ctx.db(), embedding_provider.dimension()).await?;
}
{
let slice = ctx.slice();
record_namespace_state(
ctx.descriptor(),
dataset.metadata.id.as_str(),
slice.manifest.slice_id.as_str(),
expected_fingerprint.as_str(),
&namespace,
&database,
requested_cases,
)
.await;
}
}
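// Re-apply system settings whenever they were freshly initialised or the namespace was reset.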
if ctx.must_reapply_settings {
let mut settings = SystemSettings::get_current(ctx.db())
.await
.context("reloading system settings after namespace reset")?;
settings =
enforce_system_settings(ctx.db(), settings, embedding_provider.dimension(), config)
.await?;
ctx.settings = Some(settings);
ctx.must_reapply_settings = false;
}
let user = ensure_eval_user(ctx.db()).await?;
ctx.eval_user = Some(user);
let total_manifest_questions = manifest_for_seed.questions.len();
let cases = cases_from_manifest(&manifest_for_seed);
let include_impossible = manifest_for_seed.metadata.include_unanswerable;
let require_verified_chunks = manifest_for_seed.metadata.require_verified_chunks;
let filtered = total_manifest_questions.saturating_sub(cases.len());
if filtered > 0 {
info!(
filtered_questions = filtered,
total_questions = total_manifest_questions,
includes_impossible = include_impossible,
require_verified_chunks = require_verified_chunks,
"Filtered questions not eligible for this evaluation mode (impossible or unverifiable)"
);
}
if cases.is_empty() {
return Err(anyhow!(
"no eligible questions found in converted dataset for evaluation (consider --llm-mode or refreshing ingestion data)"
));
}
ctx.cases = cases;
ctx.filtered_questions = filtered;
ctx.namespace_reused = namespace_reused;
ctx.namespace_seed_ms = namespace_seed_ms;
info!(
cases = ctx.cases.len(),
window_offset = ctx.window_offset,
namespace_reused = namespace_reused,
"Dataset ready"
);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_namespace()
.map_err(|(_, guard)| map_guard_error("prepare_namespace", guard))
}

View File

@@ -0,0 +1,72 @@
use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{
eval::{default_database, default_namespace, ledger_target},
slice,
};
use super::super::{
context::{EvalStage, EvaluationContext},
state::{EvaluationMachine, Ready, SlicePrepared},
};
use super::{map_guard_error, StageResult};
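/// Resolves the dataset slice and selection window for this run and derives the namespace and
/// database names, using config overrides when present or deterministic defaults based on the
/// dataset id and case limit.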
pub(crate) async fn prepare_slice(
machine: EvaluationMachine<(), Ready>,
ctx: &mut EvaluationContext<'_>,
) -> StageResult<SlicePrepared> {
let stage = EvalStage::PrepareSlice;
info!(
evaluation_stage = stage.label(),
"starting evaluation stage"
);
let started = Instant::now();
let ledger_limit = ledger_target(ctx.config());
let slice_settings = slice::slice_config_with_limit(ctx.config(), ledger_limit);
let resolved_slice =
slice::resolve_slice(ctx.dataset(), &slice_settings).context("resolving dataset slice")?;
let window = slice::select_window(
&resolved_slice,
ctx.config().slice_offset,
ctx.config().limit,
)
.context("selecting slice window (use --slice-grow to extend the ledger first)")?;
ctx.ledger_limit = ledger_limit;
ctx.slice_settings = Some(slice_settings);
ctx.slice = Some(resolved_slice.clone());
ctx.window_offset = window.offset;
ctx.window_length = window.length;
ctx.window_total_cases = window.total_cases;
ctx.namespace = ctx
.config()
.database
.db_namespace
.clone()
.unwrap_or_else(|| {
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
});
ctx.database = ctx
.config()
.database
.db_database
.clone()
.unwrap_or_else(default_database);
let elapsed = started.elapsed();
ctx.record_stage_duration(stage, elapsed);
info!(
evaluation_stage = stage.label(),
duration_ms = elapsed.as_millis(),
"completed evaluation stage"
);
machine
.prepare_slice()
.map_err(|(_, guard)| map_guard_error("prepare_slice", guard))
}

Some files were not shown because too many files have changed in this diff.