13 Commits

Author SHA1 Message Date
Per Stark
4d237ff6d9 release: 1.0.2 2026-02-15 11:57:04 +01:00
Per Stark
eb928cdb0e test: minio to devenv, improved testing s3 and relationships 2026-02-15 08:52:56 +01:00
Per Stark
1490852a09 chore: dep updates & kv-mem separation to test feature
docker builder update
2026-02-15 08:51:48 +01:00
Per Stark
b0b01182d7 test: add admin auth integration coverage 2026-02-14 23:11:35 +01:00
Per Stark
679308aa1d feat: caching chat history & dto 2026-02-14 19:43:34 +01:00
Per Stark
f93c06b347 fix: harden html responses and cache chat sidebar data
Use strict template response handling and sanitized template user context, then add an in-process conversation archive cache with mutation-driven invalidation for chat sidebar renders.
2026-02-14 17:47:14 +01:00
Per Stark
a3f207beb1 fix: simplified admin checking 2026-02-13 23:04:01 +01:00
Per Stark
e07199adfc fix: name harmonization of endpoints & ingestion security hardening 2026-02-13 22:36:00 +01:00
Per Stark
f22cac891c fix: redact ingestion payload logs and update changelog 2026-02-13 12:06:18 +01:00
Per Stark
b89171d934 fix: parameterize storage-layer queries and add injection tests 2026-02-12 21:42:46 +01:00
Per Stark
0133eead63 fix: border in navigation 2026-02-12 20:39:36 +01:00
Per Stark
e5d2b6605f fix: browser back navigation from chat windows
addendum
2026-02-12 20:32:06 +01:00
Per Stark
bbad91d55b fix: references bug
fix
2026-02-11 22:02:40 +01:00
64 changed files with 9220 additions and 2010 deletions

View File

@@ -1,4 +1,14 @@
# Changelog
## Unreleased
## 1.0.2 (2026-02-15)
- Fix: edge case where navigation back to a chat page could trigger a new response generation
- Fix: chat references now validate and render more reliably
- Fix: improved admin access checks for restricted routes
- Performance: faster chat sidebar loads from cached conversation archive data
- API: harmonized ingest endpoint naming and added configurable ingest safety limits
- Security: hardened query handling and ingestion logging to reduce injection and data exposure risk
## 1.0.1 (2026-02-11)
- Shipped an S3 storage backend so content can be stored in object storage instead of local disk, with configuration support for S3 deployments.
- Introduced user theme preferences with the new Obsidian Prism look and improved dark mode styling.

3201
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -40,7 +40,7 @@ serde_json = "1.0.128"
serde = { version = "1", features = ["derive"] }
sha2 = "0.10.8"
surrealdb-migrations = "2.2.2"
surrealdb = { version = "2", features = ["kv-mem"] }
surrealdb = { version = "2" }
tempfile = "3.12.0"
text-splitter = { version = "0.18.1", features = ["markdown", "tokenizers"] }
tokenizers = { version = "0.20.4", features = ["http"] }

View File

