mirror of
https://github.com/perstarkse/minne.git
synced 2026-02-20 23:27:39 +01:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d237ff6d9 | ||
|
|
eb928cdb0e | ||
|
|
1490852a09 | ||
|
|
b0b01182d7 | ||
|
|
679308aa1d | ||
|
|
f93c06b347 | ||
|
|
a3f207beb1 | ||
|
|
e07199adfc | ||
|
|
f22cac891c | ||
|
|
b89171d934 | ||
|
|
0133eead63 | ||
|
|
e5d2b6605f | ||
|
|
bbad91d55b |
10
CHANGELOG.md
10
CHANGELOG.md
@@ -1,4 +1,14 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
|
||||
## 1.0.2 (2026-02-15)
|
||||
- Fix: edge case where navigation back to a chat page could trigger a new response generation
|
||||
- Fix: chat references now validate and render more reliably
|
||||
- Fix: improved admin access checks for restricted routes
|
||||
- Performance: faster chat sidebar loads from cached conversation archive data
|
||||
- API: harmonized ingest endpoint naming and added configurable ingest safety limits
|
||||
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
|
||||
|
||||
## 1.0.1 (2026-02-11)
|
||||
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
|
||||
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.
|
||||
|
||||
3201
Cargo.lock
generated
3201
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -40,7 +40,7 @@ serde_json = "1.0.128"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
sha2 = "0.10.8"
|
||||
surrealdb-migrations = "2.2.2"
|
||||
surrealdb = { version = "2", features = ["kv-mem"] }
|
||||
surrealdb = { version = "2" }
|
||||
tempfile = "3.12.0"
|
||||
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
|
||||
tokenizers = { version = "0.20.4", features = ["http"] }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# === Builder ===
|
||||
FROM rust:1.86-bookworm AS builder
|
||||
FROM rust:1.89-bookworm AS builder
|
||||
WORKDIR /usr/src/minne
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
|
||||
@@ -30,8 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libgomp1 libstdc++6 curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ONNX Runtime (CPU). Change if you bump ort.
|
||||
ARG ORT_VERSION=1.22.0
|
||||
# ONNX Runtime (CPU). Keep in sync with ort crate requirements.
|
||||
ARG ORT_VERSION=1.23.2
|
||||
RUN mkdir -p /opt/onnxruntime && \
|
||||
curl -fsSL -o /tmp/ort.tgz \
|
||||
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \
|
||||
|
||||
@@ -20,6 +20,9 @@ pub enum ApiError {
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error("Payload too large: {0}")]
|
||||
PayloadTooLarge(String),
|
||||
}
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
@@ -67,6 +70,13 @@ impl IntoResponse for ApiError {
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
Self::PayloadTooLarge(message) => (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
ErrorResponse {
|
||||
error: message,
|
||||
status: "error".to_string(),
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
(status, Json(error_response)).into_response()
|
||||
@@ -132,6 +142,10 @@ mod tests {
|
||||
// Test unauthorized status
|
||||
let error = ApiError::Unauthorized("not allowed".to_string());
|
||||
assert_status_code(error, StatusCode::UNAUTHORIZED);
|
||||
|
||||
// Test payload too large status
|
||||
let error = ApiError::PayloadTooLarge("too big".to_string());
|
||||
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
|
||||
// Alternative approach that doesn't try to parse the response body
|
||||
|
||||
@@ -6,7 +6,7 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use middleware_api_auth::api_auth;
|
||||
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready};
|
||||
use routes::{categories::get_categories, ingest::ingest_data, liveness::live, readiness::ready};
|
||||
|
||||
pub mod api_state;
|
||||
pub mod error;
|
||||
@@ -26,9 +26,13 @@ where
|
||||
|
||||
// Protected API endpoints (require auth)
|
||||
let protected = Router::new()
|
||||
.route("/ingress", post(ingest_data))
|
||||
.route(
|
||||
"/ingest",
|
||||
post(ingest_data).layer(DefaultBodyLimit::max(
|
||||
app_state.config.ingest_max_body_bytes,
|
||||
)),
|
||||
)
|
||||
.route("/categories", get(get_categories))
|
||||
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
|
||||
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
|
||||
|
||||
public.merge(protected)
|
||||
|
||||
@@ -6,6 +6,7 @@ use common::{
|
||||
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
|
||||
user::User,
|
||||
},
|
||||
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
|
||||
};
|
||||
use futures::{future::try_join_all, TryFutureExt};
|
||||
use serde_json::json;
|
||||
@@ -19,7 +20,7 @@ pub struct IngestParams {
|
||||
pub content: Option<String>,
|
||||
pub context: String,
|
||||
pub category: String,
|
||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
||||
#[form_data(limit = "20000000")]
|
||||
#[form_data(default)]
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
@@ -29,8 +30,38 @@ pub async fn ingest_data(
|
||||
Extension(user): Extension<User>,
|
||||
TypedMultipart(input): TypedMultipart<IngestParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
info!("Received input: {:?}", input);
|
||||
let user_id = user.id;
|
||||
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
|
||||
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||
let context_bytes = input.context.len();
|
||||
let category_bytes = input.category.len();
|
||||
let file_count = input.files.len();
|
||||
|
||||
match validate_ingest_input(
|
||||
&state.config,
|
||||
input.content.as_deref(),
|
||||
&input.context,
|
||||
&input.category,
|
||||
file_count,
|
||||
) {
|
||||
Ok(()) => {}
|
||||
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
||||
return Err(ApiError::PayloadTooLarge(message));
|
||||
}
|
||||
Err(IngestValidationError::BadRequest(message)) => {
|
||||
return Err(ApiError::ValidationError(message));
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
user_id = %user_id,
|
||||
has_content,
|
||||
content_bytes,
|
||||
context_bytes,
|
||||
category_bytes,
|
||||
file_count,
|
||||
"Received ingest request"
|
||||
);
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)
|
||||
@@ -1,4 +1,4 @@
|
||||
pub mod categories;
|
||||
pub mod ingress;
|
||||
pub mod ingest;
|
||||
pub mod liveness;
|
||||
pub mod readiness;
|
||||
|
||||
@@ -16,7 +16,7 @@ tracing = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
surrealdb = { workspace = true, features = ["kv-mem"] }
|
||||
surrealdb = { workspace = true }
|
||||
async-openai = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
@@ -49,4 +49,7 @@ fastembed = { workspace = true }
|
||||
|
||||
|
||||
[features]
|
||||
test-utils = []
|
||||
test-utils = ["surrealdb/kv-mem"]
|
||||
|
||||
[dev-dependencies]
|
||||
surrealdb = { workspace = true, features = ["kv-mem"] }
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -28,6 +28,7 @@\n # Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\n DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n+DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY\n\n # Defines the schema for the 'file' table (used by FileInfo).\n\n","events":null}
|
||||
@@ -13,3 +13,4 @@ DEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;
|
||||
# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)
|
||||
DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;
|
||||
DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY
|
||||
DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY
|
||||
|
||||
@@ -281,6 +281,33 @@ pub mod testing {
|
||||
use crate::utils::config::{AppConfig, PdfIngestMode};
|
||||
use uuid;
|
||||
|
||||
const DEFAULT_TEST_S3_BUCKET: &str = "minne-tests";
|
||||
const DEFAULT_TEST_S3_ENDPOINT: &str = "http://127.0.0.1:19000";
|
||||
|
||||
fn configured_test_s3_bucket() -> String {
|
||||
std::env::var("MINNE_TEST_S3_BUCKET")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.or_else(|| {
|
||||
std::env::var("S3_BUCKET")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
})
|
||||
.unwrap_or_else(|| DEFAULT_TEST_S3_BUCKET.to_string())
|
||||
}
|
||||
|
||||
fn configured_test_s3_endpoint() -> String {
|
||||
std::env::var("MINNE_TEST_S3_ENDPOINT")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.or_else(|| {
|
||||
std::env::var("S3_ENDPOINT")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
})
|
||||
.unwrap_or_else(|| DEFAULT_TEST_S3_ENDPOINT.to_string())
|
||||
}
|
||||
|
||||
/// Create a test configuration with memory storage.
|
||||
///
|
||||
/// This provides a ready-to-use configuration for testing scenarios
|
||||
@@ -326,7 +353,8 @@ pub mod testing {
|
||||
|
||||
/// Create a test configuration with S3 storage (MinIO).
|
||||
///
|
||||
/// This requires a running MinIO instance on localhost:9000.
|
||||
/// Uses `MINNE_TEST_S3_ENDPOINT` / `S3_ENDPOINT` and
|
||||
/// `MINNE_TEST_S3_BUCKET` / `S3_BUCKET` when provided.
|
||||
pub fn test_config_s3() -> AppConfig {
|
||||
AppConfig {
|
||||
openai_api_key: "test".into(),
|
||||
@@ -339,8 +367,8 @@ pub mod testing {
|
||||
http_port: 0,
|
||||
openai_base_url: "..".into(),
|
||||
storage: StorageKind::S3,
|
||||
s3_bucket: Some("minne-tests".into()),
|
||||
s3_endpoint: Some("http://localhost:9000".into()),
|
||||
s3_bucket: Some(configured_test_s3_bucket()),
|
||||
s3_endpoint: Some(configured_test_s3_endpoint()),
|
||||
s3_region: Some("us-east-1".into()),
|
||||
pdf_ingest_mode: PdfIngestMode::LlmFirst,
|
||||
..Default::default()
|
||||
@@ -391,8 +419,7 @@ pub mod testing {
|
||||
|
||||
/// Create a new TestStorageManager with S3 backend (MinIO).
|
||||
///
|
||||
/// This requires a running MinIO instance on localhost:9000 with
|
||||
/// default credentials (minioadmin/minioadmin) and a 'minne-tests' bucket.
|
||||
/// This requires a reachable MinIO endpoint and an existing test bucket.
|
||||
pub async fn new_s3() -> object_store::Result<Self> {
|
||||
// Ensure credentials are set for MinIO
|
||||
// We set these env vars for the process, which AmazonS3Builder will pick up
|
||||
@@ -403,6 +430,11 @@ pub mod testing {
|
||||
let cfg = test_config_s3();
|
||||
let storage = StorageManager::new(&cfg).await?;
|
||||
|
||||
// Probe the bucket so tests can cleanly skip when the endpoint is unreachable
|
||||
// or the test bucket is not provisioned.
|
||||
let probe_prefix = format!("__minne_s3_probe__/{}", uuid::Uuid::new_v4());
|
||||
storage.list(Some(&probe_prefix)).await?;
|
||||
|
||||
Ok(Self {
|
||||
storage,
|
||||
_temp_dir: None,
|
||||
@@ -923,8 +955,8 @@ mod tests {
|
||||
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
|
||||
}
|
||||
|
||||
// S3 Tests - Require MinIO on localhost:9000 with bucket 'minne-tests'
|
||||
// These tests will fail if MinIO is not running or bucket doesn't exist.
|
||||
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
|
||||
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_manager_s3_basic_operations() {
|
||||
|
||||
@@ -10,6 +10,54 @@ stored_object!(Conversation, "conversation", {
|
||||
title: String
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||
pub struct SidebarConversation {
|
||||
#[serde(deserialize_with = "deserialize_sidebar_id")]
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
}
|
||||
|
||||
struct SidebarIdVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for SidebarIdVisitor {
|
||||
type Value = String;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a string id or a SurrealDB Thing")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(value.to_string())
|
||||
}
|
||||
|
||||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::MapAccess<'de>,
|
||||
{
|
||||
let thing = <surrealdb::sql::Thing as serde::Deserialize>::deserialize(
|
||||
serde::de::value::MapAccessDeserializer::new(map),
|
||||
)?;
|
||||
Ok(thing.id.to_raw())
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_sidebar_id<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_any(SidebarIdVisitor)
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub fn new(user_id: String, title: String) -> Self {
|
||||
let now = Utc::now();
|
||||
@@ -75,6 +123,23 @@ impl Conversation {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_user_sidebar_conversations(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<SidebarConversation>, AppError> {
|
||||
let conversations: Vec<SidebarConversation> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT id, title, updated_at FROM type::table($table_name) WHERE user_id = $user_id ORDER BY updated_at DESC",
|
||||
)
|
||||
.bind(("table_name", Self::table_name()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -249,6 +314,96 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "sidebar_user";
|
||||
let other_user_id = "other_user";
|
||||
let base = Utc::now();
|
||||
|
||||
let mut oldest = Conversation::new(user_id.to_string(), "Oldest".to_string());
|
||||
oldest.updated_at = base - chrono::Duration::minutes(30);
|
||||
|
||||
let mut newest = Conversation::new(user_id.to_string(), "Newest".to_string());
|
||||
newest.updated_at = base - chrono::Duration::minutes(5);
|
||||
|
||||
let mut middle = Conversation::new(user_id.to_string(), "Middle".to_string());
|
||||
middle.updated_at = base - chrono::Duration::minutes(15);
|
||||
|
||||
let mut other_user = Conversation::new(other_user_id.to_string(), "Other".to_string());
|
||||
other_user.updated_at = base;
|
||||
|
||||
db.store_item(oldest.clone())
|
||||
.await
|
||||
.expect("Failed to store oldest conversation");
|
||||
db.store_item(newest.clone())
|
||||
.await
|
||||
.expect("Failed to store newest conversation");
|
||||
db.store_item(middle.clone())
|
||||
.await
|
||||
.expect("Failed to store middle conversation");
|
||||
db.store_item(other_user)
|
||||
.await
|
||||
.expect("Failed to store other-user conversation");
|
||||
|
||||
let sidebar_items = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to get sidebar conversations");
|
||||
|
||||
assert_eq!(sidebar_items.len(), 3);
|
||||
assert_eq!(sidebar_items[0].id, newest.id);
|
||||
assert_eq!(sidebar_items[0].title, "Newest");
|
||||
assert_eq!(sidebar_items[1].id, middle.id);
|
||||
assert_eq!(sidebar_items[1].title, "Middle");
|
||||
assert_eq!(sidebar_items[2].id, oldest.id);
|
||||
assert_eq!(sidebar_items[2].title, "Oldest");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let user_id = "sidebar_patch_user";
|
||||
let base = Utc::now();
|
||||
|
||||
let mut first = Conversation::new(user_id.to_string(), "First".to_string());
|
||||
first.updated_at = base - chrono::Duration::minutes(20);
|
||||
|
||||
let mut second = Conversation::new(user_id.to_string(), "Second".to_string());
|
||||
second.updated_at = base - chrono::Duration::minutes(10);
|
||||
|
||||
db.store_item(first.clone())
|
||||
.await
|
||||
.expect("Failed to store first conversation");
|
||||
db.store_item(second.clone())
|
||||
.await
|
||||
.expect("Failed to store second conversation");
|
||||
|
||||
let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to get sidebar conversations before patch");
|
||||
assert_eq!(before_patch[0].id, second.id);
|
||||
|
||||
Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
|
||||
.await
|
||||
.expect("Failed to patch conversation title");
|
||||
|
||||
let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
|
||||
.await
|
||||
.expect("Failed to get sidebar conversations after patch");
|
||||
assert_eq!(after_patch[0].id, first.id);
|
||||
assert_eq!(after_patch[0].title, "First (renamed)");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_with_messages() {
|
||||
// Setup in-memory database for testing
|
||||
|
||||
@@ -137,8 +137,12 @@ impl FileInfo {
|
||||
/// # Returns
|
||||
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
|
||||
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
|
||||
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
|
||||
.bind(("sha256", sha256.to_owned()))
|
||||
.await?;
|
||||
let response: Vec<FileInfo> = response.take(0)?;
|
||||
|
||||
response
|
||||
.into_iter()
|
||||
@@ -665,6 +669,36 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let now = Utc::now();
|
||||
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id: "user123".to_string(),
|
||||
sha256: "known_sha_value".to_string(),
|
||||
path: "/path/to/file.txt".to_string(),
|
||||
file_name: "file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
db.store_item(file_info)
|
||||
.await
|
||||
.expect("Failed to store test file info");
|
||||
|
||||
let malicious_sha = "known_sha_value' OR true --";
|
||||
let result = FileInfo::get_by_sha(malicious_sha, &db).await;
|
||||
|
||||
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -174,12 +174,15 @@ impl KnowledgeEntity {
|
||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -761,6 +764,69 @@ mod tests {
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
"safe_source".to_string(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
"other_source".to_string(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id,
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity1");
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
|
||||
.await
|
||||
.expect("store entity2");
|
||||
|
||||
let malicious_source = "safe_source' OR 1=1 --";
|
||||
KnowledgeEntity::delete_by_source_id(malicious_source, &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
let remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table)")
|
||||
.bind(("table", KnowledgeEntity::table_name()))
|
||||
.await
|
||||
.expect("query failed")
|
||||
.take(0)
|
||||
.expect("take failed");
|
||||
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
2,
|
||||
"malicious input must not delete unrelated entities"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -40,22 +40,28 @@ impl KnowledgeRelationship {
|
||||
}
|
||||
}
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
r#"DELETE relates_to:`{rel_id}`;
|
||||
RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
|
||||
SET
|
||||
metadata.user_id = '{user_id}',
|
||||
metadata.source_id = '{source_id}',
|
||||
metadata.relationship_type = '{relationship_type}'"#,
|
||||
rel_id = self.id,
|
||||
in_id = self.in_,
|
||||
out_id = self.out,
|
||||
user_id = self.metadata.user_id.as_str(),
|
||||
source_id = self.metadata.source_id.as_str(),
|
||||
relationship_type = self.metadata.relationship_type.as_str()
|
||||
);
|
||||
|
||||
db_client.query(query).await?.check()?;
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
r#"BEGIN TRANSACTION;
|
||||
LET $in_entity = type::thing('knowledge_entity', $in_id);
|
||||
LET $out_entity = type::thing('knowledge_entity', $out_id);
|
||||
LET $relation = type::thing('relates_to', $rel_id);
|
||||
DELETE type::thing('relates_to', $rel_id);
|
||||
RELATE $in_entity->$relation->$out_entity SET
|
||||
metadata.user_id = $user_id,
|
||||
metadata.source_id = $source_id,
|
||||
metadata.relationship_type = $relationship_type;
|
||||
COMMIT TRANSACTION;"#,
|
||||
)
|
||||
.bind(("rel_id", self.id.clone()))
|
||||
.bind(("in_id", self.in_.clone()))
|
||||
.bind(("out_id", self.out.clone()))
|
||||
.bind(("user_id", self.metadata.user_id.clone()))
|
||||
.bind(("source_id", self.metadata.source_id.clone()))
|
||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -64,11 +70,12 @@ impl KnowledgeRelationship {
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -79,15 +86,20 @@ impl KnowledgeRelationship {
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let mut authorized_result = db_client
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
|
||||
))
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
|
||||
)
|
||||
.bind(("id", id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
||||
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
.query(format!("SELECT * FROM relates_to:`{id}`"))
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?;
|
||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
||||
|
||||
@@ -99,7 +111,12 @@ impl KnowledgeRelationship {
|
||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||
}
|
||||
} else {
|
||||
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -110,6 +127,34 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
async fn get_relationship_by_id(
|
||||
relationship_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Option<KnowledgeRelationship> {
|
||||
let mut result = db_client
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||
.bind(("id", relationship_id.to_owned()))
|
||||
.await
|
||||
.expect("relationship query by id failed");
|
||||
|
||||
result.take(0).expect("failed to take relationship by id")
|
||||
}
|
||||
|
||||
// Helper function to create a test knowledge entity for the relationship tests
|
||||
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
|
||||
let source_id = "source123".to_string();
|
||||
@@ -161,15 +206,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_store_and_verify_by_source_id() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
@@ -194,30 +231,69 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store relationship");
|
||||
|
||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
||||
.await
|
||||
.expect("Relationship should be retrievable by id");
|
||||
assert_eq!(persisted.in_, entity1_id);
|
||||
assert_eq!(persisted.out, entity2_id);
|
||||
assert_eq!(persisted.metadata.user_id, user_id);
|
||||
assert_eq!(persisted.metadata.source_id, source_id);
|
||||
|
||||
// Query to verify the relationship exists by checking for relationships with our source_id
|
||||
// This approach is more reliable than trying to look up by ID
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
source_id
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let mut check_result = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await
|
||||
.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Just verify that a relationship was created
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship should exist in the database"
|
||||
assert_eq!(
|
||||
check_results.len(),
|
||||
1,
|
||||
"Expected one relationship for source_id"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id,
|
||||
entity2_id,
|
||||
"user'123".to_string(),
|
||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||
);
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("store relationship should safely handle quote-containing values");
|
||||
|
||||
let mut res = db
|
||||
.client
|
||||
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||
.bind(("id", relationship.id.clone()))
|
||||
.await
|
||||
.expect("query relationship by id failed");
|
||||
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(
|
||||
rows[0].metadata.source_id,
|
||||
"source123'; DELETE FROM relates_to; --"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create two entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
@@ -278,11 +354,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationship_by_id_unauthorized() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
@@ -346,11 +418,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_store_relationship_exists() {
|
||||
// Setup in-memory database for testing
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
let db = setup_test_db().await;
|
||||
|
||||
// Create entities to relate
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
@@ -402,49 +470,87 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store different relationship");
|
||||
|
||||
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
|
||||
let mut before_delete = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.clone()))
|
||||
.await
|
||||
.expect("before delete query failed");
|
||||
let before_delete_rows: Vec<KnowledgeRelationship> =
|
||||
before_delete.take(0).unwrap_or_default();
|
||||
assert_eq!(before_delete_rows.len(), 2);
|
||||
|
||||
let mut before_delete_different = db
|
||||
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", different_source_id.clone()))
|
||||
.await
|
||||
.expect("before delete different query failed");
|
||||
let before_delete_different_rows: Vec<KnowledgeRelationship> =
|
||||
before_delete_different.take(0).unwrap_or_default();
|
||||
assert_eq!(before_delete_different_rows.len(), 1);
|
||||
|
||||
// Delete relationships by source_id
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
|
||||
.await
|
||||
.expect("Failed to delete relationships by source_id");
|
||||
|
||||
// Query to verify the relationships with source_id were deleted
|
||||
let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id);
|
||||
let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id);
|
||||
let different_query = format!(
|
||||
"SELECT * FROM relates_to WHERE id = '{}'",
|
||||
different_relationship.id
|
||||
);
|
||||
|
||||
let mut result1 = db.query(query1).await.expect("Query 1 failed");
|
||||
let results1: Vec<KnowledgeRelationship> = result1.take(0).unwrap_or_default();
|
||||
|
||||
let mut result2 = db.query(query2).await.expect("Query 2 failed");
|
||||
let results2: Vec<KnowledgeRelationship> = result2.take(0).unwrap_or_default();
|
||||
|
||||
let mut different_result = db
|
||||
.query(different_query)
|
||||
.await
|
||||
.expect("Different query failed");
|
||||
let _different_results: Vec<KnowledgeRelationship> =
|
||||
different_result.take(0).unwrap_or_default();
|
||||
// Query to verify the specific relationships with source_id were deleted.
|
||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
||||
|
||||
// Verify relationships with the source_id are deleted
|
||||
assert!(results1.is_empty(), "Relationship 1 should be deleted");
|
||||
assert!(results2.is_empty(), "Relationship 2 should be deleted");
|
||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||
let remaining =
|
||||
different_result.expect("Relationship with different source_id should remain");
|
||||
assert_eq!(remaining.metadata.source_id, different_source_id);
|
||||
}
|
||||
|
||||
// For the relationship with different source ID, we need to check differently
|
||||
// Let's just verify we have a relationship where the source_id matches different_source_id
|
||||
let check_query = format!(
|
||||
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
|
||||
different_source_id
|
||||
#[tokio::test]
|
||||
async fn test_delete_relationships_by_source_id_resists_query_injection() {
|
||||
let db = setup_test_db().await;
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
let entity3_id = create_test_entity("Entity 3", &db).await;
|
||||
|
||||
let safe_relationship = KnowledgeRelationship::new(
|
||||
entity1_id.clone(),
|
||||
entity2_id.clone(),
|
||||
"user123".to_string(),
|
||||
"safe_source".to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
let mut check_result = db.query(check_query).await.expect("Check query failed");
|
||||
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify the relationship with a different source_id still exists
|
||||
let other_relationship = KnowledgeRelationship::new(
|
||||
entity2_id,
|
||||
entity3_id,
|
||||
"user123".to_string(),
|
||||
"other_source".to_string(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
|
||||
safe_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("store safe relationship");
|
||||
other_relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("store other relationship");
|
||||
|
||||
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
|
||||
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
|
||||
|
||||
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
||||
assert!(
|
||||
!check_results.is_empty(),
|
||||
"Relationship with different source_id should still exist"
|
||||
remaining_other.is_some(),
|
||||
"Other relationship should remain"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,12 +47,15 @@ impl TextChunk {
|
||||
// Delete embeddings first
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -617,6 +620,57 @@ mod tests {
|
||||
assert_eq!(remaining.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
"safe_source".to_string(),
|
||||
"Safe chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
let chunk2 = TextChunk::new(
|
||||
"other_source".to_string(),
|
||||
"Other chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
let malicious_source = "safe_source' OR 1=1 --";
|
||||
TextChunk::delete_by_source_id(malicious_source, &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table)")
|
||||
.bind(("table", TextChunk::table_name()))
|
||||
.await
|
||||
.expect("query failed")
|
||||
.take(0)
|
||||
.expect("take failed");
|
||||
|
||||
assert_eq!(
|
||||
remaining.len(),
|
||||
2,
|
||||
"malicious input must not delete unrelated rows"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -86,6 +86,16 @@ pub struct AppConfig {
|
||||
pub retrieval_strategy: Option<String>,
|
||||
#[serde(default)]
|
||||
pub embedding_backend: EmbeddingBackend,
|
||||
#[serde(default = "default_ingest_max_body_bytes")]
|
||||
pub ingest_max_body_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_files")]
|
||||
pub ingest_max_files: usize,
|
||||
#[serde(default = "default_ingest_max_content_bytes")]
|
||||
pub ingest_max_content_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_context_bytes")]
|
||||
pub ingest_max_context_bytes: usize,
|
||||
#[serde(default = "default_ingest_max_category_bytes")]
|
||||
pub ingest_max_category_bytes: usize,
|
||||
}
|
||||
|
||||
/// Default data directory for persisted assets.
|
||||
@@ -103,6 +113,26 @@ fn default_reranking_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn default_ingest_max_body_bytes() -> usize {
|
||||
20_000_000
|
||||
}
|
||||
|
||||
fn default_ingest_max_files() -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_ingest_max_content_bytes() -> usize {
|
||||
262_144
|
||||
}
|
||||
|
||||
fn default_ingest_max_context_bytes() -> usize {
|
||||
16_384
|
||||
}
|
||||
|
||||
fn default_ingest_max_category_bytes() -> usize {
|
||||
128
|
||||
}
|
||||
|
||||
pub fn ensure_ort_path() {
|
||||
if env::var_os("ORT_DYLIB_PATH").is_some() {
|
||||
return;
|
||||
@@ -157,6 +187,11 @@ impl Default for AppConfig {
|
||||
fastembed_max_length: None,
|
||||
retrieval_strategy: None,
|
||||
embedding_backend: EmbeddingBackend::default(),
|
||||
ingest_max_body_bytes: default_ingest_max_body_bytes(),
|
||||
ingest_max_files: default_ingest_max_files(),
|
||||
ingest_max_content_bytes: default_ingest_max_content_bytes(),
|
||||
ingest_max_context_bytes: default_ingest_max_context_bytes(),
|
||||
ingest_max_category_bytes: default_ingest_max_category_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
113
common/src/utils/ingest_limits.rs
Normal file
113
common/src/utils/ingest_limits.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use super::config::AppConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum IngestValidationError {
|
||||
PayloadTooLarge(String),
|
||||
BadRequest(String),
|
||||
}
|
||||
|
||||
pub fn validate_ingest_input(
|
||||
config: &AppConfig,
|
||||
content: Option<&str>,
|
||||
context: &str,
|
||||
category: &str,
|
||||
file_count: usize,
|
||||
) -> Result<(), IngestValidationError> {
|
||||
if file_count > config.ingest_max_files {
|
||||
return Err(IngestValidationError::BadRequest(format!(
|
||||
"Too many files. Maximum allowed is {}",
|
||||
config.ingest_max_files
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(content) = content {
|
||||
if content.len() > config.ingest_max_content_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Content is too large. Maximum allowed is {} bytes",
|
||||
config.ingest_max_content_bytes
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if context.len() > config.ingest_max_context_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Context is too large. Maximum allowed is {} bytes",
|
||||
config.ingest_max_context_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
if category.len() > config.ingest_max_category_bytes {
|
||||
return Err(IngestValidationError::PayloadTooLarge(format!(
|
||||
"Category is too large. Maximum allowed is {} bytes",
|
||||
config.ingest_max_category_bytes
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_rejects_too_many_files() {
|
||||
let config = AppConfig {
|
||||
ingest_max_files: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 2);
|
||||
|
||||
assert!(matches!(result, Err(IngestValidationError::BadRequest(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_rejects_oversized_content() {
|
||||
let config = AppConfig {
|
||||
ingest_max_content_bytes: 4,
|
||||
..Default::default()
|
||||
};
|
||||
let result = validate_ingest_input(&config, Some("12345"), "ctx", "cat", 0);
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(IngestValidationError::PayloadTooLarge(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_rejects_oversized_context() {
|
||||
let config = AppConfig {
|
||||
ingest_max_context_bytes: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = validate_ingest_input(&config, None, "long", "cat", 0);
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(IngestValidationError::PayloadTooLarge(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_rejects_oversized_category() {
|
||||
let config = AppConfig {
|
||||
ingest_max_category_bytes: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let result = validate_ingest_input(&config, None, "ok", "long", 0);
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(IngestValidationError::PayloadTooLarge(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_ingest_input_accepts_valid_payload() {
|
||||
let config = AppConfig::default();
|
||||
let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 1);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod config;
|
||||
pub mod embedding;
|
||||
pub mod ingest_limits;
|
||||
pub mod template_engine;
|
||||
|
||||
40
devenv.lock
40
devenv.lock
@@ -3,10 +3,10 @@
|
||||
"devenv": {
|
||||
"locked": {
|
||||
"dir": "src/modules",
|
||||
"lastModified": 1761839147,
|
||||
"lastModified": 1771066302,
|
||||
"owner": "cachix",
|
||||
"repo": "devenv",
|
||||
"rev": "bb7849648b68035f6b910120252c22b28195cf54",
|
||||
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -22,10 +22,10 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"lastModified": 1771052630,
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -37,14 +37,14 @@
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761588595,
|
||||
"owner": "edolstra",
|
||||
"lastModified": 1767039857,
|
||||
"owner": "NixOS",
|
||||
"repo": "flake-compat",
|
||||
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
|
||||
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "edolstra",
|
||||
"owner": "NixOS",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
@@ -58,10 +58,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1760663237,
|
||||
"lastModified": 1770726378,
|
||||
"owner": "cachix",
|
||||
"repo": "git-hooks.nix",
|
||||
"rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37",
|
||||
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -78,10 +78,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1709087332,
|
||||
"lastModified": 1762808025,
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
|
||||
"rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -92,10 +92,10 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"lastModified": 1771008912,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -107,10 +107,10 @@
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1761880412,
|
||||
"lastModified": 1770843696,
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a7fc11be66bdfb5cdde611ee5ce381c183da8386",
|
||||
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -135,10 +135,10 @@
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"lastModified": 1771007332,
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -155,10 +155,10 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761878277,
|
||||
"lastModified": 1771038269,
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "6604534e44090c917db714faa58d47861657690c",
|
||||
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
17
devenv.nix
17
devenv.nix
@@ -10,6 +10,7 @@
|
||||
packages = [
|
||||
pkgs.openssl
|
||||
pkgs.nodejs
|
||||
pkgs.watchman
|
||||
pkgs.vscode-langservers-extracted
|
||||
pkgs.cargo-dist
|
||||
pkgs.cargo-xwin
|
||||
@@ -29,11 +30,25 @@
|
||||
|
||||
env = {
|
||||
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
|
||||
S3_ENDPOINT = "http://127.0.0.1:19000";
|
||||
S3_BUCKET = "minne-tests";
|
||||
MINNE_TEST_S3_ENDPOINT = "http://127.0.0.1:19000";
|
||||
MINNE_TEST_S3_BUCKET = "minne-tests";
|
||||
};
|
||||
|
||||
services.minio = {
|
||||
enable = true;
|
||||
listenAddress = "127.0.0.1:19000";
|
||||
consoleAddress = "127.0.0.1:19001";
|
||||
buckets = ["minne-tests"];
|
||||
accessKey = "minioadmin";
|
||||
secretKey = "minioadmin";
|
||||
region = "us-east-1";
|
||||
};
|
||||
|
||||
processes = {
|
||||
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
|
||||
server.exec = "cargo watch -x 'run --bin main'";
|
||||
tailwind.exec = "cd html-router && npm run tailwind";
|
||||
tailwind.exec = "tailwindcss --cwd html-router -i app.css -o assets/style.css --watch=always";
|
||||
};
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
|
||||
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
|
||||
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
|
||||
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
|
||||
| `INGEST_MAX_BODY_BYTES` | Max request body size for ingest endpoints | `20000000` |
|
||||
| `INGEST_MAX_FILES` | Max files allowed per ingest request | `5` |
|
||||
| `INGEST_MAX_CONTENT_BYTES` | Max `content` field size for ingest requests | `262144` |
|
||||
| `INGEST_MAX_CONTEXT_BYTES` | Max `context` field size for ingest requests | `16384` |
|
||||
| `INGEST_MAX_CATEGORY_BYTES` | Max `category` field size for ingest requests | `128` |
|
||||
|
||||
### S3 Storage (Optional)
|
||||
|
||||
@@ -76,6 +81,13 @@ embedding_backend: "fastembed"
|
||||
# Optional reranking
|
||||
reranking_enabled: true
|
||||
reranking_pool_size: 2
|
||||
|
||||
# Ingest safety limits
|
||||
ingest_max_body_bytes: 20000000
|
||||
ingest_max_files: 5
|
||||
ingest_max_content_bytes: 262144
|
||||
ingest_max_context_bytes: 16384
|
||||
ingest_max_category_bytes: 128
|
||||
```
|
||||
|
||||
## AI Provider Setup
|
||||
|
||||
@@ -33,3 +33,4 @@ clap = { version = "4.4", features = ["derive", "env"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
@@ -9,6 +9,8 @@ use std::{
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_openai::Client;
|
||||
use chrono::Utc;
|
||||
#[cfg(not(test))]
|
||||
use common::utils::config::get_config;
|
||||
use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
@@ -421,11 +423,7 @@ async fn ingest_paragraph_batch(
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let namespace = format!("ingest_eval_{}", Uuid::new_v4());
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::memory(&namespace, "corpus")
|
||||
.await
|
||||
.context("creating in-memory surrealdb for ingestion")?,
|
||||
);
|
||||
let db = create_ingest_db(&namespace).await?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("applying migrations for ingestion")?;
|
||||
@@ -487,6 +485,29 @@ async fn ingest_paragraph_batch(
|
||||
Ok(shards)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let db = SurrealDbClient::memory(namespace, "corpus")
|
||||
.await
|
||||
.context("creating in-memory surrealdb for ingestion")?;
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
|
||||
let config = get_config().context("loading app config for ingestion database")?;
|
||||
let db = SurrealDbClient::new(
|
||||
&config.surrealdb_address,
|
||||
&config.surrealdb_username,
|
||||
&config.surrealdb_password,
|
||||
namespace,
|
||||
"corpus",
|
||||
)
|
||||
.await
|
||||
.context("creating surrealdb database for ingestion")?;
|
||||
Ok(Arc::new(db))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn ingest_single_paragraph(
|
||||
pipeline: Arc<IngestionPipeline>,
|
||||
|
||||
6
flake.lock
generated
6
flake.lock
generated
@@ -35,11 +35,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
"lastModified": 1771008912,
|
||||
"narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -41,5 +41,8 @@ common = { path = "../common" }
|
||||
retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
json-stream-parser = { path = "../json-stream-parser" }
|
||||
|
||||
[dev-dependencies]
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
[build-dependencies]
|
||||
minijinja-embed = { version = "2.8.0" }
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -2,7 +2,7 @@
|
||||
"name": "html-router",
|
||||
"version": "1.0.0",
|
||||
"scripts": {
|
||||
"tailwind": "npx @tailwindcss/cli -i app.css -o assets/style.css -w -m"
|
||||
"tailwind": "tailwindcss -i app.css -o assets/style.css --watch=always"
|
||||
},
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
use common::storage::types::conversation::SidebarConversation;
|
||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{OpenAIClientType, SessionStoreType};
|
||||
@@ -18,8 +25,20 @@ pub struct HtmlState {
|
||||
pub storage: StorageManager,
|
||||
pub reranker_pool: Option<Arc<RerankerPool>>,
|
||||
pub embedding_provider: Arc<EmbeddingProvider>,
|
||||
conversation_archive_cache: Arc<RwLock<HashMap<String, ConversationArchiveCacheEntry>>>,
|
||||
conversation_archive_cache_writes: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConversationArchiveCacheEntry {
|
||||
conversations: Vec<SidebarConversation>,
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
|
||||
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
|
||||
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
|
||||
|
||||
impl HtmlState {
|
||||
pub async fn new_with_resources(
|
||||
db: Arc<SurrealDbClient>,
|
||||
@@ -44,6 +63,8 @@ impl HtmlState {
|
||||
storage,
|
||||
reranker_pool,
|
||||
embedding_provider,
|
||||
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -54,6 +75,86 @@ impl HtmlState {
|
||||
.and_then(|value| value.parse().ok())
|
||||
.unwrap_or(RetrievalStrategy::Default)
|
||||
}
|
||||
|
||||
pub async fn get_cached_conversation_archive(
|
||||
&self,
|
||||
user_id: &str,
|
||||
) -> Option<Vec<SidebarConversation>> {
|
||||
let now = Instant::now();
|
||||
let should_evict_expired = {
|
||||
let cache = self.conversation_archive_cache.read().await;
|
||||
if let Some(entry) = cache.get(user_id) {
|
||||
if entry.expires_at > now {
|
||||
return Some(entry.conversations.clone());
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if should_evict_expired {
|
||||
let mut cache = self.conversation_archive_cache.write().await;
|
||||
cache.remove(user_id);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn set_cached_conversation_archive(
|
||||
&self,
|
||||
user_id: &str,
|
||||
conversations: Vec<SidebarConversation>,
|
||||
) {
|
||||
let now = Instant::now();
|
||||
let mut cache = self.conversation_archive_cache.write().await;
|
||||
cache.insert(
|
||||
user_id.to_string(),
|
||||
ConversationArchiveCacheEntry {
|
||||
conversations,
|
||||
expires_at: now + CONVERSATION_ARCHIVE_CACHE_TTL,
|
||||
},
|
||||
);
|
||||
|
||||
let writes = self
|
||||
.conversation_archive_cache_writes
|
||||
.fetch_add(1, Ordering::Relaxed)
|
||||
+ 1;
|
||||
if writes % CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL == 0 {
|
||||
Self::purge_expired_entries(&mut cache, now);
|
||||
}
|
||||
|
||||
Self::enforce_cache_capacity(&mut cache);
|
||||
}
|
||||
|
||||
pub async fn invalidate_conversation_archive_cache(&self, user_id: &str) {
|
||||
let mut cache = self.conversation_archive_cache.write().await;
|
||||
cache.remove(user_id);
|
||||
}
|
||||
|
||||
fn purge_expired_entries(
|
||||
cache: &mut HashMap<String, ConversationArchiveCacheEntry>,
|
||||
now: Instant,
|
||||
) {
|
||||
cache.retain(|_, entry| entry.expires_at > now);
|
||||
}
|
||||
|
||||
fn enforce_cache_capacity(cache: &mut HashMap<String, ConversationArchiveCacheEntry>) {
|
||||
if cache.len() <= CONVERSATION_ARCHIVE_CACHE_MAX_USERS {
|
||||
return;
|
||||
}
|
||||
|
||||
let overflow = cache.len() - CONVERSATION_ARCHIVE_CACHE_MAX_USERS;
|
||||
let mut by_expiry: Vec<(String, Instant)> = cache
|
||||
.iter()
|
||||
.map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
|
||||
.collect();
|
||||
by_expiry.sort_by_key(|(_, expires_at)| *expires_at);
|
||||
|
||||
for (user_id, _) in by_expiry.into_iter().take(overflow) {
|
||||
cache.remove(&user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ProvidesDb for HtmlState {
|
||||
fn db(&self) -> &Arc<SurrealDbClient> {
|
||||
@@ -71,3 +172,87 @@ impl crate::middlewares::response_middleware::ProvidesHtmlState for HtmlState {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::{
|
||||
storage::types::conversation::SidebarConversation,
|
||||
utils::{
|
||||
config::{AppConfig, StorageKind},
|
||||
embedding::EmbeddingProvider,
|
||||
},
|
||||
};
|
||||
|
||||
async fn test_state() -> HtmlState {
|
||||
let namespace = "test_ns";
|
||||
let database = &uuid::Uuid::new_v4().to_string();
|
||||
let db = Arc::new(
|
||||
SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to create in-memory DB"),
|
||||
);
|
||||
|
||||
let session_store = Arc::new(
|
||||
db.create_session_store()
|
||||
.await
|
||||
.expect("Failed to create session store"),
|
||||
);
|
||||
|
||||
let mut config = AppConfig::default();
|
||||
config.storage = StorageKind::Memory;
|
||||
|
||||
let storage = StorageManager::new(&config)
|
||||
.await
|
||||
.expect("Failed to create storage manager");
|
||||
|
||||
let embedding_provider = Arc::new(
|
||||
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
|
||||
);
|
||||
|
||||
HtmlState::new_with_resources(
|
||||
db,
|
||||
Arc::new(async_openai::Client::new()),
|
||||
session_store,
|
||||
storage,
|
||||
config,
|
||||
None,
|
||||
embedding_provider,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create HtmlState")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expired_conversation_archive_entry_is_evicted_on_read() {
|
||||
let state = test_state().await;
|
||||
let user_id = "expired-user";
|
||||
|
||||
{
|
||||
let mut cache = state.conversation_archive_cache.write().await;
|
||||
cache.insert(
|
||||
user_id.to_string(),
|
||||
ConversationArchiveCacheEntry {
|
||||
conversations: vec![SidebarConversation {
|
||||
id: "conv-1".to_string(),
|
||||
title: "A stale chat".to_string(),
|
||||
}],
|
||||
expires_at: Instant::now() - Duration::from_secs(1),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let cached = state.get_cached_conversation_archive(user_id).await;
|
||||
assert!(
|
||||
cached.is_none(),
|
||||
"Expired cache entry should not be returned"
|
||||
);
|
||||
|
||||
let cache = state.conversation_archive_cache.read().await;
|
||||
assert!(
|
||||
!cache.contains_key(user_id),
|
||||
"Expired cache entry should be evicted after read"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,9 @@ where
|
||||
.add_protected_routes(routes::chat::router())
|
||||
.add_protected_routes(routes::content::router())
|
||||
.add_protected_routes(routes::knowledge::router())
|
||||
.add_protected_routes(routes::ingestion::router())
|
||||
.add_protected_routes(routes::ingestion::router(
|
||||
app_state.config.ingest_max_body_bytes,
|
||||
))
|
||||
.add_protected_routes(routes::scratchpad::router())
|
||||
.with_compression()
|
||||
.build()
|
||||
|
||||
@@ -46,3 +46,14 @@ pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Nex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn require_admin(auth: AuthSessionType, mut request: Request, next: Next) -> Response {
|
||||
match auth.current_user {
|
||||
Some(user) if user.admin => {
|
||||
request.extensions_mut().insert(user);
|
||||
next.run(request).await
|
||||
}
|
||||
Some(_) => TemplateResponse::redirect("/").into_response(),
|
||||
None => TemplateResponse::redirect("/signin").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ use tracing::error;
|
||||
|
||||
use crate::{html_state::HtmlState, AuthSessionType};
|
||||
use common::storage::types::{
|
||||
conversation::Conversation,
|
||||
conversation::{Conversation, SidebarConversation},
|
||||
user::{Theme, User},
|
||||
};
|
||||
|
||||
@@ -27,7 +27,7 @@ pub trait ProvidesHtmlState {
|
||||
fn html_state(&self) -> &HtmlState;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum TemplateKind {
|
||||
Full(String),
|
||||
Partial(String, String),
|
||||
@@ -114,13 +114,34 @@ impl IntoResponse for TemplateResponse {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TemplateUser {
|
||||
id: String,
|
||||
email: String,
|
||||
admin: bool,
|
||||
timezone: String,
|
||||
theme: String,
|
||||
}
|
||||
|
||||
impl From<&User> for TemplateUser {
|
||||
fn from(user: &User) -> Self {
|
||||
Self {
|
||||
id: user.id.clone(),
|
||||
email: user.email.clone(),
|
||||
admin: user.admin,
|
||||
timezone: user.timezone.clone(),
|
||||
theme: user.theme.as_str().to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ContextWrapper<'a> {
|
||||
user_theme: &'a str,
|
||||
initial_theme: &'a str,
|
||||
is_authenticated: bool,
|
||||
user: Option<&'a User>,
|
||||
conversation_archive: Vec<Conversation>,
|
||||
user: Option<&'a TemplateUser>,
|
||||
conversation_archive: Vec<SidebarConversation>,
|
||||
#[serde(flatten)]
|
||||
context: HashMap<String, Value>,
|
||||
}
|
||||
@@ -138,6 +159,7 @@ where
|
||||
let mut initial_theme = Theme::System.initial_theme();
|
||||
let mut is_authenticated = false;
|
||||
let mut current_user_id = None;
|
||||
let mut current_user = None;
|
||||
|
||||
{
|
||||
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
|
||||
@@ -146,6 +168,7 @@ where
|
||||
current_user_id = Some(user.id.clone());
|
||||
user_theme = user.theme.as_str();
|
||||
initial_theme = user.theme.initial_theme();
|
||||
current_user = Some(TemplateUser::from(user));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,17 +181,48 @@ where
|
||||
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
|
||||
let template_engine = state.template_engine();
|
||||
|
||||
let mut current_user = None;
|
||||
let mut conversation_archive = Vec::new();
|
||||
|
||||
if let Some(user_id) = current_user_id {
|
||||
let html_state = state.html_state();
|
||||
if let Ok(Some(user)) = html_state.db.get_item::<User>(&user_id).await {
|
||||
// Fetch conversation archive globally for authenticated users
|
||||
if let Ok(archive) = User::get_user_conversations(&user.id, &html_state.db).await {
|
||||
let should_load_conversation_archive =
|
||||
matches!(&template_response.template_kind, TemplateKind::Full(_));
|
||||
|
||||
if should_load_conversation_archive {
|
||||
if let Some(user_id) = current_user_id {
|
||||
let html_state = state.html_state();
|
||||
if let Some(cached_archive) =
|
||||
html_state.get_cached_conversation_archive(&user_id).await
|
||||
{
|
||||
conversation_archive = cached_archive;
|
||||
} else if let Ok(archive) =
|
||||
Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
|
||||
{
|
||||
html_state
|
||||
.set_cached_conversation_archive(&user_id, archive.clone())
|
||||
.await;
|
||||
conversation_archive = archive;
|
||||
}
|
||||
current_user = Some(user);
|
||||
}
|
||||
}
|
||||
|
||||
fn context_to_map(
|
||||
value: &Value,
|
||||
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
|
||||
match value.kind() {
|
||||
minijinja::value::ValueKind::Map => {
|
||||
let mut map = HashMap::new();
|
||||
if let Ok(keys) = value.try_iter() {
|
||||
for key in keys {
|
||||
if let Ok(val) = value.get_item(&key) {
|
||||
map.insert(key.to_string(), val);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
minijinja::value::ValueKind::None | minijinja::value::ValueKind::Undefined => {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
other => Err(other),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,19 +237,15 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// Convert minijinja::Value to HashMap if it's a map, otherwise use empty HashMap
|
||||
let context_map = if template_response.context.kind() == minijinja::value::ValueKind::Map {
|
||||
let mut map = HashMap::new();
|
||||
if let Ok(keys) = template_response.context.try_iter() {
|
||||
for key in keys {
|
||||
if let Ok(val) = template_response.context.get_item(&key) {
|
||||
map.insert(key.to_string(), val);
|
||||
}
|
||||
}
|
||||
let context_map = match context_to_map(&template_response.context) {
|
||||
Ok(map) => map,
|
||||
Err(kind) => {
|
||||
error!(
|
||||
"Template context must be a map or unit, got kind={:?} for template_kind={:?}",
|
||||
kind, template_response.template_kind
|
||||
);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, Html(fallback_error())).into_response();
|
||||
}
|
||||
map
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
let context = ContextWrapper {
|
||||
|
||||
@@ -17,10 +17,16 @@ use crate::html_state::HtmlState;
|
||||
pub struct AccountPageData {
|
||||
timezones: Vec<String>,
|
||||
theme_options: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
api_key: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
selected_timezone: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
selected_theme: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn show_account_page(
|
||||
RequireUser(_user): RequireUser,
|
||||
RequireUser(user): RequireUser,
|
||||
State(_state): State<HtmlState>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let timezones = TZ_VARIANTS
|
||||
@@ -40,6 +46,9 @@ pub async fn show_account_page(
|
||||
AccountPageData {
|
||||
timezones,
|
||||
theme_options,
|
||||
api_key: user.api_key,
|
||||
selected_timezone: None,
|
||||
selected_theme: None,
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -50,7 +59,7 @@ pub async fn set_api_key(
|
||||
auth: AuthSessionType,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Generate and set the API key
|
||||
User::set_api_key(&user.id, &state.db).await?;
|
||||
let api_key = User::set_api_key(&user.id, &state.db).await?;
|
||||
|
||||
// Clear the cache so new requests have access to the user with api key
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
@@ -62,6 +71,9 @@ pub async fn set_api_key(
|
||||
AccountPageData {
|
||||
timezones: vec![],
|
||||
theme_options: vec![],
|
||||
api_key: Some(api_key),
|
||||
selected_timezone: None,
|
||||
selected_theme: None,
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -108,6 +120,9 @@ pub async fn update_timezone(
|
||||
AccountPageData {
|
||||
timezones,
|
||||
theme_options: vec![],
|
||||
api_key: None,
|
||||
selected_timezone: Some(form.timezone),
|
||||
selected_theme: None,
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -142,6 +157,9 @@ pub async fn update_theme(
|
||||
AccountPageData {
|
||||
timezones: vec![],
|
||||
theme_options,
|
||||
api_key: None,
|
||||
selected_timezone: None,
|
||||
selected_theme: Some(form.theme),
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
@@ -23,10 +23,7 @@ use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
middlewares::{
|
||||
auth_middleware::RequireUser,
|
||||
response_middleware::{HtmlError, TemplateResponse},
|
||||
},
|
||||
middlewares::response_middleware::{HtmlError, TemplateResponse},
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -60,7 +57,6 @@ pub struct AdminPanelQuery {
|
||||
|
||||
pub async fn show_admin_panel(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(_user): RequireUser,
|
||||
Query(query): Query<AdminPanelQuery>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let section = match query.section.as_deref() {
|
||||
@@ -131,14 +127,8 @@ pub struct RegistrationToggleData {
|
||||
|
||||
pub async fn toggle_registration_status(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(input): Form<RegistrationToggleInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
@@ -175,14 +165,8 @@ pub struct ModelSettingsData {
|
||||
|
||||
pub async fn update_model_settings(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(input): Form<ModelSettingsInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
|
||||
@@ -295,13 +279,7 @@ pub struct SystemPromptEditData {
|
||||
|
||||
pub async fn show_edit_system_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -325,14 +303,8 @@ pub struct SystemPromptSectionData {
|
||||
|
||||
pub async fn patch_query_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(input): Form<SystemPromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
@@ -359,13 +331,7 @@ pub struct IngestionPromptEditData {
|
||||
|
||||
pub async fn show_edit_ingestion_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -384,14 +350,8 @@ pub struct IngestionPromptUpdateInput {
|
||||
|
||||
pub async fn patch_ingestion_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(input): Form<IngestionPromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
@@ -418,13 +378,7 @@ pub struct ImagePromptEditData {
|
||||
|
||||
pub async fn show_edit_image_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
@@ -443,14 +397,8 @@ pub struct ImagePromptUpdateInput {
|
||||
|
||||
pub async fn patch_image_prompt(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(input): Form<ImagePromptUpdateInput>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Early return if the user is not admin
|
||||
if !user.admin {
|
||||
return Ok(TemplateResponse::redirect("/"));
|
||||
}
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod handlers;
|
||||
use axum::{
|
||||
extract::FromRef,
|
||||
middleware::from_fn,
|
||||
routing::{get, patch},
|
||||
Router,
|
||||
};
|
||||
@@ -10,7 +11,7 @@ use handlers::{
|
||||
toggle_registration_status, update_model_settings,
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
use crate::{html_state::HtmlState, middlewares::auth_middleware::require_admin};
|
||||
|
||||
pub fn router<S>() -> Router<S>
|
||||
where
|
||||
@@ -27,4 +28,5 @@ where
|
||||
.route("/update-ingestion-prompt", patch(patch_ingestion_prompt))
|
||||
.route("/edit-image-prompt", get(show_edit_image_prompt))
|
||||
.route("/update-image-prompt", patch(patch_image_prompt))
|
||||
.route_layer(from_fn(require_admin))
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{Html, IntoResponse},
|
||||
Form,
|
||||
};
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum_htmx::HxBoosted;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -46,7 +42,7 @@ pub async fn authenticate_user(
|
||||
let user = match User::authenticate(&form.email, &form.password, &state.db).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return Ok(Html("<p>Incorrect email or password </p>").into_response());
|
||||
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{Html, IntoResponse},
|
||||
Form,
|
||||
};
|
||||
use axum::{extract::State, response::IntoResponse, Form};
|
||||
use axum_htmx::HxBoosted;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -57,7 +53,7 @@ pub async fn process_signup_and_show_verification(
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
tracing::error!("{:?}", e);
|
||||
return Ok(Html(format!("<p>{e}</p>")).into_response());
|
||||
return Ok(TemplateResponse::bad_request(&e.to_string()).into_response());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ pub async fn show_initialized_chat(
|
||||
state.db.store_item(conversation.clone()).await?;
|
||||
state.db.store_item(ai_message.clone()).await?;
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
let messages = vec![user_message, ai_message];
|
||||
|
||||
@@ -178,7 +179,7 @@ pub async fn new_chat_user_message(
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id, "New chat".to_string());
|
||||
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
|
||||
let user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
@@ -188,6 +189,7 @@ pub async fn new_chat_user_message(
|
||||
|
||||
state.db.store_item(conversation.clone()).await?;
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
@@ -252,6 +254,7 @@ pub async fn patch_conversation_title(
|
||||
Form(form): Form<PatchConversationTitle>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"sidebar.html",
|
||||
@@ -281,6 +284,7 @@ pub async fn delete_conversation(
|
||||
.db
|
||||
.delete_item::<Conversation>(&conversation_id)
|
||||
.await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"sidebar.html",
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
sse::{Event, KeepAlive, KeepAliveStream},
|
||||
Sse,
|
||||
},
|
||||
};
|
||||
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
use tokio::sync::{mpsc::channel, Mutex};
|
||||
use tracing::{debug, error};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
@@ -38,10 +40,21 @@ use common::storage::{
|
||||
|
||||
use crate::{html_state::HtmlState, AuthSessionType};
|
||||
|
||||
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
|
||||
|
||||
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
|
||||
type SseResponse = Sse<KeepAliveStream<EventStream>>;
|
||||
|
||||
fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
|
||||
Sse::new(stream).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
}
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
|
||||
fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
@@ -51,53 +64,125 @@ async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<
|
||||
(Message, User, Conversation, Vec<Message>),
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
|
||||
// Check authentication
|
||||
let user = match current_user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)))
|
||||
}
|
||||
let Some(user) = current_user else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match db.get_item::<Message>(message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
return Err(Sse::new(create_error_stream(
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Message not found: the specified message does not exist",
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Get conversation history
|
||||
let (conversation, mut history) =
|
||||
let (conversation, history) =
|
||||
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
|
||||
{
|
||||
Err(e) => {
|
||||
error!("Database error retrieving message {}: {:?}", message_id, e);
|
||||
return Err(Sse::new(create_error_stream(
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Failed to retrieve message: database error",
|
||||
)));
|
||||
}
|
||||
Ok((conversation, history)) => (conversation, history),
|
||||
};
|
||||
|
||||
// Remove the last message, its the same as the message
|
||||
history.pop();
|
||||
let Some(message_index) = find_message_index(&history, message_id) else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Message not found in conversation history",
|
||||
)));
|
||||
};
|
||||
|
||||
Ok((message, user, conversation, history))
|
||||
let Some(message_from_history) = history.get(message_index) else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Message not found in conversation history",
|
||||
)));
|
||||
};
|
||||
|
||||
if message_from_history.role != MessageRole::User {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"Only user messages can be used to generate a response",
|
||||
)));
|
||||
}
|
||||
|
||||
let message = message_from_history.clone();
|
||||
|
||||
let history_before_message = history_before_message(&history, message_index);
|
||||
let existing_ai_response = find_existing_ai_response(&history, message_index);
|
||||
|
||||
Ok((
|
||||
message,
|
||||
user,
|
||||
conversation,
|
||||
history_before_message,
|
||||
existing_ai_response,
|
||||
))
|
||||
}
|
||||
|
||||
fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
|
||||
messages.iter().position(|message| message.id == message_id)
|
||||
}
|
||||
|
||||
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
|
||||
messages
|
||||
.iter()
|
||||
.skip(user_message_index + 1)
|
||||
.take_while(|message| message.role != MessageRole::User)
|
||||
.find(|message| message.role == MessageRole::AI)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn history_before_message(messages: &[Message], message_index: usize) -> Vec<Message> {
|
||||
messages.iter().take(message_index).cloned().collect()
|
||||
}
|
||||
|
||||
fn create_replayed_response_stream(state: &HtmlState, existing_ai_message: Message) -> SseResponse {
|
||||
let references_event = if existing_ai_message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(|references| !references.is_empty())
|
||||
{
|
||||
state
|
||||
.templates
|
||||
.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(ReferenceData {
|
||||
message: existing_ai_message.clone(),
|
||||
}),
|
||||
)
|
||||
.ok()
|
||||
.map(|html| Event::default().event("references").data(html))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let answer = existing_ai_message.content;
|
||||
|
||||
let event_stream = stream! {
|
||||
yield Ok(Event::default().event("chat_message").data(answer));
|
||||
|
||||
if let Some(event) = references_event {
|
||||
yield Ok(event);
|
||||
}
|
||||
|
||||
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
||||
};
|
||||
|
||||
sse_with_keep_alive(event_stream.boxed())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -105,21 +190,42 @@ pub struct QueryParams {
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
|
||||
response
|
||||
.references
|
||||
.iter()
|
||||
.map(|reference| reference.reference.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn get_response_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
) -> SseResponse {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user, _conversation, history) =
|
||||
let (user_message, user, _conversation, history, existing_ai_response) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history)) => {
|
||||
(user_message, user, conversation, history)
|
||||
}
|
||||
Ok((user_message, user, conversation, history, existing_ai_response)) => (
|
||||
user_message,
|
||||
user,
|
||||
conversation,
|
||||
history,
|
||||
existing_ai_response,
|
||||
),
|
||||
Err(error_stream) => return error_stream,
|
||||
};
|
||||
|
||||
if let Some(existing_ai_message) = existing_ai_response {
|
||||
return create_replayed_response_stream(&state, existing_ai_message);
|
||||
}
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
@@ -142,15 +248,17 @@ pub async fn get_response_stream(
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve knowledge"));
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
|
||||
}
|
||||
};
|
||||
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
// 3. Create the OpenAI request with appropriate context format
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
let context_json = match &retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(&entities)
|
||||
retrieved_entities_to_json(entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
// For chat, use chunks from the search result
|
||||
@@ -159,24 +267,18 @@ pub async fn get_response_stream(
|
||||
};
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||
let settings = match SystemSettings::get_current(&state.db).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Sse::new(create_error_stream("Failed to retrieve system settings"));
|
||||
}
|
||||
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
|
||||
};
|
||||
let request = match create_chat_request(formatted_user_message, &settings) {
|
||||
Ok(req) => req,
|
||||
Err(..) => {
|
||||
return Sse::new(create_error_stream("Failed to create chat request"));
|
||||
}
|
||||
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -186,7 +288,9 @@ pub async fn get_response_stream(
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = state.db.clone();
|
||||
let db_client = Arc::clone(&state.db);
|
||||
let user_id = user.id.clone();
|
||||
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
@@ -198,17 +302,55 @@ pub async fn get_response_stream(
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let references: Vec<String> = response
|
||||
.references
|
||||
.into_iter()
|
||||
.map(|r| r.reference)
|
||||
.collect();
|
||||
let raw_references = extract_reference_strings(&response);
|
||||
let answer = response.answer;
|
||||
|
||||
let initial_validation = match validate_references(
|
||||
&user_id,
|
||||
raw_references,
|
||||
&allowed_reference_ids,
|
||||
&db_client,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
error!(error = %err, "Reference validation failed, storing answer without references");
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
answer,
|
||||
Some(Vec::new()),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
if let Err(store_err) = db_client.store_item(ai_message).await {
|
||||
error!(error = ?store_err, "Failed to store AI message after validation failure");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
total_refs = initial_validation.reason_stats.total,
|
||||
valid_refs = initial_validation.valid_refs.len(),
|
||||
invalid_refs = initial_validation.invalid_refs.len(),
|
||||
invalid_empty = initial_validation.reason_stats.empty,
|
||||
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
|
||||
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
|
||||
invalid_duplicate = initial_validation.reason_stats.duplicate,
|
||||
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
|
||||
invalid_not_found = initial_validation.reason_stats.not_found,
|
||||
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
|
||||
invalid_over_limit = initial_validation.reason_stats.over_limit,
|
||||
"Post-LLM reference validation complete"
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
response.answer,
|
||||
Some(references),
|
||||
answer,
|
||||
Some(initial_validation.valid_refs),
|
||||
);
|
||||
|
||||
let _ = tx_final.send(ai_message.clone()).await;
|
||||
@@ -240,7 +382,7 @@ pub async fn get_response_stream(
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = json_state.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
@@ -288,12 +430,6 @@ pub async fn get_response_stream(
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Prepare data for template
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
@@ -323,11 +459,7 @@ pub async fn get_response_stream(
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
Sse::new(event_stream.boxed()).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive"),
|
||||
)
|
||||
sse_with_keep_alive(event_stream.boxed())
|
||||
}
|
||||
|
||||
struct StreamParserState {
|
||||
@@ -375,3 +507,195 @@ impl StreamParserState {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{Duration as ChronoDuration, Utc};
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{conversation::Conversation, user::Theme},
|
||||
};
|
||||
use retrieval_pipeline::answer_retrieval::Reference;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn make_test_message(id: &str, role: MessageRole) -> Message {
|
||||
let mut message = Message::new(
|
||||
"conversation-1".to_string(),
|
||||
role,
|
||||
format!("content-{id}"),
|
||||
None,
|
||||
);
|
||||
message.id = id.to_string();
|
||||
message
|
||||
}
|
||||
|
||||
fn make_test_user(id: &str) -> User {
|
||||
User {
|
||||
id: id.to_string(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
email: "test@example.com".to_string(),
|
||||
password: "password".to_string(),
|
||||
anonymous: false,
|
||||
api_key: None,
|
||||
admin: false,
|
||||
timezone: "UTC".to_string(),
|
||||
theme: Theme::System,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_reference_strings_in_order() {
|
||||
let response = LLMResponseFormat {
|
||||
answer: "answer".to_string(),
|
||||
references: vec![
|
||||
Reference {
|
||||
reference: "a".to_string(),
|
||||
},
|
||||
Reference {
|
||||
reference: "b".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let extracted = extract_reference_strings(&response);
|
||||
assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_message_index_for_existing_message() {
|
||||
let messages = vec![
|
||||
make_test_message("m1", MessageRole::User),
|
||||
make_test_message("m2", MessageRole::AI),
|
||||
make_test_message("m3", MessageRole::User),
|
||||
];
|
||||
|
||||
assert_eq!(find_message_index(&messages, "m2"), Some(1));
|
||||
assert_eq!(find_message_index(&messages, "missing"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_existing_ai_response_for_same_turn() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("system", MessageRole::System),
|
||||
make_test_message("a1", MessageRole::AI),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response");
|
||||
assert_eq!(ai_reply.id, "a1");
|
||||
|
||||
let ai_reply_second_turn =
|
||||
find_existing_ai_response(&messages, 3).expect("expected AI response");
|
||||
assert_eq!(ai_reply_second_turn.id, "a2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn does_not_replay_ai_response_from_later_turn() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
assert!(find_existing_ai_response(&messages, 0).is_none());
|
||||
|
||||
let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response");
|
||||
assert_eq!(ai_reply.id, "a2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_before_message_excludes_target_and_future_messages() {
|
||||
let messages = vec![
|
||||
make_test_message("u1", MessageRole::User),
|
||||
make_test_message("a1", MessageRole::AI),
|
||||
make_test_message("u2", MessageRole::User),
|
||||
make_test_message("a2", MessageRole::AI),
|
||||
];
|
||||
|
||||
let history_for_u2 = history_before_message(&messages, 2);
|
||||
let history_ids: Vec<String> = history_for_u2
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() {
|
||||
let namespace = "chat_stream_replay";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("failed to create in-memory db");
|
||||
|
||||
let user = make_test_user("user-1");
|
||||
let conversation = Conversation::new(user.id.clone(), "Conversation".to_string());
|
||||
|
||||
let mut user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
"Question one".to_string(),
|
||||
None,
|
||||
);
|
||||
user_message.id = "u1".to_string();
|
||||
|
||||
let mut ai_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::AI,
|
||||
"Answer one".to_string(),
|
||||
Some(vec!["ref-1".to_string()]),
|
||||
);
|
||||
ai_message.id = "a1".to_string();
|
||||
ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1);
|
||||
ai_message.updated_at = ai_message.created_at;
|
||||
|
||||
let mut second_user_message = Message::new(
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
"Question two".to_string(),
|
||||
None,
|
||||
);
|
||||
second_user_message.id = "u2".to_string();
|
||||
second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1);
|
||||
second_user_message.updated_at = second_user_message.created_at;
|
||||
|
||||
db.store_item(conversation.clone())
|
||||
.await
|
||||
.expect("failed to store conversation");
|
||||
db.store_item(user_message.clone())
|
||||
.await
|
||||
.expect("failed to store user message");
|
||||
db.store_item(ai_message.clone())
|
||||
.await
|
||||
.expect("failed to store ai message");
|
||||
db.store_item(second_user_message.clone())
|
||||
.await
|
||||
.expect("failed to store second user message");
|
||||
|
||||
let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
|
||||
get_message_and_user(&db, Some(user.clone()), &user_message.id)
|
||||
.await
|
||||
.expect("expected first turn to load");
|
||||
|
||||
assert!(history_for_first_turn.is_empty());
|
||||
let existing_ai_for_first_turn =
|
||||
existing_ai_for_first_turn.expect("expected first-turn AI response");
|
||||
assert_eq!(existing_ai_for_first_turn.id, ai_message.id);
|
||||
|
||||
let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
|
||||
get_message_and_user(&db, Some(user), &second_user_message.id)
|
||||
.await
|
||||
.expect("expected second turn to load");
|
||||
|
||||
let history_ids: Vec<String> = history_for_second_turn
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
|
||||
assert!(existing_ai_for_second_turn.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod chat_handlers;
|
||||
mod message_response_stream;
|
||||
mod reference_validation;
|
||||
mod references;
|
||||
|
||||
use axum::{
|
||||
|
||||
477
html-router/src/routes/chat/reference_validation.rs
Normal file
477
html-router/src/routes/chat/reference_validation.rs
Normal file
@@ -0,0 +1,477 @@
|
||||
#![allow(clippy::arithmetic_side_effects, clippy::missing_docs_in_private_items)]
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
};
|
||||
use retrieval_pipeline::StrategyOutput;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum InvalidReferenceReason {
|
||||
Empty,
|
||||
UnsupportedPrefix,
|
||||
MalformedUuid,
|
||||
Duplicate,
|
||||
NotInContext,
|
||||
NotFound,
|
||||
WrongUser,
|
||||
OverLimit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct InvalidReference {
|
||||
pub raw: String,
|
||||
pub normalized: Option<String>,
|
||||
pub reason: InvalidReferenceReason,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub(crate) struct ReferenceReasonStats {
|
||||
pub total: usize,
|
||||
pub empty: usize,
|
||||
pub unsupported_prefix: usize,
|
||||
pub malformed_uuid: usize,
|
||||
pub duplicate: usize,
|
||||
pub not_in_context: usize,
|
||||
pub not_found: usize,
|
||||
pub wrong_user: usize,
|
||||
pub over_limit: usize,
|
||||
}
|
||||
|
||||
impl ReferenceReasonStats {
|
||||
fn record(&mut self, reason: &InvalidReferenceReason) {
|
||||
match reason {
|
||||
InvalidReferenceReason::Empty => self.empty += 1,
|
||||
InvalidReferenceReason::UnsupportedPrefix => self.unsupported_prefix += 1,
|
||||
InvalidReferenceReason::MalformedUuid => self.malformed_uuid += 1,
|
||||
InvalidReferenceReason::Duplicate => self.duplicate += 1,
|
||||
InvalidReferenceReason::NotInContext => self.not_in_context += 1,
|
||||
InvalidReferenceReason::NotFound => self.not_found += 1,
|
||||
InvalidReferenceReason::WrongUser => self.wrong_user += 1,
|
||||
InvalidReferenceReason::OverLimit => self.over_limit += 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub(crate) struct ReferenceValidationResult {
|
||||
pub valid_refs: Vec<String>,
|
||||
pub invalid_refs: Vec<InvalidReference>,
|
||||
pub reason_stats: ReferenceReasonStats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum ReferenceLookupTarget {
|
||||
TextChunk,
|
||||
KnowledgeEntity,
|
||||
Any,
|
||||
}
|
||||
|
||||
pub(crate) fn collect_reference_ids_from_retrieval(
|
||||
retrieval_result: &StrategyOutput,
|
||||
) -> Vec<String> {
|
||||
let mut ids = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
match retrieval_result {
|
||||
StrategyOutput::Chunks(chunks) => {
|
||||
for chunk in chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
StrategyOutput::Entities(entities) => {
|
||||
for entity in entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
StrategyOutput::Search(search) => {
|
||||
for chunk in &search.chunks {
|
||||
let id = chunk.chunk.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
for entity in &search.entities {
|
||||
let id = entity.entity.id.clone();
|
||||
if seen.insert(id.clone()) {
|
||||
ids.push(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ids
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_references(
|
||||
user_id: &str,
|
||||
refs: Vec<String>,
|
||||
allowed_ids: &[String],
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<ReferenceValidationResult, AppError> {
|
||||
let mut result = ReferenceValidationResult::default();
|
||||
result.reason_stats.total = refs.len();
|
||||
|
||||
let mut seen = HashSet::new();
|
||||
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
|
||||
let enforce_context = !allowed_set.is_empty();
|
||||
|
||||
for raw in refs {
|
||||
let (normalized, target) = match normalize_reference(&raw) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(reason) => {
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: None,
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !seen.insert(normalized.clone()) {
|
||||
let reason = InvalidReferenceReason::Duplicate;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
|
||||
let reason = InvalidReferenceReason::OverLimit;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if enforce_context && !allowed_set.contains(normalized.as_str()) {
|
||||
let reason = InvalidReferenceReason::NotInContext;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
|
||||
LookupResult::Found => result.valid_refs.push(normalized),
|
||||
LookupResult::WrongUser => {
|
||||
let reason = InvalidReferenceReason::WrongUser;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
LookupResult::NotFound => {
|
||||
let reason = InvalidReferenceReason::NotFound;
|
||||
result.reason_stats.record(&reason);
|
||||
result.invalid_refs.push(InvalidReference {
|
||||
raw,
|
||||
normalized: Some(normalized),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn normalize_reference(
|
||||
raw: &str,
|
||||
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(InvalidReferenceReason::Empty);
|
||||
}
|
||||
|
||||
let (candidate, target) = if let Some((prefix, rest)) = trimmed.split_once(':') {
|
||||
let lookup_target = if prefix.eq_ignore_ascii_case("knowledge_entity") {
|
||||
ReferenceLookupTarget::KnowledgeEntity
|
||||
} else if prefix.eq_ignore_ascii_case("text_chunk") {
|
||||
ReferenceLookupTarget::TextChunk
|
||||
} else {
|
||||
return Err(InvalidReferenceReason::UnsupportedPrefix);
|
||||
};
|
||||
|
||||
(rest.trim(), lookup_target)
|
||||
} else {
|
||||
(trimmed, ReferenceLookupTarget::Any)
|
||||
};
|
||||
|
||||
if candidate.is_empty() {
|
||||
return Err(InvalidReferenceReason::MalformedUuid);
|
||||
}
|
||||
|
||||
Uuid::parse_str(candidate)
|
||||
.map(|uuid| (uuid.to_string(), target))
|
||||
.map_err(|_| InvalidReferenceReason::MalformedUuid)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum LookupResult {
|
||||
Found,
|
||||
WrongUser,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
async fn lookup_reference_for_user(
|
||||
id: &str,
|
||||
target: &ReferenceLookupTarget,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<LookupResult, AppError> {
|
||||
match target {
|
||||
ReferenceLookupTarget::TextChunk => lookup_single_type::<TextChunk>(id, user_id, db).await,
|
||||
ReferenceLookupTarget::KnowledgeEntity => {
|
||||
lookup_single_type::<KnowledgeEntity>(id, user_id, db).await
|
||||
}
|
||||
ReferenceLookupTarget::Any => {
|
||||
let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
|
||||
if chunk_result == LookupResult::Found {
|
||||
return Ok(LookupResult::Found);
|
||||
}
|
||||
|
||||
let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
|
||||
if entity_result == LookupResult::Found {
|
||||
return Ok(LookupResult::Found);
|
||||
}
|
||||
|
||||
if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
|
||||
return Ok(LookupResult::WrongUser);
|
||||
}
|
||||
|
||||
Ok(LookupResult::NotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn lookup_single_type<T>(
|
||||
id: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<LookupResult, AppError>
|
||||
where
|
||||
T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
|
||||
{
|
||||
let item = db.get_item::<T>(id).await?;
|
||||
Ok(match item {
|
||||
Some(item) if item.user_id() == user_id => LookupResult::Found,
|
||||
Some(_) => LookupResult::WrongUser,
|
||||
None => LookupResult::NotFound,
|
||||
})
|
||||
}
|
||||
|
||||
trait HasUserId {
|
||||
fn user_id(&self) -> &str;
|
||||
}
|
||||
|
||||
impl HasUserId for TextChunk {
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
impl HasUserId for KnowledgeEntity {
|
||||
fn user_id(&self) -> &str {
|
||||
&self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(
|
||||
clippy::cloned_ref_to_slice_refs,
|
||||
clippy::expect_used,
|
||||
clippy::indexing_slicing
|
||||
)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::knowledge_entity::KnowledgeEntityType;
|
||||
use surrealdb::engine::any::connect;
|
||||
|
||||
async fn setup_test_db() -> SurrealDbClient {
|
||||
let client = connect("mem://")
|
||||
.await
|
||||
.expect("failed to create in-memory surrealdb client");
|
||||
let namespace = format!("test_ns_{}", Uuid::new_v4());
|
||||
let database = format!("test_db_{}", Uuid::new_v4());
|
||||
client
|
||||
.use_ns(namespace)
|
||||
.use_db(database)
|
||||
.await
|
||||
.expect("failed to select namespace/db");
|
||||
|
||||
let db = SurrealDbClient { client };
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_uuid_exists_and_belongs_to_user() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity A".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to store entity");
|
||||
|
||||
let result =
|
||||
validate_references(user_id, vec![entity.id.clone()], &[entity.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![entity.id]);
|
||||
assert!(result.invalid_refs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_uuid_exists_but_wrong_user_is_rejected() {
|
||||
let db = setup_test_db().await;
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity B".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
"other-user".to_string(),
|
||||
);
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to store entity");
|
||||
|
||||
let result =
|
||||
validate_references("user-a", vec![entity.id.clone()], &[entity.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert!(result.valid_refs.is_empty());
|
||||
assert_eq!(result.invalid_refs.len(), 1);
|
||||
assert_eq!(
|
||||
result.invalid_refs[0].reason,
|
||||
InvalidReferenceReason::WrongUser
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_uuid_is_rejected() {
|
||||
let db = setup_test_db().await;
|
||||
let result = validate_references(
|
||||
"user-a",
|
||||
vec!["not-a-uuid".to_string()],
|
||||
&["not-a-uuid".to_string()],
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert!(result.valid_refs.is_empty());
|
||||
assert_eq!(result.invalid_refs.len(), 1);
|
||||
assert_eq!(
|
||||
result.invalid_refs[0].reason,
|
||||
InvalidReferenceReason::MalformedUuid
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mixed_duplicates_are_deduped() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
|
||||
let first = KnowledgeEntity::new(
|
||||
"source-1".to_string(),
|
||||
"Entity 1".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
let second = KnowledgeEntity::new(
|
||||
"source-2".to_string(),
|
||||
"Entity 2".to_string(),
|
||||
"Entity description".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(first.clone())
|
||||
.await
|
||||
.expect("failed to store first entity");
|
||||
db.store_item(second.clone())
|
||||
.await
|
||||
.expect("failed to store second entity");
|
||||
|
||||
let refs = vec![
|
||||
first.id.clone(),
|
||||
format!("knowledge_entity:{}", first.id),
|
||||
second.id.clone(),
|
||||
second.id.clone(),
|
||||
];
|
||||
|
||||
let allowed = vec![first.id.clone(), second.id.clone()];
|
||||
let result = validate_references(user_id, refs, &allowed, &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![first.id, second.id]);
|
||||
assert_eq!(result.invalid_refs.len(), 2);
|
||||
assert!(result
|
||||
.invalid_refs
|
||||
.iter()
|
||||
.all(|entry| entry.reason == InvalidReferenceReason::Duplicate));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bare_uuid_prefers_chunk_lookup_before_entity() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "user-a";
|
||||
let chunk = TextChunk::new(
|
||||
"source-1".to_string(),
|
||||
"Chunk body".to_string(),
|
||||
user_id.to_string(),
|
||||
);
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("failed to store chunk");
|
||||
|
||||
let result = validate_references(user_id, vec![chunk.id.clone()], &[chunk.id.clone()], &db)
|
||||
.await
|
||||
.expect("validation should not fail");
|
||||
|
||||
assert_eq!(result.valid_refs, vec![chunk.id]);
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use chrono_tz::Tz;
|
||||
use serde::Serialize;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
|
||||
use common::storage::types::{
|
||||
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, user::User,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -17,29 +20,101 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::reference_validation::{normalize_reference, ReferenceLookupTarget};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
text_chunk: Option<TextChunk>,
|
||||
text_chunk_updated_at: Option<String>,
|
||||
entity: Option<KnowledgeEntity>,
|
||||
entity_updated_at: Option<String>,
|
||||
user: User,
|
||||
}
|
||||
|
||||
fn format_datetime_for_user(datetime: DateTime<Utc>, timezone: &str) -> String {
|
||||
match timezone.parse::<Tz>() {
|
||||
Ok(tz) => datetime
|
||||
.with_timezone(&tz)
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.to_string(),
|
||||
Err(_) => datetime.format("%Y-%m-%d %H:%M:%S").to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn show_reference_tooltip(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Path(reference_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let entity: KnowledgeEntity = state
|
||||
.db
|
||||
.get_item(&reference_id)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
|
||||
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
|
||||
return Ok(TemplateResponse::not_found());
|
||||
};
|
||||
|
||||
if entity.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
let lookup_order = match target {
|
||||
ReferenceLookupTarget::TextChunk | ReferenceLookupTarget::Any => [
|
||||
ReferenceLookupTarget::TextChunk,
|
||||
ReferenceLookupTarget::KnowledgeEntity,
|
||||
],
|
||||
ReferenceLookupTarget::KnowledgeEntity => [
|
||||
ReferenceLookupTarget::KnowledgeEntity,
|
||||
ReferenceLookupTarget::TextChunk,
|
||||
],
|
||||
};
|
||||
|
||||
let mut text_chunk: Option<TextChunk> = None;
|
||||
let mut knowledge_entity: Option<KnowledgeEntity> = None;
|
||||
|
||||
for lookup_target in lookup_order {
|
||||
match lookup_target {
|
||||
ReferenceLookupTarget::TextChunk => {
|
||||
if let Some(chunk) = state
|
||||
.db
|
||||
.get_item::<TextChunk>(&normalized_reference_id)
|
||||
.await?
|
||||
{
|
||||
if chunk.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
text_chunk = Some(chunk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReferenceLookupTarget::KnowledgeEntity => {
|
||||
if let Some(entity) = state
|
||||
.db
|
||||
.get_item::<KnowledgeEntity>(&normalized_reference_id)
|
||||
.await?
|
||||
{
|
||||
if entity.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized());
|
||||
}
|
||||
knowledge_entity = Some(entity);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReferenceLookupTarget::Any => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ReferenceTooltipData {
|
||||
entity: KnowledgeEntity,
|
||||
user: User,
|
||||
if text_chunk.is_none() && knowledge_entity.is_none() {
|
||||
return Ok(TemplateResponse::not_found());
|
||||
}
|
||||
|
||||
let text_chunk_updated_at = text_chunk
|
||||
.as_ref()
|
||||
.map(|chunk| format_datetime_for_user(chunk.updated_at, &user.timezone));
|
||||
let entity_updated_at = knowledge_entity
|
||||
.as_ref()
|
||||
.map(|entity| format_datetime_for_user(entity.updated_at, &user.timezone));
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"chat/reference_tooltip.html",
|
||||
ReferenceTooltipData { entity, user },
|
||||
ReferenceTooltipData {
|
||||
text_chunk,
|
||||
text_chunk_updated_at,
|
||||
entity: knowledge_entity,
|
||||
entity_updated_at,
|
||||
user,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
@@ -42,9 +42,8 @@ pub async fn index_handler(
|
||||
return Ok(TemplateResponse::redirect("/signin"));
|
||||
};
|
||||
|
||||
let (text_contents, _conversation_archive, stats, active_jobs) = try_join!(
|
||||
let (text_contents, stats, active_jobs) = try_join!(
|
||||
User::get_latest_text_contents(&user.id, &state.db),
|
||||
User::get_user_conversations(&user.id, &state.db),
|
||||
User::get_dashboard_stats(&user.id, &state.db),
|
||||
User::get_unfinished_ingestion_tasks(&user.id, &state.db)
|
||||
)?;
|
||||
|
||||
@@ -2,9 +2,10 @@ use std::{pin::Pin, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
Html, IntoResponse, Sse,
|
||||
sse::{Event, KeepAlive, KeepAliveStream},
|
||||
IntoResponse, Response, Sse,
|
||||
},
|
||||
};
|
||||
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
|
||||
@@ -23,6 +24,7 @@ use common::{
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
user::User,
|
||||
},
|
||||
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -34,30 +36,41 @@ use crate::{
|
||||
AuthSessionType,
|
||||
};
|
||||
|
||||
pub async fn show_ingress_form(
|
||||
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
|
||||
type TaskSse = Sse<KeepAliveStream<EventStream>>;
|
||||
|
||||
fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
|
||||
Sse::new(stream).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive-ping"),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn show_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ShowIngressFormData {
|
||||
pub struct ShowIngestFormData {
|
||||
user_categories: Vec<String>,
|
||||
}
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"ingestion_modal.html",
|
||||
ShowIngressFormData { user_categories },
|
||||
ShowIngestFormData { user_categories },
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn hide_ingress_form(
|
||||
pub async fn hide_ingest_form(
|
||||
RequireUser(_user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
Ok(Html(
|
||||
"<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
|
||||
)
|
||||
.into_response())
|
||||
Ok(TemplateResponse::new_template(
|
||||
"ingestion/add_content_button.html",
|
||||
(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug, TryFromMultipart)]
|
||||
@@ -65,37 +78,59 @@ pub struct IngestionParams {
|
||||
pub content: Option<String>,
|
||||
pub context: String,
|
||||
pub category: String,
|
||||
#[form_data(limit = "10000000")] // Adjust limit as needed
|
||||
#[form_data(limit = "20000000")]
|
||||
#[form_data(default)]
|
||||
pub files: Vec<FieldData<NamedTempFile>>,
|
||||
}
|
||||
|
||||
pub async fn process_ingress_form(
|
||||
pub async fn process_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
TypedMultipart(input): TypedMultipart<IngestionParams>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
pub struct IngressFormData {
|
||||
context: String,
|
||||
content: String,
|
||||
category: String,
|
||||
error: String,
|
||||
}
|
||||
|
||||
) -> Result<Response, HtmlError> {
|
||||
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
|
||||
return Ok(TemplateResponse::new_template(
|
||||
"index/signed_in/ingress_form.html",
|
||||
IngressFormData {
|
||||
context: input.context.clone(),
|
||||
content: input.content.clone().unwrap_or_default(),
|
||||
category: input.category.clone(),
|
||||
error: "You need to either add files or content".to_string(),
|
||||
},
|
||||
));
|
||||
return Ok(
|
||||
TemplateResponse::bad_request("You need to either add files or content")
|
||||
.into_response(),
|
||||
);
|
||||
}
|
||||
|
||||
info!("{:?}", input);
|
||||
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
|
||||
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
|
||||
let context_bytes = input.context.len();
|
||||
let category_bytes = input.category.len();
|
||||
let file_count = input.files.len();
|
||||
|
||||
match validate_ingest_input(
|
||||
&state.config,
|
||||
input.content.as_deref(),
|
||||
&input.context,
|
||||
&input.category,
|
||||
file_count,
|
||||
) {
|
||||
Ok(()) => {}
|
||||
Err(IngestValidationError::PayloadTooLarge(message)) => {
|
||||
return Ok(TemplateResponse::error(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"Payload Too Large",
|
||||
&message,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
Err(IngestValidationError::BadRequest(message)) => {
|
||||
return Ok(TemplateResponse::bad_request(&message).into_response());
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
user_id = %user.id,
|
||||
has_content,
|
||||
content_bytes,
|
||||
context_bytes,
|
||||
category_bytes,
|
||||
file_count,
|
||||
"Received ingest form submission"
|
||||
);
|
||||
|
||||
let file_infos = try_join_all(input.files.into_iter().map(|file| {
|
||||
FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage)
|
||||
@@ -123,10 +158,10 @@ pub async fn process_ingress_form(
|
||||
tasks: Vec<IngestionTask>,
|
||||
}
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/current_task.html",
|
||||
NewTasksData { tasks },
|
||||
))
|
||||
Ok(
|
||||
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
|
||||
.into_response(),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -134,9 +169,7 @@ pub struct QueryParams {
|
||||
task_id: String,
|
||||
}
|
||||
|
||||
fn create_error_stream(
|
||||
message: impl Into<String>,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
|
||||
fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
@@ -145,13 +178,13 @@ pub async fn get_task_updates_stream(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
|
||||
) -> TaskSse {
|
||||
let task_id = params.task_id.clone();
|
||||
let db = state.db.clone();
|
||||
|
||||
// 1. Check for authenticated user
|
||||
let Some(current_user) = auth.current_user else {
|
||||
return Sse::new(create_error_stream("User not authenticated"));
|
||||
return sse_with_keep_alive(create_error_stream("User not authenticated"));
|
||||
};
|
||||
|
||||
// 2. Fetch task for initial authorization and to ensure it exists
|
||||
@@ -159,7 +192,7 @@ pub async fn get_task_updates_stream(
|
||||
Ok(Some(task)) => {
|
||||
// 3. Validate user ownership
|
||||
if task.user_id != current_user.id {
|
||||
return Sse::new(create_error_stream(
|
||||
return sse_with_keep_alive(create_error_stream(
|
||||
"Access denied: You do not have permission to view updates for this task.",
|
||||
));
|
||||
}
|
||||
@@ -245,18 +278,14 @@ pub async fn get_task_updates_stream(
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(sse_stream.boxed()).keep_alive(
|
||||
KeepAlive::new()
|
||||
.interval(Duration::from_secs(15))
|
||||
.text("keep-alive-ping"),
|
||||
)
|
||||
sse_with_keep_alive(sse_stream.boxed())
|
||||
}
|
||||
Ok(None) => Sse::new(create_error_stream(format!(
|
||||
Ok(None) => sse_with_keep_alive(create_error_stream(format!(
|
||||
"Task with ID '{task_id}' not found."
|
||||
))),
|
||||
Err(e) => {
|
||||
error!("Failed to fetch task '{task_id}' for authorization: {e:?}");
|
||||
Sse::new(create_error_stream(
|
||||
sse_with_keep_alive(create_error_stream(
|
||||
"An error occurred while retrieving task details. Please try again later.",
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
mod handlers;
|
||||
|
||||
use axum::{extract::FromRef, routing::get, Router};
|
||||
use handlers::{
|
||||
get_task_updates_stream, hide_ingress_form, process_ingress_form, show_ingress_form,
|
||||
};
|
||||
use axum::{extract::DefaultBodyLimit, extract::FromRef, routing::get, Router};
|
||||
use handlers::{get_task_updates_stream, hide_ingest_form, process_ingest_form, show_ingest_form};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
pub fn router<S>() -> Router<S>
|
||||
pub fn router<S>(max_body_bytes: usize) -> Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route(
|
||||
"/ingress-form",
|
||||
get(show_ingress_form).post(process_ingress_form),
|
||||
"/ingest-form",
|
||||
get(show_ingest_form)
|
||||
.post(process_ingest_form)
|
||||
.layer(DefaultBodyLimit::max(max_body_bytes)),
|
||||
)
|
||||
.route("/task/status-stream", get(get_task_updates_stream))
|
||||
.route("/hide-ingress-form", get(hide_ingress_form))
|
||||
.route("/hide-ingest-form", get(hide_ingest_form))
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
<label class="w-full">
|
||||
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">API Key</div>
|
||||
{% block api_key_section %}
|
||||
{% if user.api_key %}
|
||||
{% if api_key %}
|
||||
<div class="relative">
|
||||
<input id="api_key_input" type="text" name="api_key" value="{{ user.api_key }}"
|
||||
<input id="api_key_input" type="text" name="api_key" value="{{ api_key }}"
|
||||
class="nb-input w-full pr-14" disabled />
|
||||
<button type="button" id="copy_api_key_btn" onclick="copy_api_key()"
|
||||
class="absolute inset-y-0 right-0 flex items-center px-2 nb-btn btn-sm" aria-label="Copy API key"
|
||||
@@ -48,9 +48,10 @@
|
||||
<label class="w-full">
|
||||
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Timezone</div>
|
||||
{% block timezone_section %}
|
||||
{% set active_timezone = selected_timezone|default(user.timezone) %}
|
||||
<select name="timezone" class="nb-select w-full" hx-patch="/update-timezone" hx-swap="outerHTML">
|
||||
{% for tz in timezones %}
|
||||
<option value="{{ tz }}" {% if tz==user.timezone %}selected{% endif %}>{{ tz }}</option>
|
||||
<option value="{{ tz }}" {% if tz==active_timezone %}selected{% endif %}>{{ tz }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
{% endblock %}
|
||||
@@ -59,13 +60,14 @@
|
||||
<label class="w-full">
|
||||
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Theme</div>
|
||||
{% block theme_section %}
|
||||
{% set active_theme = selected_theme|default(user.theme) %}
|
||||
<select name="theme" class="nb-select w-full" hx-patch="/update-theme" hx-swap="outerHTML">
|
||||
{% for option in theme_options %}
|
||||
<option value="{{ option }}" {% if option==user.theme %}selected{% endif %}>{{ option }}</option>
|
||||
<option value="{{ option }}" {% if option==active_theme %}selected{% endif %}>{{ option }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
<script>
|
||||
document.documentElement.setAttribute('data-theme-preference', '{{ user.theme }}');
|
||||
document.documentElement.setAttribute('data-theme-preference', '{{ active_theme }}');
|
||||
</script>
|
||||
{% endblock %}
|
||||
</label>
|
||||
|
||||
@@ -7,4 +7,5 @@
|
||||
{% block auth_content %}
|
||||
{% endblock %}
|
||||
</div>
|
||||
<div id="toast-container" class="fixed bottom-4 right-4 z-50 space-y-2"></div>
|
||||
{% endblock %}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
</div>
|
||||
<div class="u-hairline mb-3"></div>
|
||||
|
||||
<form hx-post="/signin" hx-target="#login-result" class="flex flex-col gap-2">
|
||||
<form hx-post="/signin" hx-swap="none" class="flex flex-col gap-2">
|
||||
<label class="w-full">
|
||||
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
|
||||
<input name="email" type="email" placeholder="Email" class="nb-input w-full validator" required />
|
||||
@@ -19,8 +19,6 @@
|
||||
minlength="8" />
|
||||
</label>
|
||||
|
||||
<div class="mt-1 text-error" id="login-result"></div>
|
||||
|
||||
<div class="form-control mt-1">
|
||||
<label class="label cursor-pointer justify-start gap-3">
|
||||
<input type="checkbox" name="remember_me" class="nb-checkbox" />
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
</div>
|
||||
<div class="u-hairline mb-3"></div>
|
||||
|
||||
<form hx-post="/signup" hx-target="#signup-result" class="flex flex-col gap-4">
|
||||
<form hx-post="/signup" hx-swap="none" class="flex flex-col gap-4">
|
||||
<label class="w-full">
|
||||
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
|
||||
<input type="email" placeholder="Email" name="email" required class="nb-input w-full validator" />
|
||||
@@ -31,7 +31,6 @@
|
||||
</p>
|
||||
</label>
|
||||
|
||||
<div class="mt-2 text-error" id="signup-result"></div>
|
||||
<div class="form-control mt-1">
|
||||
<button id="submit-btn" class="nb-btn nb-cta w-full">Create Account</button>
|
||||
</div>
|
||||
|
||||
@@ -12,16 +12,34 @@
|
||||
</label>
|
||||
</form>
|
||||
<script>
|
||||
document.getElementById('chat-input').addEventListener('keydown', function (e) {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
htmx.trigger('#chat-form', 'submit');
|
||||
}
|
||||
});
|
||||
// Clear textarea after successful submission
|
||||
document.getElementById('chat-form').addEventListener('htmx:afterRequest', function (e) {
|
||||
if (e.detail.successful) { // Check if the request was successful
|
||||
document.getElementById('chat-input').value = ''; // Clear the textarea
|
||||
}
|
||||
});
|
||||
</script>
|
||||
(function () {
|
||||
const newChatStreamId = 'ai-stream-{{ user_message.id }}';
|
||||
|
||||
document.getElementById('chat-input').addEventListener('keydown', function (e) {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
htmx.trigger('#chat-form', 'submit');
|
||||
}
|
||||
});
|
||||
// Clear textarea after successful submission
|
||||
document.getElementById('chat-form').addEventListener('htmx:afterRequest', function (e) {
|
||||
if (e.detail.successful) { // Check if the request was successful
|
||||
document.getElementById('chat-input').value = ''; // Clear the textarea
|
||||
}
|
||||
});
|
||||
|
||||
const refreshSidebarAfterFirstResponse = function (e) {
|
||||
const streamEl = document.getElementById(newChatStreamId);
|
||||
if (!streamEl || e.target !== streamEl) return;
|
||||
|
||||
htmx.ajax('GET', '/chat/sidebar', {
|
||||
target: '.drawer-side',
|
||||
swap: 'outerHTML'
|
||||
});
|
||||
|
||||
document.body.removeEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
|
||||
};
|
||||
|
||||
document.body.addEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
|
||||
})();
|
||||
</script>
|
||||
|
||||
@@ -111,12 +111,23 @@
|
||||
// Load content if needed
|
||||
if (!tooltipContent) {
|
||||
fetch(`/chat/reference/${encodeURIComponent(reference)}`)
|
||||
.then(response => response.text())
|
||||
.then(response => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`reference lookup failed with status ${response.status}`);
|
||||
}
|
||||
return response.text();
|
||||
})
|
||||
.then(html => {
|
||||
tooltipContent = html;
|
||||
if (document.getElementById(tooltipId)) {
|
||||
document.getElementById(tooltipId).innerHTML = html;
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
tooltipContent = '<div class="text-xs opacity-70">Reference unavailable.</div>';
|
||||
if (document.getElementById(tooltipId)) {
|
||||
document.getElementById(tooltipId).innerHTML = tooltipContent;
|
||||
}
|
||||
});
|
||||
} else if (tooltip) {
|
||||
// Set content if already loaded
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
<div>{{entity.name}}</div>
|
||||
<div>{{entity.description}}</div>
|
||||
<div>{{entity.updated_at|datetimeformat(format="short", tz=user.timezone)}} </div>
|
||||
{% if text_chunk %}
|
||||
<div class="font-semibold">Chunk Reference</div>
|
||||
<div class="text-sm whitespace-pre-wrap">{{text_chunk.chunk}}</div>
|
||||
<div class="text-xs opacity-70">{{text_chunk_updated_at}}</div>
|
||||
{% elif entity %}
|
||||
<div class="font-semibold">{{entity.name}}</div>
|
||||
<div class="text-sm">{{entity.description}}</div>
|
||||
<div class="text-xs opacity-70">{{entity_updated_at}}</div>
|
||||
{% else %}
|
||||
<div class="text-xs opacity-70">Reference unavailable.</div>
|
||||
{% endif %}
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="chat chat-start">
|
||||
<div hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
|
||||
<div id="ai-stream-{{user_message.id}}" hx-ext="sse"
|
||||
sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
|
||||
hx-swap="beforeend">
|
||||
<div class="chat-bubble">
|
||||
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
|
||||
@@ -27,13 +28,22 @@
|
||||
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
|
||||
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
|
||||
});
|
||||
document.body.addEventListener('htmx:sseClose', function () {
|
||||
document.body.addEventListener('htmx:sseClose', function (e) {
|
||||
const msgId = '{{ user_message.id }}';
|
||||
const streamEl = document.getElementById('ai-stream-' + msgId);
|
||||
if (streamEl && e.target !== streamEl) return;
|
||||
|
||||
const el = document.getElementById('ai-message-content-' + msgId);
|
||||
if (el && window.markdownBuffer[msgId]) {
|
||||
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
|
||||
delete window.markdownBuffer[msgId];
|
||||
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
|
||||
}
|
||||
|
||||
if (streamEl) {
|
||||
streamEl.removeAttribute('sse-connect');
|
||||
streamEl.removeAttribute('sse-close');
|
||||
streamEl.removeAttribute('hx-ext');
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<nav class="sticky top-0 z-10 nb-panel nb-panel-canvas border-t-0 border-l-0">
|
||||
<nav class="sticky top-0 z-10 nb-panel nb-panel-canvas border-t-0" style="border-left: 0">
|
||||
<div class="container mx-auto navbar">
|
||||
<div class="mr-2 flex-1">
|
||||
{% block navbar_search %}
|
||||
@@ -11,4 +11,4 @@
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
</nav>
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
{% block dashboard_header %}
|
||||
<h1 class="text-xl font-extrabold tracking-tight">Dashboard</h1>
|
||||
<button class="nb-btn nb-cta" hx-get="/ingress-form" hx-target="#modal" hx-swap="innerHTML">
|
||||
<button class="nb-btn nb-cta" hx-get="/ingest-form" hx-target="#modal" hx-swap="innerHTML">
|
||||
{% include "icons/send_icon.html" %}
|
||||
<span class="ml-2">Add Content</span>
|
||||
</button>
|
||||
|
||||
1
html-router/templates/ingestion/add_content_button.html
Normal file
1
html-router/templates/ingestion/add_content_button.html
Normal file
@@ -0,0 +1 @@
|
||||
<a class="btn btn-primary" hx-get="/ingest-form" hx-target="#modal" hx-swap="innerHTML">Add Content</a>
|
||||
@@ -3,7 +3,7 @@
|
||||
{% block modal_class %}max-w-3xl{% endblock %}
|
||||
|
||||
{% block form_attributes %}
|
||||
hx-post="/ingress-form"
|
||||
hx-post="/ingest-form"
|
||||
enctype="multipart/form-data"
|
||||
{% endblock %}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
</li>
|
||||
{% endfor %}
|
||||
<li>
|
||||
<button class="nb-btn nb-cta w-full flex items-center gap-3 justify-start mt-2" hx-get="/ingress-form"
|
||||
<button class="nb-btn nb-cta w-full flex items-center gap-3 justify-start mt-2" hx-get="/ingest-form"
|
||||
hx-target="#modal" hx-swap="innerHTML">{% include "icons/send_icon.html" %} Add
|
||||
Content</button>
|
||||
</li>
|
||||
|
||||
@@ -38,3 +38,6 @@ retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
|
||||
[features]
|
||||
docker = []
|
||||
|
||||
[dev-dependencies]
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "main"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2"
|
||||
edition = "2021"
|
||||
repository = "https://github.com/perstarkse/minne"
|
||||
license = "AGPL-3.0-or-later"
|
||||
@@ -30,6 +30,7 @@ retrieval-pipeline = { path = "../retrieval-pipeline" }
|
||||
[dev-dependencies]
|
||||
tower = "0.5"
|
||||
uuid = { workspace = true }
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
[[bin]]
|
||||
name = "server"
|
||||
|
||||
254
main/src/main.rs
254
main/src/main.rs
@@ -217,8 +217,16 @@ struct AppState {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{body::Body, http::Request, http::StatusCode, Router};
|
||||
use common::storage::store::StorageManager;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{header, Request, StatusCode},
|
||||
response::Response,
|
||||
Router,
|
||||
};
|
||||
use common::storage::{
|
||||
store::StorageManager,
|
||||
types::{system_settings::SystemSettings, user::User},
|
||||
};
|
||||
use common::utils::config::{AppConfig, PdfIngestMode, StorageKind};
|
||||
use std::{path::Path, sync::Arc};
|
||||
use tower::ServiceExt;
|
||||
@@ -241,11 +249,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn smoke_startup_with_in_memory_surrealdb() {
|
||||
async fn build_test_app() -> (Router, Arc<SurrealDbClient>, std::path::PathBuf) {
|
||||
let namespace = "test_ns";
|
||||
let database = format!("test_db_{}", Uuid::new_v4());
|
||||
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
|
||||
|
||||
tokio::fs::create_dir_all(&data_dir)
|
||||
.await
|
||||
.expect("failed to create temp data directory");
|
||||
@@ -304,6 +312,66 @@ mod tests {
|
||||
html_state,
|
||||
});
|
||||
|
||||
(app, db, data_dir)
|
||||
}
|
||||
|
||||
fn assert_redirect_to(response: &Response, expected_location: &str) {
|
||||
assert!(response.status().is_redirection());
|
||||
let location = response
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.expect("redirect should contain a Location header")
|
||||
.to_str()
|
||||
.expect("location header must be valid utf-8");
|
||||
assert_eq!(location, expected_location);
|
||||
}
|
||||
|
||||
fn extract_session_cookie(response: &Response) -> String {
|
||||
let cookie_header = response
|
||||
.headers()
|
||||
.get_all(header::SET_COOKIE)
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.to_str()
|
||||
.expect("set-cookie header must be valid utf-8")
|
||||
.split(';')
|
||||
.next()
|
||||
.expect("set-cookie should include key=value pair")
|
||||
.to_string()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert!(
|
||||
!cookie_header.is_empty(),
|
||||
"login response should set at least one cookie"
|
||||
);
|
||||
|
||||
cookie_header.join("; ")
|
||||
}
|
||||
|
||||
async fn sign_in_and_get_cookie(app: &Router, email: &str, password: &str) -> String {
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/signin")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(Body::from(format!("email={email}&password={password}")))
|
||||
.expect("signin request"),
|
||||
)
|
||||
.await
|
||||
.expect("signin response");
|
||||
|
||||
assert_redirect_to(&response, "/");
|
||||
extract_session_cookie(&response)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn smoke_startup_with_in_memory_surrealdb() {
|
||||
let (app, _db, data_dir) = build_test_app().await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
@@ -329,4 +397,182 @@ mod tests {
|
||||
|
||||
tokio::fs::remove_dir_all(&data_dir).await.ok();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn admin_route_enforces_unauth_non_admin_and_admin_access() {
|
||||
let (app, db, data_dir) = build_test_app().await;
|
||||
|
||||
let admin = User::create_new(
|
||||
"admin_user".to_string(),
|
||||
"admin_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("admin user should be created");
|
||||
let non_admin = User::create_new(
|
||||
"member_user".to_string(),
|
||||
"member_password".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("non-admin user should be created");
|
||||
|
||||
assert!(admin.admin, "first user should become admin");
|
||||
assert!(!non_admin.admin, "second user should not be admin");
|
||||
|
||||
let unauth_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin")
|
||||
.body(Body::empty())
|
||||
.expect("unauth admin request"),
|
||||
)
|
||||
.await
|
||||
.expect("unauth admin response");
|
||||
assert_redirect_to(&unauth_response, "/signin");
|
||||
|
||||
let non_admin_cookie = sign_in_and_get_cookie(&app, "member_user", "member_password").await;
|
||||
let non_admin_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin")
|
||||
.header(header::COOKIE, non_admin_cookie)
|
||||
.body(Body::empty())
|
||||
.expect("non-admin request"),
|
||||
)
|
||||
.await
|
||||
.expect("non-admin response");
|
||||
assert_redirect_to(&non_admin_response, "/");
|
||||
|
||||
let admin_cookie = sign_in_and_get_cookie(&app, "admin_user", "admin_password").await;
|
||||
let admin_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin")
|
||||
.header(header::COOKIE, admin_cookie)
|
||||
.body(Body::empty())
|
||||
.expect("admin request"),
|
||||
)
|
||||
.await
|
||||
.expect("admin response");
|
||||
assert_eq!(admin_response.status(), StatusCode::OK);
|
||||
|
||||
tokio::fs::remove_dir_all(&data_dir).await.ok();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn admin_patch_blocks_non_admin_and_unauth_before_side_effects() {
|
||||
let (app, db, data_dir) = build_test_app().await;
|
||||
|
||||
User::create_new(
|
||||
"admin_user_patch".to_string(),
|
||||
"admin_password_patch".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("admin user should be created");
|
||||
User::create_new(
|
||||
"member_user_patch".to_string(),
|
||||
"member_password_patch".to_string(),
|
||||
&db,
|
||||
"UTC".to_string(),
|
||||
"system".to_string(),
|
||||
)
|
||||
.await
|
||||
.expect("non-admin user should be created");
|
||||
|
||||
let initial_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("settings should be available");
|
||||
|
||||
let patch_body = if initial_settings.registrations_enabled {
|
||||
String::new()
|
||||
} else {
|
||||
"registration_open=on".to_string()
|
||||
};
|
||||
let expected_after_admin_patch = !initial_settings.registrations_enabled;
|
||||
|
||||
let unauth_patch_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("PATCH")
|
||||
.uri("/toggle-registrations")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(Body::from(patch_body.clone()))
|
||||
.expect("unauth patch request"),
|
||||
)
|
||||
.await
|
||||
.expect("unauth patch response");
|
||||
assert_redirect_to(&unauth_patch_response, "/signin");
|
||||
|
||||
let settings_after_unauth = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("settings should still be available");
|
||||
assert_eq!(
|
||||
settings_after_unauth.registrations_enabled,
|
||||
initial_settings.registrations_enabled
|
||||
);
|
||||
|
||||
let non_admin_cookie =
|
||||
sign_in_and_get_cookie(&app, "member_user_patch", "member_password_patch").await;
|
||||
let non_admin_patch_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("PATCH")
|
||||
.uri("/toggle-registrations")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.header(header::COOKIE, non_admin_cookie)
|
||||
.body(Body::from(patch_body.clone()))
|
||||
.expect("non-admin patch request"),
|
||||
)
|
||||
.await
|
||||
.expect("non-admin patch response");
|
||||
assert_redirect_to(&non_admin_patch_response, "/");
|
||||
|
||||
let settings_after_non_admin = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("settings should still be available");
|
||||
assert_eq!(
|
||||
settings_after_non_admin.registrations_enabled,
|
||||
initial_settings.registrations_enabled
|
||||
);
|
||||
|
||||
let admin_cookie =
|
||||
sign_in_and_get_cookie(&app, "admin_user_patch", "admin_password_patch").await;
|
||||
let admin_patch_response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("PATCH")
|
||||
.uri("/toggle-registrations")
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.header(header::COOKIE, admin_cookie)
|
||||
.body(Body::from(patch_body))
|
||||
.expect("admin patch request"),
|
||||
)
|
||||
.await
|
||||
.expect("admin patch response");
|
||||
assert_eq!(admin_patch_response.status(), StatusCode::OK);
|
||||
|
||||
let settings_after_admin = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("settings should still be available");
|
||||
assert_eq!(
|
||||
settings_after_admin.registrations_enabled,
|
||||
expected_after_admin_patch
|
||||
);
|
||||
|
||||
tokio::fs::remove_dir_all(&data_dir).await.ok();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,4 +23,7 @@ uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
|
||||
common = { path = "../common" }
|
||||
|
||||
[dev-dependencies]
|
||||
common = { path = "../common", features = ["test-utils"] }
|
||||
|
||||
@@ -61,8 +61,8 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
|
||||
.iter()
|
||||
.map(|chunk| {
|
||||
serde_json::json!({
|
||||
"id": chunk.chunk.id,
|
||||
"content": chunk.chunk.chunk,
|
||||
"source_id": chunk.chunk.source_id,
|
||||
"score": round_score(chunk.score),
|
||||
})
|
||||
})
|
||||
@@ -117,7 +117,7 @@ pub fn create_chat_request(
|
||||
.build()
|
||||
}
|
||||
|
||||
pub async fn process_llm_response(
|
||||
pub fn process_llm_response(
|
||||
response: CreateChatCompletionResponse,
|
||||
) -> Result<LLMResponseFormat, AppError> {
|
||||
response
|
||||
|
||||
Reference in New Issue
Block a user