37 Commits

Author SHA1 Message Date
Per Stark
f22a1e5ba4 chore: devenv inconsistency, spawn server manually in dev 2026-02-15 18:31:43 +01:00
Per Stark
4d237ff6d9 release: 1.0.2 2026-02-15 11:57:04 +01:00
Per Stark
eb928cdb0e test: minio to devenv, improved testing s3 and relationships 2026-02-15 08:52:56 +01:00
Per Stark
1490852a09 chore: dep updates & kv-mem separation to test feature
docker builder update
2026-02-15 08:51:48 +01:00
Per Stark
b0b01182d7 test: add admin auth integration coverage 2026-02-14 23:11:35 +01:00
Per Stark
679308aa1d feat: caching chat history & dto 2026-02-14 19:43:34 +01:00
Per Stark
f93c06b347 fix: harden html responses and cache chat sidebar data
Use strict template response handling and sanitized template user context, then add an in-process conversation archive cache with mutation-driven invalidation for chat sidebar renders.
2026-02-14 17:47:14 +01:00
Per Stark
a3f207beb1 fix: simplified admin checking 2026-02-13 23:04:01 +01:00
Per Stark
e07199adfc fix: name harmonization of endpoints & ingestion security hardening 2026-02-13 22:36:00 +01:00
Per Stark
f22cac891c fix: redact ingestion payload logs and update changelog 2026-02-13 12:06:18 +01:00
Per Stark
b89171d934 fix: parameterize storage-layer queries and add injection tests 2026-02-12 21:42:46 +01:00
Per Stark
0133eead63 fix: border in navigation 2026-02-12 20:39:36 +01:00
Per Stark
e5d2b6605f fix: browser back navigation from chat windows
addenum
2026-02-12 20:32:06 +01:00
Per Stark
bbad91d55b fix: references bug
fix
2026-02-11 22:02:40 +01:00
Per Stark
96846ad664 release: 1.0.1 2026-02-11 15:39:28 +01:00
Per Stark
269bcec659 docs: updated domain name 2026-02-11 15:17:03 +01:00
Per Stark
7c738c4b30 fix: gracefully handle old users 2026-02-11 07:50:19 +01:00
Per Stark
cb88127fcb docs: updated readme 2026-01-18 18:48:53 +01:00
Per Stark
49e1fbd985 dev: devenv processes 2026-01-18 18:45:30 +01:00
Per Stark
f2fa5bbbcc fix: edge case when deleting content
nit
2026-01-18 18:45:21 +01:00
Per Stark
a3bc6fba98 design: better dark mode 2026-01-17 23:31:05 +01:00
Per Stark
ece744d5a0 refactor: additional responsibilities to middleware, simplified handlers
fix
2026-01-17 21:07:25 +01:00
Per Stark
a9fda67209 theme: obsidian-prism 2026-01-17 08:45:47 +01:00
Per Stark
fa7f407306 feat: s3 storage backend 2026-01-16 23:38:47 +01:00
Per Stark
b25cfb4633 feat: add user theme preference
- Add theme field to User model (common)
- Create migration for theme field
- Add theme selection to Account Settings (html-router)
- Implement server-side theme rendering in base template
- Update JS for system/preference theme handling
- Remove header theme toggle for authenticated users
2026-01-16 13:54:07 +01:00
Per Stark
0df2b9810c docs: addenum 2026-01-14 22:24:23 +01:00
Per Stark
354dc727c1 refactor: extendable templates
refactor: simplification

refactor: simplification
2026-01-13 22:18:00 +01:00
Per Stark
037057d108 fix: allow for multiple templates directories 2026-01-12 21:25:12 +01:00
Per Stark
9f17c6c2b0 fix: updating models in admin view 2026-01-12 21:01:53 +01:00
Per Stark
17f252e630 release: 1.0.0
fix: cargo dist
2026-01-11 20:35:01 +01:00
Per Stark
db43be1606 fix: schemafull and textcontent 2026-01-02 15:41:22 +01:00
Per Stark
8e8370b080 docs: more complete and correct 2025-12-24 23:36:58 +01:00
Per Stark
84695fa0cc chore: wording 2025-12-22 23:03:33 +01:00
Per Stark
654add98bc fix: never block fts, rely on rrf 2025-12-22 22:56:57 +01:00
Per Stark
244ec0ea25 fix: migrating embeddings to new dimensions
changing order
2025-12-22 22:39:14 +01:00
Per Stark
d8416ac711 fix: ordering of index creation 2025-12-22 21:59:35 +01:00
Per Stark
f9f48d1046 docs: evaluations instructions and readme refactoring 2025-12-22 18:55:47 +01:00
139 changed files with 12107 additions and 3535 deletions


@@ -1,8 +1,24 @@
# Changelog # Changelog
## Unreleased ## Unreleased
## 1.0.2 (2026-02-15)
- Fix: edge case where navigation back to a chat page could trigger a new response generation
- Fix: chat references now validate and render more reliably
- Fix: improved admin access checks for restricted routes
- Performance: faster chat sidebar loads from cached conversation archive data
- API: harmonized ingest endpoint naming and added configurable ingest safety limits
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
## 1.0.1 (2026-02-11)
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
- Fixed edge cases, including content deletion behavior and compatibility for older user records.
## 1.0.0 (2026-01-02)
- **Locally generated embeddings are now the default**. If you want to continue using API embeddings, set EMBEDDING_BACKEND to openai. The default will download an ONNX model and recreate all embeddings, but in most instances it's well worth it: it removes the network-bound call for creating embeddings. Creating embeddings on my N100 device is extremely fast. Typically a search response is provided in less than 50ms.
- Added a benchmarks crate for evaluating the retrieval process - Added a benchmarks crate for evaluating the retrieval process
- Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms. - Added fastembed embedding support, enables the use of local CPU generated embeddings, greatly improved latency if machine can handle it. Quick search has vastly better accuracy and is much faster, 50ms latency when testing compared to minimum 300ms.
- Embeddings stored on own table - Embeddings stored on own table.
- Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details. - Refactored retrieval pipeline to use the new, faster and more accurate strategy. Read [blog post](https://blog.stark.pub/posts/eval-retrieval-refactor/) for more details.
## Version 0.2.7 (2025-12-04) ## Version 0.2.7 (2025-12-04)

Cargo.lock generated

File diff suppressed because it is too large


@@ -40,7 +40,7 @@ serde_json = "1.0.128"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
sha2 = "0.10.8" sha2 = "0.10.8"
surrealdb-migrations = "2.2.2" surrealdb-migrations = "2.2.2"
surrealdb = { version = "2", features = ["kv-mem"] } surrealdb = { version = "2" }
tempfile = "3.12.0" tempfile = "3.12.0"
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] } text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
tokenizers = { version = "0.20.4", features = ["http"] } tokenizers = { version = "0.20.4", features = ["http"] }
@@ -56,7 +56,7 @@ url = { version = "2.5.2", features = ["serde"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] } uuid = { version = "1.10.0", features = ["v4", "serde"] }
tokio-retry = "0.3.0" tokio-retry = "0.3.0"
base64 = "0.22.1" base64 = "0.22.1"
object_store = { version = "0.11.2" } object_store = { version = "0.11.2", features = ["aws"] }
bytes = "1.7.1" bytes = "1.7.1"
state-machines = "0.2.0" state-machines = "0.2.0"
fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] } fastembed = { version = "5.2.0", default-features = false, features = ["hf-hub-native-tls", "ort-load-dynamic"] }


@@ -1,5 +1,5 @@
# === Builder === # === Builder ===
FROM rust:1.86-bookworm AS builder FROM rust:1.89-bookworm AS builder
WORKDIR /usr/src/minne WORKDIR /usr/src/minne
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/* pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
@@ -30,8 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 libstdc++6 curl \ libgomp1 libstdc++6 curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# ONNX Runtime (CPU). Change if you bump ort. # ONNX Runtime (CPU). Keep in sync with ort crate requirements.
ARG ORT_VERSION=1.22.0 ARG ORT_VERSION=1.23.2
RUN mkdir -p /opt/onnxruntime && \ RUN mkdir -p /opt/onnxruntime && \
curl -fsSL -o /tmp/ort.tgz \ curl -fsSL -o /tmp/ort.tgz \
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \ "https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \

README.md

@@ -1,61 +1,265 @@
# Minne # Minne - A Graph-Powered Personal Knowledge Base
**A graph-powered personal knowledge base that remembers for you.** **Minne (Swedish for "memory")** is a personal knowledge management system and save-for-later application for capturing, organizing, and accessing your information. Inspired by the Zettelkasten method, it uses a graph database to automatically create connections between your notes without manual linking overhead.
Capture content effortlessly, let AI discover connections, and explore your knowledge visually. Self-hosted and privacy-focused.
[![Release Status](https://github.com/perstarkse/minne/actions/workflows/release.yml/badge.svg)](https://github.com/perstarkse/minne/actions/workflows/release.yml) [![Release Status](https://github.com/perstarkse/minne/actions/workflows/release.yml/badge.svg)](https://github.com/perstarkse/minne/actions/workflows/release.yml)
[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0) [![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![Latest Release](https://img.shields.io/github/v/release/perstarkse/minne?sort=semver)](https://github.com/perstarkse/minne/releases/latest) [![Latest Release](https://img.shields.io/github/v/release/perstarkse/minne?sort=semver)](https://github.com/perstarkse/minne/releases/latest)
## Try It ![Screenshot](screenshot-graph.webp)
**[Live Demo](https://minne-demo.stark.pub)** — Read-only demo deployment ## Demo deployment
To try _Minne_ out, visit [the demo](https://minne.stark.pub) and sign in to the read-only demo deployment to explore its functionality.
## Noteworthy Features
- **Search & Chat Interface** - Find content or knowledge instantly with full-text search, or use the chat mode and conversational AI to find and reason about content
- **Manual and AI-assisted connections** - Build entities and relationships manually with full control, let AI create entities and relationships automatically, or blend both approaches with AI suggestions for manual approval
- **Hybrid Retrieval System** - Search combining vector similarity & full-text search
- **Scratchpad Feature** - Quickly capture thoughts and convert them to permanent content when ready
- **Visual Graph Explorer** - Interactive D3-based navigation of your knowledge entities and connections
- **Multi-Format Support** - Ingest text, URLs, PDFs, audio files, and images into your knowledge base
- **Performance Focus** - Built with Rust and server-side rendering for speed and efficiency
- **Self-Hosted & Privacy-Focused** - Full control over your data, and compatible with any OpenAI-compatible API that supports structured outputs
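The hybrid retrieval bullet above can be illustrated with a toy reciprocal rank fusion (RRF) merge of a vector-similarity result list and a full-text result list. This is only a sketch of the general technique under assumed inputs, not Minne's actual implementation:

```rust
use std::collections::HashMap;

/// Reciprocal rank fusion: merge several ranked result lists into one.
/// Each document scores 1 / (k + rank + 1) per list it appears in.
fn rrf(rankings: &[Vec<&str>], k: f64) -> Vec<String> {
    let mut scores: HashMap<String, f64> = HashMap::new();
    for ranking in rankings {
        for (rank, doc) in ranking.iter().enumerate() {
            *scores.entry(doc.to_string()).or_insert(0.0) += 1.0 / (k + rank as f64 + 1.0);
        }
    }
    // Sort by combined score, best first.
    let mut docs: Vec<(String, f64)> = scores.into_iter().collect();
    docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    docs.into_iter().map(|(d, _)| d).collect()
}

fn main() {
    // Invented document ids: one ranking from vector search, one from full-text.
    let vector_hits = vec!["a", "b", "c"];
    let fts_hits = vec!["b", "d", "a"];
    // "b" wins: it ranks high in both lists.
    println!("{:?}", rrf(&[vector_hits, fts_hits], 60.0));
}
```

Documents found by both retrievers bubble to the top without either score scale having to be calibrated against the other, which is why RRF is a common choice for fusing heterogeneous rankings.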
## The "Why" Behind Minne
For a while I've been fascinated by personal knowledge management systems. I wanted something that made it incredibly easy to capture content - snippets of text, URLs, and other media - while automatically discovering connections between ideas. But I also wanted to maintain control over my knowledge structure.
Traditional tools like Logseq and Obsidian are excellent, but the manual linking process often became a hindrance. Meanwhile, fully automated systems sometimes miss important context or create relationships I wouldn't have chosen myself.
So I built Minne to offer the best of both worlds: effortless content capture with AI-assisted relationship discovery, but with the flexibility to manually curate, edit, or override any connections. You can let AI handle the heavy lifting of extracting entities and finding relationships, take full control yourself, or use a hybrid approach where AI suggests connections that you can approve or modify.
While developing Minne, I discovered [KaraKeep](https://github.com/karakeep-app/karakeep) (formerly Hoarder), which is an excellent application in a similar space; you probably want to check it out! However, if you're interested in a PKM that offers both intelligent automation and manual curation, with the ability to chat with your knowledge base, then Minne might be worth testing.
## Table of Contents
- [Quick Start](#quick-start)
- [Features in Detail](#features-in-detail)
- [Configuration](#configuration)
- [Tech Stack](#tech-stack)
- [Application Architecture](#application-architecture)
- [AI Configuration](#ai-configuration--model-selection)
- [Roadmap](#roadmap)
- [Development](#development)
- [Contributing](#contributing)
- [License](#license)
## Quick Start ## Quick Start
The fastest way to get Minne running is with Docker Compose:
```bash ```bash
# Clone the repository
git clone https://github.com/perstarkse/minne.git git clone https://github.com/perstarkse/minne.git
cd minne cd minne
# Set your OpenAI API key in docker-compose.yml, then: # Start Minne and its database
docker compose up -d docker compose up -d
# Open http://localhost:3000 # Access at http://localhost:3000
``` ```
Or with Nix: **Required Setup:**
- Replace `your_openai_api_key_here` in `docker-compose.yml` with your actual API key
- Configure `OPENAI_BASE_URL` if using a custom AI provider (like Ollama)
For detailed installation options, see [Configuration](#configuration).
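As a sketch, the required setup above might look like this in `docker-compose.yml` (the service name and placeholder values are illustrative; check the shipped file for the actual structure):

```yaml
services:
  minne:
    environment:
      OPENAI_API_KEY: "sk-your-actual-key"   # replace the placeholder key
      # Only needed for custom OpenAI-compatible providers such as Ollama:
      # OPENAI_BASE_URL: "http://host.docker.internal:11434/v1"
```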
## Features in Detail
### Search vs. Chat mode
**Search** - Use when you know roughly what you're looking for. Full-text search finds items quickly by matching your query terms.
**Chat Mode** - Use when you want to explore concepts, find connections, or reason about your knowledge. The AI analyzes your query and finds relevant context across your entire knowledge base.
### Content Processing
Minne automatically processes content you save:
1. **Web scraping** extracts readable text from URLs
2. **Text analysis** identifies key concepts and relationships
3. **Graph creation** builds connections between related content
4. **Embedding generation** enables semantic search capabilities
### Visual Knowledge Graph
Explore your knowledge as an interactive network with flexible curation options:
**Manual Curation** - Create knowledge entities and relationships yourself with full control over your graph structure
**AI Automation** - Let AI automatically extract entities and discover relationships from your content
**Hybrid Approach** - Get AI-suggested relationships and entities that you can manually review, edit, or approve
The graph visualization shows:
- Knowledge entities as nodes (manually created or AI-extracted)
- Relationships as connections (manually defined, AI-discovered, or suggested)
- Interactive navigation for discovery and editing
### Optional FastEmbed Reranking
Minne ships with an opt-in reranking stage powered by [fastembed-rs](https://github.com/Anush008/fastembed-rs). When enabled, the hybrid retrieval results are rescored with a lightweight cross-encoder before being returned to chat or ingestion flows. In practice this often means more relevant results, boosting answer quality and downstream enrichment.
⚠️ **Resource notes**
- Enabling reranking downloads and caches ~1.1GB of model data on first startup (cached under `<data_dir>/fastembed/reranker` by default).
- Initialization takes longer while warming the cache, and each query consumes extra CPU. The default pool size (2) is tuned for a single-user setup, but a pool size of 1 can work as well.
- The feature is disabled by default. Set `reranking_enabled: true` (or `RERANKING_ENABLED=true`) if you're comfortable with the additional footprint.
Example configuration:
```yaml
reranking_enabled: true
reranking_pool_size: 2
fastembed_cache_dir: "/var/lib/minne/fastembed" # optional override, defaults to <data_dir>/fastembed/reranker
```
## Tech Stack
- **Backend:** Rust with Axum framework and Server-Side Rendering (SSR)
- **Frontend:** HTML with HTMX and minimal JavaScript for interactivity
- **Database:** SurrealDB (graph, document, and vector search)
- **AI Integration:** OpenAI-compatible API with structured outputs
- **Web Processing:** Headless Chrome for robust webpage content extraction
## Configuration
Minne can be configured using environment variables or a `config.yaml` file. Environment variables take precedence over `config.yaml`.
### Required Configuration
- `SURREALDB_ADDRESS`: WebSocket address of your SurrealDB instance (e.g., `ws://127.0.0.1:8000`)
- `SURREALDB_USERNAME`: Username for SurrealDB (e.g., `root_user`)
- `SURREALDB_PASSWORD`: Password for SurrealDB (e.g., `root_password`)
- `SURREALDB_DATABASE`: Database name in SurrealDB (e.g., `minne_db`)
- `SURREALDB_NAMESPACE`: Namespace in SurrealDB (e.g., `minne_ns`)
- `OPENAI_API_KEY`: Your API key for OpenAI compatible endpoint
- `HTTP_PORT`: Port for the Minne server (Default: `3000`)
### Optional Configuration
- `RUST_LOG`: Controls logging level (e.g., `minne=info,tower_http=debug`)
- `DATA_DIR`: Directory to store local data (e.g., `./data`)
- `OPENAI_BASE_URL`: Base URL for custom AI providers (like Ollama)
- `RERANKING_ENABLED` / `reranking_enabled`: Set to `true` to enable the FastEmbed reranking stage (default `false`)
- `RERANKING_POOL_SIZE` / `reranking_pool_size`: Maximum concurrent reranker workers (defaults to `2`)
- `FASTEMBED_CACHE_DIR` / `fastembed_cache_dir`: Directory for cached FastEmbed models (defaults to `<data_dir>/fastembed/reranker`)
- `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` / `fastembed_show_download_progress`: Show model download progress when warming the cache (default `true`)
### Example config.yaml
```yaml
surrealdb_address: "ws://127.0.0.1:8000"
surrealdb_username: "root_user"
surrealdb_password: "root_password"
surrealdb_database: "minne_db"
surrealdb_namespace: "minne_ns"
openai_api_key: "sk-YourActualOpenAIKeyGoesHere"
data_dir: "./minne_app_data"
http_port: 3000
# rust_log: "info"
```
## Installation Options
### 1. Docker Compose (Recommended)
```bash
# Clone and run
git clone https://github.com/perstarkse/minne.git
cd minne
docker compose up -d
```
The included `docker-compose.yml` handles SurrealDB and Chromium dependencies automatically.
### 2. Nix
```bash ```bash
nix run 'github:perstarkse/minne#main' nix run 'github:perstarkse/minne#main'
``` ```
## Features This fetches Minne and all dependencies, including Chromium.
- **Search & Chat** — Full-text search or conversational AI to find and reason about content ### 3. Pre-built Binaries
- **Knowledge Graph** — Visual exploration with automatic or manual relationship curation
- **Hybrid Retrieval** — Vector similarity + full-text + graph traversal for relevant results
- **Multi-Format** — Ingest text, URLs, PDFs, audio, and images
- **Self-Hosted** — Your data, your server, any OpenAI-compatible API
## Documentation Download binaries for Windows, macOS, and Linux from the [GitHub Releases](https://github.com/perstarkse/minne/releases/latest).
| Guide | Description | **Requirements:** You'll need to provide SurrealDB and Chromium separately.
|-------|-------------|
| [Installation](docs/installation.md) | Docker, Nix, binaries, source builds |
| [Configuration](docs/configuration.md) | Environment variables, config.yaml, AI setup |
| [Features](docs/features.md) | Search, Chat, Graph, Reranking, Ingestion |
| [Architecture](docs/architecture.md) | Tech stack, crate structure, data flow |
| [Vision](docs/vision.md) | Philosophy, roadmap, related projects |
## Tech Stack ### 4. Build from Source
Rust • Axum • HTMX • SurrealDB • FastEmbed ```bash
git clone https://github.com/perstarkse/minne.git
cd minne
cargo run --release --bin main
```
**Requirements:** SurrealDB and Chromium must be installed and accessible in your PATH.
## Application Architecture
Minne offers flexible deployment options:
- **`main`**: Combined server and worker in one process (recommended for most users)
- **`server`**: Web interface and API only
- **`worker`**: Background processing only (for resource optimization)
## Usage
Once Minne is running at `http://localhost:3000`:
1. **Web Interface**: Full-featured experience for desktop and mobile
2. **iOS Shortcut**: Use the [Minne iOS Shortcut](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) for quick content capture
3. **Content Types**: Save notes, URLs, audio files, and more
4. **Knowledge Graph**: Explore automatic connections between your content
5. **Chat Interface**: Query your knowledge base conversationally
## AI Configuration & Model Selection
### Setting Up AI Providers
Minne uses OpenAI-compatible APIs. Configure via environment variables or `config.yaml`:
- `OPENAI_API_KEY` (required): Your API key
- `OPENAI_BASE_URL` (optional): Custom provider URL (e.g., Ollama: `http://localhost:11434/v1`)
### Model Selection
1. Access the `/admin` page in your Minne instance
2. Select models for content processing and chat from your configured provider
3. **Content Processing Requirements**: The model must support structured outputs
4. **Embedding Dimensions**: Update this setting when changing embedding models (e.g., 1536 for `text-embedding-3-small`, 768 for `nomic-embed-text`)
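For example, pointing Minne at a local Ollama instance could look like this (the base URL and key value are assumptions; Ollama's OpenAI-compatible endpoint conventionally lives under `/v1` and accepts any non-empty key):

```shell
# Illustrative sketch: select a custom provider via environment variables
export OPENAI_BASE_URL="http://localhost:11434/v1"
export OPENAI_API_KEY="ollama"
# After switching to e.g. nomic-embed-text, set Embedding Dimensions to 768 in /admin
echo "provider: $OPENAI_BASE_URL"
```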
## Roadmap
Current development focus:
- TUI frontend with system editor integration
- Enhanced reranking for improved retrieval recall
- Additional content type support
Feature requests and contributions are welcome!
## Development
```bash
# Run tests
cargo test
# Development build
cargo build
# Comprehensive linting
cargo clippy --workspace --all-targets --all-features
```
The codebase includes extensive unit tests. Integration tests and additional contributions are welcome.
## Contributing ## Contributing
I've developed Minne primarily for my own use, but having been in the self-hosted space for a long time and benefited from the efforts of others, I thought I'd share it with the community. Feature requests are welcome.
Feature requests and contributions welcome. See [Vision](docs/vision.md) for roadmap.
## License ## License
[AGPL-3.0](LICENSE) Minne is licensed under the **GNU Affero General Public License v3.0 (AGPL-3.0)**. See the [LICENSE](LICENSE) file for details.


@@ -20,6 +20,9 @@ pub enum ApiError {
#[error("Unauthorized: {0}")] #[error("Unauthorized: {0}")]
Unauthorized(String), Unauthorized(String),
#[error("Payload too large: {0}")]
PayloadTooLarge(String),
} }
impl From<AppError> for ApiError { impl From<AppError> for ApiError {
@@ -67,6 +70,13 @@ impl IntoResponse for ApiError {
status: "error".to_string(), status: "error".to_string(),
}, },
), ),
Self::PayloadTooLarge(message) => (
StatusCode::PAYLOAD_TOO_LARGE,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
}; };
(status, Json(error_response)).into_response() (status, Json(error_response)).into_response()
@@ -132,6 +142,10 @@ mod tests {
// Test unauthorized status // Test unauthorized status
let error = ApiError::Unauthorized("not allowed".to_string()); let error = ApiError::Unauthorized("not allowed".to_string());
assert_status_code(error, StatusCode::UNAUTHORIZED); assert_status_code(error, StatusCode::UNAUTHORIZED);
// Test payload too large status
let error = ApiError::PayloadTooLarge("too big".to_string());
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
} }
// Alternative approach that doesn't try to parse the response body // Alternative approach that doesn't try to parse the response body


@@ -6,7 +6,7 @@ use axum::{
Router, Router,
}; };
use middleware_api_auth::api_auth; use middleware_api_auth::api_auth;
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready}; use routes::{categories::get_categories, ingest::ingest_data, liveness::live, readiness::ready};
pub mod api_state; pub mod api_state;
pub mod error; pub mod error;
@@ -26,9 +26,13 @@ where
// Protected API endpoints (require auth) // Protected API endpoints (require auth)
let protected = Router::new() let protected = Router::new()
.route("/ingress", post(ingest_data)) .route(
"/ingest",
post(ingest_data).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes,
)),
)
.route("/categories", get(get_categories)) .route("/categories", get(get_categories))
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
.route_layer(from_fn_with_state(app_state.clone(), api_auth)); .route_layer(from_fn_with_state(app_state.clone(), api_auth));
public.merge(protected) public.merge(protected)


@@ -6,6 +6,7 @@ use common::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
user::User, user::User,
}, },
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
}; };
use futures::{future::try_join_all, TryFutureExt}; use futures::{future::try_join_all, TryFutureExt};
use serde_json::json; use serde_json::json;
@@ -19,7 +20,7 @@ pub struct IngestParams {
pub content: Option<String>, pub content: Option<String>,
pub context: String, pub context: String,
pub category: String, pub category: String,
#[form_data(limit = "10000000")] // Adjust limit as needed #[form_data(limit = "20000000")]
#[form_data(default)] #[form_data(default)]
pub files: Vec<FieldData<NamedTempFile>>, pub files: Vec<FieldData<NamedTempFile>>,
} }
@@ -29,8 +30,38 @@ pub async fn ingest_data(
Extension(user): Extension<User>, Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>, TypedMultipart(input): TypedMultipart<IngestParams>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let user_id = user.id; let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
return Err(ApiError::PayloadTooLarge(message));
}
Err(IngestValidationError::BadRequest(message)) => {
return Err(ApiError::ValidationError(message));
}
}
info!(
user_id = %user_id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
"Received ingest request"
);
let file_infos = try_join_all(input.files.into_iter().map(|file| { let file_infos = try_join_all(input.files.into_iter().map(|file| {
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage) FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)


@@ -1,4 +1,4 @@
pub mod categories; pub mod categories;
pub mod ingress; pub mod ingest;
pub mod liveness; pub mod liveness;
pub mod readiness; pub mod readiness;


@@ -16,7 +16,7 @@ tracing = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
surrealdb = { workspace = true, features = ["kv-mem"] } surrealdb = { workspace = true }
async-openai = { workspace = true } async-openai = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
tempfile = { workspace = true } tempfile = { workspace = true }
@@ -49,4 +49,7 @@ fastembed = { workspace = true }
[features] [features]
test-utils = [] test-utils = ["surrealdb/kv-mem"]
[dev-dependencies]
surrealdb = { workspace = true, features = ["kv-mem"] }


@@ -62,9 +62,36 @@ DEFINE TABLE OVERWRITE conversation SCHEMAFULL;
DEFINE TABLE OVERWRITE file SCHEMAFULL; DEFINE TABLE OVERWRITE file SCHEMAFULL;
DEFINE TABLE OVERWRITE knowledge_entity SCHEMAFULL; DEFINE TABLE OVERWRITE knowledge_entity SCHEMAFULL;
DEFINE TABLE OVERWRITE message SCHEMAFULL; DEFINE TABLE OVERWRITE message SCHEMAFULL;
DEFINE TABLE OVERWRITE relates_to SCHEMAFULL; DEFINE TABLE OVERWRITE relates_to SCHEMAFULL TYPE RELATION;
DEFINE FIELD IF NOT EXISTS in ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS out ON relates_to TYPE record<knowledge_entity>;
DEFINE FIELD IF NOT EXISTS metadata ON relates_to TYPE object;
DEFINE FIELD IF NOT EXISTS metadata.user_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.source_id ON relates_to TYPE string;
DEFINE FIELD IF NOT EXISTS metadata.relationship_type ON relates_to TYPE string;
DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL; DEFINE TABLE OVERWRITE scratchpad SCHEMAFULL;
DEFINE TABLE OVERWRITE system_settings SCHEMAFULL; DEFINE TABLE OVERWRITE system_settings SCHEMAFULL;
DEFINE TABLE OVERWRITE text_chunk SCHEMAFULL; DEFINE TABLE OVERWRITE text_chunk SCHEMAFULL;
-- text_content must have fields defined before enforcing SCHEMAFULL
DEFINE TABLE OVERWRITE text_content SCHEMAFULL; DEFINE TABLE OVERWRITE text_content SCHEMAFULL;
DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS updated_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
DEFINE TABLE OVERWRITE user SCHEMAFULL; DEFINE TABLE OVERWRITE user SCHEMAFULL;


@@ -0,0 +1 @@
DEFINE FIELD IF NOT EXISTS theme ON user TYPE string DEFAULT "system";


@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -242,7 +242,7 @@\n\n # Defines the schema for the 'text_content' table.\n\n-DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;\n+DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;\n\n # Standard fields\n DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;\n@@ -254,10 +254,24 @@\n DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;\n # UrlInfo is a struct, store as object\n DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;\n+DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;\n+\n DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;\n DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;\n\n+# FileInfo fields\n+DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;\n+DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;\n+DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;\n+\n # Indexes based on query patterns\n DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;\n","events":null}

View File

@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -28,6 +28,7 @@\n # Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\n DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n+DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY\n\n # Defines the schema for the 'file' table (used by FileInfo).\n\n","events":null}

View File

@@ -13,3 +13,4 @@ DEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;
 # Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)
 DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;
 DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY
+DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY

View File

@@ -1,6 +1,6 @@
 # Defines the schema for the 'text_content' table.
-DEFINE TABLE IF NOT EXISTS text_content SCHEMALESS;
+DEFINE TABLE IF NOT EXISTS text_content SCHEMAFULL;

 # Standard fields
 DEFINE FIELD IF NOT EXISTS created_at ON text_content TYPE datetime;
@@ -12,10 +12,24 @@ DEFINE FIELD IF NOT EXISTS text ON text_content TYPE string;
 DEFINE FIELD IF NOT EXISTS file_info ON text_content TYPE option<object>;
 # UrlInfo is a struct, store as object
 DEFINE FIELD IF NOT EXISTS url_info ON text_content TYPE option<object>;
+DEFINE FIELD IF NOT EXISTS url_info.url ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS url_info.title ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS url_info.image_id ON text_content TYPE string;
+
 DEFINE FIELD IF NOT EXISTS context ON text_content TYPE option<string>;
 DEFINE FIELD IF NOT EXISTS category ON text_content TYPE string;
 DEFINE FIELD IF NOT EXISTS user_id ON text_content TYPE string;
+
+# FileInfo fields
+DEFINE FIELD IF NOT EXISTS file_info.id ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS file_info.created_at ON text_content TYPE datetime;
+DEFINE FIELD IF NOT EXISTS file_info.updated_at ON text_content TYPE datetime;
+DEFINE FIELD IF NOT EXISTS file_info.sha256 ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS file_info.path ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS file_info.file_name ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS file_info.mime_type ON text_content TYPE string;
+DEFINE FIELD IF NOT EXISTS file_info.user_id ON text_content TYPE string;
+
 # Indexes based on query patterns
 DEFINE INDEX IF NOT EXISTS text_content_user_id_idx ON text_content FIELDS user_id;
 DEFINE INDEX IF NOT EXISTS text_content_created_at_idx ON text_content FIELDS created_at;

View File

@@ -208,7 +208,26 @@ async fn ensure_runtime_indexes_inner(
             )
             .await
         }
-        HnswIndexState::Matches => Ok(()),
+        HnswIndexState::Matches => {
+            let status = get_index_status(db, spec.index_name, spec.table).await?;
+            if status.eq_ignore_ascii_case("error") {
+                warn!(
+                    index = spec.index_name,
+                    table = spec.table,
+                    "HNSW index found in error state; triggering rebuild"
+                );
+                create_index_with_polling(
+                    db,
+                    spec.definition_overwrite(embedding_dimension),
+                    spec.index_name,
+                    spec.table,
+                    Some(spec.table),
+                )
+                .await
+            } else {
+                Ok(())
+            }
+        }
         HnswIndexState::Different(existing) => {
             info!(
                 index = spec.index_name,
@@ -234,6 +253,30 @@ async fn ensure_runtime_indexes_inner(
     Ok(())
 }

+async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -> Result<String> {
+    let info_query = format!("INFO FOR INDEX {index_name} ON TABLE {table};");
+    let mut info_res = db
+        .client
+        .query(info_query)
+        .await
+        .context("checking index status")?;
+    let info: Option<Value> = info_res.take(0).context("failed to take info result")?;
+    let info = match info {
+        Some(i) => i,
+        None => return Ok("unknown".to_string()),
+    };
+    let building = info.get("building");
+    let status = building
+        .and_then(|b| b.get("status"))
+        .and_then(|s| s.as_str())
+        .unwrap_or("ready")
+        .to_string();
+    Ok(status)
+}
+
 async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
     debug!("Rebuilding indexes with concurrent definitions");
     create_fts_analyzer(db).await?;

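The `Matches` arm above only triggers a rebuild when `INFO FOR INDEX` reports a `building.status` of `error`, treating a missing `building` object as `ready` and a missing row as `unknown`. A std-only sketch of that fallback chain, with plain structs standing in for the JSON-style value SurrealDB returns (type and function names here are hypothetical):

```rust
// Sketch of the INFO-for-index fallback chain: plain structs stand in for the
// nested value returned by SurrealDB (names hypothetical, not the real API).
#[derive(Default)]
struct Building {
    status: Option<String>,
}

#[derive(Default)]
struct IndexInfo {
    building: Option<Building>,
}

fn index_status(info: Option<IndexInfo>) -> String {
    let Some(info) = info else {
        // No INFO row at all: the caller cannot tell what state the index is in.
        return "unknown".to_string();
    };
    info.building
        .and_then(|b| b.status)
        // A missing `building.status` means the index is not mid-build: ready.
        .unwrap_or_else(|| "ready".to_string())
}

fn main() {
    assert_eq!(index_status(None), "unknown");
    assert_eq!(index_status(Some(IndexInfo::default())), "ready");
    let erroring = IndexInfo {
        building: Some(Building {
            status: Some("error".to_string()),
        }),
    };
    assert_eq!(index_status(Some(erroring)), "error");
    println!("ok");
}
```

Only the exact string `"error"` (case-insensitively, in the real code) forces the rebuild; any other reported status leaves the index alone.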
View File

@@ -6,6 +6,7 @@ use anyhow::{anyhow, Result as AnyResult};
 use bytes::Bytes;
 use futures::stream::BoxStream;
 use futures::{StreamExt, TryStreamExt};
+use object_store::aws::AmazonS3Builder;
 use object_store::local::LocalFileSystem;
 use object_store::memory::InMemory;
 use object_store::{path::Path as ObjPath, ObjectStore};
@@ -234,6 +235,39 @@ async fn create_storage_backend(
             let store = InMemory::new();
             Ok((Arc::new(store), None))
         }
+        StorageKind::S3 => {
+            let bucket = cfg
+                .s3_bucket
+                .as_ref()
+                .ok_or_else(|| object_store::Error::Generic {
+                    store: "S3",
+                    source: anyhow!("s3_bucket is required for S3 storage").into(),
+                })?;
+            let mut builder = AmazonS3Builder::new()
+                .with_bucket_name(bucket)
+                .with_allow_http(true);
+            if let (Ok(key), Ok(secret)) = (
+                std::env::var("AWS_ACCESS_KEY_ID"),
+                std::env::var("AWS_SECRET_ACCESS_KEY"),
+            ) {
+                builder = builder
+                    .with_access_key_id(key)
+                    .with_secret_access_key(secret);
+            }
+            if let Some(endpoint) = &cfg.s3_endpoint {
+                builder = builder.with_endpoint(endpoint);
+            }
+            if let Some(region) = &cfg.s3_region {
+                builder = builder.with_region(region);
+            }
+            let store = builder.build()?;
+            Ok((Arc::new(store), None))
+        }
     }
 }
@@ -247,6 +281,33 @@ pub mod testing {
     use crate::utils::config::{AppConfig, PdfIngestMode};
     use uuid;

+    const DEFAULT_TEST_S3_BUCKET: &str = "minne-tests";
+    const DEFAULT_TEST_S3_ENDPOINT: &str = "http://127.0.0.1:19000";
+
+    fn configured_test_s3_bucket() -> String {
+        std::env::var("MINNE_TEST_S3_BUCKET")
+            .ok()
+            .filter(|value| !value.trim().is_empty())
+            .or_else(|| {
+                std::env::var("S3_BUCKET")
+                    .ok()
+                    .filter(|value| !value.trim().is_empty())
+            })
+            .unwrap_or_else(|| DEFAULT_TEST_S3_BUCKET.to_string())
+    }
+
+    fn configured_test_s3_endpoint() -> String {
+        std::env::var("MINNE_TEST_S3_ENDPOINT")
+            .ok()
+            .filter(|value| !value.trim().is_empty())
+            .or_else(|| {
+                std::env::var("S3_ENDPOINT")
+                    .ok()
+                    .filter(|value| !value.trim().is_empty())
+            })
+            .unwrap_or_else(|| DEFAULT_TEST_S3_ENDPOINT.to_string())
+    }
+
     /// Create a test configuration with memory storage.
     ///
     /// This provides a ready-to-use configuration for testing scenarios
@@ -290,6 +351,30 @@ pub mod testing {
         }
     }

+    /// Create a test configuration with S3 storage (MinIO).
+    ///
+    /// Uses `MINNE_TEST_S3_ENDPOINT` / `S3_ENDPOINT` and
+    /// `MINNE_TEST_S3_BUCKET` / `S3_BUCKET` when provided.
+    pub fn test_config_s3() -> AppConfig {
+        AppConfig {
+            openai_api_key: "test".into(),
+            surrealdb_address: "test".into(),
+            surrealdb_username: "test".into(),
+            surrealdb_password: "test".into(),
+            surrealdb_namespace: "test".into(),
+            surrealdb_database: "test".into(),
+            data_dir: "/tmp/unused".into(),
+            http_port: 0,
+            openai_base_url: "..".into(),
+            storage: StorageKind::S3,
+            s3_bucket: Some(configured_test_s3_bucket()),
+            s3_endpoint: Some(configured_test_s3_endpoint()),
+            s3_region: Some("us-east-1".into()),
+            pdf_ingest_mode: PdfIngestMode::LlmFirst,
+            ..Default::default()
+        }
+    }
+
     /// A specialized StorageManager for testing scenarios.
     ///
     /// This provides automatic setup for memory storage with proper isolation
@@ -332,6 +417,30 @@ pub mod testing {
         })
     }

+    /// Create a new TestStorageManager with S3 backend (MinIO).
+    ///
+    /// This requires a reachable MinIO endpoint and an existing test bucket.
+    pub async fn new_s3() -> object_store::Result<Self> {
+        // Ensure credentials are set for MinIO
+        // We set these env vars for the process, which AmazonS3Builder will pick up
+        std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
+        std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");
+        std::env::set_var("AWS_REGION", "us-east-1");
+
+        let cfg = test_config_s3();
+        let storage = StorageManager::new(&cfg).await?;
+
+        // Probe the bucket so tests can cleanly skip when the endpoint is unreachable
+        // or the test bucket is not provisioned.
+        let probe_prefix = format!("__minne_s3_probe__/{}", uuid::Uuid::new_v4());
+        storage.list(Some(&probe_prefix)).await?;
+
+        Ok(Self {
+            storage,
+            _temp_dir: None,
+        })
+    }
+
     /// Create a TestStorageManager with custom configuration.
     pub async fn with_config(cfg: &AppConfig) -> object_store::Result<Self> {
         let storage = StorageManager::new(cfg).await?;
@@ -369,6 +478,14 @@ pub mod testing {
         self.storage.get(location).await
     }

+    /// Get a streaming handle for test data.
+    pub async fn get_stream(
+        &self,
+        location: &str,
+    ) -> object_store::Result<BoxStream<'static, object_store::Result<Bytes>>> {
+        self.storage.get_stream(location).await
+    }
+
     /// Delete test data below the specified prefix.
     pub async fn delete_prefix(&self, prefix: &str) -> object_store::Result<()> {
         self.storage.delete_prefix(prefix).await
@@ -837,4 +954,117 @@ mod tests {
         // Verify it's using memory backend
         assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
     }
+
+    // S3 Tests - Require a reachable MinIO endpoint and test bucket.
+    // `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
+
+    #[tokio::test]
+    async fn test_storage_manager_s3_basic_operations() {
+        // Skip if S3 connection fails (e.g. no MinIO)
+        let Ok(storage) = testing::TestStorageManager::new_s3().await else {
+            eprintln!("Skipping S3 test (setup failed)");
+            return;
+        };
+
+        let prefix = format!("test-basic-{}", Uuid::new_v4());
+        let location = format!("{prefix}/file.txt");
+        let data = b"test data for S3";
+
+        // Test put
+        if let Err(e) = storage.put(&location, data).await {
+            eprintln!("Skipping S3 test (put failed - bucket missing?): {e}");
+            return;
+        }
+
+        // Test get
+        let retrieved = storage.get(&location).await.expect("get");
+        assert_eq!(retrieved.as_ref(), data);
+
+        // Test exists
+        assert!(storage.exists(&location).await.expect("exists"));
+
+        // Test delete
+        storage
+            .delete_prefix(&format!("{prefix}/"))
+            .await
+            .expect("delete");
+        assert!(!storage
+            .exists(&location)
+            .await
+            .expect("exists after delete"));
+    }
+
+    #[tokio::test]
+    async fn test_storage_manager_s3_list_operations() {
+        let Ok(storage) = testing::TestStorageManager::new_s3().await else {
+            return;
+        };
+
+        let prefix = format!("test-list-{}", Uuid::new_v4());
+        let files = vec![
+            (format!("{prefix}/file1.txt"), b"content1"),
+            (format!("{prefix}/file2.txt"), b"content2"),
+            (format!("{prefix}/sub/file3.txt"), b"content3"),
+        ];
+
+        for (loc, data) in &files {
+            if storage.put(loc, *data).await.is_err() {
+                return; // Abort if put fails
+            }
+        }
+
+        // List with prefix
+        let list_prefix = format!("{prefix}/");
+        let items = storage.list(Some(&list_prefix)).await.expect("list");
+        assert_eq!(items.len(), 3);
+
+        // Cleanup
+        storage.delete_prefix(&list_prefix).await.expect("cleanup");
+    }
+
+    #[tokio::test]
+    async fn test_storage_manager_s3_stream_operations() {
+        let Ok(storage) = testing::TestStorageManager::new_s3().await else {
+            return;
+        };
+
+        let prefix = format!("test-stream-{}", Uuid::new_v4());
+        let location = format!("{prefix}/large.bin");
+        let content = vec![42u8; 1024 * 10]; // 10KB
+
+        if storage.put(&location, &content).await.is_err() {
+            return;
+        }
+
+        let mut stream = storage.get_stream(&location).await.expect("get stream");
+        let mut collected = Vec::new();
+        while let Some(chunk) = stream.next().await {
+            collected.extend_from_slice(&chunk.expect("chunk"));
+        }
+        assert_eq!(collected, content);
+
+        storage
+            .delete_prefix(&format!("{prefix}/"))
+            .await
+            .expect("cleanup");
+    }
+
+    #[tokio::test]
+    async fn test_storage_manager_s3_backend_kind() {
+        let Ok(storage) = testing::TestStorageManager::new_s3().await else {
+            return;
+        };
+        assert_eq!(*storage.storage().backend_kind(), StorageKind::S3);
+    }
+
+    #[tokio::test]
+    async fn test_storage_manager_s3_error_handling() {
+        let Ok(storage) = testing::TestStorageManager::new_s3().await else {
+            return;
+        };
+
+        let location = format!("nonexistent-{}/file.txt", Uuid::new_v4());
+        assert!(storage.get(&location).await.is_err());
+        assert!(!storage.exists(&location).await.expect("exists check"));
+    }
 }

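The `configured_test_s3_*` helpers above resolve a value from the `MINNE_TEST_*` variable first, then the generic `S3_*` variable, then the devenv default, rejecting empty or whitespace-only values at each step. A std-only sketch of that precedence, using a hypothetical `env_or` helper that generalizes the two functions:

```rust
use std::env;

// Hypothetical helper mirroring the MINNE_TEST_* -> S3_* -> default fallback
// chain from the testing module; blank values fall through to the next source.
fn env_or(primary: &str, fallback: &str, default: &str) -> String {
    env::var(primary)
        .ok()
        .filter(|v| !v.trim().is_empty())
        .or_else(|| env::var(fallback).ok().filter(|v| !v.trim().is_empty()))
        .unwrap_or_else(|| default.to_string())
}

fn main() {
    env::remove_var("MINNE_TEST_S3_BUCKET");
    env::remove_var("S3_BUCKET");
    // Neither variable set: the default wins.
    assert_eq!(
        env_or("MINNE_TEST_S3_BUCKET", "S3_BUCKET", "minne-tests"),
        "minne-tests"
    );
    // Generic variable set: it beats the default.
    env::set_var("S3_BUCKET", "shared-bucket");
    assert_eq!(
        env_or("MINNE_TEST_S3_BUCKET", "S3_BUCKET", "minne-tests"),
        "shared-bucket"
    );
    // Test-specific variable set: it beats everything.
    env::set_var("MINNE_TEST_S3_BUCKET", "override");
    assert_eq!(
        env_or("MINNE_TEST_S3_BUCKET", "S3_BUCKET", "minne-tests"),
        "override"
    );
    println!("ok");
}
```

The test-specific variable lets CI point the S3 tests at a dedicated bucket without disturbing the `S3_BUCKET` the application itself reads.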
View File

@@ -10,6 +10,54 @@ stored_object!(Conversation, "conversation", {
     title: String
 });

+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
+pub struct SidebarConversation {
+    #[serde(deserialize_with = "deserialize_sidebar_id")]
+    pub id: String,
+    pub title: String,
+}
+
+struct SidebarIdVisitor;
+
+impl<'de> serde::de::Visitor<'de> for SidebarIdVisitor {
+    type Value = String;
+
+    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+        formatter.write_str("a string id or a SurrealDB Thing")
+    }
+
+    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+    where
+        E: serde::de::Error,
+    {
+        Ok(value.to_string())
+    }
+
+    fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
+    where
+        E: serde::de::Error,
+    {
+        Ok(value)
+    }
+
+    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+    where
+        A: serde::de::MapAccess<'de>,
+    {
+        let thing = <surrealdb::sql::Thing as serde::Deserialize>::deserialize(
+            serde::de::value::MapAccessDeserializer::new(map),
+        )?;
+        Ok(thing.id.to_raw())
+    }
+}
+
+fn deserialize_sidebar_id<'de, D>(deserializer: D) -> Result<String, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    deserializer.deserialize_any(SidebarIdVisitor)
+}
+
 impl Conversation {
     pub fn new(user_id: String, title: String) -> Self {
         let now = Utc::now();
@@ -75,6 +123,23 @@ impl Conversation {
         Ok(())
     }

+    pub async fn get_user_sidebar_conversations(
+        user_id: &str,
+        db: &SurrealDbClient,
+    ) -> Result<Vec<SidebarConversation>, AppError> {
+        let conversations: Vec<SidebarConversation> = db
+            .client
+            .query(
+                "SELECT id, title, updated_at FROM type::table($table_name) WHERE user_id = $user_id ORDER BY updated_at DESC",
+            )
+            .bind(("table_name", Self::table_name()))
+            .bind(("user_id", user_id.to_string()))
+            .await?
+            .take(0)?;
+
+        Ok(conversations)
+    }
 }

 #[cfg(test)]
@@ -249,6 +314,96 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
+        let namespace = "test_ns";
+        let database = &Uuid::new_v4().to_string();
+        let db = SurrealDbClient::memory(namespace, database)
+            .await
+            .expect("Failed to start in-memory surrealdb");
+
+        let user_id = "sidebar_user";
+        let other_user_id = "other_user";
+        let base = Utc::now();
+
+        let mut oldest = Conversation::new(user_id.to_string(), "Oldest".to_string());
+        oldest.updated_at = base - chrono::Duration::minutes(30);
+        let mut newest = Conversation::new(user_id.to_string(), "Newest".to_string());
+        newest.updated_at = base - chrono::Duration::minutes(5);
+        let mut middle = Conversation::new(user_id.to_string(), "Middle".to_string());
+        middle.updated_at = base - chrono::Duration::minutes(15);
+        let mut other_user = Conversation::new(other_user_id.to_string(), "Other".to_string());
+        other_user.updated_at = base;
+
+        db.store_item(oldest.clone())
+            .await
+            .expect("Failed to store oldest conversation");
+        db.store_item(newest.clone())
+            .await
+            .expect("Failed to store newest conversation");
+        db.store_item(middle.clone())
+            .await
+            .expect("Failed to store middle conversation");
+        db.store_item(other_user)
+            .await
+            .expect("Failed to store other-user conversation");
+
+        let sidebar_items = Conversation::get_user_sidebar_conversations(user_id, &db)
+            .await
+            .expect("Failed to get sidebar conversations");
+
+        assert_eq!(sidebar_items.len(), 3);
+        assert_eq!(sidebar_items[0].id, newest.id);
+        assert_eq!(sidebar_items[0].title, "Newest");
+        assert_eq!(sidebar_items[1].id, middle.id);
+        assert_eq!(sidebar_items[1].title, "Middle");
+        assert_eq!(sidebar_items[2].id, oldest.id);
+        assert_eq!(sidebar_items[2].title, "Oldest");
+    }
+
+    #[tokio::test]
+    async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
+        let namespace = "test_ns";
+        let database = &Uuid::new_v4().to_string();
+        let db = SurrealDbClient::memory(namespace, database)
+            .await
+            .expect("Failed to start in-memory surrealdb");
+
+        let user_id = "sidebar_patch_user";
+        let base = Utc::now();
+
+        let mut first = Conversation::new(user_id.to_string(), "First".to_string());
+        first.updated_at = base - chrono::Duration::minutes(20);
+        let mut second = Conversation::new(user_id.to_string(), "Second".to_string());
+        second.updated_at = base - chrono::Duration::minutes(10);
+
+        db.store_item(first.clone())
+            .await
+            .expect("Failed to store first conversation");
+        db.store_item(second.clone())
+            .await
+            .expect("Failed to store second conversation");
+
+        let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
+            .await
+            .expect("Failed to get sidebar conversations before patch");
+        assert_eq!(before_patch[0].id, second.id);
+
+        Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
+            .await
+            .expect("Failed to patch conversation title");
+
+        let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
+            .await
+            .expect("Failed to get sidebar conversations after patch");
+        assert_eq!(after_patch[0].id, first.id);
+        assert_eq!(after_patch[0].title, "First (renamed)");
+    }
+
     #[tokio::test]
     async fn test_get_complete_conversation_with_messages() {
         // Setup in-memory database for testing

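The `SidebarIdVisitor` above accepts either a plain string id or a full SurrealDB `Thing` and normalizes both to the raw record key via `thing.id.to_raw()`. A std-only sketch of that normalization on string-form record ids (the helper name is hypothetical; the real visitor operates on deserializer events, not strings):

```rust
// Hypothetical helper illustrating what SidebarIdVisitor normalizes to: the
// raw record key, whether the id arrives bare or in `table:key` form.
fn raw_record_key(id: &str) -> &str {
    match id.split_once(':') {
        // `conversation:abc123` -> `abc123`, stripping optional backtick quoting
        Some((_table, key)) => key.trim_matches('`'),
        None => id,
    }
}

fn main() {
    assert_eq!(raw_record_key("abc123"), "abc123");
    assert_eq!(raw_record_key("conversation:abc123"), "abc123");
    assert_eq!(raw_record_key("conversation:`abc-123`"), "abc-123");
    println!("ok");
}
```

Normalizing at deserialization time means the cached sidebar DTO compares equal to `Conversation::id` regardless of how SurrealDB serialized the record id.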
View File

@@ -3,7 +3,10 @@ use bytes;
 use mime_guess::from_path;
 use object_store::Error as ObjectStoreError;
 use sha2::{Digest, Sha256};
-use std::{io::{BufReader, Read}, path::Path};
+use std::{
+    io::{BufReader, Read},
+    path::Path,
+};
 use tempfile::NamedTempFile;
 use thiserror::Error;
 use tokio::task;
@@ -134,8 +137,12 @@ impl FileInfo {
     /// # Returns
     /// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
     async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
-        let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
-        let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
+        let mut response = db_client
+            .client
+            .query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
+            .bind(("sha256", sha256.to_owned()))
+            .await?;
+        let response: Vec<FileInfo> = response.take(0)?;

         response
             .into_iter()
@@ -662,6 +669,36 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn test_get_by_sha_resists_query_injection() {
+        let namespace = "test_ns";
+        let database = &Uuid::new_v4().to_string();
+        let db = SurrealDbClient::memory(namespace, database)
+            .await
+            .expect("Failed to start in-memory surrealdb");
+
+        let now = Utc::now();
+        let file_info = FileInfo {
+            id: Uuid::new_v4().to_string(),
+            created_at: now,
+            updated_at: now,
+            user_id: "user123".to_string(),
+            sha256: "known_sha_value".to_string(),
+            path: "/path/to/file.txt".to_string(),
+            file_name: "file.txt".to_string(),
+            mime_type: "text/plain".to_string(),
+        };
+        db.store_item(file_info)
+            .await
+            .expect("Failed to store test file info");
+
+        let malicious_sha = "known_sha_value' OR true --";
+        let result = FileInfo::get_by_sha(malicious_sha, &db).await;
+
+        assert!(matches!(result, Err(FileError::FileNotFound(_))));
+    }
+
     #[tokio::test]
     async fn test_manual_file_info_creation() {
         let namespace = "test_ns";

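Worth spelling out why `get_by_sha` moved to `.bind(...)`: with naive string interpolation, a crafted "sha" value terminates the quoted literal and smuggles its own predicate into the query. A minimal std-only illustration of the broken shape (the `naive_query` helper is hypothetical, reproducing only the old `format!` pattern):

```rust
// Hypothetical reconstruction of the pre-fix query building: the value is
// spliced directly into the quoted literal.
fn naive_query(sha256: &str) -> String {
    format!("SELECT * FROM file WHERE sha256 = '{}'", sha256)
}

fn main() {
    let malicious = "known_sha_value' OR true --";
    let q = naive_query(malicious);
    // The attacker-controlled input now reads as query syntax, not as a value:
    assert!(q.contains("' OR true --"));
    println!("{q}");
}
```

With `.bind(("sha256", ...))` the whole string travels as a single parameter value, so the comparison simply fails to match, which is exactly what `test_get_by_sha_resists_query_injection` asserts.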
View File

@@ -171,12 +171,18 @@ impl KnowledgeEntity {
         source_id: &str,
         db_client: &SurrealDbClient,
     ) -> Result<(), AppError> {
-        let query = format!(
-            "DELETE {} WHERE source_id = '{}'",
-            Self::table_name(),
-            source_id
-        );
-
-        db_client.query(query).await?;
+        // Delete embeddings first, while we can still look them up via the entity's source_id
+        KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
+
+        db_client
+            .client
+            .query("DELETE FROM type::table($table) WHERE source_id = $source_id")
+            .bind(("table", Self::table_name()))
+            .bind(("source_id", source_id.to_owned()))
+            .await
+            .map_err(AppError::Database)?
+            .check()
+            .map_err(AppError::Database)?;

         Ok(())
     }
@@ -224,7 +230,7 @@ impl KnowledgeEntity {
     ) -> Result<Vec<KnowledgeEntityVectorResult>, AppError> {
         #[derive(Deserialize)]
         struct Row {
-            entity_id: KnowledgeEntity,
+            entity_id: Option<KnowledgeEntity>,
             score: f32,
         }
@@ -257,9 +263,11 @@ impl KnowledgeEntity {
         Ok(rows
             .into_iter()
-            .map(|r| KnowledgeEntityVectorResult {
-                entity: r.entity_id,
-                score: r.score,
+            .filter_map(|r| {
+                r.entity_id.map(|entity| KnowledgeEntityVectorResult {
+                    entity,
+                    score: r.score,
+                })
             })
             .collect())
     }
@@ -460,7 +468,11 @@ impl KnowledgeEntity {
         for (i, entity) in all_entities.iter().enumerate() {
             if i > 0 && i % 100 == 0 {
-                info!(progress = i, total = total_entities, "Re-embedding progress");
+                info!(
+                    progress = i,
+                    total = total_entities,
+                    "Re-embedding progress"
+                );
             }

             let embedding_input = format!(
@@ -485,6 +497,32 @@ impl KnowledgeEntity {
             new_embeddings.insert(entity.id.clone(), (embedding, entity.user_id.clone()));
         }
         info!("Successfully generated all new embeddings.");
+
+        // Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
+        info!("Removing old index and clearing embeddings...");
+        // Explicitly remove the index first. This prevents background HNSW maintenance from crashing
+        // when we delete/replace data, dealing with a known SurrealDB panic.
+        db.client
+            .query(format!(
+                "REMOVE INDEX idx_embedding_knowledge_entity_embedding ON TABLE {};",
+                KnowledgeEntityEmbedding::table_name()
+            ))
+            .await
+            .map_err(AppError::Database)?
+            .check()
+            .map_err(AppError::Database)?;
+
+        db.client
+            .query(format!(
+                "DELETE FROM {};",
+                KnowledgeEntityEmbedding::table_name()
+            ))
+            .await
+            .map_err(AppError::Database)?
+            .check()
+            .map_err(AppError::Database)?;

         // Perform DB updates in a single transaction
         info!("Applying embedding updates in a transaction...");
@@ -500,11 +538,11 @@ impl KnowledgeEntity {
                 .join(",")
         );

         transaction_query.push_str(&format!(
-            "UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
+            "CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
             entity_id = type::thing('knowledge_entity', '{id}'), \
             embedding = {embedding}, \
             user_id = '{user_id}', \
-            created_at = IF created_at != NONE THEN created_at ELSE time::now() END, \
+            created_at = time::now(), \
             updated_at = time::now();",
             id = id,
             embedding = embedding_str,
@@ -520,7 +558,12 @@ impl KnowledgeEntity {
         transaction_query.push_str("COMMIT TRANSACTION;");

         // Execute the entire atomic operation
-        db.query(transaction_query).await?;
+        db.client
+            .query(transaction_query)
+            .await
+            .map_err(AppError::Database)?
+            .check()
+            .map_err(AppError::Database)?;

         info!("Re-embedding process for knowledge entities completed successfully.");
         Ok(())
@@ -721,6 +764,69 @@ mod tests {
         assert_eq!(different_remaining[0].id, different_entity.id);
     }

+    #[tokio::test]
+    async fn test_delete_by_source_id_resists_query_injection() {
+        let namespace = "test_ns";
+        let database = &Uuid::new_v4().to_string();
+        let db = SurrealDbClient::memory(namespace, database)
+            .await
+            .expect("Failed to start in-memory surrealdb");
+        db.apply_migrations()
+            .await
+            .expect("Failed to apply migrations");
+        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
+            .await
+            .expect("Failed to redefine index length");
+
+        let user_id = "user123".to_string();
+        let entity1 = KnowledgeEntity::new(
+            "safe_source".to_string(),
+            "Entity 1".to_string(),
+            "Description 1".to_string(),
+            KnowledgeEntityType::Document,
+            None,
+            user_id.clone(),
+        );
+        let entity2 = KnowledgeEntity::new(
+            "other_source".to_string(),
+            "Entity 2".to_string(),
+            "Description 2".to_string(),
+            KnowledgeEntityType::Document,
+            None,
+            user_id,
+        );
+
+        KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
+            .await
+            .expect("store entity1");
+        KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
+            .await
+            .expect("store entity2");
+
+        let malicious_source = "safe_source' OR 1=1 --";
+        KnowledgeEntity::delete_by_source_id(malicious_source, &db)
+            .await
+            .expect("delete call should succeed");
+
+        let remaining: Vec<KnowledgeEntity> = db
+            .client
+            .query("SELECT * FROM type::table($table)")
+            .bind(("table", KnowledgeEntity::table_name()))
+            .await
+            .expect("query failed")
+            .take(0)
+            .expect("take failed");
+
+        assert_eq!(
+            remaining.len(),
+            2,
+            "malicious input must not delete unrelated entities"
+        );
+    }
+
     #[tokio::test]
     async fn test_vector_search_returns_empty_when_no_embeddings() {
         let namespace = "test_ns";
@@ -879,4 +985,50 @@ mod tests {
         assert_eq!(results[0].entity.id, e2.id);
         assert_eq!(results[1].entity.id, e1.id);
     }

+    #[tokio::test]
+    async fn test_vector_search_with_orphaned_embedding() {
+        let namespace = "test_ns_orphan";
+        let database = &Uuid::new_v4().to_string();
+        let db = SurrealDbClient::memory(namespace, database)
+            .await
+            .expect("Failed to start in-memory surrealdb");
+        db.apply_migrations()
+            .await
+            .expect("Failed to apply migrations");
+        KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
+            .await
+            .expect("Failed to redefine index length");
+
+        let user_id = "user".to_string();
+        let source_id = "src".to_string();
+        let entity = KnowledgeEntity::new(
+            source_id.clone(),
+            "orphan".to_string(),
+            "orphan desc".to_string(),
+            KnowledgeEntityType::Document,
+            None,
+            user_id.clone(),
+        );
+        KnowledgeEntity::store_with_embedding(entity.clone(), vec![0.1, 0.2, 0.3], &db)
+            .await
+            .expect("store entity with embedding");
+
+        // Manually delete the entity to create an orphan
+        let query = format!("DELETE type::thing('knowledge_entity', '{}')", entity.id);
+        db.client.query(query).await.expect("delete entity");
+
+        // Now search
+        let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
+            .await
+            .expect("search should succeed even with orphans");
+
+        assert!(
+            results.is_empty(),
+            "Should return empty result for orphan, got: {:?}",
+            results
+        );
+    }
 }

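Making `Row::entity_id` an `Option` and switching from `map` to `filter_map` is what lets `vector_search` tolerate orphaned embeddings: rows whose joined entity was deleted simply drop out of the result instead of failing deserialization. The pattern in miniature, with simplified stand-in types:

```rust
// Miniature of the Option + filter_map pattern from vector_search: rows whose
// joined entity is gone (orphaned embeddings) are skipped, not treated as errors.
struct Row {
    entity_id: Option<&'static str>, // None when the entity record was deleted
    score: f32,
}

fn collect_hits(rows: Vec<Row>) -> Vec<(&'static str, f32)> {
    rows.into_iter()
        .filter_map(|r| r.entity_id.map(|entity| (entity, r.score)))
        .collect()
}

fn main() {
    let rows = vec![
        Row { entity_id: Some("e1"), score: 0.9 },
        Row { entity_id: None, score: 0.8 }, // orphan: embedding outlived its entity
        Row { entity_id: Some("e2"), score: 0.7 },
    ];
    assert_eq!(collect_hits(rows), vec![("e1", 0.9), ("e2", 0.7)]);
    println!("ok");
}
```

The same idea is exercised end to end by `test_vector_search_with_orphaned_embedding` above, which deletes the entity out from under its stored embedding before searching.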
View File

@@ -40,22 +40,28 @@ impl KnowledgeRelationship {
         }
     }

     pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
-        let query = format!(
-            r#"DELETE relates_to:`{rel_id}`;
-            RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
-            SET
-            metadata.user_id = '{user_id}',
-            metadata.source_id = '{source_id}',
-            metadata.relationship_type = '{relationship_type}'"#,
-            rel_id = self.id,
-            in_id = self.in_,
-            out_id = self.out,
-            user_id = self.metadata.user_id.as_str(),
-            source_id = self.metadata.source_id.as_str(),
-            relationship_type = self.metadata.relationship_type.as_str()
-        );
-
-        db_client.query(query).await?;
+        db_client
+            .client
+            .query(
+                r#"BEGIN TRANSACTION;
+                LET $in_entity = type::thing('knowledge_entity', $in_id);
+                LET $out_entity = type::thing('knowledge_entity', $out_id);
+                LET $relation = type::thing('relates_to', $rel_id);
+                DELETE type::thing('relates_to', $rel_id);
+                RELATE $in_entity->$relation->$out_entity SET
+                    metadata.user_id = $user_id,
+                    metadata.source_id = $source_id,
+                    metadata.relationship_type = $relationship_type;
+                COMMIT TRANSACTION;"#,
+            )
+            .bind(("rel_id", self.id.clone()))
+            .bind(("in_id", self.in_.clone()))
+            .bind(("out_id", self.out.clone()))
+            .bind(("user_id", self.metadata.user_id.clone()))
+            .bind(("source_id", self.metadata.source_id.clone()))
+            .bind(("relationship_type", self.metadata.relationship_type.clone()))
+            .await?
+            .check()?;

         Ok(())
     }
@@ -64,11 +70,12 @@ impl KnowledgeRelationship {
         source_id: &str,
         db_client: &SurrealDbClient,
     ) -> Result<(), AppError> {
-        let query = format!(
-            "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
-        );
-
-        db_client.query(query).await?;
+        db_client
+            .client
+            .query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
+            .bind(("source_id", source_id.to_owned()))
+            .await?
+            .check()?;

         Ok(())
     }
@@ -79,15 +86,20 @@ impl KnowledgeRelationship {
         db_client: &SurrealDbClient,
     ) -> Result<(), AppError> {
         let mut authorized_result = db_client
-            .query(format!(
-                "SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
-            ))
+            .client
+            .query(
+                "SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
+            )
+            .bind(("id", id.to_owned()))
+            .bind(("user_id", user_id.to_owned()))
             .await?;

         let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();

         if authorized.is_empty() {
             let mut exists_result = db_client
-                .query(format!("SELECT * FROM relates_to:`{id}`"))
+                .client
+                .query("SELECT * FROM type::thing('relates_to', $id)")
+                .bind(("id", id.to_owned()))
                 .await?;

             let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -100,8 +112,11 @@ impl KnowledgeRelationship {
}
} else {
db_client
.client
.query("DELETE type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?
.check()?;
Ok(())
}
}
@@ -112,6 +127,34 @@ mod tests {
use super::*;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
}
async fn get_relationship_by_id(
relationship_id: &str,
db_client: &SurrealDbClient,
) -> Option<KnowledgeRelationship> {
let mut result = db_client
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", relationship_id.to_owned()))
.await
.expect("relationship query by id failed");
result.take(0).expect("failed to take relationship by id")
}
// Helper function to create a test knowledge entity for the relationship tests
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
let source_id = "source123".to_string();
@@ -161,13 +204,9 @@ mod tests {
}

#[tokio::test]
async fn test_store_and_verify_by_source_id() {
// Setup in-memory database for testing
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -192,30 +231,69 @@ mod tests {
.await
.expect("Failed to store relationship");
let persisted = get_relationship_by_id(&relationship.id, &db)
.await
.expect("Relationship should be retrievable by id");
assert_eq!(persisted.in_, entity1_id);
assert_eq!(persisted.out, entity2_id);
assert_eq!(persisted.metadata.user_id, user_id);
assert_eq!(persisted.metadata.source_id, source_id);
// Query to verify the relationship exists by checking for relationships with our source_id
// This approach is more reliable than trying to look up by ID
let mut check_result = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
.await
.expect("Check query failed");
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();

assert_eq!(
check_results.len(),
1,
"Expected one relationship for source_id"
);
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() {
let db = setup_test_db().await;

let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let relationship = KnowledgeRelationship::new(
entity1_id,
entity2_id,
"user'123".to_string(),
"source123'; DELETE FROM relates_to; --".to_string(),
"references'; UPDATE user SET admin = true; --".to_string(),
);
relationship
.store_relationship(&db)
.await
.expect("store relationship should safely handle quote-containing values");
let mut res = db
.client
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
.bind(("id", relationship.id.clone()))
.await
.expect("query relationship by id failed");
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
assert_eq!(rows.len(), 1);
assert_eq!(
rows[0].metadata.source_id,
"source123'; DELETE FROM relates_to; --"
);
}
#[tokio::test]
async fn test_store_and_delete_relationship() {
// Setup in-memory database for testing
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -234,7 +312,7 @@ mod tests {
relationship_type,
);

// Store relationship
relationship
.store_relationship(&db)
.await
@@ -255,12 +333,12 @@ mod tests {
"Relationship should exist before deletion"
);

// Delete relationship by ID
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, &user_id, &db)
.await
.expect("Failed to delete relationship by ID");

// Query to verify relationship was deleted
let mut result = db
.query(format!(
"SELECT * FROM relates_to WHERE metadata.user_id = '{}' AND metadata.source_id = '{}'",
@@ -270,17 +348,13 @@ mod tests {
.expect("Query failed");
let results: Vec<KnowledgeRelationship> = result.take(0).unwrap_or_default();

// Verify relationship no longer exists
assert!(results.is_empty(), "Relationship should be deleted");
}

#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() {
let db = setup_test_db().await;

let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
@@ -342,13 +416,9 @@ mod tests {
}

#[tokio::test]
async fn test_store_relationship_exists() {
// Setup in-memory database for testing
let db = setup_test_db().await;

// Create entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -400,49 +470,87 @@ mod tests {
.await
.expect("Failed to store different relationship");
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
let mut before_delete = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
.await
.expect("before delete query failed");
let before_delete_rows: Vec<KnowledgeRelationship> =
before_delete.take(0).unwrap_or_default();
assert_eq!(before_delete_rows.len(), 2);
let mut before_delete_different = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", different_source_id.clone()))
.await
.expect("before delete different query failed");
let before_delete_different_rows: Vec<KnowledgeRelationship> =
before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1);
// Delete relationships by source_id
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
.await
.expect("Failed to delete relationships by source_id");
// Query to verify the specific relationships with source_id were deleted.
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;

// Verify relationships with the source_id are deleted
assert!(result1.is_none(), "Relationship 1 should be deleted");
assert!(result2.is_none(), "Relationship 2 should be deleted");
let remaining =
different_result.expect("Relationship with different source_id should remain");
assert_eq!(remaining.metadata.source_id, different_source_id);
}
#[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() {
let db = setup_test_db().await;

let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity3_id = create_test_entity("Entity 3", &db).await;

let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
"user123".to_string(),
"safe_source".to_string(),
"references".to_string(),
);

let other_relationship = KnowledgeRelationship::new(
entity2_id,
entity3_id,
"user123".to_string(),
"other_source".to_string(),
"contains".to_string(),
);
safe_relationship
.store_relationship(&db)
.await
.expect("store safe relationship");
other_relationship
.store_relationship(&db)
.await
.expect("store other relationship");
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
.await
.expect("delete call should succeed");
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
assert!(remaining_safe.is_some(), "Safe relationship should remain");
assert!(
remaining_other.is_some(),
"Other relationship should remain"
);
}
}


@@ -116,7 +116,7 @@ macro_rules! stored_object {
}

$(#[$struct_attr])*
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct $name {
#[serde(deserialize_with = "deserialize_flexible_id")]


@@ -44,12 +44,18 @@ impl TextChunk {
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
// Delete embeddings first
TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?;

db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;

Ok(())
}
@@ -102,7 +108,7 @@ impl TextChunk {
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Deserialize)]
struct Row {
chunk_id: Option<TextChunk>,
score: f32,
}
@@ -134,9 +140,11 @@ impl TextChunk {
Ok(rows
.into_iter()
.filter_map(|r| {
r.chunk_id.map(|chunk| TextChunkSearchResult {
chunk,
score: r.score,
})
})
.collect())
}
@@ -352,12 +360,12 @@ impl TextChunk {
// Generate all new embeddings in memory
let mut new_embeddings: HashMap<String, (Vec<f32>, String, String)> = HashMap::new();
info!("Generating new embeddings for all chunks...");

for (i, chunk) in all_chunks.iter().enumerate() {
if i > 0 && i % 100 == 0 {
info!(progress = i, total = total_chunks, "Re-embedding progress");
}

let embedding = provider
.embed(&chunk.chunk)
.await
@@ -379,6 +387,28 @@ impl TextChunk {
}

info!("Successfully generated all new embeddings.");
// Clear existing embeddings and index first to prevent SurrealDB panics and dimension conflicts.
info!("Removing old index and clearing embeddings...");
// Explicitly remove the index first. This prevents background HNSW maintenance from crashing
// when we delete/replace data, dealing with a known SurrealDB panic.
db.client
.query(format!(
"REMOVE INDEX idx_embedding_text_chunk_embedding ON TABLE {};",
TextChunkEmbedding::table_name()
))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
db.client
.query(format!("DELETE FROM {};", TextChunkEmbedding::table_name()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
// Perform DB updates in a single transaction against the embedding table
info!("Applying embedding updates in a transaction...");
let mut transaction_query = String::from("BEGIN TRANSACTION;");
@@ -394,12 +424,12 @@ impl TextChunk {
);

write!(
&mut transaction_query,
"CREATE type::thing('text_chunk_embedding', '{id}') SET \
chunk_id = type::thing('text_chunk', '{id}'), \
source_id = '{source_id}', \
embedding = {embedding}, \
user_id = '{user_id}', \
created_at = time::now(), \
updated_at = time::now();",
id = id,
embedding = embedding_str,
@@ -418,7 +448,12 @@ impl TextChunk {
transaction_query.push_str("COMMIT TRANSACTION;");

db.client
.query(transaction_query)
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;

info!("Re-embedding process for text chunks completed successfully.");
Ok(())
@@ -585,6 +620,57 @@ mod tests {
assert_eq!(remaining.len(), 1);
}
#[tokio::test]
async fn test_delete_by_source_id_resists_query_injection() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
.await
.expect("redefine index");
let chunk1 = TextChunk::new(
"safe_source".to_string(),
"Safe chunk".to_string(),
"user123".to_string(),
);
let chunk2 = TextChunk::new(
"other_source".to_string(),
"Other chunk".to_string(),
"user123".to_string(),
);
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
.await
.expect("store chunk1");
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
.await
.expect("store chunk2");
let malicious_source = "safe_source' OR 1=1 --";
TextChunk::delete_by_source_id(malicious_source, &db)
.await
.expect("delete call should succeed");
let remaining: Vec<TextChunk> = db
.client
.query("SELECT * FROM type::table($table)")
.bind(("table", TextChunk::table_name()))
.await
.expect("query failed")
.take(0)
.expect("take failed");
assert_eq!(
remaining.len(),
2,
"malicious input must not delete unrelated rows"
);
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
let namespace = "test_ns";


@@ -102,44 +102,19 @@ impl TextChunkEmbedding {
/// Delete all embeddings that belong to chunks with a given `source_id`
///
/// This uses the denormalized `source_id` on the embedding table.
pub async fn delete_by_source_id(
source_id: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE FROM {} WHERE source_id = $source_id",
Self::table_name()
);

db.client
.query(query)
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()


@@ -25,6 +25,56 @@ pub struct CategoryResponse {
category: String,
}
use std::str::FromStr;
/// Supported UI themes.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum Theme {
Light,
Dark,
WarmPaper,
ObsidianPrism,
#[default]
System,
}
impl FromStr for Theme {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"light" => Ok(Self::Light),
"dark" => Ok(Self::Dark),
"warm-paper" => Ok(Self::WarmPaper),
"obsidian-prism" => Ok(Self::ObsidianPrism),
"system" => Ok(Self::System),
_ => Err(()),
}
}
}
impl Theme {
pub fn as_str(&self) -> &'static str {
match self {
Self::Light => "light",
Self::Dark => "dark",
Self::WarmPaper => "warm-paper",
Self::ObsidianPrism => "obsidian-prism",
Self::System => "system",
}
}
/// Returns the theme that should be initially applied.
/// For "system", defaults to "light".
pub fn initial_theme(&self) -> &'static str {
match self {
Self::System => "light",
other => other.as_str(),
}
}
}
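The `FromStr`/`as_str` pair keeps the wire format (kebab-case strings) in one place, with unknown values falling back to the default via `unwrap_or_default`. A standalone re-implementation of that round-trip for illustration (ObsidianPrism omitted for brevity):

```rust
use std::str::FromStr;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum Theme {
    Light,
    Dark,
    WarmPaper,
    #[default]
    System,
}

impl FromStr for Theme {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "light" => Ok(Self::Light),
            "dark" => Ok(Self::Dark),
            "warm-paper" => Ok(Self::WarmPaper),
            "system" => Ok(Self::System),
            _ => Err(()),
        }
    }
}

impl Theme {
    fn as_str(self) -> &'static str {
        match self {
            Self::Light => "light",
            Self::Dark => "dark",
            Self::WarmPaper => "warm-paper",
            Self::System => "system",
        }
    }
}

fn main() {
    // Round-trip: every variant's string parses back to itself.
    for t in [Theme::Light, Theme::Dark, Theme::WarmPaper, Theme::System] {
        assert_eq!(Theme::from_str(t.as_str()), Ok(t));
    }
    // Unknown input falls back to the default, as validate_theme does.
    assert_eq!(Theme::from_str("neon").ok().unwrap_or_default(), Theme::System);
}
```

Because both directions are exhaustive `match` arms, adding a variant without updating either side is a compile error, which keeps parsing and serialization in sync.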
stored_object!(
#[allow(clippy::unsafe_derive_deserialize)]
User, "user", {
@@ -34,9 +84,21 @@ stored_object!(
api_key: Option<String>,
admin: bool,
#[serde(default)]
timezone: String,
#[serde(default, deserialize_with = "deserialize_theme_or_default")]
theme: Theme
});
fn deserialize_theme_or_default<'de, D>(deserializer: D) -> Result<Theme, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = Option::<String>::deserialize(deserializer)?;
Ok(raw
.and_then(|value| Theme::from_str(value.as_str()).ok())
.unwrap_or_default())
}
#[async_trait]
impl Authentication<User, String, Surreal<Any>> for User {
async fn load_user(userid: String, db: Option<&Surreal<Any>>) -> Result<User, anyhow::Error> {
@@ -70,6 +132,11 @@ fn validate_timezone(input: &str) -> String {
"UTC".to_owned()
}
/// Ensures a theme string is valid, defaulting to "system" when invalid.
fn validate_theme(input: &str) -> Theme {
Theme::from_str(input).unwrap_or_default()
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DashboardStats {
pub total_documents: i64,
@@ -168,6 +235,7 @@ impl User {
password: String,
db: &SurrealDbClient,
timezone: String,
theme: String,
) -> Result<Self, AppError> {
// verify that the application allows new creations
let systemsettings = SystemSettings::get_current(db).await?;
@@ -176,6 +244,7 @@ impl User {
}

let validated_tz = validate_timezone(&timezone);
let validated_theme = validate_theme(&theme);

let now = Utc::now();
let id = Uuid::new_v4().to_string();
@@ -190,7 +259,8 @@ impl User {
anonymous = false,
created_at = $created_at,
updated_at = $updated_at,
timezone = $timezone,
theme = $theme",
)
.bind(("table", "user"))
.bind(("id", id))
@@ -199,6 +269,7 @@ impl User {
.bind(("created_at", surrealdb::Datetime::from(now)))
.bind(("updated_at", surrealdb::Datetime::from(now)))
.bind(("timezone", validated_tz))
.bind(("theme", validated_theme.as_str()))
.await?
.take(1)?;
@@ -468,6 +539,19 @@ impl User {
Ok(())
}
pub async fn update_theme(
user_id: &str,
theme: &str,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let validated_theme = validate_theme(theme);
db.query("UPDATE type::thing('user', $user_id) SET theme = $theme")
.bind(("user_id", user_id.to_string()))
.bind(("theme", validated_theme.as_str()))
.await?;
Ok(())
}
pub async fn get_user_categories(
user_id: &str,
db: &SurrealDbClient,
@@ -674,6 +758,7 @@ mod tests {
password.to_string(),
&db,
timezone.to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -711,6 +796,7 @@ mod tests {
password.to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -858,6 +944,7 @@ mod tests {
password.to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -892,6 +979,7 @@ mod tests {
password.to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -944,6 +1032,42 @@ mod tests {
assert!(not_found.is_none());
}
#[tokio::test]
async fn test_set_api_key_with_none_theme() {
let db = setup_test_db().await;
let user = User::create_new(
"legacy_theme@example.com".to_string(),
"apikey_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
db.client
.query("UPDATE type::thing('user', $id) SET theme = NONE")
.bind(("id", user.id.clone()))
.await
.expect("Failed to set user theme to NONE");
let api_key = User::set_api_key(&user.id, &db)
.await
.expect("set_api_key should tolerate NONE theme");
assert!(api_key.starts_with("sk_"));
let updated_user = db
.get_item::<User>(&user.id)
.await
.expect("Failed to retrieve user")
.expect("User should still exist");
assert_eq!(updated_user.theme, Theme::System);
assert_eq!(updated_user.api_key, Some(api_key));
}
#[tokio::test]
async fn test_password_update() {
// Setup test database
@@ -959,6 +1083,7 @@ mod tests {
old_password.to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -1006,6 +1131,7 @@ mod tests {
"password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
@@ -1116,4 +1242,51 @@ mod tests {
);
}
}
#[tokio::test]
async fn test_validate_theme() {
assert_eq!(validate_theme("light"), Theme::Light);
assert_eq!(validate_theme("dark"), Theme::Dark);
assert_eq!(validate_theme("system"), Theme::System);
assert_eq!(validate_theme("invalid"), Theme::System);
}
#[tokio::test]
async fn test_theme_update() {
let db = setup_test_db().await;
let email = "theme_test@example.com";
let user = User::create_new(
email.to_string(),
"password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("Failed to create user");
assert_eq!(user.theme, Theme::System);
User::update_theme(&user.id, "dark", &db)
.await
.expect("update theme");
let updated = db
.get_item::<User>(&user.id)
.await
.expect("get user")
.unwrap();
assert_eq!(updated.theme, Theme::Dark);
// Invalid input falls back to the default because update_theme runs validate_theme first
User::update_theme(&user.id, "invalid", &db)
.await
.expect("update theme invalid");
let updated2 = db
.get_item::<User>(&user.id)
.await
.expect("get user")
.unwrap();
assert_eq!(updated2.theme, Theme::System);
}
}


@@ -20,6 +20,7 @@ pub enum EmbeddingBackend {
pub enum StorageKind {
Local,
Memory,
S3,
}
/// Default storage backend when none is configured.
@@ -27,6 +28,10 @@ fn default_storage_kind() -> StorageKind {
StorageKind::Local
}
fn default_s3_region() -> Option<String> {
Some("us-east-1".to_string())
}
/// Selects the strategy used for PDF ingestion.
#[derive(Clone, Deserialize, Debug)]
#[serde(rename_all = "kebab-case")]
@@ -59,6 +64,12 @@ pub struct AppConfig {
pub openai_base_url: String,
#[serde(default = "default_storage_kind")]
pub storage: StorageKind,
#[serde(default)]
pub s3_bucket: Option<String>,
#[serde(default)]
pub s3_endpoint: Option<String>,
#[serde(default = "default_s3_region")]
pub s3_region: Option<String>,
#[serde(default = "default_pdf_ingest_mode")]
pub pdf_ingest_mode: PdfIngestMode,
#[serde(default = "default_reranking_enabled")]
@@ -75,6 +86,16 @@ pub struct AppConfig {
pub retrieval_strategy: Option<String>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
#[serde(default = "default_ingest_max_body_bytes")]
pub ingest_max_body_bytes: usize,
#[serde(default = "default_ingest_max_files")]
pub ingest_max_files: usize,
#[serde(default = "default_ingest_max_content_bytes")]
pub ingest_max_content_bytes: usize,
#[serde(default = "default_ingest_max_context_bytes")]
pub ingest_max_context_bytes: usize,
#[serde(default = "default_ingest_max_category_bytes")]
pub ingest_max_category_bytes: usize,
}
/// Default data directory for persisted assets.
@@ -92,6 +113,26 @@ fn default_reranking_enabled() -> bool {
false
}
fn default_ingest_max_body_bytes() -> usize {
20_000_000
}
fn default_ingest_max_files() -> usize {
5
}
fn default_ingest_max_content_bytes() -> usize {
262_144
}
fn default_ingest_max_context_bytes() -> usize {
16_384
}
fn default_ingest_max_category_bytes() -> usize {
128
}
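These defaults feed the `validate_ingest_input` helper introduced in the new file at the end of this diff. A simplified, self-contained sketch of how such byte limits might be enforced (field names mirror the config defaults above; the error type is condensed and hypothetical):

```rust
#[derive(Debug, PartialEq)]
enum IngestError {
    PayloadTooLarge(&'static str),
}

struct Limits {
    max_content_bytes: usize,
    max_context_bytes: usize,
    max_category_bytes: usize,
}

impl Default for Limits {
    fn default() -> Self {
        // Mirrors the serde defaults defined above.
        Self {
            max_content_bytes: 262_144,
            max_context_bytes: 16_384,
            max_category_bytes: 128,
        }
    }
}

fn validate(
    limits: &Limits,
    content: &str,
    context: &str,
    category: &str,
) -> Result<(), IngestError> {
    // Compare byte lengths, not char counts: multibyte UTF-8 inflates len().
    if content.len() > limits.max_content_bytes {
        return Err(IngestError::PayloadTooLarge("content"));
    }
    if context.len() > limits.max_context_bytes {
        return Err(IngestError::PayloadTooLarge("context"));
    }
    if category.len() > limits.max_category_bytes {
        return Err(IngestError::PayloadTooLarge("category"));
    }
    Ok(())
}

fn main() {
    let limits = Limits::default();
    assert_eq!(validate(&limits, "ok", "ok", "ok"), Ok(()));
    let big = "x".repeat(200);
    assert_eq!(
        validate(&limits, "ok", "ok", &big),
        Err(IngestError::PayloadTooLarge("category"))
    );
}
```

Checking limits before any parsing or storage work keeps oversized ingestion requests cheap to reject.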
pub fn ensure_ort_path() {
if env::var_os("ORT_DYLIB_PATH").is_some() {
return;
@@ -135,6 +176,9 @@ impl Default for AppConfig {
http_port: 0, http_port: 0,
openai_base_url: default_base_url(), openai_base_url: default_base_url(),
storage: default_storage_kind(), storage: default_storage_kind(),
s3_bucket: None,
s3_endpoint: None,
s3_region: default_s3_region(),
pdf_ingest_mode: default_pdf_ingest_mode(), pdf_ingest_mode: default_pdf_ingest_mode(),
reranking_enabled: default_reranking_enabled(), reranking_enabled: default_reranking_enabled(),
reranking_pool_size: None, reranking_pool_size: None,
@@ -143,6 +187,11 @@ impl Default for AppConfig {
fastembed_max_length: None, fastembed_max_length: None,
retrieval_strategy: None, retrieval_strategy: None,
embedding_backend: EmbeddingBackend::default(), embedding_backend: EmbeddingBackend::default(),
ingest_max_body_bytes: default_ingest_max_body_bytes(),
ingest_max_files: default_ingest_max_files(),
ingest_max_content_bytes: default_ingest_max_content_bytes(),
ingest_max_context_bytes: default_ingest_max_context_bytes(),
ingest_max_category_bytes: default_ingest_max_category_bytes(),
} }
} }
} }


@@ -250,9 +250,8 @@ impl EmbeddingProvider {
match config.embedding_backend {
EmbeddingBackend::OpenAI => {
let client = openai_client
.ok_or_else(|| anyhow!("OpenAI embedding backend requires an OpenAI client"))?;
// Use defaults that match SystemSettings initial values
Self::new_openai(client, "text-embedding-3-small".to_string(), 1536)
}


@@ -0,0 +1,113 @@
use super::config::AppConfig;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError {
PayloadTooLarge(String),
BadRequest(String),
}
pub fn validate_ingest_input(
config: &AppConfig,
content: Option<&str>,
context: &str,
category: &str,
file_count: usize,
) -> Result<(), IngestValidationError> {
if file_count > config.ingest_max_files {
return Err(IngestValidationError::BadRequest(format!(
"Too many files. Maximum allowed is {}",
config.ingest_max_files
)));
}
if let Some(content) = content {
if content.len() > config.ingest_max_content_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Content is too large. Maximum allowed is {} bytes",
config.ingest_max_content_bytes
)));
}
}
if context.len() > config.ingest_max_context_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Context is too large. Maximum allowed is {} bytes",
config.ingest_max_context_bytes
)));
}
if category.len() > config.ingest_max_category_bytes {
return Err(IngestValidationError::PayloadTooLarge(format!(
"Category is too large. Maximum allowed is {} bytes",
config.ingest_max_category_bytes
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validate_ingest_input_rejects_too_many_files() {
let config = AppConfig {
ingest_max_files: 1,
..Default::default()
};
let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 2);
assert!(matches!(result, Err(IngestValidationError::BadRequest(_))));
}
#[test]
fn validate_ingest_input_rejects_oversized_content() {
let config = AppConfig {
ingest_max_content_bytes: 4,
..Default::default()
};
let result = validate_ingest_input(&config, Some("12345"), "ctx", "cat", 0);
assert!(matches!(
result,
Err(IngestValidationError::PayloadTooLarge(_))
));
}
#[test]
fn validate_ingest_input_rejects_oversized_context() {
let config = AppConfig {
ingest_max_context_bytes: 2,
..Default::default()
};
let result = validate_ingest_input(&config, None, "long", "cat", 0);
assert!(matches!(
result,
Err(IngestValidationError::PayloadTooLarge(_))
));
}
#[test]
fn validate_ingest_input_rejects_oversized_category() {
let config = AppConfig {
ingest_max_category_bytes: 2,
..Default::default()
};
let result = validate_ingest_input(&config, None, "ok", "long", 0);
assert!(matches!(
result,
Err(IngestValidationError::PayloadTooLarge(_))
));
}
#[test]
fn validate_ingest_input_accepts_valid_payload() {
let config = AppConfig::default();
let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 1);
assert!(result.is_ok());
}
}
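For context, one way a validator like the one above typically feeds an HTTP layer is a status mapping. In the sketch below the enum mirrors the diff, while `status_for` and the concrete status codes are hypothetical and not taken from Minne's handlers.

```rust
// Hypothetical HTTP status mapping for the ingest validator above.
// The enum is copied from the diff; `status_for` is illustrative only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError {
    PayloadTooLarge(String),
    BadRequest(String),
}

fn status_for(err: &IngestValidationError) -> u16 {
    match err {
        // 413 Payload Too Large for size-limit violations
        IngestValidationError::PayloadTooLarge(_) => 413,
        // 400 Bad Request for structural problems such as too many files
        IngestValidationError::BadRequest(_) => 400,
    }
}

fn main() {
    let err = IngestValidationError::PayloadTooLarge("Content is too large".into());
    assert_eq!(status_for(&err), 413);
    println!("{} -> {:?}", status_for(&err), err);
}
```

Keeping the error enum free of HTTP types lets the same validation logic serve both the JSON API and the HTML form endpoints.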


@@ -1,3 +1,4 @@
pub mod config;
pub mod embedding;
pub mod ingest_limits;
pub mod template_engine;


@@ -21,19 +21,49 @@ pub enum TemplateEngine {
#[macro_export]
macro_rules! create_template_engine {
// Single path argument
($relative_path:expr) => {
$crate::create_template_engine!($relative_path, Option::<&str>::None)
};
// Path + Fallback argument
($relative_path:expr, $fallback_path:expr) => {{
// Code for debug builds (AutoReload)
#[cfg(debug_assertions)]
{
// These lines execute in the CALLING crate's context
let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let template_path = crate_dir.join($relative_path);
let fallback_path = $fallback_path.map(|p| crate_dir.join(p));
let reloader = $crate::utils::template_engine::AutoReloader::new(move |notifier| {
let mut env = $crate::utils::template_engine::Environment::new();
let loader_primary = $crate::utils::template_engine::path_loader(&template_path);
// Clone fallback_path for the closure
let fallback = fallback_path.clone();
env.set_loader(move |name| match loader_primary(name) {
Ok(Some(tmpl)) => Ok(Some(tmpl)),
Ok(None) => {
if let Some(ref fb_path) = fallback {
let loader_fallback =
$crate::utils::template_engine::path_loader(fb_path);
loader_fallback(name)
} else {
Ok(None)
}
}
Err(e) => Err(e),
});
notifier.set_fast_reload(true);
notifier.watch_path(&template_path, true);
if let Some(ref fb) = fallback_path {
notifier.watch_path(fb, true);
}
// Add contrib filters/functions
$crate::utils::template_engine::minijinja_contrib::add_to_environment(&mut env);
Ok(env)

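The primary-then-fallback lookup that the macro's `set_loader` closure performs can be illustrated with plain maps. The `chain` helper below is a hypothetical stand-in: real minijinja loaders return `Result<Option<String>, Error>` rather than `Option`, but the control flow is the same.

```rust
use std::collections::HashMap;

// Sketch of the loader-chaining pattern: try the primary source,
// and only consult the fallback when the primary has no entry.
fn chain<'a>(
    primary: &'a HashMap<String, String>,
    fallback: Option<&'a HashMap<String, String>>,
) -> impl Fn(&str) -> Option<&'a String> {
    move |name| {
        primary
            .get(name)
            .or_else(|| fallback.and_then(|fb| fb.get(name)))
    }
}

fn main() {
    let mut primary = HashMap::new();
    primary.insert("index.html".to_string(), "primary".to_string());
    let mut fallback = HashMap::new();
    fallback.insert("base.html".to_string(), "fallback".to_string());

    let load = chain(&primary, Some(&fallback));
    assert_eq!(load("index.html").map(String::as_str), Some("primary"));
    assert_eq!(load("base.html").map(String::as_str), Some("fallback"));
    assert_eq!(load("missing.html"), None);
}
```

This ordering means a crate can override any single template locally while inheriting the rest from the shared fallback directory.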

@@ -3,10 +3,10 @@
"devenv": {
"locked": {
"dir": "src/modules",
"lastModified": 1771066302,
"owner": "cachix",
"repo": "devenv",
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
"type": "github"
},
"original": {
@@ -22,10 +22,10 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1771052630,
"owner": "nix-community",
"repo": "fenix",
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
"type": "github"
},
"original": {
@@ -37,14 +37,14 @@
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1767039857,
"owner": "NixOS",
"repo": "flake-compat",
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab",
"type": "github"
},
"original": {
"owner": "NixOS",
"repo": "flake-compat",
"type": "github"
}
@@ -58,10 +58,10 @@
]
},
"locked": {
"lastModified": 1770726378,
"owner": "cachix",
"repo": "git-hooks.nix",
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
"type": "github"
},
"original": {
@@ -78,10 +78,10 @@
]
},
"locked": {
"lastModified": 1762808025,
"owner": "hercules-ci",
"repo": "gitignore.nix",
"rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c",
"type": "github"
},
"original": {
@@ -92,10 +92,10 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1771008912,
"owner": "nixos",
"repo": "nixpkgs",
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
"type": "github"
},
"original": {
@@ -107,10 +107,10 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1770843696,
"owner": "nixos",
"repo": "nixpkgs",
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
"type": "github"
},
"original": {
@@ -135,10 +135,10 @@
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1771007332,
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
"type": "github"
},
"original": {
@@ -155,10 +155,10 @@
]
},
"locked": {
"lastModified": 1771038269,
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
"type": "github"
},
"original": {


@@ -10,11 +10,14 @@
packages = [
pkgs.openssl
pkgs.nodejs
pkgs.watchman
pkgs.vscode-langservers-extracted
pkgs.cargo-dist
pkgs.cargo-xwin
pkgs.clang
pkgs.onnxruntime
pkgs.cargo-watch
pkgs.tailwindcss_4
];
languages.rust = {
@@ -27,9 +30,24 @@
env = {
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
S3_ENDPOINT = "http://127.0.0.1:19000";
S3_BUCKET = "minne-tests";
MINNE_TEST_S3_ENDPOINT = "http://127.0.0.1:19000";
MINNE_TEST_S3_BUCKET = "minne-tests";
};
services.minio = {
enable = true;
listenAddress = "127.0.0.1:19000";
consoleAddress = "127.0.0.1:19001";
buckets = ["minne-tests"];
accessKey = "minioadmin";
secretKey = "minioadmin";
region = "us-east-1";
};
processes = {
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
tailwind.exec = "tailwindcss --cwd html-router -i app.css -o assets/style.css --watch=always";
};
}


@@ -12,11 +12,13 @@ include = ["lib"]
# The installers to generate for each app
installers = []
# Target platforms to build apps for (Rust target-triple syntax)
targets = ["aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-unknown-linux-gnu", "x86_64-pc-windows-msvc"]
# Skip checking whether the specified configuration files are up to date
allow-dirty = ["ci"]
[dist.github-custom-runners]
aarch64-apple-darwin = "macos-latest"
x86_64-apple-darwin = "macos-15-intel"
x86_64-unknown-linux-gnu = "ubuntu-22.04"
x86_64-unknown-linux-musl = "ubuntu-22.04"
x86_64-pc-windows-msvc = "windows-latest"


@@ -47,7 +47,7 @@ Content In → Ingestion Pipeline → SurrealDB
Query → Retrieval Pipeline → Results
Vector Search + FTS
RRF Fusion → (Optional Rerank) → Response
```
@@ -70,5 +70,5 @@ Embeddings are stored in dedicated tables with HNSW indexes for fast vector sear
1. **Collect candidates** — Vector similarity + full-text search
2. **Merge ranks** — Reciprocal Rank Fusion (RRF)
3. **Attach context** — Link chunks to parent entities
4. **Rerank** (optional) — Cross-encoder reranking
5. **Return** — Top-k results with metadata
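The rank-merging step above can be sketched as follows. The function shape and the constant `k = 60` are assumptions (the conventional RRF formulation), not code from this repository.

```rust
use std::collections::HashMap;

/// Reciprocal Rank Fusion: score(d) = sum over sources of 1 / (k + rank(d)),
/// merging per-source rankings such as vector search and full-text search.
fn rrf_merge(rankings: &[Vec<&str>], k: f64) -> Vec<(String, f64)> {
    let mut scores: HashMap<String, f64> = HashMap::new();
    for ranking in rankings {
        for (idx, doc) in ranking.iter().enumerate() {
            // Ranks are 1-based: the top hit in each list gets 1 / (k + 1).
            *scores.entry((*doc).to_string()).or_insert(0.0) += 1.0 / (k + idx as f64 + 1.0);
        }
    }
    let mut merged: Vec<(String, f64)> = scores.into_iter().collect();
    // Sort by fused score, highest first.
    merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    merged
}

fn main() {
    let vector_hits = vec!["a", "b", "c"];
    let fts_hits = vec!["b", "a", "d"];
    let fused = rrf_merge(&[vector_hits, fts_hits], 60.0);
    // "a" and "b" each appear in both lists at ranks 1 and 2, so they tie at the top.
    assert_eq!(fused.len(), 4);
    assert!(fused[1].1 > fused[2].1);
    println!("{:?}", fused);
}
```

Because RRF only consumes ranks, not raw scores, it needs no score normalization between the vector and full-text sources.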


@@ -13,6 +13,7 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
| `SURREALDB_DATABASE` | Database name | `minne_db` |
| `SURREALDB_NAMESPACE` | Namespace | `minne_ns` |
## Optional Settings
| Variable | Description | Default |
@@ -21,14 +22,37 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
| `DATA_DIR` | Local data directory | `./data` |
| `OPENAI_BASE_URL` | Custom AI provider URL | OpenAI default |
| `RUST_LOG` | Logging level | `info` |
| `STORAGE` | Storage backend (`local`, `memory`, `s3`) | `local` |
| `PDF_INGEST_MODE` | PDF ingestion strategy (`classic`, `llm-first`) | `llm-first` |
| `RETRIEVAL_STRATEGY` | Default retrieval strategy | - |
| `EMBEDDING_BACKEND` | Embedding provider (`openai`, `fastembed`) | `fastembed` |
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
| `INGEST_MAX_BODY_BYTES` | Max request body size for ingest endpoints | `20000000` |
| `INGEST_MAX_FILES` | Max files allowed per ingest request | `5` |
| `INGEST_MAX_CONTENT_BYTES` | Max `content` field size for ingest requests | `262144` |
| `INGEST_MAX_CONTEXT_BYTES` | Max `context` field size for ingest requests | `16384` |
| `INGEST_MAX_CATEGORY_BYTES` | Max `category` field size for ingest requests | `128` |
### S3 Storage (Optional)
Used when `STORAGE` is set to `s3`.
| Variable | Description | Default |
|----------|-------------|---------|
| `S3_BUCKET` | S3 bucket name | - |
| `S3_ENDPOINT` | Custom endpoint (e.g. MinIO) | AWS default |
| `S3_REGION` | AWS Region | `us-east-1` |
| `AWS_ACCESS_KEY_ID` | Access key | - |
| `AWS_SECRET_ACCESS_KEY` | Secret key | - |
### Reranking (Optional)
| Variable | Description | Default |
|----------|-------------|---------|
| `RERANKING_ENABLED` | Enable FastEmbed reranking | `false` |
| `RERANKING_POOL_SIZE` | Concurrent reranker workers | - |
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed/reranker` |
> [!NOTE]
> Enabling reranking downloads ~1.1 GB of model data on first startup.
@@ -45,9 +69,25 @@ openai_api_key: "sk-your-key-here"
data_dir: "./minne_data"
http_port: 3000
# New settings
storage: "local"
# storage: "s3"
# s3_bucket: "my-bucket"
# s3_endpoint: "http://localhost:9000" # Optional, for MinIO etc.
# s3_region: "us-east-1"
pdf_ingest_mode: "llm-first"
embedding_backend: "fastembed"
# Optional reranking
reranking_enabled: true
reranking_pool_size: 2
# Ingest safety limits
ingest_max_body_bytes: 20000000
ingest_max_files: 5
ingest_max_content_bytes: 262144
ingest_max_context_bytes: 16384
ingest_max_category_bytes: 128
```
## AI Provider Setup ## AI Provider Setup


@@ -33,3 +33,4 @@ clap = { version = "4.4", features = ["derive", "env"] }
[dev-dependencies]
tempfile = { workspace = true }
common = { path = "../common", features = ["test-utils"] }


@@ -9,6 +9,8 @@ use std::{
use anyhow::{anyhow, Context, Result};
use async_openai::Client;
use chrono::Utc;
#[cfg(not(test))]
use common::utils::config::get_config;
use common::{
storage::{
db::SurrealDbClient,
@@ -421,11 +423,7 @@ async fn ingest_paragraph_batch(
return Ok(Vec::new());
}
let namespace = format!("ingest_eval_{}", Uuid::new_v4());
let db = create_ingest_db(&namespace).await?;
db.apply_migrations()
.await
.context("applying migrations for ingestion")?;
@@ -487,6 +485,29 @@ async fn ingest_paragraph_batch(
Ok(shards)
}
#[cfg(test)]
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
let db = SurrealDbClient::memory(namespace, "corpus")
.await
.context("creating in-memory surrealdb for ingestion")?;
Ok(Arc::new(db))
}
#[cfg(not(test))]
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
let config = get_config().context("loading app config for ingestion database")?;
let db = SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
namespace,
"corpus",
)
.await
.context("creating surrealdb database for ingestion")?;
Ok(Arc::new(db))
}
#[allow(clippy::too_many_arguments)]
async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>,


@@ -893,158 +893,6 @@ mod tests {
}
}
#[tokio::test]
async fn seeds_manifest_with_transactional_batches() {
let namespace = "test_ns";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("memory db");
db.apply_migrations()
.await
.expect("apply migrations for memory db");
let manifest = build_manifest();
seed_manifest_into_db(&db, &manifest)
.await
.expect("manifest seed should succeed");
let text_contents: Vec<TextContent> = db
.client
.query(format!("SELECT * FROM {};", TextContent::table_name()))
.await
.expect("select text_content")
.take(0)
.unwrap_or_default();
assert_eq!(text_contents.len(), 1);
let entities: Vec<KnowledgeEntity> = db
.client
.query(format!("SELECT * FROM {};", KnowledgeEntity::table_name()))
.await
.expect("select knowledge_entity")
.take(0)
.unwrap_or_default();
assert_eq!(entities.len(), 1);
let chunks: Vec<TextChunk> = db
.client
.query(format!("SELECT * FROM {};", TextChunk::table_name()))
.await
.expect("select text_chunk")
.take(0)
.unwrap_or_default();
assert_eq!(chunks.len(), 1);
let relationships: Vec<KnowledgeRelationship> = db
.client
.query("SELECT * FROM relates_to;")
.await
.expect("select relates_to")
.take(0)
.unwrap_or_default();
assert_eq!(relationships.len(), 1);
let entity_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("select knowledge_entity_embedding")
.take(0)
.unwrap_or_default();
assert_eq!(entity_embeddings.len(), 1);
let chunk_embeddings: Vec<TextChunkEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
TextChunkEmbedding::table_name()
))
.await
.expect("select text_chunk_embedding")
.take(0)
.unwrap_or_default();
assert_eq!(chunk_embeddings.len(), 1);
}
#[tokio::test]
async fn rolls_back_when_embeddings_mismatch_index_dimension() {
let namespace = "test_ns_rollback";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("memory db");
db.apply_migrations()
.await
.expect("apply migrations for memory db");
let manifest = build_manifest();
let result = seed_manifest_into_db(&db, &manifest).await;
assert!(
result.is_ok(),
"seeding should succeed even if embedding dimensions differ from default index"
);
let text_contents: Vec<TextContent> = db
.client
.query(format!("SELECT * FROM {};", TextContent::table_name()))
.await
.expect("select text_content")
.take(0)
.unwrap_or_default();
let entities: Vec<KnowledgeEntity> = db
.client
.query(format!("SELECT * FROM {};", KnowledgeEntity::table_name()))
.await
.expect("select knowledge_entity")
.take(0)
.unwrap_or_default();
let chunks: Vec<TextChunk> = db
.client
.query(format!("SELECT * FROM {};", TextChunk::table_name()))
.await
.expect("select text_chunk")
.take(0)
.unwrap_or_default();
let relationships: Vec<KnowledgeRelationship> = db
.client
.query("SELECT * FROM relates_to;")
.await
.expect("select relates_to")
.take(0)
.unwrap_or_default();
let entity_embeddings: Vec<KnowledgeEntityEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
KnowledgeEntityEmbedding::table_name()
))
.await
.expect("select knowledge_entity_embedding")
.take(0)
.unwrap_or_default();
let chunk_embeddings: Vec<TextChunkEmbedding> = db
.client
.query(format!(
"SELECT * FROM {};",
TextChunkEmbedding::table_name()
))
.await
.expect("select text_chunk_embedding")
.take(0)
.unwrap_or_default();
assert_eq!(text_contents.len(), 1);
assert_eq!(entities.len(), 1);
assert_eq!(chunks.len(), 1);
assert_eq!(relationships.len(), 1);
assert_eq!(entity_embeddings.len(), 1);
assert_eq!(chunk_embeddings.len(), 1);
}
#[test]
fn window_manifest_trims_questions_and_negatives() {
let manifest = build_manifest();


@@ -7,7 +7,7 @@ use std::{
use anyhow::{anyhow, Context, Result};
use common::storage::{db::SurrealDbClient, types::text_chunk::TextChunk};
use crate::{args::Config, corpus, eval::connect_eval_db, snapshot::DbSnapshotState};
pub async fn inspect_question(config: &Config) -> Result<()> {
let question_id = config


@@ -2,7 +2,11 @@
use anyhow::{anyhow, Context, Result};
use chrono::Utc;
use common::storage::{
db::SurrealDbClient,
types::user::{Theme, User},
types::StoredObject,
};
use serde::Deserialize;
use tracing::{info, warn};
@@ -212,6 +216,7 @@ pub(crate) async fn ensure_eval_user(db: &SurrealDbClient) -> Result<User> {
api_key: None,
admin: false,
timezone: "UTC".to_string(),
theme: Theme::System,
};
if let Some(existing) = db.get_item::<User>(user.get_id()).await? {


@@ -20,9 +20,10 @@ use retrieval_pipeline::{
use crate::{
args::Config,
cache::EmbeddingCache,
corpus,
datasets::ConvertedDataset,
eval::{CaseDiagnostics, CaseSummary, EvaluationStageTimings, EvaluationSummary, SeededCase},
slice, snapshot,
};
pub(super) struct EvaluationContext<'a> {


@@ -3,7 +3,7 @@ use std::time::Instant;
use anyhow::Context;
use tracing::info;
use crate::{corpus, eval::can_reuse_namespace, slice, snapshot};
use super::super::{
context::{EvalStage, EvaluationContext},


@@ -5,12 +5,12 @@ use common::storage::types::system_settings::SystemSettings;
use tracing::{info, warn};
use crate::{
corpus,
db_helpers::{recreate_indexes, remove_all_indexes, reset_namespace},
eval::{
can_reuse_namespace, cases_from_manifest, enforce_system_settings, ensure_eval_user,
record_namespace_state, warm_hnsw_cache,
},
};
use super::super::{


@@ -48,7 +48,9 @@ pub(crate) async fn prepare_slice(
.database
.db_namespace
.clone()
.unwrap_or_else(|| {
default_namespace(ctx.dataset().metadata.id.as_str(), ctx.config().limit)
});
ctx.database = ctx
.config()
.database

flake.lock generated

@@ -35,11 +35,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1771008912,
"narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
"type": "github"
},
"original": {


@@ -41,5 +41,8 @@ common = { path = "../common" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
json-stream-parser = { path = "../json-stream-parser" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
[build-dependencies]
minijinja-embed = { version = "2.8.0" }


@@ -69,26 +69,41 @@
[data-theme="dark"] {
color-scheme: dark;
/* --- Canvas: Dark Warm Grey (Matches Light Mode's hue 90) --- */
--color-base-100: oklch(18% 0.01 90);
--color-base-200: oklch(15% 0.01 90);
--color-base-300: oklch(12% 0.01 90);
--color-base-content: oklch(96% 0.01 90);
/* --- Primary: Vibrant Indigo (Light Mode Hue 265, boosted for dark) --- */
--color-primary: oklch(65% 0.22 265);
--color-primary-content: oklch(98% 0.01 265);
/* --- Secondary: Deep Indigo (Similar to Light Mode Primary) --- */
--color-secondary: oklch(45% 0.18 265);
--color-secondary-content: oklch(98% 0.01 265);
/* --- Accent: Vibrant Amber (Light Mode Hue 80) --- */
--color-accent: oklch(75% 0.19 80);
--color-accent-content: oklch(18% 0.04 80);
/* --- Neutral: Warm Graphite --- */
--color-neutral: oklch(25% 0.02 90);
--color-neutral-content: oklch(96% 0.01 90);
/* --- Semantic Colors (Matching Light Mode Hues) --- */
--color-info: oklch(70% 0.15 220); /* Blue */
--color-success: oklch(72% 0.18 150); /* Green */
--color-warning: oklch(80% 0.18 85); /* Orange/Amber */
--color-error: oklch(68% 0.20 25); /* Red */
--color-info-content: oklch(15% 0.05 220);
--color-success-content: oklch(15% 0.05 150);
--color-warning-content: oklch(15% 0.05 85);
--color-error-content: oklch(98% 0.01 25);
/* --- Neobrutalist Structure --- */
--radius-selector: 0rem;
--radius-field: 0rem;
--radius-box: 0rem;
@@ -97,6 +112,103 @@
--border: 2px;
}
[data-theme="obsidian-prism"] {
color-scheme: dark;
/* --- Canvas & Surfaces --- */
--color-base-100: oklch(12% 0.015 260);
--color-base-200: oklch(9% 0.018 262);
--color-base-300: oklch(6% 0.02 265);
--color-base-content: oklch(95% 0.008 260);
/* --- Primary: Electric Violet Signal --- */
--color-primary: oklch(62% 0.28 290);
--color-primary-content: oklch(98% 0.01 290);
/* --- Secondary: Cyan Edge --- */
--color-secondary: oklch(68% 0.18 220);
--color-secondary-content: oklch(98% 0.01 220);
/* --- Accent: Ember (warm counterpoint) --- */
--color-accent: oklch(78% 0.19 55);
--color-accent-content: oklch(18% 0.04 55);
/* --- Neutral: Cold Steel --- */
--color-neutral: oklch(24% 0.02 260);
--color-neutral-content: oklch(92% 0.01 260);
/* --- Semantic Colors --- */
--color-info: oklch(72% 0.14 230);
--color-info-content: oklch(25% 0.06 230);
--color-success: oklch(74% 0.16 155);
--color-success-content: oklch(25% 0.06 155);
--color-warning: oklch(82% 0.18 75);
--color-warning-content: oklch(25% 0.08 75);
--color-error: oklch(68% 0.22 15);
--color-error-content: oklch(98% 0.02 15);
/* --- Radii (NB Law: Zero) --- */
--radius-selector: 0rem;
--radius-field: 0rem;
--radius-box: 0rem;
--size-selector: 0.25rem;
--size-field: 0.25rem;
--border: 2px;
/* --- Prismatic Shadow System --- */
--nb-shadow-hue: 290;
--nb-shadow: 4px 4px 0 0 oklch(8% 0.06 var(--nb-shadow-hue));
--nb-shadow-hover: 6px 6px 0 0 oklch(6% 0.08 calc(var(--nb-shadow-hue) + 15));
}
[data-theme="warm-paper"] {
color-scheme: light;
/* --- Canvas & Surfaces: Warm cream paper (lighter, less yellow) --- */
--color-base-100: oklch(98.5% 0.01 90);
--color-base-200: oklch(95% 0.015 90);
--color-base-300: oklch(92% 0.02 90);
--color-base-content: oklch(18% 0.015 75);
/* --- Primary: Warm Amber/Gold (the landing page CTA color) --- */
--color-primary: oklch(72% 0.16 75);
--color-primary-content: oklch(18% 0.02 75);
/* --- Secondary: Warm Terracotta --- */
--color-secondary: oklch(55% 0.14 45);
--color-secondary-content: oklch(98% 0.01 85);
/* --- Accent: Deep Charcoal (for contrast buttons like "View on GitHub") --- */
--color-accent: oklch(22% 0.01 80);
--color-accent-content: oklch(98% 0.02 85);
/* --- Neutral: Warm Charcoal --- */
--color-neutral: oklch(20% 0.015 75);
--color-neutral-content: oklch(96% 0.015 85);
/* --- Semantic Colors (warmer variants) --- */
--color-info: oklch(58% 0.12 230);
--color-info-content: oklch(98% 0.01 230);
--color-success: oklch(62% 0.15 155);
--color-success-content: oklch(98% 0.01 155);
--color-warning: oklch(78% 0.16 70);
--color-warning-content: oklch(20% 0.04 70);
--color-error: oklch(58% 0.20 25);
--color-error-content: oklch(98% 0.02 25);
/* --- Radii (NB Law: Zero) --- */
--radius-selector: 0rem;
--radius-field: 0rem;
--radius-box: 0rem;
--size-selector: 0.25rem;
--size-field: 0.25rem;
--border: 2px;
/* --- Classic Black Shadow --- */
--nb-shadow: 4px 4px 0 0 #000;
--nb-shadow-hover: 6px 6px 0 0 #000;
}
body {
background-color: var(--color-base-100);
color: var(--color-base-content);
@@ -608,7 +720,7 @@
line-height: inherit;
}
.markdown-content :not(pre)>code {
background-color: rgba(0, 0, 0, 0.05);
color: var(--color-base-content);
padding: 0.15em 0.4em;
@@ -662,7 +774,7 @@
color: var(--color-base-content);
}
[data-theme="dark"] .markdown-content :not(pre)>code {
background-color: rgba(255, 255, 255, 0.12);
color: var(--color-base-content);
}
@@ -677,6 +789,136 @@
z-index: 9999;
box-shadow: var(--nb-shadow);
}
/* .nb-label: Uppercase, bold, tracking-wide, text-xs for section headers */
.nb-label {
@apply uppercase font-bold tracking-wide text-xs;
}
/* .nb-data: JetBrains Mono, tabular-nums for timestamps, IDs, badges */
.nb-data {
font-family: 'JetBrains Mono', ui-monospace, SFMono-Regular, monospace;
font-variant-numeric: tabular-nums;
}
/* The Stamp: Button :active state pushes into page */
.nb-btn:active {
transform: translate(2px, 2px) !important;
box-shadow: 2px 2px 0 0 #000 !important;
}
/* Staggered Card Dealing Animation */
@keyframes deal-in {
0% {
opacity: 0;
transform: translateY(12px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}
/* Staggered deal-in animation - STRICTLY SCOPED to main content area */
main .nb-card,
main .nb-panel {
animation: deal-in 300ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1)) backwards;
}
/* Exclude elements that shouldn't animate even inside main */
main nav.nb-panel,
main .no-animation {
animation: none;
}
/* Apply staggered delays only to direct children of grids/lists or top-level containers */
main .nb-masonry>.nb-card:nth-child(1),
main .grid>.nb-panel:nth-child(1) {
animation-delay: 0ms;
}
main .nb-masonry>.nb-card:nth-child(2),
main .grid>.nb-panel:nth-child(2) {
animation-delay: 50ms;
}
main .nb-masonry>.nb-card:nth-child(3),
main .grid>.nb-panel:nth-child(3) {
animation-delay: 100ms;
}
main .nb-masonry>.nb-card:nth-child(4),
main .grid>.nb-panel:nth-child(4) {
animation-delay: 150ms;
}
main .nb-masonry>.nb-card:nth-child(5),
main .grid>.nb-panel:nth-child(5) {
animation-delay: 200ms;
}
main .nb-masonry>.nb-card:nth-child(6),
main .grid>.nb-panel:nth-child(6) {
animation-delay: 250ms;
}
main .nb-masonry>.nb-card:nth-child(7),
main .grid>.nb-panel:nth-child(7) {
animation-delay: 300ms;
}
main .nb-masonry>.nb-card:nth-child(8),
main .grid>.nb-panel:nth-child(8) {
animation-delay: 350ms;
}
main .nb-masonry>.nb-card:nth-child(n+9),
main .grid>.nb-panel:nth-child(n+9) {
animation-delay: 400ms;
}
/* HTMX Swap Fade-Up Animation */
@keyframes fade-up {
0% {
opacity: 0;
transform: translateY(8px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}
.animate-fade-up {
animation: fade-up 200ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1)) forwards;
}
/* Kinetic Input: Chat Armed State */
#chat-input:not(:placeholder-shown)~button {
filter: saturate(1.3) brightness(1.1);
}
#chat-input:not(:placeholder-shown) {
border-color: var(--color-accent);
}
/* Evidence Frame for images (Tufte treatment) */
.nb-evidence-frame {
@apply border-2 border-neutral m-2 bg-base-200;
}
.nb-evidence-frame img {
display: block;
width: 100%;
height: auto;
}
.nb-evidence-frame figcaption {
@apply text-xs px-2 py-1 border-t-2 border-neutral;
font-family: 'JetBrains Mono', ui-monospace, monospace;
}
}
/* Theme-aware placeholder contrast tweaks */
@@ -691,6 +933,31 @@
color: rgba(255, 255, 255, 0.78) !important;
opacity: 0.85;
}
/* === DESIGN POLISHING: Receding Reality === */
/* Modal opens → background scales and blurs */
body:has(dialog[open]) #main-content-wrapper,
body.modal-open #main-content-wrapper {
transform: scale(0.98);
filter: blur(2px);
transition: transform 250ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1)),
filter 250ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1));
}
#main-content-wrapper {
transform: scale(1);
filter: blur(0);
transition: transform 250ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1)),
filter 250ms var(--ease-mechanical, cubic-bezier(0.25, 1, 0.5, 1));
}
/* === DESIGN POLISHING: Scroll-Linked Navbar Shadow === */
nav {
--scroll-depth: 0;
box-shadow: 4px calc(4px + var(--scroll-depth) * 4px) 0 0 #000;
transition: box-shadow 150ms ease;
}
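The shadow depth in this rule is driven by a script that clamps scroll position (depth = min(scrollTop / 200, 1)). Combining that clamp with the `calc()` above gives a single mapping; a sketch under those assumptions, with an illustrative helper name:

```javascript
// Models the scroll-linked shadow: clamp scroll depth to [0, 1] over the
// first 200px, then map into the calc(4px + depth * 4px) y-offset.
function navShadowOffsetY(scrollTop) {
  const depth = Math.min(Math.max(scrollTop, 0) / 200, 1);
  return 4 + depth * 4; // shadow y-offset in px
}
```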
}
/* satoshi.css */
@@ -714,6 +981,15 @@
font-display: swap;
}
@font-face {
font-family: 'JetBrains Mono';
src: url('fonts/JetBrainsMono-Regular.woff2') format('woff2'),
url('fonts/JetBrainsMono-Variable.ttf') format('truetype');
font-weight: 400 700;
font-style: normal;
font-display: swap;
}
/* Minimal override: prevent DaisyUI .menu hover bg on our nb buttons */
@layer utilities {
@@ -737,3 +1013,111 @@
@apply text-lg font-bold;
}
}
/* Prismatic shadow hue shift on hover */
[data-theme="obsidian-prism"] .nb-panel:hover,
[data-theme="obsidian-prism"] .nb-card:hover,
[data-theme="obsidian-prism"] .nb-btn:hover {
--nb-shadow-hue: 305;
}
/* Focus state: breathing shadow pulse */
@keyframes shadow-breathe {
0%,
100% {
box-shadow: 6px 6px 0 0 oklch(8% 0.08 305);
}
50% {
box-shadow: 7px 7px 0 0 oklch(10% 0.10 310);
}
}
[data-theme="obsidian-prism"] .nb-btn:focus-visible,
[data-theme="obsidian-prism"] .nb-input:focus-visible,
[data-theme="obsidian-prism"] .nb-select:focus-visible {
animation: shadow-breathe 1.5s ease-in-out infinite;
outline: 2px solid oklch(62% 0.28 290);
outline-offset: 2px;
}
/* Selection color: Prismatic violet */
[data-theme="obsidian-prism"] ::selection {
background: oklch(62% 0.28 290 / 0.35);
color: oklch(98% 0.01 290);
}
/* Prose adjustments for Obsidian Prism */
[data-theme="obsidian-prism"] .prose-tufte,
[data-theme="obsidian-prism"] .prose-tufte-compact {
color: var(--color-base-content);
--tw-prose-body: oklch(92% 0.008 260);
--tw-prose-headings: oklch(98% 0.01 260);
--tw-prose-lead: oklch(88% 0.01 260);
--tw-prose-links: oklch(78% 0.19 55);
--tw-prose-bold: oklch(98% 0.01 260);
--tw-prose-counters: oklch(70% 0.01 260);
--tw-prose-bullets: oklch(50% 0.01 260);
--tw-prose-hr: oklch(24% 0.02 260);
--tw-prose-quotes: oklch(88% 0.01 260);
--tw-prose-quote-borders: oklch(40% 0.04 290);
--tw-prose-captions: oklch(70% 0.01 260);
--tw-prose-code: oklch(95% 0.008 260);
--tw-prose-pre-code: inherit;
--tw-prose-pre-bg: oklch(8% 0.02 262);
--tw-prose-th-borders: oklch(30% 0.02 260);
--tw-prose-td-borders: oklch(24% 0.02 260);
}
[data-theme="obsidian-prism"] .prose-tufte a,
[data-theme="obsidian-prism"] .prose-tufte-compact a {
color: oklch(78% 0.19 55);
}
/* Code blocks: deeper well */
[data-theme="obsidian-prism"] .markdown-content pre {
background-color: oklch(7% 0.018 262);
border-color: oklch(20% 0.03 290);
}
[data-theme="obsidian-prism"] .markdown-content :not(pre)>code {
background-color: oklch(18% 0.025 265);
}
/* Tables in Obsidian Prism */
[data-theme="obsidian-prism"] .markdown-content th,
[data-theme="obsidian-prism"] .markdown-content td {
border-color: oklch(24% 0.02 260);
}
/* Blockquotes */
[data-theme="obsidian-prism"] .markdown-content blockquote {
border-color: oklch(40% 0.04 290);
color: oklch(85% 0.01 260);
}
/* HR */
[data-theme="obsidian-prism"] .markdown-content hr {
border-top-color: oklch(24% 0.02 260);
}
/* Checkbox in Obsidian Prism (white tick) */
[data-theme="obsidian-prism"] .nb-checkbox:checked {
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%23fff' stroke-width='3' stroke-linecap='round' stroke-linejoin='round'><polyline points='20 6 9 17 4 12'/></svg>");
}
/* Placeholder text */
[data-theme="obsidian-prism"] .nb-input::placeholder,
[data-theme="obsidian-prism"] .input::placeholder,
[data-theme="obsidian-prism"] .textarea::placeholder,
[data-theme="obsidian-prism"] textarea::placeholder,
[data-theme="obsidian-prism"] input::placeholder {
color: oklch(70% 0.01 260) !important;
opacity: 0.85;
}
/* Nav shadow uses prismatic color */
[data-theme="obsidian-prism"] nav {
box-shadow: 4px calc(4px + var(--scroll-depth, 0) * 4px) 0 0 oklch(8% 0.06 290);
}

View File

@@ -0,0 +1,199 @@
/**
* Design Polishing Pass - Interactive Effects
*
* Includes:
* - Scroll-Linked Navbar Shadow
* - HTMX Swap Animation
* - Typewriter AI Response
* - Rubberbanding Scroll
*/
(function() {
'use strict';
// === SCROLL-LINKED NAVBAR SHADOW ===
function initScrollShadow() {
const mainContent = document.querySelector('main');
const navbar = document.querySelector('nav');
if (!mainContent || !navbar) return;
mainContent.addEventListener('scroll', () => {
const scrollTop = mainContent.scrollTop;
const scrollHeight = mainContent.scrollHeight - mainContent.clientHeight;
const scrollDepth = scrollHeight > 0 ? Math.min(scrollTop / 200, 1) : 0;
navbar.style.setProperty('--scroll-depth', scrollDepth.toFixed(2));
}, { passive: true });
}
// === HTMX SWAP ANIMATION ===
function initHtmxSwapAnimation() {
document.body.addEventListener('htmx:afterSwap', (event) => {
let target = event.detail.target;
if (!target) return;
// If full body swap (hx-boost), animate only the main content
if (target.tagName === 'BODY') {
const main = document.querySelector('main');
if (main) target = main;
}
// Only animate if target is valid and inside/is main content or a card/panel
// Avoid animating sidebar or navbar updates
if (target && (target.tagName === 'MAIN' || target.closest('main'))) {
if (!target.classList.contains('animate-fade-up')) {
target.classList.add('animate-fade-up');
// Remove class after animation completes to allow re-animation
setTimeout(() => {
target.classList.remove('animate-fade-up');
}, 250);
}
}
});
}
// === TYPEWRITER AI RESPONSE ===
// Works with SSE streaming - buffers text and reveals character by character
window.initTypewriter = function(element, options = {}) {
const {
minDelay = 5,
maxDelay = 15,
showCursor = true
} = options;
let buffer = '';
let isTyping = false;
let cursorElement = null;
if (showCursor) {
cursorElement = document.createElement('span');
cursorElement.className = 'typewriter-cursor';
cursorElement.textContent = '▌';
cursorElement.style.animation = 'blink 1s step-end infinite';
element.appendChild(cursorElement);
}
function typeNextChar() {
if (buffer.length === 0) {
isTyping = false;
return;
}
isTyping = true;
const char = buffer.charAt(0);
buffer = buffer.slice(1);
// Insert before cursor
if (cursorElement && cursorElement.parentNode) {
const textNode = document.createTextNode(char);
element.insertBefore(textNode, cursorElement);
} else {
element.textContent += char;
}
const delay = minDelay + Math.random() * (maxDelay - minDelay);
setTimeout(typeNextChar, delay);
}
return {
append: function(text) {
buffer += text;
if (!isTyping) {
typeNextChar();
}
},
complete: function() {
// Flush remaining buffer immediately
if (cursorElement && cursorElement.parentNode) {
const textNode = document.createTextNode(buffer);
element.insertBefore(textNode, cursorElement);
cursorElement.remove();
} else {
element.textContent += buffer;
}
buffer = '';
isTyping = false;
}
};
};
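The append/complete contract above (buffer SSE chunks, reveal one character at a time, flush the rest on completion) can be modeled without the DOM. A self-contained sketch, all names illustrative:

```javascript
// DOM-free model of the typewriter's buffering contract: append() queues
// text, tick() reveals one character, complete() flushes the remainder.
function makeTypewriterBuffer() {
  let buffer = '';
  let out = '';
  return {
    append(text) { buffer += text; },            // SSE chunks land here
    tick() {                                      // reveal one character
      if (buffer) { out += buffer[0]; buffer = buffer.slice(1); }
    },
    complete() { out += buffer; buffer = ''; },   // flush on stream end
    text() { return out; }
  };
}
const tw = makeTypewriterBuffer();
tw.append('Hel');
tw.tick();
tw.tick();
tw.append('lo');
tw.complete();
```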
// === RUBBERBANDING SCROLL ===
function initRubberbanding() {
const containers = document.querySelectorAll('#chat-scroll-container, .content-scroll-container');
containers.forEach(container => {
let startY = 0;
let pulling = false;
let pullDistance = 0;
const maxPull = 60;
const resistance = 0.4;
container.addEventListener('touchstart', (e) => {
startY = e.touches[0].clientY;
}, { passive: true });
container.addEventListener('touchmove', (e) => {
const currentY = e.touches[0].clientY;
const diff = currentY - startY;
// At top boundary, pulling down
if (container.scrollTop <= 0 && diff > 0) {
pulling = true;
pullDistance = Math.min(diff * resistance, maxPull);
container.style.transform = `translateY(${pullDistance}px)`;
}
// At bottom boundary, pulling up
else if (container.scrollTop + container.clientHeight >= container.scrollHeight && diff < 0) {
pulling = true;
pullDistance = Math.max(diff * resistance, -maxPull);
container.style.transform = `translateY(${pullDistance}px)`;
}
}, { passive: true });
container.addEventListener('touchend', () => {
if (pulling) {
container.style.transition = 'transform 300ms cubic-bezier(0.25, 1, 0.5, 1)';
container.style.transform = 'translateY(0)';
setTimeout(() => {
container.style.transition = '';
}, 300);
pulling = false;
pullDistance = 0;
}
}, { passive: true });
});
}
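The pull math above reduces to a scale-then-clamp: multiply the drag distance by the resistance factor and cap the result at ±maxPull. A standalone sketch using the same constants (helper name is illustrative):

```javascript
// Standalone model of the rubberband offset: scale the drag by the
// resistance factor, then clamp to the ±maxPull range.
function rubberbandOffset(diff, resistance = 0.4, maxPull = 60) {
  const scaled = diff * resistance;
  return diff > 0 ? Math.min(scaled, maxPull) : Math.max(scaled, -maxPull);
}
```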
// === INITIALIZATION ===
function init() {
initScrollShadow();
initHtmxSwapAnimation();
initRubberbanding();
}
// Run on DOMContentLoaded
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', init);
} else {
init();
}
// Re-init rubberbanding after HTMX navigations
document.body.addEventListener('htmx:afterSettle', () => {
initRubberbanding();
});
// Add typewriter cursor blink animation
const style = document.createElement('style');
style.textContent = `
@keyframes blink {
0%, 100% { opacity: 1; }
50% { opacity: 0; }
}
.typewriter-cursor {
color: var(--color-accent);
font-weight: bold;
}
`;
document.head.appendChild(style);
})();

Binary file not shown.

Binary file not shown.

View File

@@ -1,144 +0,0 @@
//==========================================================
// head-support.js
//
// An extension to add head tag merging.
//==========================================================
(function(){
var api = null;
function log() {
//console.log(arguments);
}
function mergeHead(newContent, defaultMergeStrategy) {
if (newContent && newContent.indexOf('<head') > -1) {
const htmlDoc = document.createElement("html");
// remove svgs to avoid conflicts
var contentWithSvgsRemoved = newContent.replace(/<svg(\s[^>]*>|>)([\s\S]*?)<\/svg>/gim, '');
// extract head tag
var headTag = contentWithSvgsRemoved.match(/(<head(\s[^>]*>|>)([\s\S]*?)<\/head>)/im);
// if the head tag exists...
if (headTag) {
var added = []
var removed = []
var preserved = []
var nodesToAppend = []
htmlDoc.innerHTML = headTag;
var newHeadTag = htmlDoc.querySelector("head");
var currentHead = document.head;
if (newHeadTag == null) {
return;
} else {
// put all new head elements into a Map, by their outerHTML
var srcToNewHeadNodes = new Map();
for (const newHeadChild of newHeadTag.children) {
srcToNewHeadNodes.set(newHeadChild.outerHTML, newHeadChild);
}
}
// determine merge strategy
var mergeStrategy = api.getAttributeValue(newHeadTag, "hx-head") || defaultMergeStrategy;
// get the current head
for (const currentHeadElt of currentHead.children) {
// If the current head element is in the map
var inNewContent = srcToNewHeadNodes.has(currentHeadElt.outerHTML);
var isReAppended = currentHeadElt.getAttribute("hx-head") === "re-eval";
var isPreserved = api.getAttributeValue(currentHeadElt, "hx-preserve") === "true";
if (inNewContent || isPreserved) {
if (isReAppended) {
// remove the current version and let the new version replace it and re-execute
removed.push(currentHeadElt);
} else {
// this element already exists and should not be re-appended, so remove it from
// the new content map, preserving it in the DOM
srcToNewHeadNodes.delete(currentHeadElt.outerHTML);
preserved.push(currentHeadElt);
}
} else {
if (mergeStrategy === "append") {
// we are appending and this existing element is not new content
// so if and only if it is marked for re-append do we do anything
if (isReAppended) {
removed.push(currentHeadElt);
nodesToAppend.push(currentHeadElt);
}
} else {
// if this is a merge, we remove this content since it is not in the new head
if (api.triggerEvent(document.body, "htmx:removingHeadElement", {headElement: currentHeadElt}) !== false) {
removed.push(currentHeadElt);
}
}
}
}
// Push the remaining new head elements in the Map into the
// nodes to append to the head tag
nodesToAppend.push(...srcToNewHeadNodes.values());
log("to append: ", nodesToAppend);
for (const newNode of nodesToAppend) {
log("adding: ", newNode);
var newElt = document.createRange().createContextualFragment(newNode.outerHTML);
log(newElt);
if (api.triggerEvent(document.body, "htmx:addingHeadElement", {headElement: newElt}) !== false) {
currentHead.appendChild(newElt);
added.push(newElt);
}
}
// remove all removed elements, after we have appended the new elements to avoid
// additional network requests for things like style sheets
for (const removedElement of removed) {
if (api.triggerEvent(document.body, "htmx:removingHeadElement", {headElement: removedElement}) !== false) {
currentHead.removeChild(removedElement);
}
}
api.triggerEvent(document.body, "htmx:afterHeadMerge", {added: added, kept: preserved, removed: removed});
}
}
}
htmx.defineExtension("head-support", {
init: function(apiRef) {
// store a reference to the internal API.
api = apiRef;
htmx.on('htmx:afterSwap', function(evt){
let xhr = evt.detail.xhr;
if (xhr) {
var serverResponse = xhr.response;
if (api.triggerEvent(document.body, "htmx:beforeHeadMerge", evt.detail)) {
mergeHead(serverResponse, evt.detail.boosted ? "merge" : "append");
}
}
})
htmx.on('htmx:historyRestore', function(evt){
if (api.triggerEvent(document.body, "htmx:beforeHeadMerge", evt.detail)) {
if (evt.detail.cacheMiss) {
mergeHead(evt.detail.serverResponse, "merge");
} else {
mergeHead(evt.detail.item.head, "merge");
}
}
})
htmx.on('htmx:historyItemCreated', function(evt){
var historyItem = evt.detail.item;
historyItem.head = document.head.outerHTML;
})
}
});
})()

File diff suppressed because one or more lines are too long

View File

@@ -1,32 +1,79 @@
// Global media query and listener state
const systemMediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
let isSystemListenerAttached = false;
const handleSystemThemeChange = (e) => {
const themePreference = document.documentElement.getAttribute('data-theme-preference');
if (themePreference === 'system') {
document.documentElement.setAttribute('data-theme', e.matches ? 'dark' : 'light');
}
// For explicit themes like 'obsidian-prism', 'light', 'dark' - do nothing on system change
};
const initializeTheme = () => {
const themeToggle = document.querySelector('.theme-controller');
const themePreference = document.documentElement.getAttribute('data-theme-preference');
if (themeToggle) {
// Anonymous mode
if (isSystemListenerAttached) {
systemMediaQuery.removeEventListener('change', handleSystemThemeChange);
isSystemListenerAttached = false;
}
// Avoid re-binding if already bound
if (themeToggle.dataset.bound) return;
themeToggle.dataset.bound = "true";
// Detect system preference
const prefersDark = systemMediaQuery.matches;
// Initialize theme from local storage or system preference
const savedTheme = localStorage.getItem('theme');
const initialTheme = savedTheme ? savedTheme : (prefersDark ? 'dark' : 'light');
document.documentElement.setAttribute('data-theme', initialTheme);
themeToggle.checked = initialTheme === 'dark';
// Update theme and local storage on toggle
themeToggle.addEventListener('change', () => {
const theme = themeToggle.checked ? 'dark' : 'light';
document.documentElement.setAttribute('data-theme', theme);
localStorage.setItem('theme', theme);
});
} else {
// Authenticated mode
localStorage.removeItem('theme');
if (themePreference === 'system') {
// Ensure correct theme is set immediately
const currentSystemTheme = systemMediaQuery.matches ? 'dark' : 'light';
// Only update if needed
if (document.documentElement.getAttribute('data-theme') !== currentSystemTheme) {
document.documentElement.setAttribute('data-theme', currentSystemTheme);
}
if (!isSystemListenerAttached) {
systemMediaQuery.addEventListener('change', handleSystemThemeChange);
isSystemListenerAttached = true;
}
} else {
// Explicit theme: 'light', 'dark', 'obsidian-prism', etc.
if (isSystemListenerAttached) {
systemMediaQuery.removeEventListener('change', handleSystemThemeChange);
isSystemListenerAttached = false;
}
// Ensure data-theme matches preference exactly
if (themePreference && document.documentElement.getAttribute('data-theme') !== themePreference) {
document.documentElement.setAttribute('data-theme', themePreference);
}
}
}
};
// Run the initialization after the DOM is fully loaded
document.addEventListener('DOMContentLoaded', initializeTheme);
// Reinitialize theme toggle after HTMX swaps
document.addEventListener('htmx:afterSwap', initializeTheme);
document.addEventListener('htmx:afterSettle', initializeTheme);

View File

@@ -2,7 +2,7 @@
"name": "html-router", "name": "html-router",
"version": "1.0.0", "version": "1.0.0",
"scripts": { "scripts": {
"tailwind": "npx @tailwindcss/cli -i app.css -o assets/style.css -w -m" "tailwind": "tailwindcss -i app.css -o assets/style.css --watch=always"
}, },
"author": "", "author": "",
"license": "ISC", "license": "ISC",

View File

@@ -1,9 +1,16 @@
use common::storage::types::conversation::SidebarConversation;
use common::storage::{db::SurrealDbClient, store::StorageManager};
use common::utils::embedding::EmbeddingProvider;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::debug;
use crate::{OpenAIClientType, SessionStoreType};
@@ -18,8 +25,20 @@ pub struct HtmlState {
pub storage: StorageManager,
pub reranker_pool: Option<Arc<RerankerPool>>,
pub embedding_provider: Arc<EmbeddingProvider>,
conversation_archive_cache: Arc<RwLock<HashMap<String, ConversationArchiveCacheEntry>>>,
conversation_archive_cache_writes: Arc<AtomicUsize>,
}
#[derive(Clone)]
struct ConversationArchiveCacheEntry {
conversations: Vec<SidebarConversation>,
expires_at: Instant,
}
const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
impl HtmlState {
pub async fn new_with_resources(
db: Arc<SurrealDbClient>,
@@ -29,19 +48,23 @@ impl HtmlState {
config: AppConfig,
reranker_pool: Option<Arc<RerankerPool>>,
embedding_provider: Arc<EmbeddingProvider>,
template_engine: Option<Arc<TemplateEngine>>,
) -> Result<Self, Box<dyn std::error::Error>> {
let templates =
template_engine.unwrap_or_else(|| Arc::new(create_template_engine!("templates")));
debug!("Template engine configured for html_router.");
Ok(Self {
db,
openai_client,
session_store,
templates,
config,
storage,
reranker_pool,
embedding_provider,
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
})
}
@@ -52,6 +75,86 @@ impl HtmlState {
.and_then(|value| value.parse().ok())
.unwrap_or(RetrievalStrategy::Default)
}
pub async fn get_cached_conversation_archive(
&self,
user_id: &str,
) -> Option<Vec<SidebarConversation>> {
let now = Instant::now();
let should_evict_expired = {
let cache = self.conversation_archive_cache.read().await;
if let Some(entry) = cache.get(user_id) {
if entry.expires_at > now {
return Some(entry.conversations.clone());
}
true
} else {
false
}
};
if should_evict_expired {
let mut cache = self.conversation_archive_cache.write().await;
cache.remove(user_id);
}
None
}
pub async fn set_cached_conversation_archive(
&self,
user_id: &str,
conversations: Vec<SidebarConversation>,
) {
let now = Instant::now();
let mut cache = self.conversation_archive_cache.write().await;
cache.insert(
user_id.to_string(),
ConversationArchiveCacheEntry {
conversations,
expires_at: now + CONVERSATION_ARCHIVE_CACHE_TTL,
},
);
let writes = self
.conversation_archive_cache_writes
.fetch_add(1, Ordering::Relaxed)
+ 1;
if writes % CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL == 0 {
Self::purge_expired_entries(&mut cache, now);
}
Self::enforce_cache_capacity(&mut cache);
}
pub async fn invalidate_conversation_archive_cache(&self, user_id: &str) {
let mut cache = self.conversation_archive_cache.write().await;
cache.remove(user_id);
}
fn purge_expired_entries(
cache: &mut HashMap<String, ConversationArchiveCacheEntry>,
now: Instant,
) {
cache.retain(|_, entry| entry.expires_at > now);
}
fn enforce_cache_capacity(cache: &mut HashMap<String, ConversationArchiveCacheEntry>) {
if cache.len() <= CONVERSATION_ARCHIVE_CACHE_MAX_USERS {
return;
}
let overflow = cache.len() - CONVERSATION_ARCHIVE_CACHE_MAX_USERS;
let mut by_expiry: Vec<(String, Instant)> = cache
.iter()
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
.collect();
by_expiry.sort_by_key(|(_, expires_at)| *expires_at);
for (user_id, _) in by_expiry.into_iter().take(overflow) {
cache.remove(&user_id);
}
}
}
impl ProvidesDb for HtmlState {
fn db(&self) -> &Arc<SurrealDbClient> {
@@ -63,3 +166,93 @@ impl ProvidesTemplateEngine for HtmlState {
&self.templates
}
}
impl crate::middlewares::response_middleware::ProvidesHtmlState for HtmlState {
fn html_state(&self) -> &HtmlState {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::{
storage::types::conversation::SidebarConversation,
utils::{
config::{AppConfig, StorageKind},
embedding::EmbeddingProvider,
},
};
async fn test_state() -> HtmlState {
let namespace = "test_ns";
let database = &uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to create in-memory DB"),
);
let session_store = Arc::new(
db.create_session_store()
.await
.expect("Failed to create session store"),
);
let mut config = AppConfig::default();
config.storage = StorageKind::Memory;
let storage = StorageManager::new(&config)
.await
.expect("Failed to create storage manager");
let embedding_provider = Arc::new(
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
);
HtmlState::new_with_resources(
db,
Arc::new(async_openai::Client::new()),
session_store,
storage,
config,
None,
embedding_provider,
None,
)
.await
.expect("Failed to create HtmlState")
}
#[tokio::test]
async fn test_expired_conversation_archive_entry_is_evicted_on_read() {
let state = test_state().await;
let user_id = "expired-user";
{
let mut cache = state.conversation_archive_cache.write().await;
cache.insert(
user_id.to_string(),
ConversationArchiveCacheEntry {
conversations: vec![SidebarConversation {
id: "conv-1".to_string(),
title: "A stale chat".to_string(),
}],
expires_at: Instant::now() - Duration::from_secs(1),
},
);
}
let cached = state.get_cached_conversation_archive(user_id).await;
assert!(
cached.is_none(),
"Expired cache entry should not be returned"
);
let cache = state.conversation_archive_cache.read().await;
assert!(
!cache.contains_key(user_id),
"Expired cache entry should be evicted after read"
);
}
}

View File

@@ -35,7 +35,9 @@ where
.add_protected_routes(routes::chat::router())
.add_protected_routes(routes::content::router())
.add_protected_routes(routes::knowledge::router())
.add_protected_routes(routes::ingestion::router(
app_state.config.ingest_max_body_bytes,
))
.add_protected_routes(routes::scratchpad::router())
.with_compression()
.build()

View File

@@ -46,3 +46,14 @@ pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Nex
}
}
}
pub async fn require_admin(auth: AuthSessionType, mut request: Request, next: Next) -> Response {
match auth.current_user {
Some(user) if user.admin => {
request.extensions_mut().insert(user);
next.run(request).await
}
Some(_) => TemplateResponse::redirect("/").into_response(),
None => TemplateResponse::redirect("/signin").into_response(),
}
}

View File

@@ -1,17 +1,33 @@
use std::collections::HashMap;
use axum::{
extract::{Request, State},
http::{HeaderName, StatusCode},
middleware::Next,
response::{Html, IntoResponse, Redirect, Response},
Extension,
};
use axum_htmx::{HxRequest, HX_TRIGGER};
use common::{
error::AppError,
utils::template_engine::{ProvidesTemplateEngine, Value},
};
use minijinja::context;
use serde::Serialize;
use serde_json::json;
use tracing::error;
use crate::{html_state::HtmlState, AuthSessionType};
use common::storage::types::{
conversation::{Conversation, SidebarConversation},
user::{Theme, User},
};
pub trait ProvidesHtmlState {
fn html_state(&self) -> &HtmlState;
}
#[derive(Clone, Debug)]
pub enum TemplateKind {
Full(String),
Partial(String, String),
@@ -98,20 +114,118 @@ impl IntoResponse for TemplateResponse {
     }
 }
 
+#[derive(Serialize)]
+struct TemplateUser {
+    id: String,
+    email: String,
+    admin: bool,
+    timezone: String,
+    theme: String,
+}
+
+impl From<&User> for TemplateUser {
+    fn from(user: &User) -> Self {
+        Self {
+            id: user.id.clone(),
+            email: user.email.clone(),
+            admin: user.admin,
+            timezone: user.timezone.clone(),
+            theme: user.theme.as_str().to_string(),
+        }
+    }
+}
+
+#[derive(Serialize)]
+struct ContextWrapper<'a> {
+    user_theme: &'a str,
+    initial_theme: &'a str,
+    is_authenticated: bool,
+    user: Option<&'a TemplateUser>,
+    conversation_archive: Vec<SidebarConversation>,
+    #[serde(flatten)]
+    context: HashMap<String, Value>,
+}
+
 pub async fn with_template_response<S>(
     State(state): State<S>,
     HxRequest(is_htmx): HxRequest,
-    response: Response<axum::body::Body>,
-) -> Response<axum::body::Body>
+    req: Request,
+    next: Next,
+) -> Response
 where
-    S: ProvidesTemplateEngine + Clone + Send + Sync + 'static,
+    S: ProvidesTemplateEngine + ProvidesHtmlState + Clone + Send + Sync + 'static,
 {
+    let mut user_theme = Theme::System.as_str();
+    let mut initial_theme = Theme::System.initial_theme();
+    let mut is_authenticated = false;
+    let mut current_user_id = None;
+    let mut current_user = None;
+    {
+        if let Some(auth) = req.extensions().get::<AuthSessionType>() {
+            if let Some(user) = &auth.current_user {
+                is_authenticated = true;
+                current_user_id = Some(user.id.clone());
+                user_theme = user.theme.as_str();
+                initial_theme = user.theme.initial_theme();
+                current_user = Some(TemplateUser::from(user));
+            }
+        }
+    }
+
+    let response = next.run(req).await;
+
     // Headers to forward from the original response
     const HTMX_HEADERS_TO_FORWARD: &[&str] = &["HX-Push", "HX-Trigger", "HX-Redirect"];
 
     if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
         let template_engine = state.template_engine();
 
+        let mut conversation_archive = Vec::new();
+        let should_load_conversation_archive =
+            matches!(&template_response.template_kind, TemplateKind::Full(_));
+        if should_load_conversation_archive {
+            if let Some(user_id) = current_user_id {
+                let html_state = state.html_state();
+                if let Some(cached_archive) =
+                    html_state.get_cached_conversation_archive(&user_id).await
+                {
+                    conversation_archive = cached_archive;
+                } else if let Ok(archive) =
+                    Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
+                {
+                    html_state
+                        .set_cached_conversation_archive(&user_id, archive.clone())
+                        .await;
+                    conversation_archive = archive;
+                }
+            }
+        }
+
+        fn context_to_map(
+            value: &Value,
+        ) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
+            match value.kind() {
+                minijinja::value::ValueKind::Map => {
+                    let mut map = HashMap::new();
+                    if let Ok(keys) = value.try_iter() {
+                        for key in keys {
+                            if let Ok(val) = value.get_item(&key) {
+                                map.insert(key.to_string(), val);
+                            }
+                        }
+                    }
+                    Ok(map)
+                }
+                minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => {
+                    Ok(HashMap::new())
+                }
+                other => Err(other),
+            }
+        }
+
         // Helper to forward relevant headers
         fn forward_headers(from: &axum::http::HeaderMap, to: &mut axum::http::HeaderMap) {
             for &header_name in HTMX_HEADERS_TO_FORWARD {
@@ -123,9 +237,29 @@ where
             }
         }
 
+        let context_map = match context_to_map(&template_response.context) {
+            Ok(map) => map,
+            Err(kind) => {
+                error!(
+                    "Template context must be a map or unit, got kind={:?} for template_kind={:?}",
+                    kind, template_response.template_kind
+                );
+                return (StatusCode::INTERNAL_SERVER_ERROR, Html(fallback_error())).into_response();
+            }
+        };
+
+        let context = ContextWrapper {
+            user_theme: &user_theme,
+            initial_theme: &initial_theme,
+            is_authenticated,
+            user: current_user.as_ref(),
+            conversation_archive,
+            context: context_map,
+        };
+
         match &template_response.template_kind {
             TemplateKind::Full(name) => {
-                match template_engine.render(name, &template_response.context) {
+                match template_engine.render(name, &Value::from_serialize(&context)) {
                     Ok(html) => {
                         let mut final_response = Html(html).into_response();
                         forward_headers(response.headers(), final_response.headers_mut());
@@ -138,7 +272,11 @@ where
                 }
             }
             TemplateKind::Partial(template, block) => {
-                match template_engine.render_block(template, block, &template_response.context) {
+                match template_engine.render_block(
+                    template,
+                    block,
+                    &Value::from_serialize(&context),
+                ) {
                     Ok(html) => {
                         let mut final_response = Html(html).into_response();
                         forward_headers(response.headers(), final_response.headers_mut());
@@ -169,12 +307,15 @@ where
                 let trigger_payload = json!({"toast": {"title": title, "description": description, "type": "error"}});
                 let trigger_value = serde_json::to_string(&trigger_payload).unwrap_or_else(|e| {
                     error!("Failed to serialize HX-Trigger payload: {}", e);
-                    r#"{"toast":{"title":"Error","description":"An unexpected error occurred.", "type":"error"}}"#.to_string()
+                    r#"{"toast":{"title":"Error","description":"An unexpected error occurred.", "type":"error"}}"#
+                        .to_string()
                });
                 (StatusCode::NO_CONTENT, [(HX_TRIGGER, trigger_value)], "").into_response()
             } else {
                 // Non-HTMX request: Render the full errors/error.html page
-                match template_engine.render("errors/error.html", &template_response.context) {
+                match template_engine
+                    .render("errors/error.html", &Value::from_serialize(&context))
+                {
                     Ok(html) => (*status, Html(html)).into_response(),
                     Err(e) => {
                         error!("Critical: Failed to render 'errors/error.html': {:?}", e);
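The sidebar load in this middleware is a read-through cache: consult the cache first, and only on a miss run the database query and populate the cache with its result. A synchronous std-only sketch (the `loader` closure stands in for the SurrealDB sidebar query; names are illustrative):

```rust
use std::collections::HashMap;

// Read-through load: return the cached archive if present, otherwise run the
// loader, store its result for the next request, and return it.
fn get_or_load<F>(
    cache: &mut HashMap<String, Vec<String>>,
    user_id: &str,
    loader: F,
) -> Vec<String>
where
    F: FnOnce() -> Vec<String>,
{
    if let Some(hit) = cache.get(user_id) {
        return hit.clone();
    }
    let loaded = loader();
    cache.insert(user_id.to_string(), loaded.clone());
    loaded
}
```

This is why the chat handlers later in the diff call an invalidation hook after every conversation mutation: with read-through population, stale entries disappear only through explicit invalidation or TTL expiry.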


@@ -1,8 +1,4 @@
-use axum::{
-    extract::FromRef,
-    middleware::{from_fn_with_state, map_response_with_state},
-    Router,
-};
+use axum::{extract::FromRef, middleware::from_fn_with_state, Router};
 use axum_session::SessionLayer;
 use axum_session_auth::{AuthConfig, AuthSessionLayer};
 use axum_session_surreal::SessionSurrealPool;
@@ -124,79 +120,80 @@ where
     }
 
     pub fn build(self) -> Router<S> {
-        // Start with an empty router
-        let mut public_router = Router::new();
+        // Build the "App" router (Pages, API interactions, etc.)
+        let mut app_router = Router::new();
 
-        // Merge all public routers
+        // Merge all public routers (pages)
         for router in self.public_routers {
-            public_router = public_router.merge(router);
+            app_router = app_router.merge(router);
         }
 
         // Add nested public routes
         for (path, router) in self.nested_routes {
-            public_router = public_router.nest(&path, router);
+            app_router = app_router.nest(&path, router);
         }
 
-        // Add public assets to public router
-        if let Some(assets_config) = self.public_assets_config {
-            // Call the macro using the stored relative directory path
-            let asset_service = create_asset_service!(&assets_config.directory);
-            // Nest the resulting service under the stored URL path
-            public_router = public_router.nest_service(&assets_config.path, asset_service);
-        }
-
-        // Start with an empty protected router
+        // Build protected router logic...
         let mut protected_router = Router::new();
 
-        // Check if there are any protected routers
         let has_protected_routes =
             !self.protected_routers.is_empty() || !self.nested_protected_routes.is_empty();
 
-        // Merge root-level protected routers
         for router in self.protected_routers {
            protected_router = protected_router.merge(router);
         }
 
-        // Nest protected routers
         for (path, router) in self.nested_protected_routes {
             protected_router = protected_router.nest(&path, router);
         }
 
-        // Apply auth middleware
         if has_protected_routes {
             protected_router = protected_router
                 .route_layer(from_fn_with_state(self.app_state.clone(), require_auth));
         }
 
-        // Combine public and protected routes
-        let mut router = Router::new().merge(public_router).merge(protected_router);
+        // Combine public and protected routes into the App router
+        app_router = app_router.merge(protected_router);
 
-        // Apply custom middleware in order they were added
+        // Apply custom middleware to the App router
         for middleware_fn in self.custom_middleware {
-            router = middleware_fn(router);
+            app_router = middleware_fn(app_router);
         }
 
-        // Apply common middleware
-        router = router.layer(from_fn_with_state(
+        // Apply App-specific Middleware (Analytics, Template, Auth, Session)
+        app_router = app_router.layer(from_fn_with_state(
             self.app_state.clone(),
             analytics_middleware::<HtmlState>,
         ));
-        router = router.layer(map_response_with_state(
+        app_router = app_router.layer(from_fn_with_state(
             self.app_state.clone(),
             with_template_response::<HtmlState>,
         ));
-        router = router.layer(
+        app_router = app_router.layer(
             AuthSessionLayer::<User, String, SessionSurrealPool<Any>, Surreal<Any>>::new(Some(
                 self.app_state.db.client.clone(),
             ))
             .with_config(AuthConfig::<String>::default()),
         );
-        router = router.layer(SessionLayer::new((*self.app_state.session_store).clone()));
+        app_router = app_router.layer(SessionLayer::new((*self.app_state.session_store).clone()));
 
-        if self.compression_enabled {
-            router = router.layer(compression_layer());
+        // Build the Final router, starting with assets (bypassing app middleware)
+        let mut final_router = Router::new();
+        if let Some(assets_config) = self.public_assets_config {
+            // Call the macro using the stored relative directory path
+            let asset_service = create_asset_service!(&assets_config.directory);
+            // Nest the resulting service under the stored URL path
+            final_router = final_router.nest_service(&assets_config.path, asset_service);
         }
 
-        router
+        // Merge the App router
+        final_router = final_router.merge(app_router);
+
+        // Apply Global Middleware (Compression)
+        if self.compression_enabled {
+            final_router = final_router.layer(compression_layer());
+        }
+
+        final_router
     }
 }


@@ -9,33 +9,46 @@ use crate::{
     },
     AuthSessionType,
 };
-use common::storage::types::{conversation::Conversation, user::User};
+use common::storage::types::user::{Theme, User};
 
 use crate::html_state::HtmlState;
 
 #[derive(Serialize)]
 pub struct AccountPageData {
-    user: User,
     timezones: Vec<String>,
-    conversation_archive: Vec<Conversation>,
+    theme_options: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    api_key: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    selected_timezone: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    selected_theme: Option<String>,
 }
 
 pub async fn show_account_page(
     RequireUser(user): RequireUser,
-    State(state): State<HtmlState>,
+    State(_state): State<HtmlState>,
 ) -> Result<impl IntoResponse, HtmlError> {
     let timezones = TZ_VARIANTS
         .iter()
         .map(std::string::ToString::to_string)
         .collect();
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
+    let theme_options = vec![
+        Theme::Light.as_str().to_string(),
+        Theme::Dark.as_str().to_string(),
+        Theme::WarmPaper.as_str().to_string(),
+        Theme::ObsidianPrism.as_str().to_string(),
+        Theme::System.as_str().to_string(),
+    ];
 
     Ok(TemplateResponse::new_template(
         "auth/account_settings.html",
         AccountPageData {
-            user,
             timezones,
-            conversation_archive,
+            theme_options,
+            api_key: user.api_key,
+            selected_timezone: None,
+            selected_theme: None,
         },
     ))
 }
@@ -51,20 +64,16 @@ pub async fn set_api_key(
     // Clear the cache so new requests have access to the user with api key
     auth.cache_clear_user(user.id.to_string());
 
-    // Update the user's API key
-    let updated_user = User {
-        api_key: Some(api_key),
-        ..user.clone()
-    };
-
     // Render the API key section block
     Ok(TemplateResponse::new_partial(
         "auth/account_settings.html",
         "api_key_section",
         AccountPageData {
-            user: updated_user,
             timezones: vec![],
-            conversation_archive: vec![],
+            theme_options: vec![],
+            api_key: Some(api_key),
+            selected_timezone: None,
+            selected_theme: None,
         },
     ))
 }
@@ -99,12 +108,6 @@ pub async fn update_timezone(
     // Clear the cache
     auth.cache_clear_user(user.id.to_string());
 
-    // Update the user's API key
-    let updated_user = User {
-        timezone: form.timezone,
-        ..user.clone()
-    };
-
     let timezones = TZ_VARIANTS
         .iter()
         .map(std::string::ToString::to_string)
@@ -115,9 +118,48 @@ pub async fn update_timezone(
         "auth/account_settings.html",
         "timezone_section",
         AccountPageData {
-            user: updated_user,
             timezones,
-            conversation_archive: vec![],
+            theme_options: vec![],
+            api_key: None,
+            selected_timezone: Some(form.timezone),
+            selected_theme: None,
+        },
+    ))
+}
+
+#[derive(Deserialize)]
+pub struct UpdateThemeForm {
+    theme: String,
+}
+
+pub async fn update_theme(
+    State(state): State<HtmlState>,
+    RequireUser(user): RequireUser,
+    auth: AuthSessionType,
+    Form(form): Form<UpdateThemeForm>,
+) -> Result<impl IntoResponse, HtmlError> {
+    User::update_theme(&user.id, &form.theme, &state.db).await?;
+
+    // Clear the cache
+    auth.cache_clear_user(user.id.to_string());
+
+    let theme_options = vec![
+        Theme::Light.as_str().to_string(),
+        Theme::Dark.as_str().to_string(),
+        Theme::WarmPaper.as_str().to_string(),
+        Theme::ObsidianPrism.as_str().to_string(),
+        Theme::System.as_str().to_string(),
+    ];
+
+    Ok(TemplateResponse::new_partial(
+        "auth/account_settings.html",
+        "theme_section",
+        AccountPageData {
+            timezones: vec![],
+            theme_options,
+            api_key: None,
+            selected_timezone: None,
+            selected_theme: Some(form.theme),
         },
     ))
 }
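The account handlers enumerate the theme variants by hand in each branch. A std-only sketch of the `Theme` type they assume (the real enum lives in `common::storage::types::user` and also provides `initial_theme`; the string values here are illustrative guesses, not confirmed by the diff):

```rust
// Hypothetical mirror of the Theme enum the account handlers rely on.
#[derive(Clone, Copy)]
enum Theme {
    Light,
    Dark,
    WarmPaper,
    ObsidianPrism,
    System,
}

impl Theme {
    // One canonical list of variants, so handlers cannot drift out of sync.
    const ALL: [Theme; 5] = [
        Theme::Light,
        Theme::Dark,
        Theme::WarmPaper,
        Theme::ObsidianPrism,
        Theme::System,
    ];

    // Stable string form used for both templates and persistence.
    fn as_str(self) -> &'static str {
        match self {
            Theme::Light => "light",
            Theme::Dark => "dark",
            Theme::WarmPaper => "warm-paper",
            Theme::ObsidianPrism => "obsidian-prism",
            Theme::System => "system",
        }
    }
}

// Builds the theme_options vec that show_account_page and update_theme
// currently construct inline.
fn theme_options() -> Vec<String> {
    Theme::ALL.iter().map(|t| t.as_str().to_string()).collect()
}
```

Centralizing the list in one helper (or an `ALL` constant on the enum itself) would remove the duplicated `vec![...]` blocks in `show_account_page` and `update_theme`.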


@@ -16,6 +16,7 @@ where
         .route("/account", get(handlers::show_account_page))
         .route("/set-api-key", post(handlers::set_api_key))
         .route("/update-timezone", patch(handlers::update_timezone))
+        .route("/update-theme", patch(handlers::update_theme))
         .route(
             "/change-password",
             get(handlers::show_change_password).patch(handlers::change_password),


@@ -10,7 +10,6 @@ use common::{
     error::AppError,
     storage::types::{
         analytics::Analytics,
-        conversation::Conversation,
         knowledge_entity::KnowledgeEntity,
         system_prompts::{
             DEFAULT_IMAGE_PROCESSING_PROMPT, DEFAULT_INGRESS_ANALYSIS_SYSTEM_PROMPT,
@@ -18,28 +17,22 @@ use common::{
         },
         system_settings::SystemSettings,
         text_chunk::TextChunk,
-        user::User,
     },
 };
 use tracing::{error, info};
 
 use crate::{
     html_state::HtmlState,
-    middlewares::{
-        auth_middleware::RequireUser,
-        response_middleware::{HtmlError, TemplateResponse},
-    },
+    middlewares::response_middleware::{HtmlError, TemplateResponse},
 };
 
 #[derive(Serialize)]
 pub struct AdminPanelData {
-    user: User,
     settings: SystemSettings,
     analytics: Option<Analytics>,
     users: Option<i64>,
     default_query_prompt: String,
     default_image_prompt: String,
-    conversation_archive: Vec<Conversation>,
     available_models: Option<ListModelResponse>,
     current_section: AdminSection,
 }
@@ -64,7 +57,6 @@ pub struct AdminPanelQuery {
 
 pub async fn show_admin_panel(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Query(query): Query<AdminPanelQuery>,
 ) -> Result<impl IntoResponse, HtmlError> {
     let section = match query.section.as_deref() {
@@ -72,10 +64,7 @@ pub async fn show_admin_panel(
         _ => AdminSection::Overview,
     };
 
-    let (settings, conversation_archive) = tokio::try_join!(
-        SystemSettings::get_current(&state.db),
-        User::get_user_conversations(&user.id, &state.db)
-    )?;
+    let settings = SystemSettings::get_current(&state.db).await?;
 
     let (analytics, users) = if section == AdminSection::Overview {
         let (analytics, users) = tokio::try_join!(
@@ -103,14 +92,12 @@ pub async fn show_admin_panel(
     Ok(TemplateResponse::new_template(
         "admin/base.html",
         AdminPanelData {
-            user,
             settings,
             analytics,
             available_models,
             users,
             default_query_prompt: DEFAULT_QUERY_SYSTEM_PROMPT.to_string(),
             default_image_prompt: DEFAULT_IMAGE_PROCESSING_PROMPT.to_string(),
-            conversation_archive,
             current_section: section,
         },
     ))
@@ -140,14 +127,8 @@ pub struct RegistrationToggleData {
 
 pub async fn toggle_registration_status(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Form(input): Form<RegistrationToggleInput>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let current_settings = SystemSettings::get_current(&state.db).await?;
 
     let new_settings = SystemSettings {
@@ -172,7 +153,7 @@ pub struct ModelSettingsInput {
     processing_model: String,
     image_processing_model: String,
     voice_processing_model: String,
-    embedding_model: String,
+    embedding_model: Option<String>,
     embedding_dimensions: Option<u32>,
 }
@@ -184,14 +165,8 @@ pub struct ModelSettingsData {
 
 pub async fn update_model_settings(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Form(input): Form<ModelSettingsInput>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let current_settings = SystemSettings::get_current(&state.db).await?;
 
     // Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
@@ -219,7 +194,9 @@ pub async fn update_model_settings(
                 .embedding_dimensions
                 .is_some_and(|new_dims| new_dims != current_settings.embedding_dimensions);
             (
-                input.embedding_model,
+                input
+                    .embedding_model
+                    .unwrap_or_else(|| current_settings.embedding_model.clone()),
                 input
                     .embedding_dimensions
                     .unwrap_or(current_settings.embedding_dimensions),
@@ -302,13 +279,7 @@ pub struct SystemPromptEditData {
 
 pub async fn show_edit_system_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let settings = SystemSettings::get_current(&state.db).await?;
 
     Ok(TemplateResponse::new_template(
@@ -332,14 +303,8 @@ pub struct SystemPromptSectionData {
 
 pub async fn patch_query_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Form(input): Form<SystemPromptUpdateInput>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let current_settings = SystemSettings::get_current(&state.db).await?;
 
     let new_settings = SystemSettings {
@@ -366,13 +331,7 @@ pub struct IngestionPromptEditData {
 
 pub async fn show_edit_ingestion_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let settings = SystemSettings::get_current(&state.db).await?;
 
     Ok(TemplateResponse::new_template(
@@ -391,14 +350,8 @@ pub struct IngestionPromptUpdateInput {
 
 pub async fn patch_ingestion_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Form(input): Form<IngestionPromptUpdateInput>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let current_settings = SystemSettings::get_current(&state.db).await?;
 
     let new_settings = SystemSettings {
@@ -425,13 +378,7 @@ pub struct ImagePromptEditData {
 
 pub async fn show_edit_image_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let settings = SystemSettings::get_current(&state.db).await?;
 
     Ok(TemplateResponse::new_template(
@@ -450,14 +397,8 @@ pub struct ImagePromptUpdateInput {
 
 pub async fn patch_image_prompt(
     State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
     Form(input): Form<ImagePromptUpdateInput>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    // Early return if the user is not admin
-    if !user.admin {
-        return Ok(TemplateResponse::redirect("/"));
-    }
-
     let current_settings = SystemSettings::get_current(&state.db).await?;
 
     let new_settings = SystemSettings {


@@ -1,6 +1,7 @@
 mod handlers;
 
 use axum::{
     extract::FromRef,
+    middleware::from_fn,
     routing::{get, patch},
     Router,
 };
@@ -10,7 +11,7 @@ use handlers::{
     toggle_registration_status, update_model_settings,
 };
 
-use crate::html_state::HtmlState;
+use crate::{html_state::HtmlState, middlewares::auth_middleware::require_admin};
 
 pub fn router<S>() -> Router<S>
 where
@@ -27,4 +28,5 @@ where
         .route("/update-ingestion-prompt", patch(patch_ingestion_prompt))
         .route("/edit-image-prompt", get(show_edit_image_prompt))
         .route("/update-image-prompt", patch(patch_image_prompt))
+        .route_layer(from_fn(require_admin))
 }


@@ -1,8 +1,4 @@
-use axum::{
-    extract::State,
-    response::{Html, IntoResponse},
-    Form,
-};
+use axum::{extract::State, response::IntoResponse, Form};
 use axum_htmx::HxBoosted;
 use serde::{Deserialize, Serialize};
@@ -46,7 +42,7 @@ pub async fn authenticate_user(
     let user = match User::authenticate(&form.email, &form.password, &state.db).await {
         Ok(user) => user,
         Err(_) => {
-            return Ok(Html("<p>Incorrect email or password </p>").into_response());
+            return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
         }
     };


@@ -1,12 +1,8 @@
-use axum::{
-    extract::State,
-    response::{Html, IntoResponse},
-    Form,
-};
+use axum::{extract::State, response::IntoResponse, Form};
 use axum_htmx::HxBoosted;
 use serde::{Deserialize, Serialize};
 
-use common::storage::types::user::User;
+use common::storage::types::user::{Theme, User};
 
 use crate::{
     html_state::HtmlState,
@@ -45,11 +41,19 @@ pub async fn process_signup_and_show_verification(
     auth: AuthSessionType,
     Form(form): Form<SignupParams>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    let user = match User::create_new(form.email, form.password, &state.db, form.timezone).await {
+    let user = match User::create_new(
+        form.email,
+        form.password,
+        &state.db,
+        form.timezone,
+        Theme::System.as_str().to_string(),
+    )
+    .await
+    {
         Ok(user) => user,
         Err(e) => {
             tracing::error!("{:?}", e);
-            return Ok(Html(format!("<p>{e}</p>")).into_response());
+            return Ok(TemplateResponse::bad_request(&e.to_string()).into_response());
         }
     };


@@ -45,10 +45,8 @@ where
 
 #[derive(Serialize)]
 pub struct ChatPageData {
-    user: User,
     history: Vec<Message>,
     conversation: Option<Conversation>,
-    conversation_archive: Vec<Conversation>,
 }
 
 pub async fn show_initialized_chat(
@@ -75,8 +73,7 @@ pub async fn show_initialized_chat(
             state.db.store_item(conversation.clone()).await?;
             state.db.store_item(ai_message.clone()).await?;
             state.db.store_item(user_message.clone()).await?;
-
-            let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
+            state.invalidate_conversation_archive_cache(&user.id).await;
 
             let messages = vec![user_message, ai_message];
@@ -84,8 +81,6 @@ pub async fn show_initialized_chat(
                 "chat/base.html",
                 ChatPageData {
                     history: messages,
-                    user,
-                    conversation_archive,
                     conversation: Some(conversation.clone()),
                 },
             )
@@ -100,17 +95,13 @@ pub async fn show_initialized_chat(
 }
 
 pub async fn show_chat_base(
-    State(state): State<HtmlState>,
-    RequireUser(user): RequireUser,
+    State(_state): State<HtmlState>,
+    RequireUser(_user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
-
     Ok(TemplateResponse::new_template(
         "chat/base.html",
         ChatPageData {
             history: vec![],
-            user,
-            conversation_archive,
             conversation: None,
         },
     ))
@@ -126,8 +117,6 @@ pub async fn show_existing_chat(
State(state): State<HtmlState>, State(state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
let (conversation, messages) = let (conversation, messages) =
Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db) Conversation::get_complete_conversation(conversation_id.as_str(), &user.id, &state.db)
.await?; .await?;
@@ -136,9 +125,7 @@ pub async fn show_existing_chat(
"chat/base.html", "chat/base.html",
ChatPageData { ChatPageData {
history: messages, history: messages,
user,
conversation: Some(conversation), conversation: Some(conversation),
conversation_archive,
}, },
)) ))
} }
@@ -192,7 +179,7 @@ pub async fn new_chat_user_message(
None => return Ok(Redirect::to("/").into_response()), None => return Ok(Redirect::to("/").into_response()),
}; };
let conversation = Conversation::new(user.id, "New chat".to_string()); let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
let user_message = Message::new( let user_message = Message::new(
conversation.id.clone(), conversation.id.clone(),
MessageRole::User, MessageRole::User,
@@ -202,6 +189,7 @@ pub async fn new_chat_user_message(
state.db.store_item(conversation.clone()).await?; state.db.store_item(conversation.clone()).await?;
state.db.store_item(user_message.clone()).await?; state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
#[derive(Serialize)] #[derive(Serialize)]
struct SSEResponseInitData { struct SSEResponseInitData {
@@ -232,8 +220,6 @@ pub struct PatchConversationTitle {
#[derive(Serialize)] #[derive(Serialize)]
pub struct DrawerContext { pub struct DrawerContext {
user: User,
conversation_archive: Vec<Conversation>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
edit_conversation_id: Option<String>, edit_conversation_id: Option<String>,
} }
@@ -242,20 +228,19 @@ pub async fn show_conversation_editing_title(
RequireUser(user): RequireUser, RequireUser(user): RequireUser,
Path(conversation_id): Path<String>, Path(conversation_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?; let conversation: Conversation = state
.db
.get_item(&conversation_id)
.await?
.ok_or_else(|| AppError::NotFound("Conversation not found".to_string()))?;
let owns = conversation_archive if conversation.user_id != user.id {
.iter()
.any(|c| c.id == conversation_id && c.user_id == user.id);
if !owns {
return Ok(TemplateResponse::unauthorized().into_response()); return Ok(TemplateResponse::unauthorized().into_response());
} }
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"sidebar.html", "sidebar.html",
DrawerContext { DrawerContext {
user,
conversation_archive,
edit_conversation_id: Some(conversation_id), edit_conversation_id: Some(conversation_id),
}, },
) )
@@ -269,14 +254,11 @@ pub async fn patch_conversation_title(
Form(form): Form<PatchConversationTitle>, Form(form): Form<PatchConversationTitle>,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?; Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let updated_conversations = User::get_user_conversations(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"sidebar.html", "sidebar.html",
DrawerContext { DrawerContext {
user,
conversation_archive: updated_conversations,
edit_conversation_id: None, edit_conversation_id: None,
}, },
) )
@@ -302,30 +284,23 @@ pub async fn delete_conversation(
.db .db
.delete_item::<Conversation>(&conversation_id) .delete_item::<Conversation>(&conversation_id)
.await?; .await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"sidebar.html", "sidebar.html",
DrawerContext { DrawerContext {
user,
conversation_archive,
edit_conversation_id: None, edit_conversation_id: None,
}, },
) )
.into_response()) .into_response())
} }
pub async fn reload_sidebar( pub async fn reload_sidebar(
State(state): State<HtmlState>, State(_state): State<HtmlState>,
RequireUser(user): RequireUser, RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> { ) -> Result<impl IntoResponse, HtmlError> {
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
Ok(TemplateResponse::new_template( Ok(TemplateResponse::new_template(
"sidebar.html", "sidebar.html",
DrawerContext { DrawerContext {
user,
conversation_archive,
edit_conversation_id: None, edit_conversation_id: None,
}, },
) )


@@ -1,10 +1,12 @@
+#![allow(clippy::missing_docs_in_private_items)]
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
use axum::{
extract::{Query, State},
response::{
-sse::{Event, KeepAlive},
+sse::{Event, KeepAlive, KeepAliveStream},
Sse,
},
};
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use tokio::sync::{mpsc::channel, Mutex};
-use tracing::{debug, error};
+use tracing::{debug, error, info};
use common::storage::{
db::SurrealDbClient,
@@ -38,10 +40,21 @@ use common::storage::{
use crate::{html_state::HtmlState, AuthSessionType};
+use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
+type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
+type SseResponse = Sse<KeepAliveStream<EventStream>>;
+fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
+Sse::new(stream).keep_alive(
+KeepAlive::new()
+.interval(Duration::from_secs(15))
+.text("keep-alive"),
+)
+}
// Error handling function
-fn create_error_stream(
-message: impl Into<String>,
-) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
+fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
@@ -51,53 +64,125 @@ async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
-) -> Result<
-(Message, User, Conversation, Vec<Message>),
-Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
-> {
+) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
// Check authentication
-let user = match current_user {
-Some(user) => user,
-None => {
-return Err(Sse::new(create_error_stream(
-"You must be signed in to use this feature",
-)))
-}
+let Some(user) = current_user else {
+return Err(sse_with_keep_alive(create_error_stream(
+"You must be signed in to use this feature",
+)));
};
// Retrieve message
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
-return Err(Sse::new(create_error_stream(
+return Err(sse_with_keep_alive(create_error_stream(
"Message not found: the specified message does not exist",
)))
}
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
-return Err(Sse::new(create_error_stream(
+return Err(sse_with_keep_alive(create_error_stream(
"Failed to retrieve message: database error",
)));
}
};
// Get conversation history
-let (conversation, mut history) =
+let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
-return Err(Sse::new(create_error_stream(
+return Err(sse_with_keep_alive(create_error_stream(
"Failed to retrieve message: database error",
)));
}
Ok((conversation, history)) => (conversation, history),
};
-// Remove the last message, its the same as the message
-history.pop();
-Ok((message, user, conversation, history))
+let Some(message_index) = find_message_index(&history, message_id) else {
+return Err(sse_with_keep_alive(create_error_stream(
+"Message not found in conversation history",
+)));
+};
+let Some(message_from_history) = history.get(message_index) else {
+return Err(sse_with_keep_alive(create_error_stream(
+"Message not found in conversation history",
+)));
+};
+if message_from_history.role != MessageRole::User {
+return Err(sse_with_keep_alive(create_error_stream(
+"Only user messages can be used to generate a response",
+)));
+}
+let message = message_from_history.clone();
+let history_before_message = history_before_message(&history, message_index);
+let existing_ai_response = find_existing_ai_response(&history, message_index);
+Ok((
+message,
+user,
+conversation,
+history_before_message,
+existing_ai_response,
+))
+}
+fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
+messages.iter().position(|message| message.id == message_id)
+}
+fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
+messages
+.iter()
+.skip(user_message_index + 1)
+.take_while(|message| message.role != MessageRole::User)
+.find(|message| message.role == MessageRole::AI)
+.cloned()
+}
+fn history_before_message(messages: &[Message], message_index: usize) -> Vec<Message> {
+messages.iter().take(message_index).cloned().collect()
+}
+fn create_replayed_response_stream(state: &HtmlState, existing_ai_message: Message) -> SseResponse {
+let references_event = if existing_ai_message
+.references
+.as_ref()
+.is_some_and(|references| !references.is_empty())
+{
+state
+.templates
+.render(
+"chat/reference_list.html",
+&Value::from_serialize(ReferenceData {
+message: existing_ai_message.clone(),
+}),
+)
+.ok()
+.map(|html| Event::default().event("references").data(html))
+} else {
+None
+};
+let answer = existing_ai_message.content;
+let event_stream = stream! {
+yield Ok(Event::default().event("chat_message").data(answer));
+if let Some(event) = references_event {
+yield Ok(event);
+}
+yield Ok(Event::default().event("close_stream").data("Stream complete"));
+};
+sse_with_keep_alive(event_stream.boxed())
}
#[derive(Deserialize)]
@@ -105,21 +190,42 @@ pub struct QueryParams {
message_id: String,
}
+#[derive(Serialize)]
+struct ReferenceData {
+message: Message,
+}
+fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
+response
+.references
+.iter()
+.map(|reference| reference.reference.clone())
+.collect()
+}
+#[allow(clippy::too_many_lines)]
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
-// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
-) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
+) -> SseResponse {
// 1. Authentication and initial data validation
-let (user_message, user, _conversation, history) =
+let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
-Ok((user_message, user, conversation, history)) => {
-(user_message, user, conversation, history)
-}
+Ok((user_message, user, conversation, history, existing_ai_response)) => (
+user_message,
+user,
+conversation,
+history,
+existing_ai_response,
+),
Err(error_stream) => return error_stream,
};
+if let Some(existing_ai_message) = existing_ai_response {
+return create_replayed_response_stream(&state, existing_ai_message);
+}
// 2. Retrieve knowledge entities
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
@@ -142,15 +248,17 @@ pub async fn get_response_stream(
{
Ok(result) => result,
Err(_e) => {
-return Sse::new(create_error_stream("Failed to retrieve knowledge"));
+return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
}
};
+let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format
-let context_json = match retrieval_result {
-retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
+let context_json = match &retrieval_result {
+retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
-retrieved_entities_to_json(&entities)
+retrieved_entities_to_json(entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
@@ -159,24 +267,18 @@ pub async fn get_response_stream(
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);
-let settings = match SystemSettings::get_current(&state.db).await {
-Ok(s) => s,
-Err(_) => {
-return Sse::new(create_error_stream("Failed to retrieve system settings"));
-}
+let Ok(settings) = SystemSettings::get_current(&state.db).await else {
+return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
};
-let request = match create_chat_request(formatted_user_message, &settings) {
-Ok(req) => req,
-Err(..) => {
-return Sse::new(create_error_stream("Failed to create chat request"));
-}
+let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
+return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
-return Sse::new(create_error_stream("Failed to create OpenAI stream"));
+return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};

@@ -186,7 +288,9 @@ pub async fn get_response_stream(
let (tx_final, mut rx_final) = channel::<Message>(1);
// 6. Set up the collection task for DB storage
-let db_client = state.db.clone();
+let db_client = Arc::clone(&state.db);
+let user_id = user.id.clone();
+let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
@@ -198,17 +302,55 @@ pub async fn get_response_stream(
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
-let references: Vec<String> = response
-.references
-.into_iter()
-.map(|r| r.reference)
-.collect();
+let raw_references = extract_reference_strings(&response);
+let answer = response.answer;
+let initial_validation = match validate_references(
+&user_id,
+raw_references,
+&allowed_reference_ids,
+&db_client,
+)
+.await
+{
+Ok(result) => result,
+Err(err) => {
+error!(error = %err, "Reference validation failed, storing answer without references");
+let ai_message = Message::new(
+user_message.conversation_id,
+MessageRole::AI,
+answer,
+Some(Vec::new()),
+);
+let _ = tx_final.send(ai_message.clone()).await;
+if let Err(store_err) = db_client.store_item(ai_message).await {
+error!(error = ?store_err, "Failed to store AI message after validation failure");
+}
+return;
+}
+};
+info!(
+total_refs = initial_validation.reason_stats.total,
+valid_refs = initial_validation.valid_refs.len(),
+invalid_refs = initial_validation.invalid_refs.len(),
+invalid_empty = initial_validation.reason_stats.empty,
+invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
+invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
+invalid_duplicate = initial_validation.reason_stats.duplicate,
+invalid_not_in_context = initial_validation.reason_stats.not_in_context,
+invalid_not_found = initial_validation.reason_stats.not_found,
+invalid_wrong_user = initial_validation.reason_stats.wrong_user,
+invalid_over_limit = initial_validation.reason_stats.over_limit,
+"Post-LLM reference validation complete"
+);
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
-response.answer,
-Some(references),
+answer,
+Some(initial_validation.valid_refs),
);
let _ = tx_final.send(ai_message.clone()).await;
@@ -240,7 +382,7 @@ pub async fn get_response_stream(
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
-let json_state = json_state.clone();
+let json_state = Arc::clone(&json_state);
stream! {
match result {
@@ -288,12 +430,6 @@ pub async fn get_response_stream(
return Ok(Event::default().event("empty")); // This event won't be sent
}
-// Prepare data for template
-#[derive(Serialize)]
-struct ReferenceData {
-message: Message,
-}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
@@ -323,11 +459,7 @@ pub async fn get_response_stream(
.data("Stream complete"))
}));
-Sse::new(event_stream.boxed()).keep_alive(
-KeepAlive::new()
-.interval(Duration::from_secs(15))
-.text("keep-alive"),
-)
+sse_with_keep_alive(event_stream.boxed())
}
struct StreamParserState {
@@ -375,3 +507,195 @@ impl StreamParserState {
String::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{Duration as ChronoDuration, Utc};
use common::storage::{
db::SurrealDbClient,
types::{conversation::Conversation, user::Theme},
};
use retrieval_pipeline::answer_retrieval::Reference;
use uuid::Uuid;
fn make_test_message(id: &str, role: MessageRole) -> Message {
let mut message = Message::new(
"conversation-1".to_string(),
role,
format!("content-{id}"),
None,
);
message.id = id.to_string();
message
}
fn make_test_user(id: &str) -> User {
User {
id: id.to_string(),
created_at: Utc::now(),
updated_at: Utc::now(),
email: "test@example.com".to_string(),
password: "password".to_string(),
anonymous: false,
api_key: None,
admin: false,
timezone: "UTC".to_string(),
theme: Theme::System,
}
}
#[test]
fn extracts_reference_strings_in_order() {
let response = LLMResponseFormat {
answer: "answer".to_string(),
references: vec![
Reference {
reference: "a".to_string(),
},
Reference {
reference: "b".to_string(),
},
],
};
let extracted = extract_reference_strings(&response);
assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
}
#[test]
fn finds_message_index_for_existing_message() {
let messages = vec![
make_test_message("m1", MessageRole::User),
make_test_message("m2", MessageRole::AI),
make_test_message("m3", MessageRole::User),
];
assert_eq!(find_message_index(&messages, "m2"), Some(1));
assert_eq!(find_message_index(&messages, "missing"), None);
}
#[test]
fn finds_existing_ai_response_for_same_turn() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("system", MessageRole::System),
make_test_message("a1", MessageRole::AI),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response");
assert_eq!(ai_reply.id, "a1");
let ai_reply_second_turn =
find_existing_ai_response(&messages, 3).expect("expected AI response");
assert_eq!(ai_reply_second_turn.id, "a2");
}
#[test]
fn does_not_replay_ai_response_from_later_turn() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
assert!(find_existing_ai_response(&messages, 0).is_none());
let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response");
assert_eq!(ai_reply.id, "a2");
}
#[test]
fn history_before_message_excludes_target_and_future_messages() {
let messages = vec![
make_test_message("u1", MessageRole::User),
make_test_message("a1", MessageRole::AI),
make_test_message("u2", MessageRole::User),
make_test_message("a2", MessageRole::AI),
];
let history_for_u2 = history_before_message(&messages, 2);
let history_ids: Vec<String> = history_for_u2
.into_iter()
.map(|message| message.id)
.collect();
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
}
#[tokio::test]
async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() {
let namespace = "chat_stream_replay";
let database = Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, &database)
.await
.expect("failed to create in-memory db");
let user = make_test_user("user-1");
let conversation = Conversation::new(user.id.clone(), "Conversation".to_string());
let mut user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
"Question one".to_string(),
None,
);
user_message.id = "u1".to_string();
let mut ai_message = Message::new(
conversation.id.clone(),
MessageRole::AI,
"Answer one".to_string(),
Some(vec!["ref-1".to_string()]),
);
ai_message.id = "a1".to_string();
ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1);
ai_message.updated_at = ai_message.created_at;
let mut second_user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
"Question two".to_string(),
None,
);
second_user_message.id = "u2".to_string();
second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1);
second_user_message.updated_at = second_user_message.created_at;
db.store_item(conversation.clone())
.await
.expect("failed to store conversation");
db.store_item(user_message.clone())
.await
.expect("failed to store user message");
db.store_item(ai_message.clone())
.await
.expect("failed to store ai message");
db.store_item(second_user_message.clone())
.await
.expect("failed to store second user message");
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
get_message_and_user(&db, Some(user.clone()), &user_message.id)
.await
.expect("expected first turn to load");
assert!(history_for_first_turn.is_empty());
let existing_ai_for_first_turn =
existing_ai_for_first_turn.expect("expected first-turn AI response");
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
get_message_and_user(&db, Some(user), &second_user_message.id)
.await
.expect("expected second turn to load");
let history_ids: Vec<String> = history_for_second_turn
.into_iter()
.map(|message| message.id)
.collect();
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
assert!(existing_ai_for_second_turn.is_none());
}
}


@@ -1,5 +1,6 @@
mod chat_handlers;
mod message_response_stream;
+mod reference_validation;
mod references;
use axum::{


@@ -0,0 +1,477 @@
#![allow(clippy::arithmetic_side_effects, clippy::missing_docs_in_private_items)]
use std::collections::HashSet;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
};
use retrieval_pipeline::StrategyOutput;
use uuid::Uuid;
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum InvalidReferenceReason {
Empty,
UnsupportedPrefix,
MalformedUuid,
Duplicate,
NotInContext,
NotFound,
WrongUser,
OverLimit,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct InvalidReference {
pub raw: String,
pub normalized: Option<String>,
pub reason: InvalidReferenceReason,
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct ReferenceReasonStats {
pub total: usize,
pub empty: usize,
pub unsupported_prefix: usize,
pub malformed_uuid: usize,
pub duplicate: usize,
pub not_in_context: usize,
pub not_found: usize,
pub wrong_user: usize,
pub over_limit: usize,
}
impl ReferenceReasonStats {
fn record(&mut self, reason: &InvalidReferenceReason) {
match reason {
InvalidReferenceReason::Empty => self.empty += 1,
InvalidReferenceReason::UnsupportedPrefix => self.unsupported_prefix += 1,
InvalidReferenceReason::MalformedUuid => self.malformed_uuid += 1,
InvalidReferenceReason::Duplicate => self.duplicate += 1,
InvalidReferenceReason::NotInContext => self.not_in_context += 1,
InvalidReferenceReason::NotFound => self.not_found += 1,
InvalidReferenceReason::WrongUser => self.wrong_user += 1,
InvalidReferenceReason::OverLimit => self.over_limit += 1,
}
}
}
#[derive(Debug, Clone, Default)]
pub(crate) struct ReferenceValidationResult {
pub valid_refs: Vec<String>,
pub invalid_refs: Vec<InvalidReference>,
pub reason_stats: ReferenceReasonStats,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ReferenceLookupTarget {
TextChunk,
KnowledgeEntity,
Any,
}
pub(crate) fn collect_reference_ids_from_retrieval(
retrieval_result: &StrategyOutput,
) -> Vec<String> {
let mut ids = Vec::new();
let mut seen = HashSet::new();
match retrieval_result {
StrategyOutput::Chunks(chunks) => {
for chunk in chunks {
let id = chunk.chunk.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
}
StrategyOutput::Entities(entities) => {
for entity in entities {
let id = entity.entity.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
}
StrategyOutput::Search(search) => {
for chunk in &search.chunks {
let id = chunk.chunk.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
for entity in &search.entities {
let id = entity.entity.id.clone();
if seen.insert(id.clone()) {
ids.push(id);
}
}
}
}
ids
}
pub(crate) async fn validate_references(
user_id: &str,
refs: Vec<String>,
allowed_ids: &[String],
db: &SurrealDbClient,
) -> Result<ReferenceValidationResult, AppError> {
let mut result = ReferenceValidationResult::default();
result.reason_stats.total = refs.len();
let mut seen = HashSet::new();
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
let enforce_context = !allowed_set.is_empty();
for raw in refs {
let (normalized, target) = match normalize_reference(&raw) {
Ok(parsed) => parsed,
Err(reason) => {
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: None,
reason,
});
continue;
}
};
if !seen.insert(normalized.clone()) {
let reason = InvalidReferenceReason::Duplicate;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
let reason = InvalidReferenceReason::OverLimit;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
if enforce_context && !allowed_set.contains(normalized.as_str()) {
let reason = InvalidReferenceReason::NotInContext;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
LookupResult::Found => result.valid_refs.push(normalized),
LookupResult::WrongUser => {
let reason = InvalidReferenceReason::WrongUser;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
}
LookupResult::NotFound => {
let reason = InvalidReferenceReason::NotFound;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
}
}
}
Ok(result)
}
pub(crate) fn normalize_reference(
raw: &str,
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
let trimmed = raw.trim();
if trimmed.is_empty() {
return Err(InvalidReferenceReason::Empty);
}
let (candidate, target) = if let Some((prefix, rest)) = trimmed.split_once(':') {
let lookup_target = if prefix.eq_ignore_ascii_case("knowledge_entity") {
ReferenceLookupTarget::KnowledgeEntity
} else if prefix.eq_ignore_ascii_case("text_chunk") {
ReferenceLookupTarget::TextChunk
} else {
return Err(InvalidReferenceReason::UnsupportedPrefix);
};
(rest.trim(), lookup_target)
} else {
(trimmed, ReferenceLookupTarget::Any)
};
if candidate.is_empty() {
return Err(InvalidReferenceReason::MalformedUuid);
}
Uuid::parse_str(candidate)
.map(|uuid| (uuid.to_string(), target))
.map_err(|_| InvalidReferenceReason::MalformedUuid)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LookupResult {
Found,
WrongUser,
NotFound,
}
async fn lookup_reference_for_user(
id: &str,
target: &ReferenceLookupTarget,
user_id: &str,
db: &SurrealDbClient,
) -> Result<LookupResult, AppError> {
match target {
ReferenceLookupTarget::TextChunk => lookup_single_type::<TextChunk>(id, user_id, db).await,
ReferenceLookupTarget::KnowledgeEntity => {
lookup_single_type::<KnowledgeEntity>(id, user_id, db).await
}
ReferenceLookupTarget::Any => {
let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
if chunk_result == LookupResult::Found {
return Ok(LookupResult::Found);
}
let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
if entity_result == LookupResult::Found {
return Ok(LookupResult::Found);
}
if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
return Ok(LookupResult::WrongUser);
}
Ok(LookupResult::NotFound)
}
}
}
async fn lookup_single_type<T>(
id: &str,
user_id: &str,
db: &SurrealDbClient,
) -> Result<LookupResult, AppError>
where
T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
{
let item = db.get_item::<T>(id).await?;
Ok(match item {
Some(item) if item.user_id() == user_id => LookupResult::Found,
Some(_) => LookupResult::WrongUser,
None => LookupResult::NotFound,
})
}
trait HasUserId {
fn user_id(&self) -> &str;
}
impl HasUserId for TextChunk {
fn user_id(&self) -> &str {
&self.user_id
}
}
impl HasUserId for KnowledgeEntity {
fn user_id(&self) -> &str {
&self.user_id
}
}
#[cfg(test)]
#[allow(
clippy::cloned_ref_to_slice_refs,
clippy::expect_used,
clippy::indexing_slicing
)]
mod tests {
use super::*;
use common::storage::types::knowledge_entity::KnowledgeEntityType;
use surrealdb::engine::any::connect;
async fn setup_test_db() -> SurrealDbClient {
let client = connect("mem://")
.await
.expect("failed to create in-memory surrealdb client");
let namespace = format!("test_ns_{}", Uuid::new_v4());
let database = format!("test_db_{}", Uuid::new_v4());
client
.use_ns(namespace)
.use_db(database)
.await
.expect("failed to select namespace/db");
let db = SurrealDbClient { client };
db.apply_migrations()
.await
.expect("failed to apply migrations");
db
}
#[tokio::test]
async fn valid_uuid_exists_and_belongs_to_user() {
let db = setup_test_db().await;
let user_id = "user-a";
let entity = KnowledgeEntity::new(
"source-1".to_string(),
"Entity A".to_string(),
"Entity description".to_string(),
KnowledgeEntityType::Document,
None,
user_id.to_string(),
);
db.store_item(entity.clone())
.await
.expect("failed to store entity");
let result =
validate_references(user_id, vec![entity.id.clone()], &[entity.id.clone()], &db)
.await
.expect("validation should not fail");
assert_eq!(result.valid_refs, vec![entity.id]);
assert!(result.invalid_refs.is_empty());
}
#[tokio::test]
async fn valid_uuid_exists_but_wrong_user_is_rejected() {
let db = setup_test_db().await;
let entity = KnowledgeEntity::new(
"source-1".to_string(),
"Entity B".to_string(),
"Entity description".to_string(),
KnowledgeEntityType::Document,
None,
"other-user".to_string(),
);
db.store_item(entity.clone())
.await
.expect("failed to store entity");
let result =
validate_references("user-a", vec![entity.id.clone()], &[entity.id.clone()], &db)
.await
.expect("validation should not fail");
assert!(result.valid_refs.is_empty());
assert_eq!(result.invalid_refs.len(), 1);
assert_eq!(
result.invalid_refs[0].reason,
InvalidReferenceReason::WrongUser
);
}
#[tokio::test]
async fn malformed_uuid_is_rejected() {
let db = setup_test_db().await;
let result = validate_references(
"user-a",
vec!["not-a-uuid".to_string()],
&["not-a-uuid".to_string()],
&db,
)
.await
.expect("validation should not fail");
assert!(result.valid_refs.is_empty());
assert_eq!(result.invalid_refs.len(), 1);
assert_eq!(
result.invalid_refs[0].reason,
InvalidReferenceReason::MalformedUuid
);
}
#[tokio::test]
async fn mixed_duplicates_are_deduped() {
let db = setup_test_db().await;
let user_id = "user-a";
let first = KnowledgeEntity::new(
"source-1".to_string(),
"Entity 1".to_string(),
"Entity description".to_string(),
KnowledgeEntityType::Document,
None,
user_id.to_string(),
);
let second = KnowledgeEntity::new(
"source-2".to_string(),
"Entity 2".to_string(),
"Entity description".to_string(),
KnowledgeEntityType::Document,
None,
user_id.to_string(),
);
db.store_item(first.clone())
.await
.expect("failed to store first entity");
db.store_item(second.clone())
.await
.expect("failed to store second entity");
let refs = vec![
first.id.clone(),
format!("knowledge_entity:{}", first.id),
second.id.clone(),
second.id.clone(),
];
let allowed = vec![first.id.clone(), second.id.clone()];
let result = validate_references(user_id, refs, &allowed, &db)
.await
.expect("validation should not fail");
assert_eq!(result.valid_refs, vec![first.id, second.id]);
assert_eq!(result.invalid_refs.len(), 2);
assert!(result
.invalid_refs
.iter()
.all(|entry| entry.reason == InvalidReferenceReason::Duplicate));
}
#[tokio::test]
async fn bare_uuid_prefers_chunk_lookup_before_entity() {
let db = setup_test_db().await;
let user_id = "user-a";
let chunk = TextChunk::new(
"source-1".to_string(),
"Chunk body".to_string(),
user_id.to_string(),
);
db.store_item(chunk.clone())
.await
.expect("failed to store chunk");
let result = validate_references(user_id, vec![chunk.id.clone()], &[chunk.id.clone()], &db)
.await
.expect("validation should not fail");
assert_eq!(result.valid_refs, vec![chunk.id]);
}
}
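
The prefix rule implemented by `normalize_reference` above can be sketched in isolation. This is a simplified, std-only model: a hand-rolled 8-4-4-4-12 hex check stands in for `Uuid::parse_str` (so only canonical hyphenated UUIDs pass here, unlike the real parser), and the enums are local stand-ins for `ReferenceLookupTarget` and `InvalidReferenceReason`.

```rust
// Simplified model of the reference-normalization rule above.
#[derive(Debug, PartialEq, Eq)]
enum Target {
    Any,
    KnowledgeEntity,
    TextChunk,
}

#[derive(Debug, PartialEq, Eq)]
enum Invalid {
    Empty,
    UnsupportedPrefix,
    MalformedUuid,
}

// Stand-in for Uuid::parse_str: accepts only the canonical
// 8-4-4-4-12 hyphenated hex form.
fn looks_like_uuid(s: &str) -> bool {
    let parts: Vec<&str> = s.split('-').collect();
    parts.len() == 5
        && parts
            .iter()
            .zip([8usize, 4, 4, 4, 12])
            .all(|(p, len)| p.len() == len && p.chars().all(|c| c.is_ascii_hexdigit()))
}

fn normalize(raw: &str) -> Result<(String, Target), Invalid> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(Invalid::Empty);
    }
    // A "prefix:uuid" form pins the lookup to one table; a bare UUID
    // may later match either text chunks or knowledge entities.
    let (candidate, target) = match trimmed.split_once(':') {
        Some((p, rest)) if p.eq_ignore_ascii_case("knowledge_entity") => {
            (rest.trim(), Target::KnowledgeEntity)
        }
        Some((p, rest)) if p.eq_ignore_ascii_case("text_chunk") => (rest.trim(), Target::TextChunk),
        Some(_) => return Err(Invalid::UnsupportedPrefix),
        None => (trimmed, Target::Any),
    };
    if looks_like_uuid(candidate) {
        Ok((candidate.to_ascii_lowercase(), target))
    } else {
        Err(Invalid::MalformedUuid)
    }
}
```

As in the real function, an empty candidate after a valid prefix (e.g. `"text_chunk:"`) falls through to the malformed-UUID error rather than the empty error.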

View File

@@ -1,12 +1,15 @@
+#![allow(clippy::missing_docs_in_private_items)]
 use axum::{
     extract::{Path, State},
     response::IntoResponse,
 };
+use chrono::{DateTime, Utc};
+use chrono_tz::Tz;
 use serde::Serialize;
 
-use common::{
-    error::AppError,
-    storage::types::{knowledge_entity::KnowledgeEntity, user::User},
+use common::storage::types::{
+    knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, user::User,
 };
 
 use crate::{
@@ -17,29 +20,101 @@ use crate::{
     },
 };
 
+use super::reference_validation::{normalize_reference, ReferenceLookupTarget};
+
+#[derive(Serialize)]
+struct ReferenceTooltipData {
+    text_chunk: Option<TextChunk>,
+    text_chunk_updated_at: Option<String>,
+    entity: Option<KnowledgeEntity>,
+    entity_updated_at: Option<String>,
+    user: User,
+}
+
+fn format_datetime_for_user(datetime: DateTime<Utc>, timezone: &str) -> String {
+    match timezone.parse::<Tz>() {
+        Ok(tz) => datetime
+            .with_timezone(&tz)
+            .format("%Y-%m-%d %H:%M:%S")
+            .to_string(),
+        Err(_) => datetime.format("%Y-%m-%d %H:%M:%S").to_string(),
+    }
+}
+
 pub async fn show_reference_tooltip(
     State(state): State<HtmlState>,
     RequireUser(user): RequireUser,
     Path(reference_id): Path<String>,
 ) -> Result<impl IntoResponse, HtmlError> {
-    let entity: KnowledgeEntity = state
-        .db
-        .get_item(&reference_id)
-        .await?
-        .ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
-
-    if entity.user_id != user.id {
-        return Ok(TemplateResponse::unauthorized());
-    }
-
-    #[derive(Serialize)]
-    struct ReferenceTooltipData {
-        entity: KnowledgeEntity,
-        user: User,
-    }
+    let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
+        return Ok(TemplateResponse::not_found());
+    };
+
+    let lookup_order = match target {
+        ReferenceLookupTarget::TextChunk | ReferenceLookupTarget::Any => [
+            ReferenceLookupTarget::TextChunk,
+            ReferenceLookupTarget::KnowledgeEntity,
+        ],
+        ReferenceLookupTarget::KnowledgeEntity => [
+            ReferenceLookupTarget::KnowledgeEntity,
+            ReferenceLookupTarget::TextChunk,
+        ],
+    };
+
+    let mut text_chunk: Option<TextChunk> = None;
+    let mut knowledge_entity: Option<KnowledgeEntity> = None;
+
+    for lookup_target in lookup_order {
+        match lookup_target {
+            ReferenceLookupTarget::TextChunk => {
+                if let Some(chunk) = state
+                    .db
+                    .get_item::<TextChunk>(&normalized_reference_id)
+                    .await?
+                {
+                    if chunk.user_id != user.id {
+                        return Ok(TemplateResponse::unauthorized());
+                    }
+                    text_chunk = Some(chunk);
+                    break;
+                }
+            }
+            ReferenceLookupTarget::KnowledgeEntity => {
+                if let Some(entity) = state
+                    .db
+                    .get_item::<KnowledgeEntity>(&normalized_reference_id)
+                    .await?
+                {
+                    if entity.user_id != user.id {
+                        return Ok(TemplateResponse::unauthorized());
+                    }
+                    knowledge_entity = Some(entity);
+                    break;
+                }
+            }
+            ReferenceLookupTarget::Any => {}
+        }
+    }
+
+    if text_chunk.is_none() && knowledge_entity.is_none() {
+        return Ok(TemplateResponse::not_found());
+    }
+
+    let text_chunk_updated_at = text_chunk
+        .as_ref()
+        .map(|chunk| format_datetime_for_user(chunk.updated_at, &user.timezone));
+    let entity_updated_at = knowledge_entity
+        .as_ref()
+        .map(|entity| format_datetime_for_user(entity.updated_at, &user.timezone));
 
     Ok(TemplateResponse::new_template(
         "chat/reference_tooltip.html",
-        ReferenceTooltipData { entity, user },
+        ReferenceTooltipData {
+            text_chunk,
+            text_chunk_updated_at,
+            entity: knowledge_entity,
+            entity_updated_at,
+            user,
+        },
     ))
 }
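
For a bare (`Any`) reference, the validation module's `lookup_reference_for_user` probes text chunks first and short-circuits on a hit; otherwise a `WrongUser` from either table outranks `NotFound`. That precedence can be modeled with plain enums. This is a sketch of the merge rule only: the real code issues the second query lazily rather than taking both results up front.

```rust
#[derive(Debug, PartialEq, Eq)]
enum LookupResult {
    Found,
    WrongUser,
    NotFound,
}

// Merge two per-table lookups for a bare reference, mirroring the
// precedence in `lookup_reference_for_user`: any Found wins, then
// WrongUser in either table, then NotFound.
fn merge_any(chunk: LookupResult, entity: LookupResult) -> LookupResult {
    use LookupResult::*;
    if chunk == Found || entity == Found {
        Found
    } else if chunk == WrongUser || entity == WrongUser {
        WrongUser
    } else {
        NotFound
    }
}
```

This ordering is what lets the handler report `WrongUser` (unauthorized) distinctly from a plain miss, even when the ID exists only in the second table probed.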

View File

@@ -7,8 +7,8 @@ use axum_htmx::{HxBoosted, HxRequest, HxTarget};
 use serde::{Deserialize, Serialize};
 
 use common::storage::types::{
-    conversation::Conversation, file_info::FileInfo, knowledge_entity::KnowledgeEntity,
-    text_chunk::TextChunk, text_content::TextContent, user::User,
+    file_info::FileInfo, knowledge_entity::KnowledgeEntity, text_chunk::TextChunk,
+    text_content::TextContent, user::User,
 };
 
 use crate::{
@@ -26,18 +26,15 @@ const CONTENTS_PER_PAGE: usize = 12;
 #[derive(Serialize)]
 pub struct ContentPageData {
-    user: User,
     text_contents: Vec<TextContent>,
     categories: Vec<String>,
     selected_category: Option<String>,
-    conversation_archive: Vec<Conversation>,
     pagination: Pagination,
     page_query: String,
 }
 
 #[derive(Serialize)]
 pub struct RecentTextContentData {
-    pub user: User,
     pub text_contents: Vec<TextContent>,
 }
@@ -81,13 +78,10 @@ pub async fn show_content_page(
         })
         .unwrap_or_default();
 
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
-
     let data = ContentPageData {
-        user,
         text_contents,
         categories,
         selected_category: params.category.clone(),
-        conversation_archive,
         pagination,
         page_query,
     };
@@ -112,13 +106,12 @@ pub async fn show_text_content_edit_form(
     #[derive(Serialize)]
     pub struct TextContentEditModal {
-        pub user: User,
         pub text_content: TextContent,
     }
 
     Ok(TemplateResponse::new_template(
         "content/edit_text_content_modal.html",
-        TextContentEditModal { user, text_content },
+        TextContentEditModal { text_content },
     ))
 }
@@ -145,10 +138,7 @@ pub async fn patch_text_content(
         return Ok(TemplateResponse::new_template(
             "dashboard/recent_content.html",
-            RecentTextContentData {
-                user,
-                text_contents,
-            },
+            RecentTextContentData { text_contents },
         ));
     }
@@ -159,17 +149,14 @@
     );
     let text_contents = truncate_text_contents(page_contents);
     let categories = User::get_user_categories(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
 
     Ok(TemplateResponse::new_partial(
         "content/base.html",
         "main",
         ContentPageData {
-            user,
             text_contents,
             categories,
             selected_category: None,
-            conversation_archive,
             pagination,
             page_query: String::new(),
         },
@@ -209,16 +196,13 @@ pub async fn delete_text_content(
     );
     let text_contents = truncate_text_contents(page_contents);
     let categories = User::get_user_categories(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
 
     Ok(TemplateResponse::new_template(
         "content/content_list.html",
         ContentPageData {
-            user,
             text_contents,
             categories,
             selected_category: None,
-            conversation_archive,
             pagination,
             page_query: String::new(),
         },
@@ -234,13 +218,12 @@ pub async fn show_content_read_modal(
     let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
 
     #[derive(Serialize)]
     pub struct TextContentReadModalData {
-        pub user: User,
         pub text_content: TextContent,
     }
 
     Ok(TemplateResponse::new_template(
         "content/read_content_modal.html",
-        TextContentReadModalData { user, text_content },
+        TextContentReadModalData { text_content },
     ))
 }
@@ -253,9 +236,6 @@ pub async fn show_recent_content(
     Ok(TemplateResponse::new_template(
         "dashboard/recent_content.html",
-        RecentTextContentData {
-            user,
-            text_contents,
-        },
+        RecentTextContentData { text_contents },
     ))
 }

View File

@@ -21,19 +21,17 @@ use common::storage::types::user::DashboardStats;
 use common::{
     error::AppError,
     storage::types::{
-        conversation::Conversation, file_info::FileInfo, ingestion_task::IngestionTask,
-        knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
-        text_chunk::TextChunk, text_content::TextContent, user::User,
+        file_info::FileInfo, ingestion_task::IngestionTask, knowledge_entity::KnowledgeEntity,
+        knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
+        text_content::TextContent, user::User,
     },
 };
 
 #[derive(Serialize)]
 pub struct IndexPageData {
-    user: Option<User>,
     text_contents: Vec<TextContent>,
     stats: DashboardStats,
     active_jobs: Vec<IngestionTask>,
-    conversation_archive: Vec<Conversation>,
 }
 
 pub async fn index_handler(
@@ -44,9 +42,8 @@ pub async fn index_handler(
         return Ok(TemplateResponse::redirect("/signin"));
     };
 
-    let (text_contents, conversation_archive, stats, active_jobs) = try_join!(
+    let (text_contents, stats, active_jobs) = try_join!(
         User::get_latest_text_contents(&user.id, &state.db),
-        User::get_user_conversations(&user.id, &state.db),
         User::get_dashboard_stats(&user.id, &state.db),
         User::get_unfinished_ingestion_tasks(&user.id, &state.db)
     )?;
@@ -56,10 +53,8 @@
     Ok(TemplateResponse::new_template(
         "dashboard/base.html",
         IndexPageData {
-            user: Some(user),
             text_contents,
             stats,
-            conversation_archive,
             active_jobs,
         },
     ))
@@ -68,7 +63,6 @@
 #[derive(Serialize)]
 pub struct LatestTextContentData {
     text_contents: Vec<TextContent>,
-    user: User,
 }
 
 pub async fn delete_text_content(
@@ -105,10 +99,7 @@ pub async fn delete_text_content(
     Ok(TemplateResponse::new_partial(
         "dashboard/recent_content.html",
         "latest_content_section",
-        LatestTextContentData {
-            user: user.clone(),
-            text_contents,
-        },
+        LatestTextContentData { text_contents },
     ))
 }
@@ -136,7 +127,6 @@ async fn get_and_validate_text_content(
 #[derive(Serialize)]
 pub struct ActiveJobsData {
     pub active_jobs: Vec<IngestionTask>,
-    pub user: User,
 }
 
 #[derive(Serialize)]
@@ -161,7 +151,6 @@ struct TaskArchiveEntry {
 #[derive(Serialize)]
 struct TaskArchiveData {
-    user: User,
     tasks: Vec<TaskArchiveEntry>,
 }
@@ -177,10 +166,7 @@ pub async fn delete_job(
     Ok(TemplateResponse::new_partial(
         "dashboard/active_jobs.html",
         "active_jobs_section",
-        ActiveJobsData {
-            user: user.clone(),
-            active_jobs,
-        },
+        ActiveJobsData { active_jobs },
     ))
 }
@@ -192,10 +178,7 @@ pub async fn show_active_jobs(
     Ok(TemplateResponse::new_template(
         "dashboard/active_jobs.html",
-        ActiveJobsData {
-            user: user.clone(),
-            active_jobs,
-        },
+        ActiveJobsData { active_jobs },
     ))
 }
@@ -233,10 +216,7 @@ pub async fn show_task_archive(
     Ok(TemplateResponse::new_template(
         "dashboard/task_archive_modal.html",
-        TaskArchiveData {
-            user,
-            tasks: entries,
-        },
+        TaskArchiveData { tasks: entries },
     ))
 }

View File

@@ -2,9 +2,10 @@ use std::{pin::Pin, time::Duration};
 
 use axum::{
     extract::{Query, State},
+    http::StatusCode,
     response::{
-        sse::{Event, KeepAlive},
-        Html, IntoResponse, Sse,
+        sse::{Event, KeepAlive, KeepAliveStream},
+        IntoResponse, Response, Sse,
     },
 };
 use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -23,6 +24,7 @@ use common::{
         ingestion_task::{IngestionTask, TaskState},
         user::User,
     },
+    utils::ingest_limits::{validate_ingest_input, IngestValidationError},
 };
 
 use crate::{
@@ -34,30 +36,41 @@
     AuthSessionType,
 };
 
-pub async fn show_ingress_form(
+type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
+type TaskSse = Sse<KeepAliveStream<EventStream>>;
+
+fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
+    Sse::new(stream).keep_alive(
+        KeepAlive::new()
+            .interval(Duration::from_secs(15))
+            .text("keep-alive-ping"),
+    )
+}
+
+pub async fn show_ingest_form(
     State(state): State<HtmlState>,
     RequireUser(user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
     let user_categories = User::get_user_categories(&user.id, &state.db).await?;
 
     #[derive(Serialize)]
-    pub struct ShowIngressFormData {
+    pub struct ShowIngestFormData {
         user_categories: Vec<String>,
     }
 
     Ok(TemplateResponse::new_template(
         "ingestion_modal.html",
-        ShowIngressFormData { user_categories },
+        ShowIngestFormData { user_categories },
     ))
 }
 
-pub async fn hide_ingress_form(
+pub async fn hide_ingest_form(
     RequireUser(_user): RequireUser,
 ) -> Result<impl IntoResponse, HtmlError> {
-    Ok(Html(
-        "<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
-    )
-    .into_response())
+    Ok(TemplateResponse::new_template(
+        "ingestion/add_content_button.html",
+        (),
+    ))
 }
 
 #[derive(Debug, TryFromMultipart)]
@@ -65,37 +78,59 @@ pub struct IngestionParams {
     pub content: Option<String>,
     pub context: String,
     pub category: String,
-    #[form_data(limit = "10000000")] // Adjust limit as needed
+    #[form_data(limit = "20000000")]
     #[form_data(default)]
     pub files: Vec<FieldData<NamedTempFile>>,
 }
 
-pub async fn process_ingress_form(
+pub async fn process_ingest_form(
     State(state): State<HtmlState>,
     RequireUser(user): RequireUser,
     TypedMultipart(input): TypedMultipart<IngestionParams>,
-) -> Result<impl IntoResponse, HtmlError> {
-    #[derive(Serialize)]
-    pub struct IngressFormData {
-        context: String,
-        content: String,
-        category: String,
-        error: String,
-    }
-
+) -> Result<Response, HtmlError> {
     if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
-        return Ok(TemplateResponse::new_template(
-            "index/signed_in/ingress_form.html",
-            IngressFormData {
-                context: input.context.clone(),
-                content: input.content.clone().unwrap_or_default(),
-                category: input.category.clone(),
-                error: "You need to either add files or content".to_string(),
-            },
-        ));
+        return Ok(
+            TemplateResponse::bad_request("You need to either add files or content")
+                .into_response(),
+        );
     }
 
-    info!("{:?}", input);
+    let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
+    let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
+    let context_bytes = input.context.len();
+    let category_bytes = input.category.len();
+    let file_count = input.files.len();
+
+    match validate_ingest_input(
+        &state.config,
+        input.content.as_deref(),
+        &input.context,
+        &input.category,
+        file_count,
+    ) {
+        Ok(()) => {}
+        Err(IngestValidationError::PayloadTooLarge(message)) => {
+            return Ok(TemplateResponse::error(
+                StatusCode::PAYLOAD_TOO_LARGE,
+                "Payload Too Large",
+                &message,
+            )
+            .into_response());
+        }
+        Err(IngestValidationError::BadRequest(message)) => {
+            return Ok(TemplateResponse::bad_request(&message).into_response());
+        }
    }
 
+    info!(
+        user_id = %user.id,
+        has_content,
+        content_bytes,
+        context_bytes,
+        category_bytes,
+        file_count,
+        "Received ingest form submission"
+    );
 
     let file_infos = try_join_all(input.files.into_iter().map(|file| {
         FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage)
@@ -120,14 +155,13 @@ pub async fn process_ingress_form(
     #[derive(Serialize)]
     struct NewTasksData {
-        user: User,
         tasks: Vec<IngestionTask>,
     }
 
-    Ok(TemplateResponse::new_template(
-        "dashboard/current_task.html",
-        NewTasksData { user, tasks },
-    ))
+    Ok(
+        TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
+            .into_response(),
+    )
 }
 
 #[derive(Deserialize)]
@@ -135,9 +169,7 @@ pub struct QueryParams {
     task_id: String,
 }
 
-fn create_error_stream(
-    message: impl Into<String>,
-) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
+fn create_error_stream(message: impl Into<String>) -> EventStream {
     let message = message.into();
     stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
 }
@@ -146,13 +178,13 @@ pub async fn get_task_updates_stream(
     State(state): State<HtmlState>,
     auth: AuthSessionType,
     Query(params): Query<QueryParams>,
-) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
+) -> TaskSse {
     let task_id = params.task_id.clone();
     let db = state.db.clone();
 
     // 1. Check for authenticated user
     let Some(current_user) = auth.current_user else {
-        return Sse::new(create_error_stream("User not authenticated"));
+        return sse_with_keep_alive(create_error_stream("User not authenticated"));
     };
 
     // 2. Fetch task for initial authorization and to ensure it exists
@@ -160,7 +192,7 @@
         Ok(Some(task)) => {
             // 3. Validate user ownership
             if task.user_id != current_user.id {
-                return Sse::new(create_error_stream(
+                return sse_with_keep_alive(create_error_stream(
                     "Access denied: You do not have permission to view updates for this task.",
                 ));
             }
@@ -246,18 +278,14 @@
                 }
             };
 
-            Sse::new(sse_stream.boxed()).keep_alive(
-                KeepAlive::new()
-                    .interval(Duration::from_secs(15))
-                    .text("keep-alive-ping"),
-            )
+            sse_with_keep_alive(sse_stream.boxed())
         }
-        Ok(None) => Sse::new(create_error_stream(format!(
+        Ok(None) => sse_with_keep_alive(create_error_stream(format!(
             "Task with ID '{task_id}' not found."
         ))),
         Err(e) => {
             error!("Failed to fetch task '{task_id}' for authorization: {e:?}");
-            Sse::new(create_error_stream(
+            sse_with_keep_alive(create_error_stream(
                 "An error occurred while retrieving task details. Please try again later.",
             ))
         }
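
The `validate_ingest_input` call above splits failures into a 400 and a 413 path. Below is a minimal std-only sketch of that two-way split; the limit constants are hypothetical stand-ins for the real thresholds in `state.config`, and the real function's signature (which takes the config) differs.

```rust
#[derive(Debug, PartialEq)]
enum IngestValidationError {
    BadRequest(String),
    PayloadTooLarge(String),
}

// Hypothetical limits for illustration only; the actual values live in
// the application config and are not shown in the diff.
const MAX_CONTENT_BYTES: usize = 1_000_000;
const MAX_FILES: usize = 10;

fn validate_ingest_input(
    content: Option<&str>,
    context: &str,
    category: &str,
    file_count: usize,
) -> Result<(), IngestValidationError> {
    // Semantic problems map to 400 Bad Request...
    if category.trim().is_empty() {
        return Err(IngestValidationError::BadRequest(
            "category must not be empty".into(),
        ));
    }
    if file_count > MAX_FILES {
        return Err(IngestValidationError::BadRequest(format!(
            "at most {MAX_FILES} files per submission"
        )));
    }
    // ...while oversized payloads map to 413 Payload Too Large.
    let total = content.map_or(0, str::len) + context.len() + category.len();
    if total > MAX_CONTENT_BYTES {
        return Err(IngestValidationError::PayloadTooLarge(format!(
            "text payload exceeds {MAX_CONTENT_BYTES} bytes"
        )));
    }
    Ok(())
}
```

Keeping the two error variants distinct is what lets the handler pick the right status code (`TemplateResponse::bad_request` vs `StatusCode::PAYLOAD_TOO_LARGE`) without re-inspecting the input.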

View File

@@ -1,22 +1,22 @@
 mod handlers;
 
-use axum::{extract::FromRef, routing::get, Router};
-use handlers::{
-    get_task_updates_stream, hide_ingress_form, process_ingress_form, show_ingress_form,
-};
+use axum::{extract::DefaultBodyLimit, extract::FromRef, routing::get, Router};
+use handlers::{get_task_updates_stream, hide_ingest_form, process_ingest_form, show_ingest_form};
 
 use crate::html_state::HtmlState;
 
-pub fn router<S>() -> Router<S>
+pub fn router<S>(max_body_bytes: usize) -> Router<S>
 where
     S: Clone + Send + Sync + 'static,
     HtmlState: FromRef<S>,
 {
     Router::new()
         .route(
-            "/ingress-form",
-            get(show_ingress_form).post(process_ingress_form),
+            "/ingest-form",
+            get(show_ingest_form)
+                .post(process_ingest_form)
+                .layer(DefaultBodyLimit::max(max_body_bytes)),
         )
         .route("/task/status-stream", get(get_task_updates_stream))
-        .route("/hide-ingress-form", get(hide_ingress_form))
+        .route("/hide-ingest-form", get(hide_ingest_form))
 }

View File

@@ -17,7 +17,6 @@ use serde::{
 use common::{
     error::AppError,
     storage::types::{
-        conversation::Conversation,
         knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
         knowledge_relationship::KnowledgeRelationship,
         user::User,
@@ -333,12 +332,10 @@ pub struct KnowledgeBaseData {
     entities: Vec<KnowledgeEntity>,
     visible_entities: Vec<KnowledgeEntity>,
     relationships: Vec<RelationshipTableRow>,
-    user: User,
     entity_types: Vec<String>,
     content_categories: Vec<String>,
     selected_entity_type: Option<String>,
     selected_content_category: Option<String>,
-    conversation_archive: Vec<Conversation>,
     pagination: Pagination,
     page_query: String,
     relationship_type_options: Vec<String>,
@@ -481,18 +478,15 @@ async fn build_knowledge_base_data(
         relationship_type_options,
         default_relationship_type,
     } = build_relationship_table_data(entities.clone(), filtered_relationships);
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
 
     Ok(KnowledgeBaseData {
         entities,
         visible_entities,
         relationships,
-        user: user.clone(),
         entity_types,
         content_categories,
         selected_entity_type: params.entity_type.clone(),
         selected_content_category: params.content_category.clone(),
-        conversation_archive,
         pagination,
         page_query,
         relationship_type_options,
@@ -861,7 +855,6 @@ pub async fn show_edit_knowledge_entity_form(
     pub struct EntityData {
         entity: KnowledgeEntity,
         entity_types: Vec<String>,
-        user: User,
     }
 
     // Get entity types
@@ -878,7 +871,6 @@
         EntityData {
             entity,
             entity_types,
-            user,
         },
     ))
 }
@@ -895,7 +887,6 @@ pub struct PatchKnowledgeEntityParams {
 pub struct EntityListData {
     visible_entities: Vec<KnowledgeEntity>,
     pagination: Pagination,
-    user: User,
     entity_types: Vec<String>,
     content_categories: Vec<String>,
     selected_entity_type: Option<String>,
@@ -943,7 +934,6 @@ pub async fn patch_knowledge_entity(
         EntityListData {
             visible_entities,
             pagination,
-            user,
             entity_types,
             content_categories,
             selected_entity_type: None,
@@ -982,7 +972,6 @@ pub async fn delete_knowledge_entity(
         EntityListData {
             visible_entities,
             pagination,
-            user,
             entity_types,
             content_categories,
             selected_entity_type: None,

View File

@@ -14,16 +14,13 @@ use crate::middlewares::{
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{HtmlError, TemplateResponse},
}; };
use common::storage::types::{ use common::storage::types::{
conversation::Conversation, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask, scratchpad::Scratchpad,
scratchpad::Scratchpad, user::User,
}; };
#[derive(Serialize)] #[derive(Serialize)]
pub struct ScratchpadPageData { pub struct ScratchpadPageData {
user: User,
scratchpads: Vec<ScratchpadListItem>, scratchpads: Vec<ScratchpadListItem>,
archived_scratchpads: Vec<ScratchpadArchiveItem>, archived_scratchpads: Vec<ScratchpadArchiveItem>,
conversation_archive: Vec<Conversation>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
new_scratchpad: Option<ScratchpadDetail>, new_scratchpad: Option<ScratchpadDetail>,
} }
@@ -38,9 +35,8 @@ pub struct ScratchpadListItem {
#[derive(Serialize)] #[derive(Serialize)]
pub struct ScratchpadDetailData { pub struct ScratchpadDetailData {
user: User,
scratchpad: ScratchpadDetail, scratchpad: ScratchpadDetail,
conversation_archive: Vec<Conversation>, is_editing_title: bool,
} }
#[derive(Serialize)] #[derive(Serialize)]
@@ -134,7 +130,6 @@ pub async fn show_scratchpad_page(
 ) -> Result<impl IntoResponse, HtmlError> {
     let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
     let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;

     let scratchpad_list: Vec<ScratchpadListItem> =
         scratchpads.iter().map(ScratchpadListItem::from).collect();
@@ -148,10 +143,8 @@ pub async fn show_scratchpad_page(
             "scratchpad/base.html",
             "main",
             ScratchpadPageData {
-                user,
                 scratchpads: scratchpad_list,
                 archived_scratchpads: archived_list,
-                conversation_archive,
                 new_scratchpad: None,
             },
         ))
@@ -159,10 +152,8 @@ pub async fn show_scratchpad_page(
         Ok(TemplateResponse::new_template(
             "scratchpad/base.html",
             ScratchpadPageData {
-                user,
                 scratchpads: scratchpad_list,
                 archived_scratchpads: archived_list,
-                conversation_archive,
                 new_scratchpad: None,
             },
         ))
@@ -176,19 +167,17 @@ pub async fn show_scratchpad_modal(
     Query(query): Query<EditTitleQuery>,
 ) -> Result<impl IntoResponse, HtmlError> {
     let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
     let scratchpad_detail = ScratchpadDetail::from(&scratchpad);

-    // Handle edit_title query parameter if needed in future
-    let _ = query.edit_title.unwrap_or(false);
+    // Handle edit_title query parameter
+    let is_editing_title = query.edit_title.unwrap_or(false);

     Ok(TemplateResponse::new_template(
         "scratchpad/editor_modal.html",
         ScratchpadDetailData {
-            user,
             scratchpad: scratchpad_detail,
-            conversation_archive,
+            is_editing_title,
         },
     ))
 }
@@ -204,7 +193,6 @@ pub async fn create_scratchpad(
     let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
     let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;

     let scratchpad_list: Vec<ScratchpadListItem> =
         scratchpads.iter().map(ScratchpadListItem::from).collect();
@@ -217,10 +205,8 @@ pub async fn create_scratchpad(
             "scratchpad/base.html",
             "main",
             ScratchpadPageData {
-                user,
                 scratchpads: scratchpad_list,
                 archived_scratchpads: archived_list,
-                conversation_archive,
                 new_scratchpad: Some(ScratchpadDetail::from(&scratchpad)),
             },
         ))
@@ -255,14 +241,12 @@ pub async fn update_scratchpad_title(
     Scratchpad::update_title(&scratchpad_id, &user.id, &form.title, &state.db).await?;
     let scratchpad = Scratchpad::get_by_id(&scratchpad_id, &user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;

     Ok(TemplateResponse::new_template(
         "scratchpad/editor_modal.html",
         ScratchpadDetailData {
-            user,
             scratchpad: ScratchpadDetail::from(&scratchpad),
-            conversation_archive,
+            is_editing_title: false,
         },
     ))
 }
@@ -276,7 +260,6 @@ pub async fn delete_scratchpad(
     // Return the updated main section content
     let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
     let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;

     let scratchpad_list: Vec<ScratchpadListItem> =
@@ -290,10 +273,8 @@ pub async fn delete_scratchpad(
             "scratchpad/base.html",
             "main",
             ScratchpadPageData {
-                user,
                 scratchpads: scratchpad_list,
                 archived_scratchpads: archived_list,
-                conversation_archive,
                 new_scratchpad: None,
             },
         ))
@@ -347,7 +328,6 @@ pub async fn ingest_scratchpad(
     let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
     let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;

     let scratchpad_list: Vec<ScratchpadListItem> =
         scratchpads.iter().map(ScratchpadListItem::from).collect();
@@ -371,10 +351,8 @@ pub async fn ingest_scratchpad(
         "scratchpad/base.html",
         "main",
         ScratchpadPageData {
-            user,
             scratchpads: scratchpad_list,
             archived_scratchpads: archived_list,
-            conversation_archive,
             new_scratchpad: None,
         },
     );
@@ -396,7 +374,6 @@ pub async fn archive_scratchpad(
     let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
     let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;

     let scratchpad_list: Vec<ScratchpadListItem> =
         scratchpads.iter().map(ScratchpadListItem::from).collect();
@@ -408,15 +385,59 @@ pub async fn archive_scratchpad(
     Ok(TemplateResponse::new_template(
         "scratchpad/base.html",
         ScratchpadPageData {
-            user,
             scratchpads: scratchpad_list,
             archived_scratchpads: archived_list,
-            conversation_archive,
             new_scratchpad: None,
         },
     ))
 }

+pub async fn restore_scratchpad(
+    RequireUser(user): RequireUser,
+    State(state): State<HtmlState>,
+    Path(scratchpad_id): Path<String>,
+) -> Result<impl IntoResponse, HtmlError> {
+    Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?;
+
+    let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
+    let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
+
+    let scratchpad_list: Vec<ScratchpadListItem> =
+        scratchpads.iter().map(ScratchpadListItem::from).collect();
+    let archived_list: Vec<ScratchpadArchiveItem> = archived_scratchpads
+        .iter()
+        .map(ScratchpadArchiveItem::from)
+        .collect();
+
+    let trigger_payload = serde_json::json!({
+        "toast": {
+            "title": "Scratchpad restored",
+            "description": "The scratchpad is back in your active list.",
+            "type": "info"
+        }
+    });
+    let trigger_value = serde_json::to_string(&trigger_payload).unwrap_or_else(|_| {
+        r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string()
+    });
+
+    let template_response = TemplateResponse::new_partial(
+        "scratchpad/base.html",
+        "main",
+        ScratchpadPageData {
+            scratchpads: scratchpad_list,
+            archived_scratchpads: archived_list,
+            new_scratchpad: None,
+        },
+    );
+
+    let mut response = template_response.into_response();
+    if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
+        response.headers_mut().insert(HX_TRIGGER, header_value);
+    }
+
+    Ok(response)
+}

 #[cfg(test)]
 mod tests {
     use super::*;
@@ -506,52 +527,3 @@ mod tests {
         assert_eq!(archive_item.ingested_at, None);
     }
 }
-
-pub async fn restore_scratchpad(
-    RequireUser(user): RequireUser,
-    State(state): State<HtmlState>,
-    Path(scratchpad_id): Path<String>,
-) -> Result<impl IntoResponse, HtmlError> {
-    Scratchpad::restore(&scratchpad_id, &user.id, &state.db).await?;
-
-    let scratchpads = Scratchpad::get_by_user(&user.id, &state.db).await?;
-    let archived_scratchpads = Scratchpad::get_archived_by_user(&user.id, &state.db).await?;
-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
-
-    let scratchpad_list: Vec<ScratchpadListItem> =
-        scratchpads.iter().map(ScratchpadListItem::from).collect();
-    let archived_list: Vec<ScratchpadArchiveItem> = archived_scratchpads
-        .iter()
-        .map(ScratchpadArchiveItem::from)
-        .collect();
-
-    let trigger_payload = serde_json::json!({
-        "toast": {
-            "title": "Scratchpad restored",
-            "description": "The scratchpad is back in your active list.",
-            "type": "info"
-        }
-    });
-    let trigger_value = serde_json::to_string(&trigger_payload).unwrap_or_else(|_| {
-        r#"{"toast":{"title":"Scratchpad restored","description":"The scratchpad is back in your active list.","type":"info"}}"#.to_string()
-    });
-
-    let template_response = TemplateResponse::new_partial(
-        "scratchpad/base.html",
-        "main",
-        ScratchpadPageData {
-            user,
-            scratchpads: scratchpad_list,
-            archived_scratchpads: archived_list,
-            conversation_archive,
-            new_scratchpad: None,
-        },
-    );
-
-    let mut response = template_response.into_response();
-    if let Ok(header_value) = HeaderValue::from_str(&trigger_value) {
-        response.headers_mut().insert(HX_TRIGGER, header_value);
-    }
-
-    Ok(response)
-}
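The `HeaderValue::from_str` guard around the `HX-Trigger` insertion above exists because HTTP header values may only contain visible ASCII (plus space and tab); a raw newline in the serialized toast would otherwise allow header injection. A std-only sketch of that constraint, where `is_valid_header_value` is an illustrative stand-in for the real validation, not the `http` crate's API:

```rust
// Sketch: why restore_scratchpad only sets HX-Trigger when
// HeaderValue::from_str succeeds. Header values must not contain
// control bytes such as CR/LF. is_valid_header_value is an
// illustrative stand-in, not the actual library check.
fn is_valid_header_value(s: &str) -> bool {
    s.bytes().all(|b| b == b'\t' || (b >= 32 && b != 127))
}

fn main() {
    let toast = r#"{"toast":{"title":"Scratchpad restored","type":"info"}}"#;
    assert!(is_valid_header_value(toast));
    // A value smuggling a newline is rejected instead of panicking.
    assert!(!is_valid_header_value("x\r\nSet-Cookie: hacked=1"));
    println!("ok");
}
```

Falling back to a pre-serialized literal in `unwrap_or_else` follows the same fail-safe idea: the toast header is always a known-valid string even if serialization fails.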

View File

@@ -1,6 +1,7 @@
 use std::{
     collections::{HashMap, HashSet},
-    fmt, str::FromStr,
+    fmt,
+    str::FromStr,
 };

 use axum::{
@@ -8,9 +9,7 @@ use axum::{
     response::IntoResponse,
 };
 use common::storage::types::{
-    conversation::Conversation,
     text_content::{deserialize_flexible_id, TextContent},
-    user::User,
     StoredObject,
 };
 use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
@@ -193,64 +192,62 @@ pub async fn search_result_handler(
     pub struct AnswerData {
         search_result: Vec<SearchResultForTemplate>,
         query_param: String,
-        user: User,
-        conversation_archive: Vec<Conversation>,
     }

-    let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
-
-    let (search_results_for_template, final_query_param_for_template) =
-        if let Some(actual_query) = params.query {
+    let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
+        params.query
+    {
         let trimmed_query = actual_query.trim();
         if trimmed_query.is_empty() {
             (Vec::<SearchResultForTemplate>::new(), String::new())
         } else {
             // Use retrieval pipeline Search strategy
             let config = RetrievalConfig::for_search(SearchTarget::Both);

             // Checkout a reranker lease if pool is available
             let reranker_lease = match &state.reranker_pool {
                 Some(pool) => Some(pool.checkout().await),
                 None => None,
             };

             let result = retrieval_pipeline::pipeline::run_pipeline(
                 &state.db,
                 &state.openai_client,
                 Some(&state.embedding_provider),
                 trimmed_query,
                 &user.id,
                 config,
                 reranker_lease,
             )
             .await?;

             let search_result = match result {
                 StrategyOutput::Search(sr) => sr,
                 _ => SearchResult::new(vec![], vec![]),
             };

             let mut source_ids = HashSet::new();
             for chunk_result in &search_result.chunks {
                 source_ids.insert(chunk_result.chunk.source_id.clone());
             }
             for entity_result in &search_result.entities {
                 source_ids.insert(entity_result.entity.source_id.clone());
             }

             let source_label_map = if source_ids.is_empty() {
                 HashMap::new()
             } else {
                 let record_ids: Vec<RecordId> = source_ids
                     .iter()
                     .filter_map(|id| {
                         if id.contains(':') {
                             RecordId::from_str(id).ok()
                         } else {
                             Some(RecordId::from_table_key(TextContent::table_name(), id))
                         }
                     })
                     .collect();

                 let mut response = state
                     .db
                     .client
                     .query(
@@ -260,92 +257,90 @@ pub async fn search_result_handler(
                     .bind(("user_id", user.id.clone()))
                     .bind(("record_ids", record_ids))
                     .await?;
                 let contents: Vec<SourceLabelRow> = response.take(0)?;
                 tracing::debug!(
                     source_id_count = source_ids.len(),
                     label_row_count = contents.len(),
                     "Resolved search source labels"
                 );
                 let mut labels = HashMap::new();
                 for content in contents {
                     let label = build_source_label(&content);
                     labels.insert(content.id.clone(), label.clone());
                     labels.insert(
                         format!("{}:{}", TextContent::table_name(), content.id),
                         label,
                     );
                 }
                 labels
             };

             let mut combined_results: Vec<SearchResultForTemplate> =
                 Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());

             // Add chunk results
             for chunk_result in search_result.chunks {
                 let source_label = source_label_map
                     .get(&chunk_result.chunk.source_id)
                     .cloned()
                     .unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
                 combined_results.push(SearchResultForTemplate {
                     result_type: "text_chunk".to_string(),
                     score: chunk_result.score,
                     text_chunk: Some(TextChunkForTemplate {
                         id: chunk_result.chunk.id,
                         source_id: chunk_result.chunk.source_id,
                         source_label,
                         chunk: chunk_result.chunk.chunk,
                         score: chunk_result.score,
                     }),
                     knowledge_entity: None,
                 });
             }

             // Add entity results
             for entity_result in search_result.entities {
                 let source_label = source_label_map
                     .get(&entity_result.entity.source_id)
                     .cloned()
                     .unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
                 combined_results.push(SearchResultForTemplate {
                     result_type: "knowledge_entity".to_string(),
                     score: entity_result.score,
                     text_chunk: None,
                     knowledge_entity: Some(KnowledgeEntityForTemplate {
                         id: entity_result.entity.id,
                         name: entity_result.entity.name,
                         description: entity_result.entity.description,
                         entity_type: format!("{:?}", entity_result.entity.entity_type),
                         source_id: entity_result.entity.source_id,
                         source_label,
                         score: entity_result.score,
                     }),
                 });
             }

             // Sort by score descending
             combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));

             // Limit results
             const TOTAL_LIMIT: usize = 10;
             combined_results.truncate(TOTAL_LIMIT);

             (combined_results, trimmed_query.to_string())
         }
     } else {
         (Vec::<SearchResultForTemplate>::new(), String::new())
     };

     Ok(TemplateResponse::new_template(
         "search/base.html",
         AnswerData {
             search_result: search_results_for_template,
             query_param: final_query_param_for_template,
-            user,
-            conversation_archive,
         },
     ))
 }
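The label map in this handler inserts every source label under both the bare content id and the `text_content:`-prefixed record id, so lookups succeed whichever form a search result carries, with `fallback_source_label` covering misses. A std-only sketch of that dual-key pattern (`build_label_map`, `lookup`, and the fallback text are illustrative stand-ins, not the handler's real helpers):

```rust
use std::collections::HashMap;

// Sketch of the dual-key label map: each label is stored under the
// bare id and under "text_content:<id>", mirroring the two inserts
// in the handler. Names and fallback text are illustrative only.
fn build_label_map(rows: &[(&str, &str)]) -> HashMap<String, String> {
    let mut labels = HashMap::new();
    for (id, label) in rows {
        labels.insert((*id).to_string(), (*label).to_string());
        labels.insert(format!("text_content:{id}"), (*label).to_string());
    }
    labels
}

fn lookup(labels: &HashMap<String, String>, source_id: &str) -> String {
    labels
        .get(source_id)
        .cloned()
        // Stand-in for fallback_source_label.
        .unwrap_or_else(|| format!("Unknown source ({source_id})"))
}

fn main() {
    let labels = build_label_map(&[("abc", "My note")]);
    assert_eq!(lookup(&labels, "abc"), "My note");
    assert_eq!(lookup(&labels, "text_content:abc"), "My note");
    assert_eq!(lookup(&labels, "missing"), "Unknown source (missing)");
    println!("ok");
}
```

Storing both key forms trades a little memory for avoiding id normalization at every lookup site.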

View File

@@ -0,0 +1,20 @@
{% extends 'admin/_layout.html' %}
{% block admin_navigation %}
<a href="/admin?section=overview"
class="nb-btn btn-sm px-4 {% if current_section == 'overview' %}nb-cta{% else %}btn-ghost{% endif %}">
Overview
</a>
<a href="/admin?section=models"
class="nb-btn btn-sm px-4 {% if current_section == 'models' %}nb-cta{% else %}btn-ghost{% endif %}">
Models
</a>
{% endblock %}
{% block admin_content %}
{% if current_section == 'models' %}
{% include 'admin/sections/models.html' %}
{% else %}
{% include 'admin/sections/overview.html' %}
{% endif %}
{% endblock %}

View File

@@ -0,0 +1,29 @@
{% extends 'body_base.html' %}
{% block title %}Minne - Admin{% endblock %}
{% block main %}
<div id="admin-shell" class="flex justify-center grow mt-2 sm:mt-4 pb-4">
<div class="container flex flex-col gap-4">
<section class="nb-panel p-4 sm:p-5 flex flex-col gap-3 sm:flex-row sm:items-start sm:justify-between">
<div>
<h1 class="text-xl font-extrabold tracking-tight">Admin Controls</h1>
</div>
<div class="text-xs opacity-60 sm:text-right">
Signed in as <span class="font-medium">{{ user.email }}</span>
</div>
</section>
<nav class="nb-panel p-2 flex flex-wrap gap-2 text-sm" hx-boost="true" hx-target="#admin-shell"
hx-select="#admin-shell" hx-swap="outerHTML" hx-push-url="true">
{% block admin_navigation %}
{% endblock %}
</nav>
<div id="admin-content" class="flex flex-col gap-4">
{% block admin_content %}
{% endblock %}
</div>
</div>
</div>
{% endblock %}

View File

@@ -1,51 +1 @@
-{% extends 'body_base.html' %}
-{% block title %}Minne - Admin{% endblock %}
-{% block main %}
-<div id="admin-shell" class="flex justify-center grow mt-2 sm:mt-4 pb-4">
-    <div class="container flex flex-col gap-4">
-        <section class="nb-panel p-4 sm:p-5 flex flex-col gap-3 sm:flex-row sm:items-start sm:justify-between">
-            <div>
-                <h1 class="text-xl font-extrabold tracking-tight">Admin Controls</h1>
-                <p class="text-sm opacity-70 max-w-2xl">
-                    Stay on top of analytics and manage AI integrations without waiting on long-running model calls.
-                </p>
-            </div>
-            <div class="text-xs opacity-60 sm:text-right">
-                Signed in as <span class="font-medium">{{ user.email }}</span>
-            </div>
-        </section>
-        <nav
-            class="nb-panel p-2 flex flex-wrap gap-2 text-sm"
-            hx-boost="true"
-            hx-target="#admin-shell"
-            hx-select="#admin-shell"
-            hx-swap="outerHTML"
-            hx-push-url="true"
-        >
-            <a
-                href="/admin?section=overview"
-                class="nb-btn btn-sm px-4 {% if current_section == 'overview' %}nb-cta{% else %}btn-ghost{% endif %}"
-            >
-                Overview
-            </a>
-            <a
-                href="/admin?section=models"
-                class="nb-btn btn-sm px-4 {% if current_section == 'models' %}nb-cta{% else %}btn-ghost{% endif %}"
-            >
-                Models
-            </a>
-        </nav>
-        <div id="admin-content" class="flex flex-col gap-4">
-            {% if current_section == 'models' %}
-                {% include 'admin/sections/models.html' %}
-            {% else %}
-                {% include 'admin/sections/overview.html' %}
-            {% endif %}
-        </div>
-    </div>
-</div>
-{% endblock %}
+{% extends "admin/_base.html" %}

View File

@@ -1,5 +1,7 @@
 {% extends "modal_base.html" %}
+
+{% block modal_class %}max-w-3xl{% endblock %}
 {% block form_attributes %}
 hx-patch="/update-image-prompt"
 hx-target="#system_prompt_section"

View File

@@ -1,5 +1,7 @@
 {% extends "modal_base.html" %}
+
+{% block modal_class %}max-w-3xl{% endblock %}
 {% block form_attributes %}
 hx-patch="/update-ingestion-prompt"
 hx-target="#system_prompt_section"

View File

@@ -1,5 +1,7 @@
 {% extends "modal_base.html" %}
+
+{% block modal_class %}max-w-3xl{% endblock %}
 {% block form_attributes %}
 hx-patch="/update-query-prompt"
 hx-target="#system_prompt_section"

View File

@@ -0,0 +1,88 @@
{% extends "auth/_settings_layout.html" %}
{% block settings_header %}
<h1 class="text-xl font-extrabold tracking-tight">Account Settings</h1>
{% endblock %}
{% block settings_left_column %}
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
<input type="email" name="email" value="{{ user.email }}" class="nb-input w-full" disabled />
</label>
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">API Key</div>
{% block api_key_section %}
{% if api_key %}
<div class="relative">
<input id="api_key_input" type="text" name="api_key" value="{{ api_key }}"
class="nb-input w-full pr-14" disabled />
<button type="button" id="copy_api_key_btn" onclick="copy_api_key()"
class="absolute inset-y-0 right-0 flex items-center px-2 nb-btn btn-sm" aria-label="Copy API key"
title="Copy API key">
{% include "icons/clipboard_icon.html" %}
</button>
</div>
<a href="https://www.icloud.com/shortcuts/66985f7b98a74aaeac6ba29c3f1f0960"
class="nb-btn nb-cta mt-2 w-full">Download iOS shortcut</a>
{% else %}
<button hx-post="/set-api-key" class="nb-btn nb-cta w-full" hx-swap="outerHTML">Create API-Key</button>
{% endif %}
{% endblock %}
</label>
<script>
function copy_api_key() {
const input = document.getElementById('api_key_input');
if (!input) return;
if (navigator.clipboard && window.isSecureContext) {
navigator.clipboard.writeText(input.value)
.then(() => show_toast('API key copied!', 'success'))
.catch(() => show_toast('Copy failed', 'error'));
} else {
show_toast('Copy not supported', 'info');
}
}
</script>
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Timezone</div>
{% block timezone_section %}
{% set active_timezone = selected_timezone|default(user.timezone) %}
<select name="timezone" class="nb-select w-full" hx-patch="/update-timezone" hx-swap="outerHTML">
{% for tz in timezones %}
<option value="{{ tz }}" {% if tz==active_timezone %}selected{% endif %}>{{ tz }}</option>
{% endfor %}
</select>
{% endblock %}
</label>
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Theme</div>
{% block theme_section %}
{% set active_theme = selected_theme|default(user.theme) %}
<select name="theme" class="nb-select w-full" hx-patch="/update-theme" hx-swap="outerHTML">
{% for option in theme_options %}
<option value="{{ option }}" {% if option==active_theme %}selected{% endif %}>{{ option }}</option>
{% endfor %}
</select>
<script>
document.documentElement.setAttribute('data-theme-preference', '{{ active_theme }}');
</script>
{% endblock %}
</label>
{% endblock %}
{% block settings_right_column %}
<div>
{% block change_password_section %}
<button hx-get="/change-password" hx-swap="outerHTML" class="nb-btn w-full">Change Password</button>
{% endblock %}
</div>
<div>
<button hx-delete="/delete-account"
hx-confirm="This action will permanently delete your account and all data associated. Are you sure you want to continue?"
class="nb-btn btn-error w-full">Delete Account</button>
</div>
{% endblock %}

View File

@@ -0,0 +1,11 @@
{% extends "head_base.html" %}
{% block title %}Minne - Auth{% endblock %}
{% block body %}
<div class="min-h-[100dvh] flex flex-col items-center justify-center">
{% block auth_content %}
{% endblock %}
</div>
<div id="toast-container" class="fixed bottom-4 right-4 z-50 space-y-2"></div>
{% endblock %}

View File

@@ -0,0 +1,32 @@
{% extends "body_base.html" %}
{% block title %}Minne - Account{% endblock %}
{% block main %}
<div class="flex justify-center grow mt-2 sm:mt-4 pb-4">
<div class="container">
<section class="mb-4">
<div class="nb-panel p-3 flex items-center justify-between">
{% block settings_header %}
{% endblock %}
</div>
</section>
<section class="grid grid-cols-1 lg:grid-cols-2 gap-4 space-y-2">
<!-- Left column -->
<div class="nb-panel p-4 space-y-2 flex flex-col">
{% block settings_left_column %}
{% endblock %}
</div>
<!-- Right column -->
<div class="nb-panel p-4 space-y-2">
{% block settings_right_column %}
{% endblock %}
</div>
</section>
<div id="account-result" class="mt-4"></div>
</div>
</div>
{% endblock %}

View File

@@ -1,88 +1 @@
-{% extends "body_base.html" %}
-{% block title %}Minne - Account{% endblock %}
-{% block main %}
-<div class="flex justify-center grow mt-2 sm:mt-4 pb-4">
-    <div class="container">
-        <section class="mb-4">
-            <div class="nb-panel p-3 flex items-center justify-between">
-                <h1 class="text-xl font-extrabold tracking-tight">Account Settings</h1>
-            </div>
-        </section>
-        <section class="grid grid-cols-1 lg:grid-cols-2 gap-4 space-y-2">
-            <!-- Left column -->
-            <div class="nb-panel p-4 space-y-2 flex flex-col">
-                <label class="w-full">
-                    <div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
-                    <input type="email" name="email" value="{{ user.email }}" class="nb-input w-full" disabled />
-                </label>
-                <label class="w-full">
-                    <div class="text-xs uppercase tracking-wide opacity-70 mb-1">API Key</div>
-                    {% block api_key_section %}
-                    {% if user.api_key %}
-                    <div class="relative">
-                        <input id="api_key_input" type="text" name="api_key" value="{{ user.api_key }}"
-                            class="nb-input w-full pr-14" disabled />
-                        <button type="button" id="copy_api_key_btn" onclick="copy_api_key()"
-                            class="absolute inset-y-0 right-0 flex items-center px-2 nb-btn btn-sm" aria-label="Copy API key"
-                            title="Copy API key">
-                            {% include "icons/clipboard_icon.html" %}
-                        </button>
-                    </div>
-                    <a href="https://www.icloud.com/shortcuts/66985f7b98a74aaeac6ba29c3f1f0960"
-                        class="nb-btn nb-cta mt-2 w-full">Download iOS shortcut</a>
-                    {% else %}
-                    <button hx-post="/set-api-key" class="nb-btn nb-cta w-full" hx-swap="outerHTML">Create API-Key</button>
-                    {% endif %}
-                    {% endblock %}
-                </label>
-                <script>
-                    function copy_api_key() {
-                        const input = document.getElementById('api_key_input');
-                        if (!input) return;
-                        if (navigator.clipboard && window.isSecureContext) {
-                            navigator.clipboard.writeText(input.value)
-                                .then(() => show_toast('API key copied!', 'success'))
-                                .catch(() => show_toast('Copy failed', 'error'));
-                        } else {
-                            show_toast('Copy not supported', 'info');
-                        }
-                    }
-                </script>
-                <label class="w-full">
-                    <div class="text-xs uppercase tracking-wide opacity-70 mb-1">Timezone</div>
-                    {% block timezone_section %}
-                    <select name="timezone" class="nb-select w-full" hx-patch="/update-timezone" hx-swap="outerHTML">
-                        {% for tz in timezones %}
-                        <option value="{{ tz }}" {% if tz==user.timezone %}selected{% endif %}>{{ tz }}</option>
-                        {% endfor %}
-                    </select>
-                    {% endblock %}
-                </label>
-            </div>
-            <!-- Right column -->
-            <div class="nb-panel p-4 space-y-2">
-                <div>
-                    {% block change_password_section %}
-                    <button hx-get="/change-password" hx-swap="outerHTML" class="nb-btn w-full">Change Password</button>
-                    {% endblock %}
-                </div>
-                <div>
-                    <button hx-delete="/delete-account"
-                        hx-confirm="This action will permanently delete your account and all data associated. Are you sure you want to continue?"
-                        class="nb-btn btn-error w-full">Delete Account</button>
-                </div>
-            </div>
-        </section>
-        <div id="account-result" class="mt-4"></div>
-    </div>
-</div>
-{% endblock %}
+{% extends "auth/_account_settings_core.html" %}

View File

@@ -1,9 +1,7 @@
-{% extends "head_base.html" %}
+{% extends "auth/_layout.html" %}
 {% block title %}Minne - Sign in{% endblock %}
-{% block body %}
-<div class="min-h-[100dvh] flex">
+{% block auth_content %}
 {% include "auth/signin_form.html" %}
-</div>
 {% endblock %}

View File

@@ -6,7 +6,7 @@
 </div>
 <div class="u-hairline mb-3"></div>
-<form hx-post="/signin" hx-target="#login-result" class="flex flex-col gap-2">
+<form hx-post="/signin" hx-swap="none" class="flex flex-col gap-2">
     <label class="w-full">
         <div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
         <input name="email" type="email" placeholder="Email" class="nb-input w-full validator" required />
@@ -19,8 +19,6 @@
             minlength="8" />
     </label>
-    <div class="mt-1 text-error" id="login-result"></div>
     <div class="form-control mt-1">
         <label class="label cursor-pointer justify-start gap-3">
             <input type="checkbox" name="remember_me" class="nb-checkbox" />

View File

@@ -1,9 +1,8 @@
{% extends "head_base.html" %} {% extends "auth/_layout.html" %}
{% block title %}Minne - Sign up{% endblock %} {% block title %}Minne - Sign up{% endblock %}
{% block body %} {% block auth_content %}
<div class="min-h-[100dvh] flex items-center">
<div class="container mx-auto px-4 sm:max-w-md"> <div class="container mx-auto px-4 sm:max-w-md">
<div class="nb-card p-5"> <div class="nb-card p-5">
<div class="flex items-center justify-between mb-3"> <div class="flex items-center justify-between mb-3">
@@ -12,7 +11,7 @@
 </div>
 <div class="u-hairline mb-3"></div>
-<form hx-post="/signup" hx-target="#signup-result" class="flex flex-col gap-4">
+<form hx-post="/signup" hx-swap="none" class="flex flex-col gap-4">
 <label class="w-full">
 <div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
 <input type="email" placeholder="Email" name="email" required class="nb-input w-full validator" />
@@ -32,7 +31,6 @@
 </p>
 </label>
-<div class="mt-2 text-error" id="signup-result"></div>
 <div class="form-control mt-1">
 <button id="submit-btn" class="nb-btn nb-cta w-full">Create Account</button>
 </div>
@@ -46,10 +44,9 @@
 </div>
 </div>
 </div>
-</div>
 <script>
 // Detect timezone and set hidden input
 const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
 document.getElementById("timezone").value = timezone;
 </script>
 {% endblock %}

View File

@@ -2,8 +2,8 @@
 {% block body %}
-<body class="relative" hx-ext="head-support">
+<body class="relative">
-<div class="drawer lg:drawer-open">
+<div id="main-content-wrapper" class="drawer lg:drawer-open">
 <input id="my-drawer" type="checkbox" class="drawer-toggle" />
 <!-- Page Content -->
 <div class="drawer-content flex flex-col h-screen">
@@ -14,6 +14,7 @@
 {% block main %}{% endblock %}
 <div class="p32 min-h-[10px]"></div>
 </main>
+{% block overlay %}{% endblock %}
 </div>
 <!-- Sidebar -->
 {% if user %}

View File

@@ -0,0 +1,78 @@
+{% extends 'body_base.html' %}
+{% block title %}Minne - Chat{% endblock %}
+{% block main %}
+<div class="flex grow relative justify-center mt-2 sm:mt-4">
+<div class="container">
+<section class="mb-3">
+<div class="nb-panel p-3 flex items-center justify-between">
+{% block chat_header_actions %}
+{% endblock %}
+</div>
+</section>
+<div id="chat-scroll-container" class="overflow-auto hide-scrollbar">
+{% block chat_content %}
+{% endblock %}
+</div>
+</div>
+</div>
+<script>
+function doScrollChatToBottom() {
+const mainScroll = document.querySelector('main');
+if (mainScroll) mainScroll.scrollTop = mainScroll.scrollHeight;
+const chatScroll = document.getElementById('chat-scroll-container');
+if (chatScroll) chatScroll.scrollTop = chatScroll.scrollHeight;
+const chatContainer = document.getElementById('chat_container');
+if (chatContainer) chatContainer.scrollTop = chatContainer.scrollHeight;
+window.scrollTo(0, document.body.scrollHeight);
+}
+function scrollChatToBottom() {
+if (!window.location.pathname.startsWith('/chat')) return;
+requestAnimationFrame(doScrollChatToBottom);
+}
+window.scrollChatToBottom = scrollChatToBottom;
+// Delay initial scroll to avoid interfering with view transition
+document.addEventListener('DOMContentLoaded', () => setTimeout(scrollChatToBottom, 350));
+function handleChatSwap(e) {
+if (!window.location.pathname.startsWith('/chat')) return;
+// Full page swap: delay for view transition; partial swap: immediate
+if (e.detail && e.detail.target && e.detail.target.tagName === 'BODY') {
+setTimeout(scrollChatToBottom, 350);
+} else {
+scrollChatToBottom();
+}
+}
+function cleanupChatListeners(e) {
+if (e.detail && e.detail.target && e.detail.target.tagName === 'BODY') {
+document.body.removeEventListener('htmx:afterSwap', window._chatEventHandlers.afterSwap);
+document.body.removeEventListener('htmx:afterSettle', window._chatEventHandlers.afterSettle);
+document.body.removeEventListener('htmx:beforeSwap', window._chatEventHandlers.beforeSwap);
+delete window._chatEventHandlers;
+window._chatListenersAttached = false;
+}
+}
+window._chatEventHandlers = {
+afterSwap: handleChatSwap,
+afterSettle: handleChatSwap,
+beforeSwap: cleanupChatListeners
+};
+if (!window._chatListenersAttached) {
+document.body.addEventListener('htmx:afterSwap', window._chatEventHandlers.afterSwap);
+document.body.addEventListener('htmx:afterSettle', window._chatEventHandlers.afterSettle);
+document.body.addEventListener('htmx:beforeSwap', window._chatEventHandlers.beforeSwap);
+window._chatListenersAttached = true;
+}
+</script>
+{% endblock %}

View File

@@ -1,48 +1,14 @@
-{% extends 'body_base.html' %}
-{% block title %}Minne - Chat{% endblock %}
-
-{% block head %}
-<script src="/assets/htmx-ext-sse.js" defer></script>
-{% endblock %}
-{% block main %}
-<div class="flex grow relative justify-center mt-2 sm:mt-4">
-<div class="container">
-<section class="mb-3">
-<div class="nb-panel p-3 flex items-center justify-between">
-<h1 class="text-xl font-extrabold tracking-tight">Chat</h1>
-<div class="text-xs opacity-70">Converse with your knowledge</div>
-</div>
-</section>
-<div id="chat-scroll-container" class="overflow-auto hide-scrollbar">
-{% include "chat/history.html" %}
-{% include "chat/new_message_form.html" %}
-</div>
-</div>
-</div>
-<script>
-function scrollChatToBottom() {
-requestAnimationFrame(() => {
-const mainScroll = document.querySelector('main');
-if (mainScroll) mainScroll.scrollTop = mainScroll.scrollHeight;
-const chatScroll = document.getElementById('chat-scroll-container');
-if (chatScroll) chatScroll.scrollTop = chatScroll.scrollHeight;
-const chatContainer = document.getElementById('chat_container');
-if (chatContainer) chatContainer.scrollTop = chatContainer.scrollHeight;
-window.scrollTo(0, document.body.scrollHeight);
-});
-}
-window.scrollChatToBottom = scrollChatToBottom;
-document.addEventListener('DOMContentLoaded', scrollChatToBottom);
-document.body.addEventListener('htmx:afterSwap', scrollChatToBottom);
-document.body.addEventListener('htmx:afterSettle', scrollChatToBottom);
-</script>
-{% endblock %}
+{% extends "chat/_layout.html" %}
+{% block chat_header_actions %}
+<h1 class="text-xl font-extrabold tracking-tight">Chat</h1>
+<div class="text-xs opacity-70">Converse with your knowledge</div>
+{% endblock %}
+
+{% block chat_content %}
+{% include "chat/history.html" %}
+{% endblock %}
+
+{% block overlay %}
+{% include "chat/new_message_form.html" %}
+{% endblock %}

View File

@@ -12,16 +12,34 @@
 </label>
 </form>
 <script>
+(function () {
+const newChatStreamId = 'ai-stream-{{ user_message.id }}';
+
 document.getElementById('chat-input').addEventListener('keydown', function (e) {
 if (e.key === 'Enter' && !e.shiftKey) {
 e.preventDefault();
 htmx.trigger('#chat-form', 'submit');
 }
 });
 // Clear textarea after successful submission
 document.getElementById('chat-form').addEventListener('htmx:afterRequest', function (e) {
 if (e.detail.successful) { // Check if the request was successful
 document.getElementById('chat-input').value = ''; // Clear the textarea
 }
 });
+const refreshSidebarAfterFirstResponse = function (e) {
+const streamEl = document.getElementById(newChatStreamId);
+if (!streamEl || e.target !== streamEl) return;
+htmx.ajax('GET', '/chat/sidebar', {
+target: '.drawer-side',
+swap: 'outerHTML'
+});
+document.body.removeEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
+};
+document.body.addEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
+})();
 </script>

View File

@@ -1,6 +1,6 @@
 <div class="fixed bottom-0 left-0 right-0 lg:left-72 z-20">
 <div class="mx-auto max-w-3xl px-4 pb-3">
-<div class="nb-panel p-2">
+<div class="nb-panel p-2 no-animation">
 <form hx-post="{% if conversation %} /chat/{{conversation.id}} {% else %} /chat {% endif %}"
 hx-target="#chat_container" hx-swap="beforeend" class="relative flex gap-2 items-end" id="chat-form">
 <textarea autofocus required name="content" placeholder="Type your message…" rows="3"

View File

@@ -111,12 +111,23 @@
 // Load content if needed
 if (!tooltipContent) {
 fetch(`/chat/reference/${encodeURIComponent(reference)}`)
-.then(response => response.text())
+.then(response => {
+if (!response.ok) {
+throw new Error(`reference lookup failed with status ${response.status}`);
+}
+return response.text();
+})
 .then(html => {
 tooltipContent = html;
 if (document.getElementById(tooltipId)) {
 document.getElementById(tooltipId).innerHTML = html;
 }
-});
+})
+.catch(() => {
+tooltipContent = '<div class="text-xs opacity-70">Reference unavailable.</div>';
+if (document.getElementById(tooltipId)) {
+document.getElementById(tooltipId).innerHTML = tooltipContent;
+}
+});
 } else if (tooltip) {
 // Set content if already loaded

View File

@@ -1,3 +1,11 @@
-<div>{{entity.name}}</div>
-<div>{{entity.description}}</div>
-<div>{{entity.updated_at|datetimeformat(format="short", tz=user.timezone)}} </div>
+{% if text_chunk %}
+<div class="font-semibold">Chunk Reference</div>
+<div class="text-sm whitespace-pre-wrap">{{text_chunk.chunk}}</div>
+<div class="text-xs opacity-70">{{text_chunk_updated_at}}</div>
+{% elif entity %}
+<div class="font-semibold">{{entity.name}}</div>
+<div class="text-sm">{{entity.description}}</div>
+<div class="text-xs opacity-70">{{entity_updated_at}}</div>
+{% else %}
+<div class="text-xs opacity-70">Reference unavailable.</div>
+{% endif %}

View File

@@ -4,7 +4,8 @@
 </div>
 </div>
 <div class="chat chat-start">
-<div hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
+<div id="ai-stream-{{user_message.id}}" hx-ext="sse"
+sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
 hx-swap="beforeend">
 <div class="chat-bubble">
 <span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
@@ -27,13 +28,22 @@
 el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
 if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
 });
-document.body.addEventListener('htmx:sseClose', function () {
+document.body.addEventListener('htmx:sseClose', function (e) {
 const msgId = '{{ user_message.id }}';
+const streamEl = document.getElementById('ai-stream-' + msgId);
+if (streamEl && e.target !== streamEl) return;
 const el = document.getElementById('ai-message-content-' + msgId);
 if (el && window.markdownBuffer[msgId]) {
 el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
 delete window.markdownBuffer[msgId];
 if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
 }
+if (streamEl) {
+streamEl.removeAttribute('sse-connect');
+streamEl.removeAttribute('sse-close');
+streamEl.removeAttribute('hx-ext');
+}
 });
 </script>

View File

@@ -0,0 +1,15 @@
+{% macro icon(name) %}
+{% if name == "home" %}
+{% include "icons/home_icon.html" %}
+{% elif name == "book" %}
+{% include "icons/book_icon.html" %}
+{% elif name == "document" %}
+{% include "icons/document_icon.html" %}
+{% elif name == "chat" %}
+{% include "icons/chat_icon.html" %}
+{% elif name == "search" %}
+{% include "icons/search_icon.html" %}
+{% elif name == "scratchpad" %}
+{% include "icons/scratchpad_icon.html" %}
+{% endif %}
+{% endmacro %}

Some files were not shown because too many files have changed in this diff.