@@ -1,5 +1,5 @@
# === Builder ===
FROM rust:1.86-bookworm AS builder
FROM rust:1.89-bookworm AS builder
WORKDIR /usr/src/minne
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config clang cmake git && rm -rf /var/lib/apt/lists/*
@@ -30,8 +30,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libgomp1 libstdc++6 curl \
&& rm -rf /var/lib/apt/lists/*
# ONNX Runtime (CPU). Change if you bump ort.
ARG ORT_VERSION=1.22.0
# ONNX Runtime (CPU). Keep in sync with ort crate requirements.
ARG ORT_VERSION=1.23.2
RUN mkdir -p /opt/onnxruntime && \
curl -fsSL -o /tmp/ort.tgz \
"https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-${ORT_VERSION}.tgz" && \

View File

@@ -20,6 +20,9 @@ pub enum ApiError {
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Payload too large: {0}")]
PayloadTooLarge(String),
}
impl From<AppError> for ApiError {
@@ -67,6 +70,13 @@ impl IntoResponse for ApiError {
status: "error".to_string(),
},
),
Self::PayloadTooLarge(message) => (
StatusCode::PAYLOAD_TOO_LARGE,
ErrorResponse {
error: message,
status: "error".to_string(),
},
),
};
(status, Json(error_response)).into_response()
@@ -132,6 +142,10 @@ mod tests {
// Test unauthorized status
let error = ApiError::Unauthorized("not allowed".to_string());
assert_status_code(error, StatusCode::UNAUTHORIZED);
// Test payload too large status
let error = ApiError::PayloadTooLarge("too big".to_string());
assert_status_code(error, StatusCode::PAYLOAD_TOO_LARGE);
}
// Alternative approach that doesn't try to parse the response body

View File

@@ -6,7 +6,7 @@ use axum::{
Router,
};
use middleware_api_auth::api_auth;
use routes::{categories::get_categories, ingress::ingest_data, liveness::live, readiness::ready};
use routes::{categories::get_categories, ingest::ingest_data, liveness::live, readiness::ready};
pub mod api_state;
pub mod error;
@@ -26,9 +26,13 @@ where
// Protected API endpoints (require auth)
let protected = Router::new()
.route("/ingress", post(ingest_data))
.route(
"/ingest",
post(ingest_data).layer(DefaultBodyLimit::max(
app_state.config.ingest_max_body_bytes,
)),
)
.route("/categories", get(get_categories))
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
.route_layer(from_fn_with_state(app_state.clone(), api_auth));
public.merge(protected)

View File

@@ -6,6 +6,7 @@ use common::{
file_info::FileInfo, ingestion_payload::IngestionPayload, ingestion_task::IngestionTask,
user::User,
},
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
};
use futures::{future::try_join_all, TryFutureExt};
use serde_json::json;
@@ -19,7 +20,7 @@ pub struct IngestParams {
pub content: Option<String>,
pub context: String,
pub category: String,
#[form_data(limit = "10000000")] // Adjust limit as needed
#[form_data(limit = "20000000")]
#[form_data(default)]
pub files: Vec<FieldData<NamedTempFile>>,
}
@@ -29,8 +30,38 @@ pub async fn ingest_data(
Extension(user): Extension<User>,
TypedMultipart(input): TypedMultipart<IngestParams>,
) -> Result<impl IntoResponse, ApiError> {
info!("Received input: {:?}", input);
let user_id = user.id;
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
return Err(ApiError::PayloadTooLarge(message));
}
Err(IngestValidationError::BadRequest(message)) => {
return Err(ApiError::ValidationError(message));
}
}
info!(
user_id = %user_id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
"Received ingest request"
);
let file_infos = try_join_all(input.files.into_iter().map(|file| {
FileInfo::new_with_storage(file, &state.db, &user_id, &state.storage)

View File

@@ -1,4 +1,4 @@
pub mod categories;
pub mod ingress;
pub mod ingest;
pub mod liveness;
pub mod readiness;

View File

@@ -16,7 +16,7 @@ tracing = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
serde_json = { workspace = true }
surrealdb = { workspace = true, features = ["kv-mem"] }
surrealdb = { workspace = true }
async-openai = { workspace = true }
futures = { workspace = true }
tempfile = { workspace = true }
@@ -49,4 +49,7 @@ fastembed = { workspace = true }
[features]
test-utils = []
test-utils = ["surrealdb/kv-mem"]
[dev-dependencies]
surrealdb = { workspace = true, features = ["kv-mem"] }

View File

@@ -0,0 +1 @@
{"schemas":"--- original\n+++ modified\n@@ -28,6 +28,7 @@\n # Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)\n DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;\n DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY\n+DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY\n\n # Defines the schema for the 'file' table (used by FileInfo).\n\n","events":null}

View File

@@ -13,3 +13,4 @@ DEFINE FIELD IF NOT EXISTS title ON conversation TYPE string;
# Add indexes based on query patterns (get_complete_conversation ownership check, get_user_conversations)
DEFINE INDEX IF NOT EXISTS conversation_user_id_idx ON conversation FIELDS user_id;
DEFINE INDEX IF NOT EXISTS conversation_created_at_idx ON conversation FIELDS created_at; # For get_user_conversations ORDER BY
DEFINE INDEX IF NOT EXISTS conversation_user_updated_at_idx ON conversation FIELDS user_id, updated_at; # For sidebar conversation projection ORDER BY

View File

@@ -281,6 +281,33 @@ pub mod testing {
use crate::utils::config::{AppConfig, PdfIngestMode};
use uuid;
const DEFAULT_TEST_S3_BUCKET: &str = "minne-tests";
const DEFAULT_TEST_S3_ENDPOINT: &str = "http://127.0.0.1:19000";
fn configured_test_s3_bucket() -> String {
std::env::var("MINNE_TEST_S3_BUCKET")
.ok()
.filter(|value| !value.trim().is_empty())
.or_else(|| {
std::env::var("S3_BUCKET")
.ok()
.filter(|value| !value.trim().is_empty())
})
.unwrap_or_else(|| DEFAULT_TEST_S3_BUCKET.to_string())
}
fn configured_test_s3_endpoint() -> String {
std::env::var("MINNE_TEST_S3_ENDPOINT")
.ok()
.filter(|value| !value.trim().is_empty())
.or_else(|| {
std::env::var("S3_ENDPOINT")
.ok()
.filter(|value| !value.trim().is_empty())
})
.unwrap_or_else(|| DEFAULT_TEST_S3_ENDPOINT.to_string())
}
/// Create a test configuration with memory storage.
///
/// This provides a ready-to-use configuration for testing scenarios
@@ -326,7 +353,8 @@ pub mod testing {
/// Create a test configuration with S3 storage (MinIO).
///
/// This requires a running MinIO instance on localhost:9000.
/// Uses `MINNE_TEST_S3_ENDPOINT` / `S3_ENDPOINT` and
/// `MINNE_TEST_S3_BUCKET` / `S3_BUCKET` when provided.
pub fn test_config_s3() -> AppConfig {
AppConfig {
openai_api_key: "test".into(),
@@ -339,8 +367,8 @@ pub mod testing {
http_port: 0,
openai_base_url: "..".into(),
storage: StorageKind::S3,
s3_bucket: Some("minne-tests".into()),
s3_endpoint: Some("http://localhost:9000".into()),
s3_bucket: Some(configured_test_s3_bucket()),
s3_endpoint: Some(configured_test_s3_endpoint()),
s3_region: Some("us-east-1".into()),
pdf_ingest_mode: PdfIngestMode::LlmFirst,
..Default::default()
@@ -391,8 +419,7 @@ pub mod testing {
/// Create a new TestStorageManager with S3 backend (MinIO).
///
/// This requires a running MinIO instance on localhost:9000 with
/// default credentials (minioadmin/minioadmin) and a 'minne-tests' bucket.
/// This requires a reachable MinIO endpoint and an existing test bucket.
pub async fn new_s3() -> object_store::Result<Self> {
// Ensure credentials are set for MinIO
// We set these env vars for the process, which AmazonS3Builder will pick up
@@ -403,6 +430,11 @@ pub mod testing {
let cfg = test_config_s3();
let storage = StorageManager::new(&cfg).await?;
// Probe the bucket so tests can cleanly skip when the endpoint is unreachable
// or the test bucket is not provisioned.
let probe_prefix = format!("__minne_s3_probe__/{}", uuid::Uuid::new_v4());
storage.list(Some(&probe_prefix)).await?;
Ok(Self {
storage,
_temp_dir: None,
@@ -923,8 +955,8 @@ mod tests {
assert_eq!(*test_storage.storage().backend_kind(), StorageKind::Memory);
}
// S3 Tests - Require MinIO on localhost:9000 with bucket 'minne-tests'
// These tests will fail if MinIO is not running or bucket doesn't exist.
// S3 Tests - Require a reachable MinIO endpoint and test bucket.
// `TestStorageManager::new_s3()` probes connectivity and these tests auto-skip when unavailable.
#[tokio::test]
async fn test_storage_manager_s3_basic_operations() {

View File

@@ -10,6 +10,54 @@ stored_object!(Conversation, "conversation", {
title: String
});
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct SidebarConversation {
#[serde(deserialize_with = "deserialize_sidebar_id")]
pub id: String,
pub title: String,
}
struct SidebarIdVisitor;
impl<'de> serde::de::Visitor<'de> for SidebarIdVisitor {
type Value = String;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string id or a SurrealDB Thing")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(value.to_string())
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(value)
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let thing = <surrealdb::sql::Thing as serde::Deserialize>::deserialize(
serde::de::value::MapAccessDeserializer::new(map),
)?;
Ok(thing.id.to_raw())
}
}
fn deserialize_sidebar_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(SidebarIdVisitor)
}
impl Conversation {
pub fn new(user_id: String, title: String) -> Self {
let now = Utc::now();
@@ -75,6 +123,23 @@ impl Conversation {
Ok(())
}
pub async fn get_user_sidebar_conversations(
user_id: &str,
db: &SurrealDbClient,
) -> Result<Vec<SidebarConversation>, AppError> {
let conversations: Vec<SidebarConversation> = db
.client
.query(
"SELECT id, title, updated_at FROM type::table($table_name) WHERE user_id = $user_id ORDER BY updated_at DESC",
)
.bind(("table_name", Self::table_name()))
.bind(("user_id", user_id.to_string()))
.await?
.take(0)?;
Ok(conversations)
}
}
#[cfg(test)]
@@ -249,6 +314,96 @@ mod tests {
}
}
#[tokio::test]
async fn test_get_user_sidebar_conversations_filters_and_orders_by_updated_at_desc() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let user_id = "sidebar_user";
let other_user_id = "other_user";
let base = Utc::now();
let mut oldest = Conversation::new(user_id.to_string(), "Oldest".to_string());
oldest.updated_at = base - chrono::Duration::minutes(30);
let mut newest = Conversation::new(user_id.to_string(), "Newest".to_string());
newest.updated_at = base - chrono::Duration::minutes(5);
let mut middle = Conversation::new(user_id.to_string(), "Middle".to_string());
middle.updated_at = base - chrono::Duration::minutes(15);
let mut other_user = Conversation::new(other_user_id.to_string(), "Other".to_string());
other_user.updated_at = base;
db.store_item(oldest.clone())
.await
.expect("Failed to store oldest conversation");
db.store_item(newest.clone())
.await
.expect("Failed to store newest conversation");
db.store_item(middle.clone())
.await
.expect("Failed to store middle conversation");
db.store_item(other_user)
.await
.expect("Failed to store other-user conversation");
let sidebar_items = Conversation::get_user_sidebar_conversations(user_id, &db)
.await
.expect("Failed to get sidebar conversations");
assert_eq!(sidebar_items.len(), 3);
assert_eq!(sidebar_items[0].id, newest.id);
assert_eq!(sidebar_items[0].title, "Newest");
assert_eq!(sidebar_items[1].id, middle.id);
assert_eq!(sidebar_items[1].title, "Middle");
assert_eq!(sidebar_items[2].id, oldest.id);
assert_eq!(sidebar_items[2].title, "Oldest");
}
#[tokio::test]
async fn test_sidebar_projection_reflects_patch_title_and_updated_at_reorder() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let user_id = "sidebar_patch_user";
let base = Utc::now();
let mut first = Conversation::new(user_id.to_string(), "First".to_string());
first.updated_at = base - chrono::Duration::minutes(20);
let mut second = Conversation::new(user_id.to_string(), "Second".to_string());
second.updated_at = base - chrono::Duration::minutes(10);
db.store_item(first.clone())
.await
.expect("Failed to store first conversation");
db.store_item(second.clone())
.await
.expect("Failed to store second conversation");
let before_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
.await
.expect("Failed to get sidebar conversations before patch");
assert_eq!(before_patch[0].id, second.id);
Conversation::patch_title(&first.id, user_id, "First (renamed)", &db)
.await
.expect("Failed to patch conversation title");
let after_patch = Conversation::get_user_sidebar_conversations(user_id, &db)
.await
.expect("Failed to get sidebar conversations after patch");
assert_eq!(after_patch[0].id, first.id);
assert_eq!(after_patch[0].title, "First (renamed)");
}
#[tokio::test]
async fn test_get_complete_conversation_with_messages() {
// Setup in-memory database for testing

View File

@@ -137,8 +137,12 @@ impl FileInfo {
/// # Returns
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
let mut response = db_client
.client
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
.bind(("sha256", sha256.to_owned()))
.await?;
let response: Vec<FileInfo> = response.take(0)?;
response
.into_iter()
@@ -665,6 +669,36 @@ mod tests {
}
}
#[tokio::test]
async fn test_get_by_sha_resists_query_injection() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let now = Utc::now();
let file_info = FileInfo {
id: Uuid::new_v4().to_string(),
created_at: now,
updated_at: now,
user_id: "user123".to_string(),
sha256: "known_sha_value".to_string(),
path: "/path/to/file.txt".to_string(),
file_name: "file.txt".to_string(),
mime_type: "text/plain".to_string(),
};
db.store_item(file_info)
.await
.expect("Failed to store test file info");
let malicious_sha = "known_sha_value' OR true --";
let result = FileInfo::get_by_sha(malicious_sha, &db).await;
assert!(matches!(result, Err(FileError::FileNotFound(_))));
}
#[tokio::test]
async fn test_manual_file_info_creation() {
let namespace = "test_ns";

View File

@@ -174,12 +174,15 @@ impl KnowledgeEntity {
// Delete embeddings first, while we can still look them up via the entity's source_id
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
@@ -761,6 +764,69 @@ mod tests {
assert_eq!(different_remaining[0].id, different_entity.id);
}
#[tokio::test]
async fn test_delete_by_source_id_resists_query_injection() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
.await
.expect("Failed to redefine index length");
let user_id = "user123".to_string();
let entity1 = KnowledgeEntity::new(
"safe_source".to_string(),
"Entity 1".to_string(),
"Description 1".to_string(),
KnowledgeEntityType::Document,
None,
user_id.clone(),
);
let entity2 = KnowledgeEntity::new(
"other_source".to_string(),
"Entity 2".to_string(),
"Description 2".to_string(),
KnowledgeEntityType::Document,
None,
user_id,
);
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
.await
.expect("store entity1");
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
.await
.expect("store entity2");
let malicious_source = "safe_source' OR 1=1 --";
KnowledgeEntity::delete_by_source_id(malicious_source, &db)
.await
.expect("delete call should succeed");
let remaining: Vec<KnowledgeEntity> = db
.client
.query("SELECT * FROM type::table($table)")
.bind(("table", KnowledgeEntity::table_name()))
.await
.expect("query failed")
.take(0)
.expect("take failed");
assert_eq!(
remaining.len(),
2,
"malicious input must not delete unrelated entities"
);
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";

View File

@@ -40,22 +40,28 @@ impl KnowledgeRelationship {
}
}
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
let query = format!(
r#"DELETE relates_to:`{rel_id}`;
RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
SET
metadata.user_id = '{user_id}',
metadata.source_id = '{source_id}',
metadata.relationship_type = '{relationship_type}'"#,
rel_id = self.id,
in_id = self.in_,
out_id = self.out,
user_id = self.metadata.user_id.as_str(),
source_id = self.metadata.source_id.as_str(),
relationship_type = self.metadata.relationship_type.as_str()
);
db_client.query(query).await?.check()?;
db_client
.client
.query(
r#"BEGIN TRANSACTION;
LET $in_entity = type::thing('knowledge_entity', $in_id);
LET $out_entity = type::thing('knowledge_entity', $out_id);
LET $relation = type::thing('relates_to', $rel_id);
DELETE type::thing('relates_to', $rel_id);
RELATE $in_entity->$relation->$out_entity SET
metadata.user_id = $user_id,
metadata.source_id = $source_id,
metadata.relationship_type = $relationship_type;
COMMIT TRANSACTION;"#,
)
.bind(("rel_id", self.id.clone()))
.bind(("in_id", self.in_.clone()))
.bind(("out_id", self.out.clone()))
.bind(("user_id", self.metadata.user_id.clone()))
.bind(("source_id", self.metadata.source_id.clone()))
.bind(("relationship_type", self.metadata.relationship_type.clone()))
.await?
.check()?;
Ok(())
}
@@ -64,11 +70,12 @@ impl KnowledgeRelationship {
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
);
db_client.query(query).await?;
db_client
.client
.query("DELETE FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.to_owned()))
.await?
.check()?;
Ok(())
}
@@ -79,15 +86,20 @@ impl KnowledgeRelationship {
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let mut authorized_result = db_client
.query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
))
.client
.query(
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
)
.bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
if authorized.is_empty() {
let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{id}`"))
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -99,7 +111,12 @@ impl KnowledgeRelationship {
Err(AppError::NotFound(format!("Relationship {id} not found")))
}
} else {
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
db_client
.client
.query("DELETE type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?
.check()?;
Ok(())
}
}
@@ -110,6 +127,34 @@ mod tests {
use super::*;
use crate::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
async fn setup_test_db() -> SurrealDbClient {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
db
}
async fn get_relationship_by_id(
relationship_id: &str,
db_client: &SurrealDbClient,
) -> Option<KnowledgeRelationship> {
let mut result = db_client
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", relationship_id.to_owned()))
.await
.expect("relationship query by id failed");
result.take(0).expect("failed to take relationship by id")
}
// Helper function to create a test knowledge entity for the relationship tests
async fn create_test_entity(name: &str, db_client: &SurrealDbClient) -> String {
let source_id = "source123".to_string();
@@ -161,15 +206,7 @@ mod tests {
#[tokio::test]
async fn test_store_and_verify_by_source_id() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations()
.await
.expect("Failed to apply migrations");
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -194,30 +231,69 @@ mod tests {
.await
.expect("Failed to store relationship");
let persisted = get_relationship_by_id(&relationship.id, &db)
.await
.expect("Relationship should be retrievable by id");
assert_eq!(persisted.in_, entity1_id);
assert_eq!(persisted.out, entity2_id);
assert_eq!(persisted.metadata.user_id, user_id);
assert_eq!(persisted.metadata.source_id, source_id);
// Query to verify the relationship exists by checking for relationships with our source_id
// This approach is more reliable than trying to look up by ID
let check_query = format!(
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
source_id
);
let mut check_result = db.query(check_query).await.expect("Check query failed");
let mut check_result = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
.await
.expect("Check query failed");
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
// Just verify that a relationship was created
assert!(
!check_results.is_empty(),
"Relationship should exist in the database"
assert_eq!(
check_results.len(),
1,
"Expected one relationship for source_id"
);
}
#[tokio::test]
async fn test_store_relationship_resists_query_injection() {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let relationship = KnowledgeRelationship::new(
entity1_id,
entity2_id,
"user'123".to_string(),
"source123'; DELETE FROM relates_to; --".to_string(),
"references'; UPDATE user SET admin = true; --".to_string(),
);
relationship
.store_relationship(&db)
.await
.expect("store relationship should safely handle quote-containing values");
let mut res = db
.client
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
.bind(("id", relationship.id.clone()))
.await
.expect("query relationship by id failed");
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
assert_eq!(rows.len(), 1);
assert_eq!(
rows[0].metadata.source_id,
"source123'; DELETE FROM relates_to; --"
);
}
#[tokio::test]
async fn test_store_and_delete_relationship() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = setup_test_db().await;
// Create two entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -278,11 +354,7 @@ mod tests {
#[tokio::test]
async fn test_delete_relationship_by_id_unauthorized() {
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
@@ -346,11 +418,7 @@ mod tests {
#[tokio::test]
async fn test_store_relationship_exists() {
// Setup in-memory database for testing
let namespace = "test_ns";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
let db = setup_test_db().await;
// Create entities to relate
let entity1_id = create_test_entity("Entity 1", &db).await;
@@ -402,49 +470,87 @@ mod tests {
.await
.expect("Failed to store different relationship");
// Sanity-check setup: exactly two relationships use source_id and one uses different_source_id.
let mut before_delete = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.clone()))
.await
.expect("before delete query failed");
let before_delete_rows: Vec<KnowledgeRelationship> =
before_delete.take(0).unwrap_or_default();
assert_eq!(before_delete_rows.len(), 2);
let mut before_delete_different = db
.query("SELECT * FROM relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", different_source_id.clone()))
.await
.expect("before delete different query failed");
let before_delete_different_rows: Vec<KnowledgeRelationship> =
before_delete_different.take(0).unwrap_or_default();
assert_eq!(before_delete_different_rows.len(), 1);
// Delete relationships by source_id
KnowledgeRelationship::delete_relationships_by_source_id(&source_id, &db)
.await
.expect("Failed to delete relationships by source_id");
// Query to verify the relationships with source_id were deleted
let query1 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship1.id);
let query2 = format!("SELECT * FROM relates_to WHERE id = '{}'", relationship2.id);
let different_query = format!(
"SELECT * FROM relates_to WHERE id = '{}'",
different_relationship.id
);
let mut result1 = db.query(query1).await.expect("Query 1 failed");
let results1: Vec<KnowledgeRelationship> = result1.take(0).unwrap_or_default();
let mut result2 = db.query(query2).await.expect("Query 2 failed");
let results2: Vec<KnowledgeRelationship> = result2.take(0).unwrap_or_default();
let mut different_result = db
.query(different_query)
.await
.expect("Different query failed");
let _different_results: Vec<KnowledgeRelationship> =
different_result.take(0).unwrap_or_default();
// Query to verify the specific relationships with source_id were deleted.
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
// Verify relationships with the source_id are deleted
assert!(results1.is_empty(), "Relationship 1 should be deleted");
assert!(results2.is_empty(), "Relationship 2 should be deleted");
assert!(result1.is_none(), "Relationship 1 should be deleted");
assert!(result2.is_none(), "Relationship 2 should be deleted");
let remaining =
different_result.expect("Relationship with different source_id should remain");
assert_eq!(remaining.metadata.source_id, different_source_id);
}
// For the relationship with different source ID, we need to check differently
// Let's just verify we have a relationship where the source_id matches different_source_id
let check_query = format!(
"SELECT * FROM relates_to WHERE metadata.source_id = '{}'",
different_source_id
#[tokio::test]
async fn test_delete_relationships_by_source_id_resists_query_injection() {
let db = setup_test_db().await;
let entity1_id = create_test_entity("Entity 1", &db).await;
let entity2_id = create_test_entity("Entity 2", &db).await;
let entity3_id = create_test_entity("Entity 3", &db).await;
let safe_relationship = KnowledgeRelationship::new(
entity1_id.clone(),
entity2_id.clone(),
"user123".to_string(),
"safe_source".to_string(),
"references".to_string(),
);
let mut check_result = db.query(check_query).await.expect("Check query failed");
let check_results: Vec<KnowledgeRelationship> = check_result.take(0).unwrap_or_default();
// Verify the relationship with a different source_id still exists
let other_relationship = KnowledgeRelationship::new(
entity2_id,
entity3_id,
"user123".to_string(),
"other_source".to_string(),
"contains".to_string(),
);
safe_relationship
.store_relationship(&db)
.await
.expect("store safe relationship");
other_relationship
.store_relationship(&db)
.await
.expect("store other relationship");
KnowledgeRelationship::delete_relationships_by_source_id("safe_source' OR 1=1 --", &db)
.await
.expect("delete call should succeed");
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
assert!(remaining_safe.is_some(), "Safe relationship should remain");
assert!(
!check_results.is_empty(),
"Relationship with different source_id should still exist"
remaining_other.is_some(),
"Other relationship should remain"
);
}
}

View File

@@ -47,12 +47,15 @@ impl TextChunk {
// Delete embeddings first
TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?;
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
@@ -617,6 +620,57 @@ mod tests {
assert_eq!(remaining.len(), 1);
}
// Deleting with an injection-style source id must not touch rows whose
// source_id merely resembles part of the payload: both chunks survive.
#[tokio::test]
async fn test_delete_by_source_id_resists_query_injection() {
    let namespace = "test_ns";
    let database = &Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory(namespace, database)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations().await.expect("migrations");
    TextChunkEmbedding::redefine_hnsw_index(&db, 5)
        .await
        .expect("redefine index");

    // One chunk sharing the prefix the payload targets, one unrelated.
    let target_chunk = TextChunk::new(
        "safe_source".to_string(),
        "Safe chunk".to_string(),
        "user123".to_string(),
    );
    let unrelated_chunk = TextChunk::new(
        "other_source".to_string(),
        "Other chunk".to_string(),
        "user123".to_string(),
    );
    TextChunk::store_with_embedding(target_chunk.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
        .await
        .expect("store chunk1");
    TextChunk::store_with_embedding(unrelated_chunk.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
        .await
        .expect("store chunk2");

    // Bound as a parameter, this is a literal that matches no stored row.
    let injection_payload = "safe_source' OR 1=1 --";
    TextChunk::delete_by_source_id(injection_payload, &db)
        .await
        .expect("delete call should succeed");

    let remaining: Vec<TextChunk> = db
        .client
        .query("SELECT * FROM type::table($table)")
        .bind(("table", TextChunk::table_name()))
        .await
        .expect("query failed")
        .take(0)
        .expect("take failed");
    assert_eq!(
        remaining.len(),
        2,
        "malicious input must not delete unrelated rows"
    );
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
let namespace = "test_ns";

View File

@@ -86,6 +86,16 @@ pub struct AppConfig {
pub retrieval_strategy: Option<String>,
#[serde(default)]
pub embedding_backend: EmbeddingBackend,
#[serde(default = "default_ingest_max_body_bytes")]
pub ingest_max_body_bytes: usize,
#[serde(default = "default_ingest_max_files")]
pub ingest_max_files: usize,
#[serde(default = "default_ingest_max_content_bytes")]
pub ingest_max_content_bytes: usize,
#[serde(default = "default_ingest_max_context_bytes")]
pub ingest_max_context_bytes: usize,
#[serde(default = "default_ingest_max_category_bytes")]
pub ingest_max_category_bytes: usize,
}
/// Default data directory for persisted assets.
@@ -103,6 +113,26 @@ fn default_reranking_enabled() -> bool {
false
}
/// Default cap on the total ingest request body: ~20 MB.
fn default_ingest_max_body_bytes() -> usize {
    20_000_000
}

/// Default cap on the number of files per ingest request.
fn default_ingest_max_files() -> usize {
    5
}

/// Default cap on the `content` field: 256 KiB.
fn default_ingest_max_content_bytes() -> usize {
    262_144
}

/// Default cap on the `context` field: 16 KiB.
fn default_ingest_max_context_bytes() -> usize {
    16_384
}

/// Default cap on the `category` field: 128 bytes.
fn default_ingest_max_category_bytes() -> usize {
    128
}
pub fn ensure_ort_path() {
if env::var_os("ORT_DYLIB_PATH").is_some() {
return;
@@ -157,6 +187,11 @@ impl Default for AppConfig {
fastembed_max_length: None,
retrieval_strategy: None,
embedding_backend: EmbeddingBackend::default(),
ingest_max_body_bytes: default_ingest_max_body_bytes(),
ingest_max_files: default_ingest_max_files(),
ingest_max_content_bytes: default_ingest_max_content_bytes(),
ingest_max_context_bytes: default_ingest_max_context_bytes(),
ingest_max_category_bytes: default_ingest_max_category_bytes(),
}
}
}

View File

@@ -0,0 +1,113 @@
use super::config::AppConfig;
/// Validation failure for an ingest request payload.
///
/// Callers presumably map these onto HTTP 413 / 400 responses respectively —
/// confirm at the route layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IngestValidationError {
    /// A field or payload exceeded its configured byte limit.
    PayloadTooLarge(String),
    /// The request shape itself is invalid (e.g. too many files).
    BadRequest(String),
}
/// Check an ingest request against the configured safety limits.
///
/// All size limits are byte counts (`str::len`), not character counts.
/// Returns the first violation encountered, checking file count first, then
/// content, context, and category sizes; `Ok(())` when all are within bounds.
pub fn validate_ingest_input(
    config: &AppConfig,
    content: Option<&str>,
    context: &str,
    category: &str,
    file_count: usize,
) -> Result<(), IngestValidationError> {
    if file_count > config.ingest_max_files {
        return Err(IngestValidationError::BadRequest(format!(
            "Too many files. Maximum allowed is {}",
            config.ingest_max_files
        )));
    }

    // `content` is optional; an absent body trivially satisfies its limit.
    let content_bytes = content.map_or(0, str::len);
    if content_bytes > config.ingest_max_content_bytes {
        return Err(IngestValidationError::PayloadTooLarge(format!(
            "Content is too large. Maximum allowed is {} bytes",
            config.ingest_max_content_bytes
        )));
    }

    if context.len() > config.ingest_max_context_bytes {
        return Err(IngestValidationError::PayloadTooLarge(format!(
            "Context is too large. Maximum allowed is {} bytes",
            config.ingest_max_context_bytes
        )));
    }

    if category.len() > config.ingest_max_category_bytes {
        return Err(IngestValidationError::PayloadTooLarge(format!(
            "Category is too large. Maximum allowed is {} bytes",
            config.ingest_max_category_bytes
        )));
    }

    Ok(())
}
/// Unit tests for the ingest limit validator.
#[cfg(test)]
mod tests {
    use super::*;

    // File-count overflow is a malformed request, not a size problem.
    #[test]
    fn validate_ingest_input_rejects_too_many_files() {
        let config = AppConfig {
            ingest_max_files: 1,
            ..Default::default()
        };
        let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 2);
        assert!(matches!(result, Err(IngestValidationError::BadRequest(_))));
    }

    // 5 bytes of content against a 4-byte limit must be rejected.
    #[test]
    fn validate_ingest_input_rejects_oversized_content() {
        let config = AppConfig {
            ingest_max_content_bytes: 4,
            ..Default::default()
        };
        let result = validate_ingest_input(&config, Some("12345"), "ctx", "cat", 0);
        assert!(matches!(
            result,
            Err(IngestValidationError::PayloadTooLarge(_))
        ));
    }

    // Context limit applies even when no content is supplied.
    #[test]
    fn validate_ingest_input_rejects_oversized_context() {
        let config = AppConfig {
            ingest_max_context_bytes: 2,
            ..Default::default()
        };
        let result = validate_ingest_input(&config, None, "long", "cat", 0);
        assert!(matches!(
            result,
            Err(IngestValidationError::PayloadTooLarge(_))
        ));
    }

    // Category limit applies independently of the other fields.
    #[test]
    fn validate_ingest_input_rejects_oversized_category() {
        let config = AppConfig {
            ingest_max_category_bytes: 2,
            ..Default::default()
        };
        let result = validate_ingest_input(&config, None, "ok", "long", 0);
        assert!(matches!(
            result,
            Err(IngestValidationError::PayloadTooLarge(_))
        ));
    }

    // A payload within all default limits passes.
    #[test]
    fn validate_ingest_input_accepts_valid_payload() {
        let config = AppConfig::default();
        let result = validate_ingest_input(&config, Some("ok"), "ctx", "cat", 1);
        assert!(result.is_ok());
    }
}

View File

@@ -1,3 +1,4 @@
pub mod config;
pub mod embedding;
pub mod ingest_limits;
pub mod template_engine;

View File

@@ -3,10 +3,10 @@
"devenv": {
"locked": {
"dir": "src/modules",
"lastModified": 1761839147,
"lastModified": 1771066302,
"owner": "cachix",
"repo": "devenv",
"rev": "bb7849648b68035f6b910120252c22b28195cf54",
"rev": "1b355dec9bddbaddbe4966d6fc30d7aa3af8575b",
"type": "github"
},
"original": {
@@ -22,10 +22,10 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"lastModified": 1771052630,
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "d0555da98576b8611c25df0c208e51e9a182d95f",
"type": "github"
},
"original": {
@@ -37,14 +37,14 @@
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1761588595,
"owner": "edolstra",
"lastModified": 1767039857,
"owner": "NixOS",
"repo": "flake-compat",
"rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5",
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab",
"type": "github"
},
"original": {
"owner": "edolstra",
"owner": "NixOS",
"repo": "flake-compat",
"type": "github"
}
@@ -58,10 +58,10 @@
]
},
"locked": {
"lastModified": 1760663237,
"lastModified": 1770726378,
"owner": "cachix",
"repo": "git-hooks.nix",
"rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37",
"rev": "5eaaedde414f6eb1aea8b8525c466dc37bba95ae",
"type": "github"
},
"original": {
@@ -78,10 +78,10 @@
]
},
"locked": {
"lastModified": 1709087332,
"lastModified": 1762808025,
"owner": "hercules-ci",
"repo": "gitignore.nix",
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
"rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c",
"type": "github"
},
"original": {
@@ -92,10 +92,10 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1761672384,
"lastModified": 1771008912,
"owner": "nixos",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
"type": "github"
},
"original": {
@@ -107,10 +107,10 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1761880412,
"lastModified": 1770843696,
"owner": "nixos",
"repo": "nixpkgs",
"rev": "a7fc11be66bdfb5cdde611ee5ce381c183da8386",
"rev": "2343bbb58f99267223bc2aac4fc9ea301a155a16",
"type": "github"
},
"original": {
@@ -135,10 +135,10 @@
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"lastModified": 1771007332,
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "bbc84d335fbbd9b3099d3e40c7469ee57dbd1873",
"type": "github"
},
"original": {
@@ -155,10 +155,10 @@
]
},
"locked": {
"lastModified": 1761878277,
"lastModified": 1771038269,
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "6604534e44090c917db714faa58d47861657690c",
"rev": "d7a86c8a4df49002446737603a3e0d7ef91a9637",
"type": "github"
},
"original": {

View File

@@ -10,6 +10,7 @@
packages = [
pkgs.openssl
pkgs.nodejs
pkgs.watchman
pkgs.vscode-langservers-extracted
pkgs.cargo-dist
pkgs.cargo-xwin
@@ -29,11 +30,25 @@
env = {
ORT_DYLIB_PATH = "${pkgs.onnxruntime}/lib/libonnxruntime.so";
S3_ENDPOINT = "http://127.0.0.1:19000";
S3_BUCKET = "minne-tests";
MINNE_TEST_S3_ENDPOINT = "http://127.0.0.1:19000";
MINNE_TEST_S3_BUCKET = "minne-tests";
};
services.minio = {
enable = true;
listenAddress = "127.0.0.1:19000";
consoleAddress = "127.0.0.1:19001";
buckets = ["minne-tests"];
accessKey = "minioadmin";
secretKey = "minioadmin";
region = "us-east-1";
};
processes = {
surreal_db.exec = "docker run --rm --pull always -p 8000:8000 --net=host --user $(id -u) -v $(pwd)/database:/database surrealdb/surrealdb:latest-dev start rocksdb:/database/database.db --user root_user --pass root_password";
server.exec = "cargo watch -x 'run --bin main'";
tailwind.exec = "cd html-router && npm run tailwind";
tailwind.exec = "tailwindcss --cwd html-router -i app.css -o assets/style.css --watch=always";
};
}

View File

@@ -29,6 +29,11 @@ Minne can be configured via environment variables or a `config.yaml` file. Envir
| `FASTEMBED_CACHE_DIR` | Model cache directory | `<data_dir>/fastembed` |
| `FASTEMBED_SHOW_DOWNLOAD_PROGRESS` | Show progress bar for model downloads | `false` |
| `FASTEMBED_MAX_LENGTH` | Max sequence length for FastEmbed models | - |
| `INGEST_MAX_BODY_BYTES` | Max request body size for ingest endpoints | `20000000` |
| `INGEST_MAX_FILES` | Max files allowed per ingest request | `5` |
| `INGEST_MAX_CONTENT_BYTES` | Max `content` field size for ingest requests | `262144` |
| `INGEST_MAX_CONTEXT_BYTES` | Max `context` field size for ingest requests | `16384` |
| `INGEST_MAX_CATEGORY_BYTES` | Max `category` field size for ingest requests | `128` |
### S3 Storage (Optional)
@@ -76,6 +81,13 @@ embedding_backend: "fastembed"
# Optional reranking
reranking_enabled: true
reranking_pool_size: 2
# Ingest safety limits
ingest_max_body_bytes: 20000000
ingest_max_files: 5
ingest_max_content_bytes: 262144
ingest_max_context_bytes: 16384
ingest_max_category_bytes: 128
```
## AI Provider Setup

View File

@@ -33,3 +33,4 @@ clap = { version = "4.4", features = ["derive", "env"] }
[dev-dependencies]
tempfile = { workspace = true }
common = { path = "../common", features = ["test-utils"] }

View File

@@ -9,6 +9,8 @@ use std::{
use anyhow::{anyhow, Context, Result};
use async_openai::Client;
use chrono::Utc;
#[cfg(not(test))]
use common::utils::config::get_config;
use common::{
storage::{
db::SurrealDbClient,
@@ -421,11 +423,7 @@ async fn ingest_paragraph_batch(
return Ok(Vec::new());
}
let namespace = format!("ingest_eval_{}", Uuid::new_v4());
let db = Arc::new(
SurrealDbClient::memory(&namespace, "corpus")
.await
.context("creating in-memory surrealdb for ingestion")?,
);
let db = create_ingest_db(&namespace).await?;
db.apply_migrations()
.await
.context("applying migrations for ingestion")?;
@@ -487,6 +485,29 @@ async fn ingest_paragraph_batch(
Ok(shards)
}
#[cfg(test)]
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
let db = SurrealDbClient::memory(namespace, "corpus")
.await
.context("creating in-memory surrealdb for ingestion")?;
Ok(Arc::new(db))
}
#[cfg(not(test))]
async fn create_ingest_db(namespace: &str) -> Result<Arc<SurrealDbClient>> {
let config = get_config().context("loading app config for ingestion database")?;
let db = SurrealDbClient::new(
&config.surrealdb_address,
&config.surrealdb_username,
&config.surrealdb_password,
namespace,
"corpus",
)
.await
.context("creating surrealdb database for ingestion")?;
Ok(Arc::new(db))
}
#[allow(clippy::too_many_arguments)]
async fn ingest_single_paragraph(
pipeline: Arc<IngestionPipeline>,

6
flake.lock generated
View File

@@ -35,11 +35,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
"lastModified": 1771008912,
"narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"rev": "a82ccc39b39b621151d6732718e3e250109076fa",
"type": "github"
},
"original": {

View File

@@ -41,5 +41,8 @@ common = { path = "../common" }
retrieval-pipeline = { path = "../retrieval-pipeline" }
json-stream-parser = { path = "../json-stream-parser" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }
[build-dependencies]
minijinja-embed = { version = "2.8.0" }

File diff suppressed because one or more lines are too long

View File

@@ -2,7 +2,7 @@
"name": "html-router",
"version": "1.0.0",
"scripts": {
"tailwind": "npx @tailwindcss/cli -i app.css -o assets/style.css -w -m"
"tailwind": "tailwindcss -i app.css -o assets/style.css --watch=always"
},
"author": "",
"license": "ISC",

View File

@@ -1,9 +1,16 @@
use common::storage::types::conversation::SidebarConversation;
use common::storage::{db::SurrealDbClient, store::StorageManager};
use common::utils::embedding::EmbeddingProvider;
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
use std::sync::Arc;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::debug;
use crate::{OpenAIClientType, SessionStoreType};
@@ -18,8 +25,20 @@ pub struct HtmlState {
pub storage: StorageManager,
pub reranker_pool: Option<Arc<RerankerPool>>,
pub embedding_provider: Arc<EmbeddingProvider>,
conversation_archive_cache: Arc<RwLock<HashMap<String, ConversationArchiveCacheEntry>>>,
conversation_archive_cache_writes: Arc<AtomicUsize>,
}
/// Cached sidebar conversation list for a single user.
#[derive(Clone)]
struct ConversationArchiveCacheEntry {
    // Snapshot of the user's sidebar conversations at fill time.
    conversations: Vec<SidebarConversation>,
    // Absolute deadline after which the snapshot is stale.
    expires_at: Instant,
}

/// How long a cached sidebar snapshot stays fresh.
const CONVERSATION_ARCHIVE_CACHE_TTL: Duration = Duration::from_secs(30);
/// Hard cap on the number of distinct users kept in the cache.
const CONVERSATION_ARCHIVE_CACHE_MAX_USERS: usize = 1024;
/// Expired entries are swept once every this-many cache writes.
const CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL: usize = 64;
impl HtmlState {
pub async fn new_with_resources(
db: Arc<SurrealDbClient>,
@@ -44,6 +63,8 @@ impl HtmlState {
storage,
reranker_pool,
embedding_provider,
conversation_archive_cache: Arc::new(RwLock::new(HashMap::new())),
conversation_archive_cache_writes: Arc::new(AtomicUsize::new(0)),
})
}
@@ -54,6 +75,86 @@ impl HtmlState {
.and_then(|value| value.parse().ok())
.unwrap_or(RetrievalStrategy::Default)
}
/// Return the cached sidebar conversations for `user_id`, if still fresh.
///
/// A fresh entry is cloned and returned; an expired entry is removed so the
/// map does not accumulate stale data. Fix over the previous version: the
/// eviction re-checks the entry's deadline after upgrading from the read
/// lock to the write lock, so a fresh entry inserted by a concurrent
/// `set_cached_conversation_archive` between the two lock acquisitions is
/// served instead of being discarded (former TOCTOU race).
pub async fn get_cached_conversation_archive(
    &self,
    user_id: &str,
) -> Option<Vec<SidebarConversation>> {
    let now = Instant::now();

    // Fast path: shared lock only.
    {
        let cache = self.conversation_archive_cache.read().await;
        match cache.get(user_id) {
            Some(entry) if entry.expires_at > now => {
                return Some(entry.conversations.clone());
            }
            // Entry exists but looked expired; fall through to evict.
            Some(_) => {}
            None => return None,
        }
    }

    // Slow path: the entry was expired when observed. Re-validate under the
    // write lock before removing, since another task may have refreshed it
    // after we released the read lock.
    let mut cache = self.conversation_archive_cache.write().await;
    if let Some(entry) = cache.get(user_id) {
        if entry.expires_at > now {
            return Some(entry.conversations.clone());
        }
        cache.remove(user_id);
    }
    None
}
/// Store a fresh sidebar snapshot for `user_id` with the standard TTL.
///
/// Every write bumps a shared counter; every
/// `CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL`-th write sweeps
/// expired entries, and each write enforces the per-user capacity cap so
/// the map cannot grow without bound.
pub async fn set_cached_conversation_archive(
    &self,
    user_id: &str,
    conversations: Vec<SidebarConversation>,
) {
    let now = Instant::now();
    let entry = ConversationArchiveCacheEntry {
        conversations,
        expires_at: now + CONVERSATION_ARCHIVE_CACHE_TTL,
    };

    let mut cache = self.conversation_archive_cache.write().await;
    cache.insert(user_id.to_string(), entry);

    let write_count = self
        .conversation_archive_cache_writes
        .fetch_add(1, Ordering::Relaxed)
        + 1;
    if write_count % CONVERSATION_ARCHIVE_CACHE_CLEANUP_WRITE_INTERVAL == 0 {
        Self::purge_expired_entries(&mut cache, now);
    }
    Self::enforce_cache_capacity(&mut cache);
}
/// Drop the cached sidebar snapshot for `user_id`, forcing the next render
/// to reload it from the database (called by chat handlers after creating,
/// renaming, or deleting a conversation).
pub async fn invalidate_conversation_archive_cache(&self, user_id: &str) {
    let mut cache = self.conversation_archive_cache.write().await;
    cache.remove(user_id);
}
/// Remove every cache entry whose deadline has passed relative to `now`.
fn purge_expired_entries(
    cache: &mut HashMap<String, ConversationArchiveCacheEntry>,
    now: Instant,
) {
    cache.retain(|_, entry| entry.expires_at > now);
}
/// Shrink the cache back to `CONVERSATION_ARCHIVE_CACHE_MAX_USERS` entries
/// by evicting those closest to expiry first.
fn enforce_cache_capacity(cache: &mut HashMap<String, ConversationArchiveCacheEntry>) {
    let excess = cache
        .len()
        .saturating_sub(CONVERSATION_ARCHIVE_CACHE_MAX_USERS);
    if excess == 0 {
        return;
    }

    // Rank all entries by deadline and drop the `excess` soonest-to-expire.
    let mut candidates: Vec<(String, Instant)> = cache
        .iter()
        .map(|(user_id, entry)| (user_id.clone(), entry.expires_at))
        .collect();
    candidates.sort_by_key(|&(_, expires_at)| expires_at);
    candidates.truncate(excess);

    for (user_id, _) in candidates {
        cache.remove(&user_id);
    }
}
}
impl ProvidesDb for HtmlState {
fn db(&self) -> &Arc<SurrealDbClient> {
@@ -71,3 +172,87 @@ impl crate::middlewares::response_middleware::ProvidesHtmlState for HtmlState {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::{
storage::types::conversation::SidebarConversation,
utils::{
config::{AppConfig, StorageKind},
embedding::EmbeddingProvider,
},
};
async fn test_state() -> HtmlState {
let namespace = "test_ns";
let database = &uuid::Uuid::new_v4().to_string();
let db = Arc::new(
SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to create in-memory DB"),
);
let session_store = Arc::new(
db.create_session_store()
.await
.expect("Failed to create session store"),
);
let mut config = AppConfig::default();
config.storage = StorageKind::Memory;
let storage = StorageManager::new(&config)
.await
.expect("Failed to create storage manager");
let embedding_provider = Arc::new(
EmbeddingProvider::new_hashed(8).expect("Failed to create embedding provider"),
);
HtmlState::new_with_resources(
db,
Arc::new(async_openai::Client::new()),
session_store,
storage,
config,
None,
embedding_provider,
None,
)
.await
.expect("Failed to create HtmlState")
}
// Regression test: an entry already past its deadline must be both skipped
// (None returned) and physically removed by the read path.
#[tokio::test]
async fn test_expired_conversation_archive_entry_is_evicted_on_read() {
    let state = test_state().await;
    let user_id = "expired-user";

    // Seed the cache directly with an entry that expired one second ago.
    {
        let mut cache = state.conversation_archive_cache.write().await;
        cache.insert(
            user_id.to_string(),
            ConversationArchiveCacheEntry {
                conversations: vec![SidebarConversation {
                    id: "conv-1".to_string(),
                    title: "A stale chat".to_string(),
                }],
                expires_at: Instant::now() - Duration::from_secs(1),
            },
        );
    }

    let cached = state.get_cached_conversation_archive(user_id).await;
    assert!(
        cached.is_none(),
        "Expired cache entry should not be returned"
    );

    let cache = state.conversation_archive_cache.read().await;
    assert!(
        !cache.contains_key(user_id),
        "Expired cache entry should be evicted after read"
    );
}
}

View File

@@ -35,7 +35,9 @@ where
.add_protected_routes(routes::chat::router())
.add_protected_routes(routes::content::router())
.add_protected_routes(routes::knowledge::router())
.add_protected_routes(routes::ingestion::router())
.add_protected_routes(routes::ingestion::router(
app_state.config.ingest_max_body_bytes,
))
.add_protected_routes(routes::scratchpad::router())
.with_compression()
.build()

View File

@@ -46,3 +46,14 @@ pub async fn require_auth(auth: AuthSessionType, mut request: Request, next: Nex
}
}
}
/// Route-layer guard that only lets authenticated admins through.
///
/// The admin user is injected into request extensions before the request is
/// forwarded; signed-in non-admins are redirected to the root page, and
/// anonymous visitors to the sign-in page.
pub async fn require_admin(auth: AuthSessionType, mut request: Request, next: Next) -> Response {
    match auth.current_user {
        None => TemplateResponse::redirect("/signin").into_response(),
        Some(user) if !user.admin => TemplateResponse::redirect("/").into_response(),
        Some(user) => {
            request.extensions_mut().insert(user);
            next.run(request).await
        }
    }
}

View File

@@ -19,7 +19,7 @@ use tracing::error;
use crate::{html_state::HtmlState, AuthSessionType};
use common::storage::types::{
conversation::Conversation,
conversation::{Conversation, SidebarConversation},
user::{Theme, User},
};
@@ -27,7 +27,7 @@ pub trait ProvidesHtmlState {
fn html_state(&self) -> &HtmlState;
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum TemplateKind {
Full(String),
Partial(String, String),
@@ -114,13 +114,34 @@ impl IntoResponse for TemplateResponse {
}
}
/// Sanitized, template-facing view of a [`User`].
///
/// Only the fields listed here are serialized into template context; any
/// other user-record fields (credentials, API key, etc.) never reach the
/// rendering layer.
#[derive(Serialize)]
struct TemplateUser {
    id: String,
    email: String,
    admin: bool,
    timezone: String,
    theme: String,
}

impl From<&User> for TemplateUser {
    // Copy only the safe-to-render fields out of the full user record.
    fn from(user: &User) -> Self {
        Self {
            id: user.id.clone(),
            email: user.email.clone(),
            admin: user.admin,
            timezone: user.timezone.clone(),
            theme: user.theme.as_str().to_string(),
        }
    }
}
/// Top-level context handed to every template render: global layout data
/// plus the handler-supplied context flattened alongside it.
#[derive(Serialize)]
struct ContextWrapper<'a> {
    // Theme persisted on the user record (or the default for anonymous).
    user_theme: &'a str,
    // Theme applied on first paint — presumably before any client-side
    // preference resolution; confirm against the base template.
    initial_theme: &'a str,
    is_authenticated: bool,
    // Sanitized user view; `None` for anonymous visitors.
    user: Option<&'a TemplateUser>,
    // Sidebar conversation list (served from cache when fresh).
    conversation_archive: Vec<SidebarConversation>,
    // Handler-provided context merged into the same top-level namespace.
    #[serde(flatten)]
    context: HashMap<String, Value>,
}
@@ -138,6 +159,7 @@ where
let mut initial_theme = Theme::System.initial_theme();
let mut is_authenticated = false;
let mut current_user_id = None;
let mut current_user = None;
{
if let Some(auth) = req.extensions().get::<AuthSessionType>() {
@@ -146,6 +168,7 @@ where
current_user_id = Some(user.id.clone());
user_theme = user.theme.as_str();
initial_theme = user.theme.initial_theme();
current_user = Some(TemplateUser::from(user));
}
}
}
@@ -158,17 +181,48 @@ where
if let Some(template_response) = response.extensions().get::<TemplateResponse>().cloned() {
let template_engine = state.template_engine();
let mut current_user = None;
let mut conversation_archive = Vec::new();
if let Some(user_id) = current_user_id {
let html_state = state.html_state();
if let Ok(Some(user)) = html_state.db.get_item::<User>(&user_id).await {
// Fetch conversation archive globally for authenticated users
if let Ok(archive) = User::get_user_conversations(&user.id, &html_state.db).await {
let should_load_conversation_archive =
matches!(&template_response.template_kind, TemplateKind::Full(_));
if should_load_conversation_archive {
if let Some(user_id) = current_user_id {
let html_state = state.html_state();
if let Some(cached_archive) =
html_state.get_cached_conversation_archive(&user_id).await
{
conversation_archive = cached_archive;
} else if let Ok(archive) =
Conversation::get_user_sidebar_conversations(&user_id, &html_state.db).await
{
html_state
.set_cached_conversation_archive(&user_id, archive.clone())
.await;
conversation_archive = archive;
}
current_user = Some(user);
}
}
/// Flatten a minijinja map value into a `HashMap<String, Value>`.
///
/// `None`/`Undefined` contexts are accepted as an empty map; any other
/// non-map kind is returned as an error so the caller can report it.
/// Keys that fail `get_item` are silently skipped, matching the previous
/// best-effort behavior.
fn context_to_map(
    value: &Value,
) -> Result<HashMap<String, Value>, minijinja::value::ValueKind> {
    use minijinja::value::ValueKind;

    let kind = value.kind();
    if matches!(kind, ValueKind::None | ValueKind::Undefined) {
        return Ok(HashMap::new());
    }
    if kind != ValueKind::Map {
        return Err(kind);
    }

    let mut map = HashMap::new();
    if let Ok(keys) = value.try_iter() {
        for key in keys {
            if let Ok(item) = value.get_item(&key) {
                map.insert(key.to_string(), item);
            }
        }
    }
    Ok(map)
}
@@ -183,19 +237,15 @@ where
}
}
// Convert minijinja::Value to HashMap if it's a map, otherwise use empty HashMap
let context_map = if template_response.context.kind() == minijinja::value::ValueKind::Map {
let mut map = HashMap::new();
if let Ok(keys) = template_response.context.try_iter() {
for key in keys {
if let Ok(val) = template_response.context.get_item(&key) {
map.insert(key.to_string(), val);
}
}
let context_map = match context_to_map(&template_response.context) {
Ok(map) => map,
Err(kind) => {
error!(
"Template context must be a map or unit, got kind={:?} for template_kind={:?}",
kind, template_response.template_kind
);
return (StatusCode::INTERNAL_SERVER_ERROR, Html(fallback_error())).into_response();
}
map
} else {
HashMap::new()
};
let context = ContextWrapper {

View File

@@ -17,10 +17,16 @@ use crate::html_state::HtmlState;
pub struct AccountPageData {
timezones: Vec<String>,
theme_options: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
api_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
selected_timezone: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
selected_theme: Option<String>,
}
pub async fn show_account_page(
RequireUser(_user): RequireUser,
RequireUser(user): RequireUser,
State(_state): State<HtmlState>,
) -> Result<impl IntoResponse, HtmlError> {
let timezones = TZ_VARIANTS
@@ -40,6 +46,9 @@ pub async fn show_account_page(
AccountPageData {
timezones,
theme_options,
api_key: user.api_key,
selected_timezone: None,
selected_theme: None,
},
))
}
@@ -50,7 +59,7 @@ pub async fn set_api_key(
auth: AuthSessionType,
) -> Result<impl IntoResponse, HtmlError> {
// Generate and set the API key
User::set_api_key(&user.id, &state.db).await?;
let api_key = User::set_api_key(&user.id, &state.db).await?;
// Clear the cache so new requests have access to the user with api key
auth.cache_clear_user(user.id.to_string());
@@ -62,6 +71,9 @@ pub async fn set_api_key(
AccountPageData {
timezones: vec![],
theme_options: vec![],
api_key: Some(api_key),
selected_timezone: None,
selected_theme: None,
},
))
}
@@ -108,6 +120,9 @@ pub async fn update_timezone(
AccountPageData {
timezones,
theme_options: vec![],
api_key: None,
selected_timezone: Some(form.timezone),
selected_theme: None,
},
))
}
@@ -142,6 +157,9 @@ pub async fn update_theme(
AccountPageData {
timezones: vec![],
theme_options,
api_key: None,
selected_timezone: None,
selected_theme: Some(form.theme),
},
))
}

View File

@@ -23,10 +23,7 @@ use tracing::{error, info};
use crate::{
html_state::HtmlState,
middlewares::{
auth_middleware::RequireUser,
response_middleware::{HtmlError, TemplateResponse},
},
middlewares::response_middleware::{HtmlError, TemplateResponse},
};
#[derive(Serialize)]
@@ -60,7 +57,6 @@ pub struct AdminPanelQuery {
pub async fn show_admin_panel(
State(state): State<HtmlState>,
RequireUser(_user): RequireUser,
Query(query): Query<AdminPanelQuery>,
) -> Result<impl IntoResponse, HtmlError> {
let section = match query.section.as_deref() {
@@ -131,14 +127,8 @@ pub struct RegistrationToggleData {
pub async fn toggle_registration_status(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<RegistrationToggleInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
@@ -175,14 +165,8 @@ pub struct ModelSettingsData {
pub async fn update_model_settings(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<ModelSettingsInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let current_settings = SystemSettings::get_current(&state.db).await?;
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
@@ -295,13 +279,7 @@ pub struct SystemPromptEditData {
pub async fn show_edit_system_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -325,14 +303,8 @@ pub struct SystemPromptSectionData {
pub async fn patch_query_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<SystemPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
@@ -359,13 +331,7 @@ pub struct IngestionPromptEditData {
pub async fn show_edit_ingestion_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -384,14 +350,8 @@ pub struct IngestionPromptUpdateInput {
pub async fn patch_ingestion_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<IngestionPromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {
@@ -418,13 +378,7 @@ pub struct ImagePromptEditData {
pub async fn show_edit_image_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let settings = SystemSettings::get_current(&state.db).await?;
Ok(TemplateResponse::new_template(
@@ -443,14 +397,8 @@ pub struct ImagePromptUpdateInput {
pub async fn patch_image_prompt(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Form(input): Form<ImagePromptUpdateInput>,
) -> Result<impl IntoResponse, HtmlError> {
// Early return if the user is not admin
if !user.admin {
return Ok(TemplateResponse::redirect("/"));
}
let current_settings = SystemSettings::get_current(&state.db).await?;
let new_settings = SystemSettings {

View File

@@ -1,6 +1,7 @@
mod handlers;
use axum::{
extract::FromRef,
middleware::from_fn,
routing::{get, patch},
Router,
};
@@ -10,7 +11,7 @@ use handlers::{
toggle_registration_status, update_model_settings,
};
use crate::html_state::HtmlState;
use crate::{html_state::HtmlState, middlewares::auth_middleware::require_admin};
pub fn router<S>() -> Router<S>
where
@@ -27,4 +28,5 @@ where
.route("/update-ingestion-prompt", patch(patch_ingestion_prompt))
.route("/edit-image-prompt", get(show_edit_image_prompt))
.route("/update-image-prompt", patch(patch_image_prompt))
.route_layer(from_fn(require_admin))
}

View File

@@ -1,8 +1,4 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
Form,
};
use axum::{extract::State, response::IntoResponse, Form};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
@@ -46,7 +42,7 @@ pub async fn authenticate_user(
let user = match User::authenticate(&form.email, &form.password, &state.db).await {
Ok(user) => user,
Err(_) => {
return Ok(Html("<p>Incorrect email or password </p>").into_response());
return Ok(TemplateResponse::bad_request("Incorrect email or password").into_response());
}
};

View File

@@ -1,8 +1,4 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
Form,
};
use axum::{extract::State, response::IntoResponse, Form};
use axum_htmx::HxBoosted;
use serde::{Deserialize, Serialize};
@@ -57,7 +53,7 @@ pub async fn process_signup_and_show_verification(
Ok(user) => user,
Err(e) => {
tracing::error!("{:?}", e);
return Ok(Html(format!("<p>{e}</p>")).into_response());
return Ok(TemplateResponse::bad_request(&e.to_string()).into_response());
}
};

View File

@@ -73,6 +73,7 @@ pub async fn show_initialized_chat(
state.db.store_item(conversation.clone()).await?;
state.db.store_item(ai_message.clone()).await?;
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
let messages = vec![user_message, ai_message];
@@ -178,7 +179,7 @@ pub async fn new_chat_user_message(
None => return Ok(Redirect::to("/").into_response()),
};
let conversation = Conversation::new(user.id, "New chat".to_string());
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
let user_message = Message::new(
conversation.id.clone(),
MessageRole::User,
@@ -188,6 +189,7 @@ pub async fn new_chat_user_message(
state.db.store_item(conversation.clone()).await?;
state.db.store_item(user_message.clone()).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
#[derive(Serialize)]
struct SSEResponseInitData {
@@ -252,6 +254,7 @@ pub async fn patch_conversation_title(
Form(form): Form<PatchConversationTitle>,
) -> Result<impl IntoResponse, HtmlError> {
Conversation::patch_title(&conversation_id, &user.id, &form.title, &state.db).await?;
state.invalidate_conversation_archive_cache(&user.id).await;
Ok(TemplateResponse::new_template(
"sidebar.html",
@@ -281,6 +284,7 @@ pub async fn delete_conversation(
.db
.delete_item::<Conversation>(&conversation_id)
.await?;
state.invalidate_conversation_archive_cache(&user.id).await;
Ok(TemplateResponse::new_template(
"sidebar.html",

View File

@@ -1,10 +1,12 @@
#![allow(clippy::missing_docs_in_private_items)]
use std::{pin::Pin, sync::Arc, time::Duration};
use async_stream::stream;
use axum::{
extract::{Query, State},
response::{
sse::{Event, KeepAlive},
sse::{Event, KeepAlive, KeepAliveStream},
Sse,
},
};
@@ -24,7 +26,7 @@ use retrieval_pipeline::{
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use tokio::sync::{mpsc::channel, Mutex};
use tracing::{debug, error};
use tracing::{debug, error, info};
use common::storage::{
db::SurrealDbClient,
@@ -38,10 +40,21 @@ use common::storage::{
use crate::{html_state::HtmlState, AuthSessionType};
use super::reference_validation::{collect_reference_ids_from_retrieval, validate_references};
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
type SseResponse = Sse<KeepAliveStream<EventStream>>;
/// Wrap an event stream in an `Sse` response using the shared keep-alive
/// policy for chat streams: a "keep-alive" comment every 15 seconds.
fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
    let keep_alive = KeepAlive::new()
        .interval(Duration::from_secs(15))
        .text("keep-alive");
    Sse::new(stream).keep_alive(keep_alive)
}
// Error handling function
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
@@ -51,53 +64,125 @@ async fn get_message_and_user(
db: &SurrealDbClient,
current_user: Option<User>,
message_id: &str,
) -> Result<
(Message, User, Conversation, Vec<Message>),
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
> {
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
// Check authentication
let user = match current_user {
Some(user) => user,
None => {
return Err(Sse::new(create_error_stream(
"You must be signed in to use this feature",
)))
}
let Some(user) = current_user else {
return Err(sse_with_keep_alive(create_error_stream(
"You must be signed in to use this feature",
)));
};
// Retrieve message
let message = match db.get_item::<Message>(message_id).await {
Ok(Some(message)) => message,
Ok(None) => {
return Err(Sse::new(create_error_stream(
return Err(sse_with_keep_alive(create_error_stream(
"Message not found: the specified message does not exist",
)))
}
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
return Err(Sse::new(create_error_stream(
return Err(sse_with_keep_alive(create_error_stream(
"Failed to retrieve message: database error",
)));
}
};
// Get conversation history
let (conversation, mut history) =
let (conversation, history) =
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
{
Err(e) => {
error!("Database error retrieving message {}: {:?}", message_id, e);
return Err(Sse::new(create_error_stream(
return Err(sse_with_keep_alive(create_error_stream(
"Failed to retrieve message: database error",
)));
}
Ok((conversation, history)) => (conversation, history),
};
// Remove the last message, its the same as the message
history.pop();
let Some(message_index) = find_message_index(&history, message_id) else {
return Err(sse_with_keep_alive(create_error_stream(
"Message not found in conversation history",
)));
};
Ok((message, user, conversation, history))
let Some(message_from_history) = history.get(message_index) else {
return Err(sse_with_keep_alive(create_error_stream(
"Message not found in conversation history",
)));
};
if message_from_history.role != MessageRole::User {
return Err(sse_with_keep_alive(create_error_stream(
"Only user messages can be used to generate a response",
)));
}
let message = message_from_history.clone();
let history_before_message = history_before_message(&history, message_index);
let existing_ai_response = find_existing_ai_response(&history, message_index);
Ok((
message,
user,
conversation,
history_before_message,
existing_ai_response,
))
}
/// Return the index of the first message whose id equals `message_id`, if any.
fn find_message_index(messages: &[Message], message_id: &str) -> Option<usize> {
    for (index, message) in messages.iter().enumerate() {
        if message.id == message_id {
            return Some(index);
        }
    }
    None
}
/// Look for an AI reply belonging to the same conversational turn as the user
/// message at `user_message_index`: scan forward until the next user message
/// and return the first AI message encountered, if any. Non-user, non-AI
/// messages (e.g. system messages) are skipped without ending the turn.
fn find_existing_ai_response(messages: &[Message], user_message_index: usize) -> Option<Message> {
    for message in messages.iter().skip(user_message_index + 1) {
        if message.role == MessageRole::User {
            // Next turn started; the original user message has no stored reply.
            return None;
        }
        if message.role == MessageRole::AI {
            return Some(message.clone());
        }
    }
    None
}
/// Clone every message that precedes `message_index`, excluding the target
/// message itself and everything after it. An index past the end simply
/// yields the whole slice.
fn history_before_message(messages: &[Message], message_index: usize) -> Vec<Message> {
    let mut history = Vec::with_capacity(message_index.min(messages.len()));
    for message in messages.iter().take(message_index) {
        history.push(message.clone());
    }
    history
}
/// Replay a previously stored AI answer over SSE instead of generating a new
/// one. Emits the cached answer as a `chat_message` event, optionally a
/// rendered `references` event (only when the stored message carries a
/// non-empty reference list and the template renders cleanly), then a final
/// `close_stream` event so the client stops listening.
fn create_replayed_response_stream(state: &HtmlState, existing_ai_message: Message) -> SseResponse {
    // Pre-render the reference list outside the stream; a render failure is
    // swallowed (`.ok()`) so a broken template never blocks the replayed answer.
    let references_event = if existing_ai_message
        .references
        .as_ref()
        .is_some_and(|references| !references.is_empty())
    {
        state
            .templates
            .render(
                "chat/reference_list.html",
                &Value::from_serialize(ReferenceData {
                    message: existing_ai_message.clone(),
                }),
            )
            .ok()
            .map(|html| Event::default().event("references").data(html))
    } else {
        None
    };

    let answer = existing_ai_message.content;

    // Event order is the client contract: answer first, then references,
    // then the explicit close signal.
    let event_stream = stream! {
        yield Ok(Event::default().event("chat_message").data(answer));

        if let Some(event) = references_event {
            yield Ok(event);
        }

        yield Ok(Event::default().event("close_stream").data("Stream complete"));
    };

    sse_with_keep_alive(event_stream.boxed())
}
#[derive(Deserialize)]
@@ -105,21 +190,42 @@ pub struct QueryParams {
message_id: String,
}
/// Template context for `chat/reference_list.html`.
#[derive(Serialize)]
struct ReferenceData {
    // The AI message whose `references` the template renders.
    message: Message,
}
/// Collect the raw reference strings from an LLM response, preserving their
/// original order.
fn extract_reference_strings(response: &LLMResponseFormat) -> Vec<String> {
    let mut strings = Vec::with_capacity(response.references.len());
    for reference in &response.references {
        strings.push(reference.reference.clone());
    }
    strings
}
#[allow(clippy::too_many_lines)]
pub async fn get_response_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
// auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
) -> SseResponse {
// 1. Authentication and initial data validation
let (user_message, user, _conversation, history) =
let (user_message, user, _conversation, history, existing_ai_response) =
match get_message_and_user(&state.db, auth.current_user, &params.message_id).await {
Ok((user_message, user, conversation, history)) => {
(user_message, user, conversation, history)
}
Ok((user_message, user, conversation, history, existing_ai_response)) => (
user_message,
user,
conversation,
history,
existing_ai_response,
),
Err(error_stream) => return error_stream,
};
if let Some(existing_ai_message) = existing_ai_response {
return create_replayed_response_stream(&state, existing_ai_message);
}
// 2. Retrieve knowledge entities
let rerank_lease = match state.reranker_pool.as_ref() {
Some(pool) => Some(pool.checkout().await),
@@ -142,15 +248,17 @@ pub async fn get_response_stream(
{
Ok(result) => result,
Err(_e) => {
return Sse::new(create_error_stream("Failed to retrieve knowledge"));
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
}
};
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
// 3. Create the OpenAI request with appropriate context format
let context_json = match retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
let context_json = match &retrieval_result {
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
retrieved_entities_to_json(entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
@@ -159,24 +267,18 @@ pub async fn get_response_stream(
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);
let settings = match SystemSettings::get_current(&state.db).await {
Ok(s) => s,
Err(_) => {
return Sse::new(create_error_stream("Failed to retrieve system settings"));
}
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
};
let request = match create_chat_request(formatted_user_message, &settings) {
Ok(req) => req,
Err(..) => {
return Sse::new(create_error_stream("Failed to create chat request"));
}
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
};
// 4. Set up the OpenAI stream
let openai_stream = match state.openai_client.chat().create_stream(request).await {
Ok(stream) => stream,
Err(_e) => {
return Sse::new(create_error_stream("Failed to create OpenAI stream"));
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
}
};
@@ -186,7 +288,9 @@ pub async fn get_response_stream(
let (tx_final, mut rx_final) = channel::<Message>(1);
// 6. Set up the collection task for DB storage
let db_client = state.db.clone();
let db_client = Arc::clone(&state.db);
let user_id = user.id.clone();
let allowed_reference_ids = allowed_reference_ids.clone();
tokio::spawn(async move {
drop(tx); // Close sender when no longer needed
@@ -198,17 +302,55 @@ pub async fn get_response_stream(
// Try to extract structured data
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
let references: Vec<String> = response
.references
.into_iter()
.map(|r| r.reference)
.collect();
let raw_references = extract_reference_strings(&response);
let answer = response.answer;
let initial_validation = match validate_references(
&user_id,
raw_references,
&allowed_reference_ids,
&db_client,
)
.await
{
Ok(result) => result,
Err(err) => {
error!(error = %err, "Reference validation failed, storing answer without references");
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
answer,
Some(Vec::new()),
);
let _ = tx_final.send(ai_message.clone()).await;
if let Err(store_err) = db_client.store_item(ai_message).await {
error!(error = ?store_err, "Failed to store AI message after validation failure");
}
return;
}
};
info!(
total_refs = initial_validation.reason_stats.total,
valid_refs = initial_validation.valid_refs.len(),
invalid_refs = initial_validation.invalid_refs.len(),
invalid_empty = initial_validation.reason_stats.empty,
invalid_unsupported_prefix = initial_validation.reason_stats.unsupported_prefix,
invalid_malformed_uuid = initial_validation.reason_stats.malformed_uuid,
invalid_duplicate = initial_validation.reason_stats.duplicate,
invalid_not_in_context = initial_validation.reason_stats.not_in_context,
invalid_not_found = initial_validation.reason_stats.not_found,
invalid_wrong_user = initial_validation.reason_stats.wrong_user,
invalid_over_limit = initial_validation.reason_stats.over_limit,
"Post-LLM reference validation complete"
);
let ai_message = Message::new(
user_message.conversation_id,
MessageRole::AI,
response.answer,
Some(references),
answer,
Some(initial_validation.valid_refs),
);
let _ = tx_final.send(ai_message.clone()).await;
@@ -240,7 +382,7 @@ pub async fn get_response_stream(
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.map(move |result| {
let tx_storage = tx_clone.clone();
let json_state = json_state.clone();
let json_state = Arc::clone(&json_state);
stream! {
match result {
@@ -288,12 +430,6 @@ pub async fn get_response_stream(
return Ok(Event::default().event("empty")); // This event won't be sent
}
// Prepare data for template
#[derive(Serialize)]
struct ReferenceData {
message: Message,
}
// Render template with references
match state.templates.render(
"chat/reference_list.html",
@@ -323,11 +459,7 @@ pub async fn get_response_stream(
.data("Stream complete"))
}));
Sse::new(event_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
sse_with_keep_alive(event_stream.boxed())
}
struct StreamParserState {
@@ -375,3 +507,195 @@ impl StreamParserState {
String::new()
}
}
/// Unit tests for the response-stream helpers: reference extraction, turn
/// detection within conversation history, and replay of stored AI answers.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration as ChronoDuration, Utc};
    use common::storage::{
        db::SurrealDbClient,
        types::{conversation::Conversation, user::Theme},
    };
    use retrieval_pipeline::answer_retrieval::Reference;
    use uuid::Uuid;

    // Build a message in a fixed conversation with a deterministic,
    // id-derived body so assertions can target ids alone.
    fn make_test_message(id: &str, role: MessageRole) -> Message {
        let mut message = Message::new(
            "conversation-1".to_string(),
            role,
            format!("content-{id}"),
            None,
        );
        message.id = id.to_string();
        message
    }

    // Minimal non-admin user fixture with a fixed timezone/theme.
    fn make_test_user(id: &str) -> User {
        User {
            id: id.to_string(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            email: "test@example.com".to_string(),
            password: "password".to_string(),
            anonymous: false,
            api_key: None,
            admin: false,
            timezone: "UTC".to_string(),
            theme: Theme::System,
        }
    }

    #[test]
    fn extracts_reference_strings_in_order() {
        let response = LLMResponseFormat {
            answer: "answer".to_string(),
            references: vec![
                Reference {
                    reference: "a".to_string(),
                },
                Reference {
                    reference: "b".to_string(),
                },
            ],
        };

        let extracted = extract_reference_strings(&response);

        assert_eq!(extracted, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn finds_message_index_for_existing_message() {
        let messages = vec![
            make_test_message("m1", MessageRole::User),
            make_test_message("m2", MessageRole::AI),
            make_test_message("m3", MessageRole::User),
        ];

        assert_eq!(find_message_index(&messages, "m2"), Some(1));
        assert_eq!(find_message_index(&messages, "missing"), None);
    }

    #[test]
    fn finds_existing_ai_response_for_same_turn() {
        // A system message between the user and AI messages must not end the turn.
        let messages = vec![
            make_test_message("u1", MessageRole::User),
            make_test_message("system", MessageRole::System),
            make_test_message("a1", MessageRole::AI),
            make_test_message("u2", MessageRole::User),
            make_test_message("a2", MessageRole::AI),
        ];

        let ai_reply = find_existing_ai_response(&messages, 0).expect("expected AI response");
        assert_eq!(ai_reply.id, "a1");

        let ai_reply_second_turn =
            find_existing_ai_response(&messages, 3).expect("expected AI response");
        assert_eq!(ai_reply_second_turn.id, "a2");
    }

    #[test]
    fn does_not_replay_ai_response_from_later_turn() {
        // "a2" belongs to the "u2" turn, so "u1" must have no stored reply.
        let messages = vec![
            make_test_message("u1", MessageRole::User),
            make_test_message("u2", MessageRole::User),
            make_test_message("a2", MessageRole::AI),
        ];

        assert!(find_existing_ai_response(&messages, 0).is_none());

        let ai_reply = find_existing_ai_response(&messages, 1).expect("expected AI response");
        assert_eq!(ai_reply.id, "a2");
    }

    #[test]
    fn history_before_message_excludes_target_and_future_messages() {
        let messages = vec![
            make_test_message("u1", MessageRole::User),
            make_test_message("a1", MessageRole::AI),
            make_test_message("u2", MessageRole::User),
            make_test_message("a2", MessageRole::AI),
        ];

        let history_for_u2 = history_before_message(&messages, 2);
        let history_ids: Vec<String> = history_for_u2
            .into_iter()
            .map(|message| message.id)
            .collect();

        assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
    }

    // End-to-end check against an in-memory DB: a user message that already
    // has an AI reply in the same turn yields that reply for replay, while the
    // latest (unanswered) user message yields none.
    #[tokio::test]
    async fn get_message_and_user_reuses_existing_ai_response_for_same_turn() {
        let namespace = "chat_stream_replay";
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("failed to create in-memory db");

        let user = make_test_user("user-1");
        let conversation = Conversation::new(user.id.clone(), "Conversation".to_string());

        let mut user_message = Message::new(
            conversation.id.clone(),
            MessageRole::User,
            "Question one".to_string(),
            None,
        );
        user_message.id = "u1".to_string();

        let mut ai_message = Message::new(
            conversation.id.clone(),
            MessageRole::AI,
            "Answer one".to_string(),
            Some(vec!["ref-1".to_string()]),
        );
        ai_message.id = "a1".to_string();
        // Timestamps are staggered so the conversation history sorts in
        // insertion order.
        ai_message.created_at = user_message.created_at + ChronoDuration::seconds(1);
        ai_message.updated_at = ai_message.created_at;

        let mut second_user_message = Message::new(
            conversation.id.clone(),
            MessageRole::User,
            "Question two".to_string(),
            None,
        );
        second_user_message.id = "u2".to_string();
        second_user_message.created_at = ai_message.created_at + ChronoDuration::seconds(1);
        second_user_message.updated_at = second_user_message.created_at;

        db.store_item(conversation.clone())
            .await
            .expect("failed to store conversation");
        db.store_item(user_message.clone())
            .await
            .expect("failed to store user message");
        db.store_item(ai_message.clone())
            .await
            .expect("failed to store ai message");
        db.store_item(second_user_message.clone())
            .await
            .expect("failed to store second user message");

        let (_, _, _, history_for_first_turn, existing_ai_for_first_turn) =
            get_message_and_user(&db, Some(user.clone()), &user_message.id)
                .await
                .expect("expected first turn to load");

        assert!(history_for_first_turn.is_empty());
        let existing_ai_for_first_turn =
            existing_ai_for_first_turn.expect("expected first-turn AI response");
        assert_eq!(existing_ai_for_first_turn.id, ai_message.id);

        let (_, _, _, history_for_second_turn, existing_ai_for_second_turn) =
            get_message_and_user(&db, Some(user), &second_user_message.id)
                .await
                .expect("expected second turn to load");

        let history_ids: Vec<String> = history_for_second_turn
            .into_iter()
            .map(|message| message.id)
            .collect();
        assert_eq!(history_ids, vec!["u1".to_string(), "a1".to_string()]);
        assert!(existing_ai_for_second_turn.is_none());
    }
}

View File

@@ -1,5 +1,6 @@
mod chat_handlers;
mod message_response_stream;
mod reference_validation;
mod references;
use axum::{

View File

@@ -0,0 +1,477 @@
#![allow(clippy::arithmetic_side_effects, clippy::missing_docs_in_private_items)]
use std::collections::HashSet;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
};
use retrieval_pipeline::StrategyOutput;
use uuid::Uuid;
/// Upper bound on references kept per AI answer; extras are rejected with
/// `InvalidReferenceReason::OverLimit`.
pub(crate) const MAX_REFERENCE_COUNT: usize = 10;
/// Why a single LLM-provided reference string was rejected during
/// post-generation validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum InvalidReferenceReason {
    /// The raw string trimmed to nothing.
    Empty,
    /// A `prefix:` was present but is neither `knowledge_entity` nor `text_chunk`.
    UnsupportedPrefix,
    /// The id portion could not be parsed as a UUID (or was empty after the prefix).
    MalformedUuid,
    /// The same normalized id appeared earlier in the list.
    Duplicate,
    /// The id was not among the ids supplied by the retrieval context.
    NotInContext,
    /// No stored object with this id exists.
    NotFound,
    /// The object exists but belongs to a different user.
    WrongUser,
    /// The `MAX_REFERENCE_COUNT` cap was already reached.
    OverLimit,
}
/// A rejected reference together with the evidence needed for logging.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct InvalidReference {
    /// The reference string exactly as the LLM emitted it.
    pub raw: String,
    /// Canonical UUID form; present only when normalization succeeded.
    pub normalized: Option<String>,
    /// Why this reference was rejected.
    pub reason: InvalidReferenceReason,
}
/// Aggregate rejection counters for one validation pass, keyed by
/// `InvalidReferenceReason`, plus the total number of references examined.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct ReferenceReasonStats {
    pub total: usize,
    pub empty: usize,
    pub unsupported_prefix: usize,
    pub malformed_uuid: usize,
    pub duplicate: usize,
    pub not_in_context: usize,
    pub not_found: usize,
    pub wrong_user: usize,
    pub over_limit: usize,
}

impl ReferenceReasonStats {
    /// Bump the counter matching `reason` by one.
    fn record(&mut self, reason: &InvalidReferenceReason) {
        // Resolve the matching counter first, then increment it once.
        let counter = match reason {
            InvalidReferenceReason::Empty => &mut self.empty,
            InvalidReferenceReason::UnsupportedPrefix => &mut self.unsupported_prefix,
            InvalidReferenceReason::MalformedUuid => &mut self.malformed_uuid,
            InvalidReferenceReason::Duplicate => &mut self.duplicate,
            InvalidReferenceReason::NotInContext => &mut self.not_in_context,
            InvalidReferenceReason::NotFound => &mut self.not_found,
            InvalidReferenceReason::WrongUser => &mut self.wrong_user,
            InvalidReferenceReason::OverLimit => &mut self.over_limit,
        };
        *counter += 1;
    }
}
/// Outcome of validating one answer's reference list.
#[derive(Debug, Clone, Default)]
pub(crate) struct ReferenceValidationResult {
    /// Normalized UUIDs accepted for storage, capped at `MAX_REFERENCE_COUNT`.
    pub valid_refs: Vec<String>,
    /// Every rejected reference with its reason.
    pub invalid_refs: Vec<InvalidReference>,
    /// Per-reason rejection counts plus the overall total examined.
    pub reason_stats: ReferenceReasonStats,
}
/// Which storage table(s) a normalized reference should be resolved against,
/// derived from the optional type prefix on the raw reference string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ReferenceLookupTarget {
    /// Explicit `text_chunk:` prefix.
    TextChunk,
    /// Explicit `knowledge_entity:` prefix.
    KnowledgeEntity,
    /// Bare UUID: text chunks are tried first, then knowledge entities.
    Any,
}
/// Gather the ids of every chunk/entity present in the retrieval output,
/// preserving first-seen order and dropping duplicates. These ids form the
/// "allowed context" set later used to reject references the LLM invented.
pub(crate) fn collect_reference_ids_from_retrieval(
    retrieval_result: &StrategyOutput,
) -> Vec<String> {
    let mut ids = Vec::new();
    let mut seen = HashSet::new();
    // Shared de-dup push: keeps only the first occurrence of each id.
    // (Previously this logic was duplicated across all three match arms.)
    let mut push_unique = |id: String| {
        if seen.insert(id.clone()) {
            ids.push(id);
        }
    };

    match retrieval_result {
        StrategyOutput::Chunks(chunks) => {
            for chunk in chunks {
                push_unique(chunk.chunk.id.clone());
            }
        }
        StrategyOutput::Entities(entities) => {
            for entity in entities {
                push_unique(entity.entity.id.clone());
            }
        }
        StrategyOutput::Search(search) => {
            // Search output carries both kinds; chunks are collected first,
            // matching the pre-existing ordering.
            for chunk in &search.chunks {
                push_unique(chunk.chunk.id.clone());
            }
            for entity in &search.entities {
                push_unique(entity.entity.id.clone());
            }
        }
    }

    ids
}
pub(crate) async fn validate_references(
user_id: &str,
refs: Vec<String>,
allowed_ids: &[String],
db: &SurrealDbClient,
) -> Result<ReferenceValidationResult, AppError> {
let mut result = ReferenceValidationResult::default();
result.reason_stats.total = refs.len();
let mut seen = HashSet::new();
let allowed_set: HashSet<&str> = allowed_ids.iter().map(String::as_str).collect();
let enforce_context = !allowed_set.is_empty();
for raw in refs {
let (normalized, target) = match normalize_reference(&raw) {
Ok(parsed) => parsed,
Err(reason) => {
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: None,
reason,
});
continue;
}
};
if !seen.insert(normalized.clone()) {
let reason = InvalidReferenceReason::Duplicate;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
if result.valid_refs.len() >= MAX_REFERENCE_COUNT {
let reason = InvalidReferenceReason::OverLimit;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
if enforce_context && !allowed_set.contains(normalized.as_str()) {
let reason = InvalidReferenceReason::NotInContext;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
continue;
}
match lookup_reference_for_user(&normalized, &target, user_id, db).await? {
LookupResult::Found => result.valid_refs.push(normalized),
LookupResult::WrongUser => {
let reason = InvalidReferenceReason::WrongUser;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
}
LookupResult::NotFound => {
let reason = InvalidReferenceReason::NotFound;
result.reason_stats.record(&reason);
result.invalid_refs.push(InvalidReference {
raw,
normalized: Some(normalized),
reason,
});
}
}
}
Ok(result)
}
/// Canonicalize one raw reference string.
///
/// Accepts either a bare UUID or `knowledge_entity:<uuid>` /
/// `text_chunk:<uuid>` (prefix matched case-insensitively, whitespace
/// tolerated around both parts). On success, returns the UUID in canonical
/// form together with the lookup target implied by the prefix.
pub(crate) fn normalize_reference(
    raw: &str,
) -> Result<(String, ReferenceLookupTarget), InvalidReferenceReason> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(InvalidReferenceReason::Empty);
    }

    let (candidate, target) = match trimmed.split_once(':') {
        Some((prefix, rest)) if prefix.eq_ignore_ascii_case("knowledge_entity") => {
            (rest.trim(), ReferenceLookupTarget::KnowledgeEntity)
        }
        Some((prefix, rest)) if prefix.eq_ignore_ascii_case("text_chunk") => {
            (rest.trim(), ReferenceLookupTarget::TextChunk)
        }
        Some(_) => return Err(InvalidReferenceReason::UnsupportedPrefix),
        None => (trimmed, ReferenceLookupTarget::Any),
    };

    if candidate.is_empty() {
        return Err(InvalidReferenceReason::MalformedUuid);
    }

    match Uuid::parse_str(candidate) {
        Ok(uuid) => Ok((uuid.to_string(), target)),
        Err(_) => Err(InvalidReferenceReason::MalformedUuid),
    }
}
/// Result of resolving a normalized id against one (or both) storage tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LookupResult {
    /// Object exists and belongs to the requesting user.
    Found,
    /// Object exists but is owned by a different user.
    WrongUser,
    /// No object with this id.
    NotFound,
}
/// Resolve a normalized reference id for `user_id` according to `target`.
///
/// A bare id (`Any`) is checked against text chunks first, then knowledge
/// entities; a `WrongUser` hit in either table takes precedence over
/// `NotFound` so ownership violations stay visible in the stats.
async fn lookup_reference_for_user(
    id: &str,
    target: &ReferenceLookupTarget,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<LookupResult, AppError> {
    if *target == ReferenceLookupTarget::TextChunk {
        return lookup_single_type::<TextChunk>(id, user_id, db).await;
    }
    if *target == ReferenceLookupTarget::KnowledgeEntity {
        return lookup_single_type::<KnowledgeEntity>(id, user_id, db).await;
    }

    // `Any`: probe both tables, chunks first.
    let chunk_result = lookup_single_type::<TextChunk>(id, user_id, db).await?;
    if chunk_result == LookupResult::Found {
        return Ok(LookupResult::Found);
    }
    let entity_result = lookup_single_type::<KnowledgeEntity>(id, user_id, db).await?;
    if entity_result == LookupResult::Found {
        return Ok(LookupResult::Found);
    }
    if chunk_result == LookupResult::WrongUser || entity_result == LookupResult::WrongUser {
        return Ok(LookupResult::WrongUser);
    }
    Ok(LookupResult::NotFound)
}
/// Fetch one stored object of type `T` by id and classify it by ownership.
async fn lookup_single_type<T>(
    id: &str,
    user_id: &str,
    db: &SurrealDbClient,
) -> Result<LookupResult, AppError>
where
    T: StoredObject + for<'de> serde::Deserialize<'de> + HasUserId,
{
    let outcome = if let Some(item) = db.get_item::<T>(id).await? {
        if item.user_id() == user_id {
            LookupResult::Found
        } else {
            LookupResult::WrongUser
        }
    } else {
        LookupResult::NotFound
    };
    Ok(outcome)
}
/// Internal abstraction over stored types that carry an owning user id,
/// letting `lookup_single_type` stay generic over chunks and entities.
trait HasUserId {
    /// Id of the user that owns this object.
    fn user_id(&self) -> &str;
}

impl HasUserId for TextChunk {
    fn user_id(&self) -> &str {
        &self.user_id
    }
}

impl HasUserId for KnowledgeEntity {
    fn user_id(&self) -> &str {
        &self.user_id
    }
}
/// Integration-style tests for reference validation against an in-memory
/// SurrealDB instance: ownership checks, malformed ids, de-duplication, and
/// bare-UUID lookup order.
#[cfg(test)]
#[allow(
    clippy::cloned_ref_to_slice_refs,
    clippy::expect_used,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use surrealdb::engine::any::connect;

    // Fresh in-memory DB with a unique namespace/database per test, with
    // migrations applied.
    async fn setup_test_db() -> SurrealDbClient {
        let client = connect("mem://")
            .await
            .expect("failed to create in-memory surrealdb client");
        let namespace = format!("test_ns_{}", Uuid::new_v4());
        let database = format!("test_db_{}", Uuid::new_v4());
        client
            .use_ns(namespace)
            .use_db(database)
            .await
            .expect("failed to select namespace/db");
        let db = SurrealDbClient { client };
        db.apply_migrations()
            .await
            .expect("failed to apply migrations");
        db
    }

    #[tokio::test]
    async fn valid_uuid_exists_and_belongs_to_user() {
        let db = setup_test_db().await;
        let user_id = "user-a";
        let entity = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity A".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let result =
            validate_references(user_id, vec![entity.id.clone()], &[entity.id.clone()], &db)
                .await
                .expect("validation should not fail");

        assert_eq!(result.valid_refs, vec![entity.id]);
        assert!(result.invalid_refs.is_empty());
    }

    #[tokio::test]
    async fn valid_uuid_exists_but_wrong_user_is_rejected() {
        let db = setup_test_db().await;
        // Stored under a different owner than the requesting user.
        let entity = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity B".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            "other-user".to_string(),
        );
        db.store_item(entity.clone())
            .await
            .expect("failed to store entity");

        let result =
            validate_references("user-a", vec![entity.id.clone()], &[entity.id.clone()], &db)
                .await
                .expect("validation should not fail");

        assert!(result.valid_refs.is_empty());
        assert_eq!(result.invalid_refs.len(), 1);
        assert_eq!(
            result.invalid_refs[0].reason,
            InvalidReferenceReason::WrongUser
        );
    }

    #[tokio::test]
    async fn malformed_uuid_is_rejected() {
        let db = setup_test_db().await;

        let result = validate_references(
            "user-a",
            vec!["not-a-uuid".to_string()],
            &["not-a-uuid".to_string()],
            &db,
        )
        .await
        .expect("validation should not fail");

        assert!(result.valid_refs.is_empty());
        assert_eq!(result.invalid_refs.len(), 1);
        assert_eq!(
            result.invalid_refs[0].reason,
            InvalidReferenceReason::MalformedUuid
        );
    }

    #[tokio::test]
    async fn mixed_duplicates_are_deduped() {
        let db = setup_test_db().await;
        let user_id = "user-a";
        let first = KnowledgeEntity::new(
            "source-1".to_string(),
            "Entity 1".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        let second = KnowledgeEntity::new(
            "source-2".to_string(),
            "Entity 2".to_string(),
            "Entity description".to_string(),
            KnowledgeEntityType::Document,
            None,
            user_id.to_string(),
        );
        db.store_item(first.clone())
            .await
            .expect("failed to store first entity");
        db.store_item(second.clone())
            .await
            .expect("failed to store second entity");

        // A prefixed form of an already-seen id and a literal repeat must both
        // be flagged as duplicates.
        let refs = vec![
            first.id.clone(),
            format!("knowledge_entity:{}", first.id),
            second.id.clone(),
            second.id.clone(),
        ];
        let allowed = vec![first.id.clone(), second.id.clone()];

        let result = validate_references(user_id, refs, &allowed, &db)
            .await
            .expect("validation should not fail");

        assert_eq!(result.valid_refs, vec![first.id, second.id]);
        assert_eq!(result.invalid_refs.len(), 2);
        assert!(result
            .invalid_refs
            .iter()
            .all(|entry| entry.reason == InvalidReferenceReason::Duplicate));
    }

    #[tokio::test]
    async fn bare_uuid_prefers_chunk_lookup_before_entity() {
        let db = setup_test_db().await;
        let user_id = "user-a";
        let chunk = TextChunk::new(
            "source-1".to_string(),
            "Chunk body".to_string(),
            user_id.to_string(),
        );
        db.store_item(chunk.clone())
            .await
            .expect("failed to store chunk");

        let result = validate_references(user_id, vec![chunk.id.clone()], &[chunk.id.clone()], &db)
            .await
            .expect("validation should not fail");

        assert_eq!(result.valid_refs, vec![chunk.id]);
    }
}

View File

@@ -1,12 +1,15 @@
#![allow(clippy::missing_docs_in_private_items)]
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use chrono::{DateTime, Utc};
use chrono_tz::Tz;
use serde::Serialize;
use common::{
error::AppError,
storage::types::{knowledge_entity::KnowledgeEntity, user::User},
use common::storage::types::{
knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, user::User,
};
use crate::{
@@ -17,29 +20,101 @@ use crate::{
},
};
use super::reference_validation::{normalize_reference, ReferenceLookupTarget};
/// Template context for `chat/reference_tooltip.html`. At most one of
/// `text_chunk`/`entity` is populated, together with its `updated_at`
/// timestamp pre-formatted in the user's timezone.
#[derive(Serialize)]
struct ReferenceTooltipData {
    // Set when the reference resolved to a text chunk.
    text_chunk: Option<TextChunk>,
    text_chunk_updated_at: Option<String>,
    // Set when the reference resolved to a knowledge entity.
    entity: Option<KnowledgeEntity>,
    entity_updated_at: Option<String>,
    user: User,
}
/// Format a UTC timestamp for display in the user's configured timezone,
/// falling back to plain UTC when the timezone string does not parse.
fn format_datetime_for_user(datetime: DateTime<Utc>, timezone: &str) -> String {
    const FORMAT: &str = "%Y-%m-%d %H:%M:%S";
    timezone.parse::<Tz>().map_or_else(
        |_| datetime.format(FORMAT).to_string(),
        |tz| datetime.with_timezone(&tz).format(FORMAT).to_string(),
    )
}
pub async fn show_reference_tooltip(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Path(reference_id): Path<String>,
) -> Result<impl IntoResponse, HtmlError> {
let entity: KnowledgeEntity = state
.db
.get_item(&reference_id)
.await?
.ok_or_else(|| AppError::NotFound("Item was not found".to_string()))?;
let Ok((normalized_reference_id, target)) = normalize_reference(&reference_id) else {
return Ok(TemplateResponse::not_found());
};
if entity.user_id != user.id {
return Ok(TemplateResponse::unauthorized());
let lookup_order = match target {
ReferenceLookupTarget::TextChunk | ReferenceLookupTarget::Any => [
ReferenceLookupTarget::TextChunk,
ReferenceLookupTarget::KnowledgeEntity,
],
ReferenceLookupTarget::KnowledgeEntity => [
ReferenceLookupTarget::KnowledgeEntity,
ReferenceLookupTarget::TextChunk,
],
};
let mut text_chunk: Option<TextChunk> = None;
let mut knowledge_entity: Option<KnowledgeEntity> = None;
for lookup_target in lookup_order {
match lookup_target {
ReferenceLookupTarget::TextChunk => {
if let Some(chunk) = state
.db
.get_item::<TextChunk>(&normalized_reference_id)
.await?
{
if chunk.user_id != user.id {
return Ok(TemplateResponse::unauthorized());
}
text_chunk = Some(chunk);
break;
}
}
ReferenceLookupTarget::KnowledgeEntity => {
if let Some(entity) = state
.db
.get_item::<KnowledgeEntity>(&normalized_reference_id)
.await?
{
if entity.user_id != user.id {
return Ok(TemplateResponse::unauthorized());
}
knowledge_entity = Some(entity);
break;
}
}
ReferenceLookupTarget::Any => {}
}
}
#[derive(Serialize)]
struct ReferenceTooltipData {
entity: KnowledgeEntity,
user: User,
if text_chunk.is_none() && knowledge_entity.is_none() {
return Ok(TemplateResponse::not_found());
}
let text_chunk_updated_at = text_chunk
.as_ref()
.map(|chunk| format_datetime_for_user(chunk.updated_at, &user.timezone));
let entity_updated_at = knowledge_entity
.as_ref()
.map(|entity| format_datetime_for_user(entity.updated_at, &user.timezone));
Ok(TemplateResponse::new_template(
"chat/reference_tooltip.html",
ReferenceTooltipData { entity, user },
ReferenceTooltipData {
text_chunk,
text_chunk_updated_at,
entity: knowledge_entity,
entity_updated_at,
user,
},
))
}

View File

@@ -42,9 +42,8 @@ pub async fn index_handler(
return Ok(TemplateResponse::redirect("/signin"));
};
let (text_contents, _conversation_archive, stats, active_jobs) = try_join!(
let (text_contents, stats, active_jobs) = try_join!(
User::get_latest_text_contents(&user.id, &state.db),
User::get_user_conversations(&user.id, &state.db),
User::get_dashboard_stats(&user.id, &state.db),
User::get_unfinished_ingestion_tasks(&user.id, &state.db)
)?;

View File

@@ -2,9 +2,10 @@ use std::{pin::Pin, time::Duration};
use axum::{
extract::{Query, State},
http::StatusCode,
response::{
sse::{Event, KeepAlive},
Html, IntoResponse, Sse,
sse::{Event, KeepAlive, KeepAliveStream},
IntoResponse, Response, Sse,
},
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
@@ -23,6 +24,7 @@ use common::{
ingestion_task::{IngestionTask, TaskState},
user::User,
},
utils::ingest_limits::{validate_ingest_input, IngestValidationError},
};
use crate::{
@@ -34,30 +36,41 @@ use crate::{
AuthSessionType,
};
pub async fn show_ingress_form(
/// Boxed SSE event stream shared by the task-update endpoints.
type EventStream = Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>;
/// SSE response type with the keep-alive wrapper applied.
type TaskSse = Sse<KeepAliveStream<EventStream>>;

/// Wraps `stream` in an SSE response that pings every 15 seconds so proxies
/// and browsers do not drop idle task-status connections.
fn sse_with_keep_alive(stream: EventStream) -> TaskSse {
    let heartbeat = KeepAlive::new()
        .interval(Duration::from_secs(15))
        .text("keep-alive-ping");
    Sse::new(stream).keep_alive(heartbeat)
}
pub async fn show_ingest_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
#[derive(Serialize)]
pub struct ShowIngressFormData {
pub struct ShowIngestFormData {
user_categories: Vec<String>,
}
Ok(TemplateResponse::new_template(
"ingestion_modal.html",
ShowIngressFormData { user_categories },
ShowIngestFormData { user_categories },
))
}
pub async fn hide_ingress_form(
pub async fn hide_ingest_form(
RequireUser(_user): RequireUser,
) -> Result<impl IntoResponse, HtmlError> {
Ok(Html(
"<a class='btn btn-primary' hx-get='/ingress-form' hx-swap='outerHTML'>Add Content</a>",
)
.into_response())
Ok(TemplateResponse::new_template(
"ingestion/add_content_button.html",
(),
))
}
#[derive(Debug, TryFromMultipart)]
@@ -65,37 +78,59 @@ pub struct IngestionParams {
pub content: Option<String>,
pub context: String,
pub category: String,
#[form_data(limit = "10000000")] // Adjust limit as needed
#[form_data(limit = "20000000")]
#[form_data(default)]
pub files: Vec<FieldData<NamedTempFile>>,
}
pub async fn process_ingress_form(
pub async fn process_ingest_form(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
TypedMultipart(input): TypedMultipart<IngestionParams>,
) -> Result<impl IntoResponse, HtmlError> {
#[derive(Serialize)]
pub struct IngressFormData {
context: String,
content: String,
category: String,
error: String,
}
) -> Result<Response, HtmlError> {
if input.content.as_ref().is_none_or(|c| c.len() < 2) && input.files.is_empty() {
return Ok(TemplateResponse::new_template(
"index/signed_in/ingress_form.html",
IngressFormData {
context: input.context.clone(),
content: input.content.clone().unwrap_or_default(),
category: input.category.clone(),
error: "You need to either add files or content".to_string(),
},
));
return Ok(
TemplateResponse::bad_request("You need to either add files or content")
.into_response(),
);
}
info!("{:?}", input);
let content_bytes = input.content.as_ref().map_or(0, |c| c.len());
let has_content = input.content.as_ref().is_some_and(|c| !c.trim().is_empty());
let context_bytes = input.context.len();
let category_bytes = input.category.len();
let file_count = input.files.len();
match validate_ingest_input(
&state.config,
input.content.as_deref(),
&input.context,
&input.category,
file_count,
) {
Ok(()) => {}
Err(IngestValidationError::PayloadTooLarge(message)) => {
return Ok(TemplateResponse::error(
StatusCode::PAYLOAD_TOO_LARGE,
"Payload Too Large",
&message,
)
.into_response());
}
Err(IngestValidationError::BadRequest(message)) => {
return Ok(TemplateResponse::bad_request(&message).into_response());
}
}
info!(
user_id = %user.id,
has_content,
content_bytes,
context_bytes,
category_bytes,
file_count,
"Received ingest form submission"
);
let file_infos = try_join_all(input.files.into_iter().map(|file| {
FileInfo::new_with_storage(file, &state.db, &user.id, &state.storage)
@@ -123,10 +158,10 @@ pub async fn process_ingress_form(
tasks: Vec<IngestionTask>,
}
Ok(TemplateResponse::new_template(
"dashboard/current_task.html",
NewTasksData { tasks },
))
Ok(
TemplateResponse::new_template("dashboard/current_task.html", NewTasksData { tasks })
.into_response(),
)
}
#[derive(Deserialize)]
@@ -134,9 +169,7 @@ pub struct QueryParams {
task_id: String,
}
fn create_error_stream(
message: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>> {
fn create_error_stream(message: impl Into<String>) -> EventStream {
let message = message.into();
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
}
@@ -145,13 +178,13 @@ pub async fn get_task_updates_stream(
State(state): State<HtmlState>,
auth: AuthSessionType,
Query(params): Query<QueryParams>,
) -> Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>> {
) -> TaskSse {
let task_id = params.task_id.clone();
let db = state.db.clone();
// 1. Check for authenticated user
let Some(current_user) = auth.current_user else {
return Sse::new(create_error_stream("User not authenticated"));
return sse_with_keep_alive(create_error_stream("User not authenticated"));
};
// 2. Fetch task for initial authorization and to ensure it exists
@@ -159,7 +192,7 @@ pub async fn get_task_updates_stream(
Ok(Some(task)) => {
// 3. Validate user ownership
if task.user_id != current_user.id {
return Sse::new(create_error_stream(
return sse_with_keep_alive(create_error_stream(
"Access denied: You do not have permission to view updates for this task.",
));
}
@@ -245,18 +278,14 @@ pub async fn get_task_updates_stream(
}
};
Sse::new(sse_stream.boxed()).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive-ping"),
)
sse_with_keep_alive(sse_stream.boxed())
}
Ok(None) => Sse::new(create_error_stream(format!(
Ok(None) => sse_with_keep_alive(create_error_stream(format!(
"Task with ID '{task_id}' not found."
))),
Err(e) => {
error!("Failed to fetch task '{task_id}' for authorization: {e:?}");
Sse::new(create_error_stream(
sse_with_keep_alive(create_error_stream(
"An error occurred while retrieving task details. Please try again later.",
))
}

View File

@@ -1,22 +1,22 @@
mod handlers;
use axum::{extract::FromRef, routing::get, Router};
use handlers::{
get_task_updates_stream, hide_ingress_form, process_ingress_form, show_ingress_form,
};
use axum::{extract::DefaultBodyLimit, extract::FromRef, routing::get, Router};
use handlers::{get_task_updates_stream, hide_ingest_form, process_ingest_form, show_ingest_form};
use crate::html_state::HtmlState;
pub fn router<S>() -> Router<S>
pub fn router<S>(max_body_bytes: usize) -> Router<S>
where
S: Clone + Send + Sync + 'static,
HtmlState: FromRef<S>,
{
Router::new()
.route(
"/ingress-form",
get(show_ingress_form).post(process_ingress_form),
"/ingest-form",
get(show_ingest_form)
.post(process_ingest_form)
.layer(DefaultBodyLimit::max(max_body_bytes)),
)
.route("/task/status-stream", get(get_task_updates_stream))
.route("/hide-ingress-form", get(hide_ingress_form))
.route("/hide-ingest-form", get(hide_ingest_form))
}

View File

@@ -13,9 +13,9 @@
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">API Key</div>
{% block api_key_section %}
{% if user.api_key %}
{% if api_key %}
<div class="relative">
<input id="api_key_input" type="text" name="api_key" value="{{ user.api_key }}"
<input id="api_key_input" type="text" name="api_key" value="{{ api_key }}"
class="nb-input w-full pr-14" disabled />
<button type="button" id="copy_api_key_btn" onclick="copy_api_key()"
class="absolute inset-y-0 right-0 flex items-center px-2 nb-btn btn-sm" aria-label="Copy API key"
@@ -48,9 +48,10 @@
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Timezone</div>
{% block timezone_section %}
{% set active_timezone = selected_timezone|default(user.timezone) %}
<select name="timezone" class="nb-select w-full" hx-patch="/update-timezone" hx-swap="outerHTML">
{% for tz in timezones %}
<option value="{{ tz }}" {% if tz==user.timezone %}selected{% endif %}>{{ tz }}</option>
<option value="{{ tz }}" {% if tz==active_timezone %}selected{% endif %}>{{ tz }}</option>
{% endfor %}
</select>
{% endblock %}
@@ -59,13 +60,14 @@
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Theme</div>
{% block theme_section %}
{% set active_theme = selected_theme|default(user.theme) %}
<select name="theme" class="nb-select w-full" hx-patch="/update-theme" hx-swap="outerHTML">
{% for option in theme_options %}
<option value="{{ option }}" {% if option==user.theme %}selected{% endif %}>{{ option }}</option>
<option value="{{ option }}" {% if option==active_theme %}selected{% endif %}>{{ option }}</option>
{% endfor %}
</select>
<script>
document.documentElement.setAttribute('data-theme-preference', '{{ user.theme }}');
document.documentElement.setAttribute('data-theme-preference', '{{ active_theme }}');
</script>
{% endblock %}
</label>

View File

@@ -7,4 +7,5 @@
{% block auth_content %}
{% endblock %}
</div>
<div id="toast-container" class="fixed bottom-4 right-4 z-50 space-y-2"></div>
{% endblock %}

View File

@@ -6,7 +6,7 @@
</div>
<div class="u-hairline mb-3"></div>
<form hx-post="/signin" hx-target="#login-result" class="flex flex-col gap-2">
<form hx-post="/signin" hx-swap="none" class="flex flex-col gap-2">
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
<input name="email" type="email" placeholder="Email" class="nb-input w-full validator" required />
@@ -19,8 +19,6 @@
minlength="8" />
</label>
<div class="mt-1 text-error" id="login-result"></div>
<div class="form-control mt-1">
<label class="label cursor-pointer justify-start gap-3">
<input type="checkbox" name="remember_me" class="nb-checkbox" />

View File

@@ -11,7 +11,7 @@
</div>
<div class="u-hairline mb-3"></div>
<form hx-post="/signup" hx-target="#signup-result" class="flex flex-col gap-4">
<form hx-post="/signup" hx-swap="none" class="flex flex-col gap-4">
<label class="w-full">
<div class="text-xs uppercase tracking-wide opacity-70 mb-1">Email</div>
<input type="email" placeholder="Email" name="email" required class="nb-input w-full validator" />
@@ -31,7 +31,6 @@
</p>
</label>
<div class="mt-2 text-error" id="signup-result"></div>
<div class="form-control mt-1">
<button id="submit-btn" class="nb-btn nb-cta w-full">Create Account</button>
</div>

View File

@@ -12,16 +12,34 @@
</label>
</form>
<script>
document.getElementById('chat-input').addEventListener('keydown', function (e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
htmx.trigger('#chat-form', 'submit');
}
});
// Clear textarea after successful submission
document.getElementById('chat-form').addEventListener('htmx:afterRequest', function (e) {
if (e.detail.successful) { // Check if the request was successful
document.getElementById('chat-input').value = ''; // Clear the textarea
}
});
</script>
(function () {
// Element id of the SSE stream container rendered for this user message;
// used below to make sure we only react to *this* conversation's stream.
const newChatStreamId = 'ai-stream-{{ user_message.id }}';
// Submit the chat form on plain Enter; Shift+Enter keeps inserting a newline.
document.getElementById('chat-input').addEventListener('keydown', function (e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
htmx.trigger('#chat-form', 'submit');
}
});
// Clear textarea after successful submission
document.getElementById('chat-form').addEventListener('htmx:afterRequest', function (e) {
if (e.detail.successful) { // Check if the request was successful
document.getElementById('chat-input').value = ''; // Clear the textarea
}
});
// After the first AI response stream closes, reload the sidebar once so the
// new conversation appears in the archive list, then detach this listener.
const refreshSidebarAfterFirstResponse = function (e) {
const streamEl = document.getElementById(newChatStreamId);
// Ignore sseClose events from other streams (or if ours is gone).
if (!streamEl || e.target !== streamEl) return;
htmx.ajax('GET', '/chat/sidebar', {
target: '.drawer-side',
swap: 'outerHTML'
});
// One-shot handler: remove itself so later streams don't re-trigger it.
document.body.removeEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
};
document.body.addEventListener('htmx:sseClose', refreshSidebarAfterFirstResponse);
})();
</script>

View File

@@ -111,12 +111,23 @@
// Load content if needed
if (!tooltipContent) {
fetch(`/chat/reference/${encodeURIComponent(reference)}`)
.then(response => response.text())
.then(response => {
if (!response.ok) {
throw new Error(`reference lookup failed with status ${response.status}`);
}
return response.text();
})
.then(html => {
tooltipContent = html;
if (document.getElementById(tooltipId)) {
document.getElementById(tooltipId).innerHTML = html;
}
})
.catch(() => {
tooltipContent = '<div class="text-xs opacity-70">Reference unavailable.</div>';
if (document.getElementById(tooltipId)) {
document.getElementById(tooltipId).innerHTML = tooltipContent;
}
});
} else if (tooltip) {
// Set content if already loaded

View File

@@ -1,3 +1,11 @@
<div>{{entity.name}}</div>
<div>{{entity.description}}</div>
<div>{{entity.updated_at|datetimeformat(format="short", tz=user.timezone)}} </div>
{% if text_chunk %}
<div class="font-semibold">Chunk Reference</div>
<div class="text-sm whitespace-pre-wrap">{{text_chunk.chunk}}</div>
<div class="text-xs opacity-70">{{text_chunk_updated_at}}</div>
{% elif entity %}
<div class="font-semibold">{{entity.name}}</div>
<div class="text-sm">{{entity.description}}</div>
<div class="text-xs opacity-70">{{entity_updated_at}}</div>
{% else %}
<div class="text-xs opacity-70">Reference unavailable.</div>
{% endif %}

View File

@@ -4,7 +4,8 @@
</div>
</div>
<div class="chat chat-start">
<div hx-ext="sse" sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
<div id="ai-stream-{{user_message.id}}" hx-ext="sse"
sse-connect="/chat/response-stream?message_id={{user_message.id}}" sse-close="close_stream"
hx-swap="beforeend">
<div class="chat-bubble">
<span class="loading loading-dots loading-sm loading-id-{{user_message.id}}"></span>
@@ -27,13 +28,22 @@
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
});
document.body.addEventListener('htmx:sseClose', function () {
document.body.addEventListener('htmx:sseClose', function (e) {
const msgId = '{{ user_message.id }}';
const streamEl = document.getElementById('ai-stream-' + msgId);
if (streamEl && e.target !== streamEl) return;
const el = document.getElementById('ai-message-content-' + msgId);
if (el && window.markdownBuffer[msgId]) {
el.innerHTML = marked.parse(window.markdownBuffer[msgId].replace(/\\n/g, '\n'));
delete window.markdownBuffer[msgId];
if (typeof window.scrollChatToBottom === "function") window.scrollChatToBottom();
}
if (streamEl) {
streamEl.removeAttribute('sse-connect');
streamEl.removeAttribute('sse-close');
streamEl.removeAttribute('hx-ext');
}
});
</script>

View File

@@ -1,4 +1,4 @@
<nav class="sticky top-0 z-10 nb-panel nb-panel-canvas border-t-0 border-l-0">
<nav class="sticky top-0 z-10 nb-panel nb-panel-canvas border-t-0" style="border-left: 0">
<div class="container mx-auto navbar">
<div class="mr-2 flex-1">
{% block navbar_search %}
@@ -11,4 +11,4 @@
</ul>
</div>
</div>
</nav>
</nav>

View File

@@ -2,7 +2,7 @@
{% block dashboard_header %}
<h1 class="text-xl font-extrabold tracking-tight">Dashboard</h1>
<button class="nb-btn nb-cta" hx-get="/ingress-form" hx-target="#modal" hx-swap="innerHTML">
<button class="nb-btn nb-cta" hx-get="/ingest-form" hx-target="#modal" hx-swap="innerHTML">
{% include "icons/send_icon.html" %}
<span class="ml-2">Add Content</span>
</button>

View File

@@ -0,0 +1 @@
<a class="btn btn-primary" hx-get="/ingest-form" hx-target="#modal" hx-swap="innerHTML">Add Content</a>

View File

@@ -3,7 +3,7 @@
{% block modal_class %}max-w-3xl{% endblock %}
{% block form_attributes %}
hx-post="/ingress-form"
hx-post="/ingest-form"
enctype="multipart/form-data"
{% endblock %}

View File

@@ -18,7 +18,7 @@
</li>
{% endfor %}
<li>
<button class="nb-btn nb-cta w-full flex items-center gap-3 justify-start mt-2" hx-get="/ingress-form"
<button class="nb-btn nb-cta w-full flex items-center gap-3 justify-start mt-2" hx-get="/ingest-form"
hx-target="#modal" hx-swap="innerHTML">{% include "icons/send_icon.html" %} Add
Content</button>
</li>

View File

@@ -38,3 +38,6 @@ retrieval-pipeline = { path = "../retrieval-pipeline" }
[features]
docker = []
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }

View File

@@ -1,6 +1,6 @@
[package]
name = "main"
version = "1.0.1"
version = "1.0.2"
edition = "2021"
repository = "https://github.com/perstarkse/minne"
license = "AGPL-3.0-or-later"
@@ -30,6 +30,7 @@ retrieval-pipeline = { path = "../retrieval-pipeline" }
[dev-dependencies]
tower = "0.5"
uuid = { workspace = true }
common = { path = "../common", features = ["test-utils"] }
[[bin]]
name = "server"

View File

@@ -217,8 +217,16 @@ struct AppState {
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, http::Request, http::StatusCode, Router};
use common::storage::store::StorageManager;
use axum::{
body::Body,
http::{header, Request, StatusCode},
response::Response,
Router,
};
use common::storage::{
store::StorageManager,
types::{system_settings::SystemSettings, user::User},
};
use common::utils::config::{AppConfig, PdfIngestMode, StorageKind};
use std::{path::Path, sync::Arc};
use tower::ServiceExt;
@@ -241,11 +249,11 @@ mod tests {
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn smoke_startup_with_in_memory_surrealdb() {
async fn build_test_app() -> (Router, Arc<SurrealDbClient>, std::path::PathBuf) {
let namespace = "test_ns";
let database = format!("test_db_{}", Uuid::new_v4());
let data_dir = std::env::temp_dir().join(format!("minne_smoke_{}", Uuid::new_v4()));
tokio::fs::create_dir_all(&data_dir)
.await
.expect("failed to create temp data directory");
@@ -304,6 +312,66 @@ mod tests {
html_state,
});
(app, db, data_dir)
}
/// Asserts that `response` is an HTTP redirect whose `Location` header equals
/// `expected_location`. Panics (failing the test) otherwise.
fn assert_redirect_to(response: &Response, expected_location: &str) {
    assert!(response.status().is_redirection());
    let header_value = response
        .headers()
        .get(header::LOCATION)
        .expect("redirect should contain a Location header");
    let actual_location = header_value
        .to_str()
        .expect("location header must be valid utf-8");
    assert_eq!(actual_location, expected_location);
}
/// Collects every `Set-Cookie` header from `response` and joins their
/// `key=value` parts (attributes such as `Path`/`HttpOnly` stripped) into a
/// single `Cookie` header value for follow-up requests.
///
/// Panics (failing the test) when no cookie was set or a header is not UTF-8.
fn extract_session_cookie(response: &Response) -> String {
    let mut pairs: Vec<String> = Vec::new();
    for value in response.headers().get_all(header::SET_COOKIE) {
        let raw = value
            .to_str()
            .expect("set-cookie header must be valid utf-8");
        // Keep only the leading `name=value` pair; drop cookie attributes.
        let pair = raw
            .split(';')
            .next()
            .expect("set-cookie should include key=value pair");
        pairs.push(pair.to_string());
    }
    assert!(
        !pairs.is_empty(),
        "login response should set at least one cookie"
    );
    pairs.join("; ")
}
/// Posts the signin form with the given credentials and returns the session
/// cookie(s) issued by the server, ready to send as a `Cookie` header.
///
/// Panics (failing the test) when the request cannot be sent or the server
/// does not redirect to `/` on successful login.
async fn sign_in_and_get_cookie(app: &Router, email: &str, password: &str) -> String {
    let form_body = format!("email={email}&password={password}");
    let request = Request::builder()
        .method("POST")
        .uri("/signin")
        .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
        .body(Body::from(form_body))
        .expect("signin request");
    let response = app.clone().oneshot(request).await.expect("signin response");
    assert_redirect_to(&response, "/");
    extract_session_cookie(&response)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn smoke_startup_with_in_memory_surrealdb() {
let (app, _db, data_dir) = build_test_app().await;
let response = app
.clone()
.oneshot(
@@ -329,4 +397,182 @@ mod tests {
tokio::fs::remove_dir_all(&data_dir).await.ok();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// End-to-end check of /admin access control:
// unauthenticated -> redirect to /signin, non-admin -> redirect to /,
// admin -> 200 OK.
async fn admin_route_enforces_unauth_non_admin_and_admin_access() {
let (app, db, data_dir) = build_test_app().await;
// Create two users; per the asserts below, the first created user is
// expected to be promoted to admin and the second to stay a regular user.
let admin = User::create_new(
"admin_user".to_string(),
"admin_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("admin user should be created");
let non_admin = User::create_new(
"member_user".to_string(),
"member_password".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("non-admin user should be created");
// Guard the fixture assumptions before exercising the routes.
assert!(admin.admin, "first user should become admin");
assert!(!non_admin.admin, "second user should not be admin");
// 1. No session cookie: request must bounce to the signin page.
let unauth_response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.body(Body::empty())
.expect("unauth admin request"),
)
.await
.expect("unauth admin response");
assert_redirect_to(&unauth_response, "/signin");
// 2. Signed-in non-admin: request must be redirected back to the root.
let non_admin_cookie = sign_in_and_get_cookie(&app, "member_user", "member_password").await;
let non_admin_response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, non_admin_cookie)
.body(Body::empty())
.expect("non-admin request"),
)
.await
.expect("non-admin response");
assert_redirect_to(&non_admin_response, "/");
// 3. Signed-in admin: the admin page renders successfully.
let admin_cookie = sign_in_and_get_cookie(&app, "admin_user", "admin_password").await;
let admin_response = app
.clone()
.oneshot(
Request::builder()
.uri("/admin")
.header(header::COOKIE, admin_cookie)
.body(Body::empty())
.expect("admin request"),
)
.await
.expect("admin response");
assert_eq!(admin_response.status(), StatusCode::OK);
// Best-effort cleanup of the per-test temp data directory.
tokio::fs::remove_dir_all(&data_dir).await.ok();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
// Verifies the admin-only PATCH /toggle-registrations route: unauthenticated
// and non-admin requests must be rejected *and* must leave the persisted
// SystemSettings untouched; only the admin request flips the flag.
async fn admin_patch_blocks_non_admin_and_unauth_before_side_effects() {
let (app, db, data_dir) = build_test_app().await;
// First created user is expected to be the admin (see the companion
// admin-route test); the second is a regular member.
User::create_new(
"admin_user_patch".to_string(),
"admin_password_patch".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("admin user should be created");
User::create_new(
"member_user_patch".to_string(),
"member_password_patch".to_string(),
&db,
"UTC".to_string(),
"system".to_string(),
)
.await
.expect("non-admin user should be created");
// Snapshot current settings so each step can be checked against them.
let initial_settings = SystemSettings::get_current(&db)
.await
.expect("settings should be available");
// Checkbox-style form encoding: presence of `registration_open=on` means
// "enable"; an empty body means "disable". Build the body that would
// toggle the current state.
let patch_body = if initial_settings.registrations_enabled {
String::new()
} else {
"registration_open=on".to_string()
};
let expected_after_admin_patch = !initial_settings.registrations_enabled;
// 1. Unauthenticated PATCH: redirected to signin, settings unchanged.
let unauth_patch_response = app
.clone()
.oneshot(
Request::builder()
.method("PATCH")
.uri("/toggle-registrations")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(patch_body.clone()))
.expect("unauth patch request"),
)
.await
.expect("unauth patch response");
assert_redirect_to(&unauth_patch_response, "/signin");
let settings_after_unauth = SystemSettings::get_current(&db)
.await
.expect("settings should still be available");
assert_eq!(
settings_after_unauth.registrations_enabled,
initial_settings.registrations_enabled
);
// 2. Non-admin PATCH: redirected to root, settings still unchanged.
let non_admin_cookie =
sign_in_and_get_cookie(&app, "member_user_patch", "member_password_patch").await;
let non_admin_patch_response = app
.clone()
.oneshot(
Request::builder()
.method("PATCH")
.uri("/toggle-registrations")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::COOKIE, non_admin_cookie)
.body(Body::from(patch_body.clone()))
.expect("non-admin patch request"),
)
.await
.expect("non-admin patch response");
assert_redirect_to(&non_admin_patch_response, "/");
let settings_after_non_admin = SystemSettings::get_current(&db)
.await
.expect("settings should still be available");
assert_eq!(
settings_after_non_admin.registrations_enabled,
initial_settings.registrations_enabled
);
// 3. Admin PATCH: succeeds and flips the persisted flag.
let admin_cookie =
sign_in_and_get_cookie(&app, "admin_user_patch", "admin_password_patch").await;
let admin_patch_response = app
.clone()
.oneshot(
Request::builder()
.method("PATCH")
.uri("/toggle-registrations")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.header(header::COOKIE, admin_cookie)
.body(Body::from(patch_body))
.expect("admin patch request"),
)
.await
.expect("admin patch response");
assert_eq!(admin_patch_response.status(), StatusCode::OK);
let settings_after_admin = SystemSettings::get_current(&db)
.await
.expect("settings should still be available");
assert_eq!(
settings_after_admin.registrations_enabled,
expected_after_admin_patch
);
// Best-effort cleanup of the per-test temp data directory.
tokio::fs::remove_dir_all(&data_dir).await.ok();
}
}

View File

@@ -23,4 +23,7 @@ uuid = { workspace = true }
fastembed = { workspace = true }
clap = { version = "4.4", features = ["derive"] }
common = { path = "../common" }
[dev-dependencies]
common = { path = "../common", features = ["test-utils"] }

View File

@@ -61,8 +61,8 @@ pub fn chunks_to_chat_context(chunks: &[crate::RetrievedChunk]) -> Value {
.iter()
.map(|chunk| {
serde_json::json!({
"id": chunk.chunk.id,
"content": chunk.chunk.chunk,
"source_id": chunk.chunk.source_id,
"score": round_score(chunk.score),
})
})
@@ -117,7 +117,7 @@ pub fn create_chat_request(
.build()
}
pub async fn process_llm_response(
pub fn process_llm_response(
response: CreateChatCompletionResponse,
) -> Result<LLMResponseFormat, AppError> {
response