mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-30 03:10:45 +02:00
chore: harden analytics, conversation access, and per-user file dedup
Use UPSERT for analytics counters, enforce message ownership in SQL, return NotFound when patch_title updates nothing, scope file dedup by user_id with a composite unique index, and expand tests for auth, ordering, and edge cases.
This commit is contained in:
@@ -0,0 +1,3 @@
|
|||||||
|
-- Per-user deduplication: same SHA256 may exist for different users.
|
||||||
|
REMOVE INDEX IF EXISTS file_sha256_idx ON file;
|
||||||
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{"schemas":"--- original\n+++ modified\n@@ -45,9 +45,8 @@\n DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n-# Indexes based on usage (get_by_sha, potentially user lookups)\n-# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\n-DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\n+# Indexes based on usage (get_by_sha scoped by user_id, user lookups)\n+DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;\n DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n","events":null}
|
||||||
@@ -13,7 +13,6 @@ DEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;
|
|||||||
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
||||||
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
||||||
|
|
||||||
# Indexes based on usage (get_by_sha, potentially user lookups)
|
# Indexes based on usage (get_by_sha scoped by user_id, user lookups)
|
||||||
# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates
|
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||||
DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;
|
|
||||||
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
||||||
|
|||||||
@@ -22,24 +22,24 @@ impl StoredObject for Analytics {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Analytics {
|
impl Analytics {
|
||||||
|
const RECORD_ID: &'static str = "current";
|
||||||
|
|
||||||
|
/// Ensures the singleton analytics record exists (idempotent).
|
||||||
|
///
|
||||||
|
/// Production databases are also seeded by `20250503_215025_initial_setup.surql`;
|
||||||
|
/// this uses an atomic `UPSERT` for tests and recovery.
|
||||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let analytics = db.get_item::<Self>("current").await?;
|
let analytics: Option<Self> = db
|
||||||
|
.client
|
||||||
if analytics.is_none() {
|
.query(
|
||||||
let created_analytics = Analytics {
|
"UPSERT type::thing('analytics', $id) SET visitors = visitors ?? 0, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
id: "current".to_string(),
|
)
|
||||||
visitors: 0,
|
.bind(("id", Self::RECORD_ID))
|
||||||
page_loads: 0,
|
.await?
|
||||||
};
|
.take(0)?;
|
||||||
|
|
||||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
|
||||||
return stored.ok_or(AppError::Validation(
|
|
||||||
"failed to initialize analytics".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
analytics.ok_or(AppError::Validation(
|
analytics.ok_or(AppError::Validation(
|
||||||
"Failed to initialize analytics".into(),
|
"failed to initialize analytics".into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
@@ -50,7 +50,10 @@ impl Analytics {
|
|||||||
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET visitors = (visitors ?? 0) + 1, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
@@ -60,7 +63,10 @@ impl Analytics {
|
|||||||
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||||
let updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.client
|
.client
|
||||||
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
|
.query(
|
||||||
|
"UPSERT type::thing('analytics', $id) SET page_loads = (page_loads ?? 0) + 1, visitors = visitors ?? 0 RETURN AFTER",
|
||||||
|
)
|
||||||
|
.bind(("id", Self::RECORD_ID))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
@@ -227,6 +233,53 @@ mod tests {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(analytics.visitors, 1);
|
||||||
|
assert_eq!(analytics.page_loads, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(analytics.page_loads, 1);
|
||||||
|
assert_eq!(analytics.visitors, 0);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_visitors.visitors, 1);
|
||||||
|
assert_eq!(after_visitors.page_loads, 0);
|
||||||
|
|
||||||
|
let after_page_load = Analytics::increment_page_loads(&db).await?;
|
||||||
|
assert_eq!(after_page_load.visitors, 1);
|
||||||
|
assert_eq!(after_page_load.page_loads, 1);
|
||||||
|
|
||||||
|
let after_second_visitor = Analytics::increment_visitors(&db).await?;
|
||||||
|
assert_eq!(after_second_visitor.visitors, 2);
|
||||||
|
assert_eq!(after_second_visitor.page_loads, 1);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||||
// Setup in-memory database for testing
|
// Setup in-memory database for testing
|
||||||
|
|||||||
@@ -88,10 +88,15 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let messages:Vec<Message> = db.client.
|
let messages: Vec<Message> = db
|
||||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
.client
|
||||||
bind(("table_name", Message::table_name())).
|
.query(
|
||||||
bind(("conversation_id", conversation_id.to_string()))
|
"SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at",
|
||||||
|
)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Self::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
.await?
|
.await?
|
||||||
.take(0)?;
|
.take(0)?;
|
||||||
|
|
||||||
@@ -114,7 +119,7 @@ impl Conversation {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let _updated: Option<Self> = db
|
let updated: Option<Self> = db
|
||||||
.update((Self::table_name(), id))
|
.update((Self::table_name(), id))
|
||||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||||
.patch(PatchOp::replace(
|
.patch(PatchOp::replace(
|
||||||
@@ -123,6 +128,10 @@ impl Conversation {
|
|||||||
))
|
))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if updated.is_none() {
|
||||||
|
return Err(AppError::NotFound("conversation not found".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +161,24 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
const MESSAGE_QUERY_FOR_OWNER: &str = "SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at";
|
||||||
|
|
||||||
|
async fn fetch_messages_for_owner(
|
||||||
|
db: &SurrealDbClient,
|
||||||
|
conversation_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<Vec<Message>, AppError> {
|
||||||
|
db.client
|
||||||
|
.query(MESSAGE_QUERY_FOR_OWNER)
|
||||||
|
.bind(("message_table", Message::table_name()))
|
||||||
|
.bind(("conversation_table", Conversation::table_name()))
|
||||||
|
.bind(("conversation_id", conversation_id.to_string()))
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.await?
|
||||||
|
.take(0)
|
||||||
|
.map_err(AppError::from)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
@@ -488,4 +515,146 @@ mod tests {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||||
|
let item: SidebarConversation =
|
||||||
|
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap();
|
||||||
|
assert_eq!(item.id, "conv-plain");
|
||||||
|
assert_eq!(item.title, "My chat");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.expect("Failed to start in-memory surrealdb");
|
||||||
|
|
||||||
|
let owner = "sidebar_owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||||
|
let expected_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation)
|
||||||
|
.await
|
||||||
|
.expect("Failed to store conversation");
|
||||||
|
|
||||||
|
let items = Conversation::get_user_sidebar_conversations(owner, &db)
|
||||||
|
.await
|
||||||
|
.expect("Failed to load sidebar");
|
||||||
|
assert_eq!(items.len(), 1);
|
||||||
|
assert_eq!(items[0].id, expected_id);
|
||||||
|
assert_eq!(items[0].title, "Sidebar title");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner_user";
|
||||||
|
let intruder = "intruder_user";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "Private".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.store_item(Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"secret message".to_string(),
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let owner_messages =
|
||||||
|
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||||
|
assert_eq!(owner_messages.len(), 1);
|
||||||
|
assert_eq!(owner_messages[0].content, "secret message");
|
||||||
|
|
||||||
|
let intruder_messages =
|
||||||
|
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||||
|
assert!(
|
||||||
|
intruder_messages.is_empty(),
|
||||||
|
"SQL owner filter must not return messages for a non-owner user_id"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let user_id = "order_user";
|
||||||
|
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
|
||||||
|
let base = Utc::now();
|
||||||
|
let mut first = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::User,
|
||||||
|
"first".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
first.updated_at = base - chrono::Duration::minutes(20);
|
||||||
|
|
||||||
|
let mut second = Message::new(
|
||||||
|
conversation_id.clone(),
|
||||||
|
MessageRole::AI,
|
||||||
|
"second".to_string(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
second.updated_at = base - chrono::Duration::minutes(5);
|
||||||
|
|
||||||
|
db.store_item(first).await?;
|
||||||
|
db.store_item(second).await?;
|
||||||
|
|
||||||
|
let (_, messages) =
|
||||||
|
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
|
||||||
|
|
||||||
|
assert_eq!(messages.len(), 2);
|
||||||
|
assert_eq!(messages[0].content, "first");
|
||||||
|
assert_eq!(messages[1].content, "second");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let owner = "owner";
|
||||||
|
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||||
|
let conversation_id = conversation.id.clone();
|
||||||
|
db.store_item(conversation).await?;
|
||||||
|
db.delete_item::<Conversation>(&conversation_id).await?;
|
||||||
|
|
||||||
|
let result = Conversation::patch_title(&conversation_id, owner, "New title", &db).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
match result {
|
||||||
|
Err(AppError::NotFound(_)) => {}
|
||||||
|
other => anyhow::bail!("expected NotFound, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_conversation_new_initializes_timestamps_and_id() {
|
||||||
|
let before = Utc::now();
|
||||||
|
let conversation = Conversation::new("user".to_string(), "Title".to_string());
|
||||||
|
let after = Utc::now();
|
||||||
|
|
||||||
|
assert!(!conversation.id.is_empty());
|
||||||
|
assert!(conversation.created_at >= before && conversation.created_at <= after);
|
||||||
|
assert_eq!(conversation.created_at, conversation.updated_at);
|
||||||
|
assert_eq!(conversation.user_id, "user");
|
||||||
|
assert_eq!(conversation.title, "Title");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use mime_guess::from_path;
|
|||||||
use object_store::Error as ObjectStoreError;
|
use object_store::Error as ObjectStoreError;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::{
|
use std::{
|
||||||
io::{BufReader, Read},
|
io::{BufReader, Read, Seek, SeekFrom},
|
||||||
path::Path,
|
path::Path,
|
||||||
};
|
};
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
@@ -73,57 +73,8 @@ impl FileInfo {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculates the SHA256 hash of the given file.
|
fn sanitize_name_segment(segment: &str) -> String {
|
||||||
///
|
segment
|
||||||
/// # Arguments
|
|
||||||
/// * `file` - The file to hash.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
|
|
||||||
#[allow(clippy::indexing_slicing)]
|
|
||||||
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
|
|
||||||
let mut file_clone = file.as_file().try_clone()?;
|
|
||||||
|
|
||||||
let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> {
|
|
||||||
let mut reader = BufReader::new(&mut file_clone);
|
|
||||||
let mut hasher = Sha256::new();
|
|
||||||
let mut buffer = [0u8; 8192]; // 8KB buffer
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let n = reader.read(&mut buffer)?;
|
|
||||||
if n == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
hasher.update(&buffer[..n]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok::<_, std::io::Error>(hasher.finalize())
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.map_err(std::io::Error::other)??;
|
|
||||||
|
|
||||||
Ok(format!("{digest:x}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
|
||||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
|
|
||||||
fn sanitize_file_name(file_name: &str) -> String {
|
|
||||||
if let Some(idx) = file_name.rfind('.') {
|
|
||||||
let (name, ext) = file_name.split_at(idx);
|
|
||||||
let sanitized_name: String = name
|
|
||||||
.chars()
|
|
||||||
.map(|c| {
|
|
||||||
if c.is_ascii_alphanumeric() || c == '_' {
|
|
||||||
c
|
|
||||||
} else {
|
|
||||||
'_'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
format!("{sanitized_name}{ext}")
|
|
||||||
} else {
|
|
||||||
// No extension
|
|
||||||
file_name
|
|
||||||
.chars()
|
.chars()
|
||||||
.map(|c| {
|
.map(|c| {
|
||||||
if c.is_ascii_alphanumeric() || c == '_' {
|
if c.is_ascii_alphanumeric() || c == '_' {
|
||||||
@@ -134,21 +85,63 @@ impl FileInfo {
|
|||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||||
|
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in
|
||||||
|
/// both the stem and extension.
|
||||||
|
fn sanitize_file_name(file_name: &str) -> String {
|
||||||
|
if let Some(idx) = file_name.rfind('.') {
|
||||||
|
let name = Self::sanitize_name_segment(&file_name[..idx]);
|
||||||
|
let ext = Self::sanitize_name_segment(&file_name[idx + 1..]);
|
||||||
|
if ext.is_empty() {
|
||||||
|
name
|
||||||
|
} else {
|
||||||
|
format!("{name}.{ext}")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Self::sanitize_name_segment(file_name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieves a `FileInfo` by SHA256.
|
/// Reads the file once and returns its SHA-256 hex digest and raw bytes.
|
||||||
///
|
#[allow(clippy::indexing_slicing)]
|
||||||
/// # Arguments
|
async fn read_and_hash(file: &NamedTempFile) -> Result<(String, Vec<u8>), FileError> {
|
||||||
/// * `sha256` - The SHA256 hash string.
|
let mut file_clone = file.as_file().try_clone()?;
|
||||||
/// * `db_client` - Reference to the SurrealDbClient.
|
|
||||||
///
|
task::spawn_blocking(move || -> Result<_, std::io::Error> {
|
||||||
/// # Returns
|
file_clone.seek(SeekFrom::Start(0))?;
|
||||||
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
|
let mut reader = BufReader::new(&mut file_clone);
|
||||||
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
let mut hasher = Sha256::new();
|
||||||
|
let mut buffer = [0u8; 8192];
|
||||||
|
let mut bytes = Vec::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let n = reader.read(&mut buffer)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
hasher.update(&buffer[..n]);
|
||||||
|
bytes.extend_from_slice(&buffer[..n]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((format!("{:x}", hasher.finalize()), bytes))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(std::io::Error::other)?
|
||||||
|
.map_err(FileError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves a `FileInfo` for this user by SHA256.
|
||||||
|
async fn get_by_sha(
|
||||||
|
sha256: &str,
|
||||||
|
user_id: &str,
|
||||||
|
db_client: &SurrealDbClient,
|
||||||
|
) -> Result<FileInfo, FileError> {
|
||||||
let mut response = db_client
|
let mut response = db_client
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
|
.query("SELECT * FROM file WHERE sha256 = $sha256 AND user_id = $user_id LIMIT 1")
|
||||||
.bind(("sha256", sha256.to_owned()))
|
.bind(("sha256", sha256.to_owned()))
|
||||||
|
.bind(("user_id", user_id.to_owned()))
|
||||||
.await?;
|
.await?;
|
||||||
let response: Vec<FileInfo> = response.take(0)?;
|
let response: Vec<FileInfo> = response.take(0)?;
|
||||||
|
|
||||||
@@ -197,33 +190,33 @@ impl FileInfo {
|
|||||||
.ok_or(FileError::MissingFileName)?;
|
.ok_or(FileError::MissingFileName)?;
|
||||||
let original_file_name = file_name.clone();
|
let original_file_name = file_name.clone();
|
||||||
|
|
||||||
// Calculate SHA256
|
let (sha256, bytes) = Self::read_and_hash(&file).await?;
|
||||||
let sha256 = Self::get_sha(&file).await?;
|
|
||||||
|
|
||||||
// Early return if file already exists
|
if let Ok(existing_file) = Self::get_by_sha(&sha256, user_id, db_client).await {
|
||||||
match Self::get_by_sha(&sha256, db_client).await {
|
info!(
|
||||||
Ok(existing_file) => {
|
"File already exists for user {} with SHA256: {}",
|
||||||
info!("File already exists with SHA256: {}", sha256);
|
user_id, sha256
|
||||||
|
);
|
||||||
return Ok(existing_file);
|
return Ok(existing_file);
|
||||||
}
|
}
|
||||||
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
|
|
||||||
Err(e) => return Err(e), // Propagate unexpected errors
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate UUID and prepare paths
|
|
||||||
let uuid = Uuid::new_v4();
|
let uuid = Uuid::new_v4();
|
||||||
let sanitized_file_name = Self::sanitize_file_name(&file_name);
|
let sanitized_file_name = Self::sanitize_file_name(&file_name);
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
let storage_prefix = format!("{user_id}/{uuid}");
|
||||||
|
|
||||||
let path =
|
let path = Self::persist_bytes_with_storage(
|
||||||
Self::persist_file_with_storage(&uuid, file, &sanitized_file_name, user_id, storage)
|
&storage_prefix,
|
||||||
|
&sanitized_file_name,
|
||||||
|
bytes,
|
||||||
|
storage,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Create FileInfo struct
|
|
||||||
let file_info = FileInfo {
|
let file_info = FileInfo {
|
||||||
id: uuid.to_string(),
|
id: uuid.to_string(),
|
||||||
user_id: user_id.to_string(),
|
user_id: user_id.to_string(),
|
||||||
sha256,
|
sha256: sha256.clone(),
|
||||||
file_name: original_file_name,
|
file_name: original_file_name,
|
||||||
path,
|
path,
|
||||||
mime_type: Self::guess_mime_type(Path::new(&file_name)),
|
mime_type: Self::guess_mime_type(Path::new(&file_name)),
|
||||||
@@ -231,13 +224,22 @@ impl FileInfo {
|
|||||||
updated_at: now,
|
updated_at: now,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Store in database
|
match db_client.store_item(file_info.clone()).await {
|
||||||
db_client
|
Ok(_) => Ok(file_info),
|
||||||
.store_item(file_info.clone())
|
Err(e) => {
|
||||||
.await
|
if let Err(cleanup_err) = storage.delete_prefix(&storage_prefix).await {
|
||||||
.map_err(FileError::SurrealError)?;
|
tracing::warn!(
|
||||||
|
prefix = %storage_prefix,
|
||||||
Ok(file_info)
|
error = %cleanup_err,
|
||||||
|
"failed to remove orphaned file after database error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Ok(existing) = Self::get_by_sha(&sha256, user_id, db_client).await {
|
||||||
|
return Ok(existing);
|
||||||
|
}
|
||||||
|
Err(FileError::SurrealError(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete a FileInfo by ID using StorageManager for storage operations.
|
/// Delete a FileInfo by ID using StorageManager for storage operations.
|
||||||
@@ -294,29 +296,16 @@ impl FileInfo {
|
|||||||
.map_err(AppError::Storage)
|
.map_err(AppError::Storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Persist file to storage using StorageManager.
|
/// Persist bytes to storage using StorageManager.
|
||||||
///
|
async fn persist_bytes_with_storage(
|
||||||
/// # Arguments
|
storage_prefix: &str,
|
||||||
/// * `uuid` - The UUID for the file
|
|
||||||
/// * `file` - The temporary file to persist
|
|
||||||
/// * `file_name` - The name of the file
|
|
||||||
/// * `user_id` - The user ID
|
|
||||||
/// * `storage` - A StorageManager instance for storage operations
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// * `Result<String, FileError>` - The logical object location or an error.
|
|
||||||
async fn persist_file_with_storage(
|
|
||||||
uuid: &Uuid,
|
|
||||||
file: NamedTempFile,
|
|
||||||
file_name: &str,
|
file_name: &str,
|
||||||
user_id: &str,
|
bytes: Vec<u8>,
|
||||||
storage: &StorageManager,
|
storage: &StorageManager,
|
||||||
) -> Result<String, FileError> {
|
) -> Result<String, FileError> {
|
||||||
// Logical object location relative to the store root
|
let location = format!("{storage_prefix}/{file_name}");
|
||||||
let location = format!("{user_id}/{uuid}/{file_name}");
|
|
||||||
info!("Persisting to object location: {}", location);
|
info!("Persisting to object location: {}", location);
|
||||||
|
|
||||||
let bytes = tokio::fs::read(file.path()).await?;
|
|
||||||
storage
|
storage
|
||||||
.put(&location, bytes.into())
|
.put(&location, bytes.into())
|
||||||
.await
|
.await
|
||||||
@@ -332,6 +321,7 @@ mod tests {
|
|||||||
use anyhow::{self, Context};
|
use anyhow::{self, Context};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::error::AppError;
|
||||||
use crate::storage::store::testing::TestStorageManager;
|
use crate::storage::store::testing::TestStorageManager;
|
||||||
use axum::http::HeaderMap;
|
use axum::http::HeaderMap;
|
||||||
use axum_typed_multipart::FieldMetadata;
|
use axum_typed_multipart::FieldMetadata;
|
||||||
@@ -362,6 +352,26 @@ mod tests {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_test_file_without_name(content: &[u8]) -> anyhow::Result<FieldData<NamedTempFile>> {
|
||||||
|
let mut temp_file =
|
||||||
|
NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
|
||||||
|
temp_file
|
||||||
|
.write_all(content)
|
||||||
|
.with_context(|| "Failed to write to temp file".to_string())?;
|
||||||
|
|
||||||
|
let metadata = FieldMetadata {
|
||||||
|
name: Some("file".to_string()),
|
||||||
|
file_name: None,
|
||||||
|
content_type: None,
|
||||||
|
headers: HeaderMap::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(FieldData {
|
||||||
|
metadata,
|
||||||
|
contents: temp_file,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
@@ -669,6 +679,10 @@ mod tests {
|
|||||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||||
"___dangerous.txt"
|
"___dangerous.txt"
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
FileInfo::sanitize_file_name("file.evil/../../x"),
|
||||||
|
"file_evil_____._x"
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -680,8 +694,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
|
||||||
// Try to find a file with a SHA that doesn't exist
|
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
||||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -716,11 +729,173 @@ mod tests {
|
|||||||
.expect("Failed to store test file info");
|
.expect("Failed to store test file info");
|
||||||
|
|
||||||
let malicious_sha = "known_sha_value' OR true --";
|
let malicious_sha = "known_sha_value' OR true --";
|
||||||
let result = FileInfo::get_by_sha(malicious_sha, &db).await;
|
let result = FileInfo::get_by_sha(malicious_sha, "user123", &db).await;
|
||||||
|
|
||||||
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||||
|
db.apply_migrations()
|
||||||
|
.await
|
||||||
|
.with_context(|| "migrations".to_string())?;
|
||||||
|
|
||||||
|
let content = b"shared content across users";
|
||||||
|
let test_storage = TestStorageManager::new_memory()
|
||||||
|
.await
|
||||||
|
.with_context(|| "create test storage manager".to_string())?;
|
||||||
|
|
||||||
|
let user_a = "user_a";
|
||||||
|
let user_b = "user_b";
|
||||||
|
|
||||||
|
let file_a = FileInfo::new_with_storage(
|
||||||
|
create_test_file(content, "a.txt")?,
|
||||||
|
&db,
|
||||||
|
user_a,
|
||||||
|
test_storage.storage(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let file_b = FileInfo::new_with_storage(
|
||||||
|
create_test_file(content, "b.txt")?,
|
||||||
|
&db,
|
||||||
|
user_b,
|
||||||
|
test_storage.storage(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(file_a.sha256, file_b.sha256);
|
||||||
|
assert_ne!(file_a.id, file_b.id);
|
||||||
|
assert_eq!(file_a.user_id, user_a);
|
||||||
|
assert_eq!(file_b.user_id, user_b);
|
||||||
|
assert_ne!(file_a.path, file_b.path);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let sha = "abc123sha";
|
||||||
|
let owner = "owner_user";
|
||||||
|
let other = "other_user";
|
||||||
|
|
||||||
|
db.store_item(FileInfo {
|
||||||
|
id: Uuid::new_v4().to_string(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
user_id: owner.to_string(),
|
||||||
|
sha256: sha.to_string(),
|
||||||
|
path: format!("{owner}/id/file.txt"),
|
||||||
|
file_name: "file.txt".to_string(),
|
||||||
|
mime_type: "text/plain".to_string(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let result = FileInfo::get_by_sha(sha, other, &db).await;
|
||||||
|
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||||
|
|
||||||
|
let found = FileInfo::get_by_sha(sha, owner, &db).await?;
|
||||||
|
assert_eq!(found.sha256, sha);
|
||||||
|
assert_eq!(found.user_id, owner);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
|
|
||||||
|
let field_data = create_test_file_without_name(b"data")?;
|
||||||
|
let result =
|
||||||
|
FileInfo::new_with_storage(field_data, &db, "user", test_storage.storage()).await;
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(FileError::MissingFileName)));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
|
|
||||||
|
let file_info = FileInfo::new_with_storage(
|
||||||
|
create_test_file(&[], "empty.bin")?,
|
||||||
|
&db,
|
||||||
|
"empty_user",
|
||||||
|
test_storage.storage(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(file_info.sha256.len() == 64);
|
||||||
|
let bytes = file_info
|
||||||
|
.get_content_with_storage(test_storage.storage())
|
||||||
|
.await
|
||||||
|
.map_err(AppError::from)?;
|
||||||
|
assert!(bytes.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
||||||
|
let namespace = "test_ns";
|
||||||
|
let database = &Uuid::new_v4().to_string();
|
||||||
|
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||||
|
db.apply_migrations().await?;
|
||||||
|
let test_storage = TestStorageManager::new_memory().await?;
|
||||||
|
let storage = test_storage.storage();
|
||||||
|
let user_id = "dedup_user";
|
||||||
|
let content = b"dedup content";
|
||||||
|
|
||||||
|
let first = FileInfo::new_with_storage(
|
||||||
|
create_test_file(content, "first.txt")?,
|
||||||
|
&db,
|
||||||
|
user_id,
|
||||||
|
storage,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let second = FileInfo::new_with_storage(
|
||||||
|
create_test_file(content, "second.txt")?,
|
||||||
|
&db,
|
||||||
|
user_id,
|
||||||
|
storage,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(first.id, second.id);
|
||||||
|
assert_eq!(first.sha256, second.sha256);
|
||||||
|
|
||||||
|
let rows: Vec<FileInfo> = db
|
||||||
|
.client
|
||||||
|
.query("SELECT * FROM file WHERE user_id = $user_id AND sha256 = $sha256")
|
||||||
|
.bind(("user_id", user_id.to_string()))
|
||||||
|
.bind(("sha256", first.sha256.clone()))
|
||||||
|
.await?
|
||||||
|
.take(0)?;
|
||||||
|
assert_eq!(
|
||||||
|
rows.len(),
|
||||||
|
1,
|
||||||
|
"unique (user_id, sha256) index should keep a single row"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_manual_file_info_creation() {
|
async fn test_manual_file_info_creation() {
|
||||||
let namespace = "test_ns";
|
let namespace = "test_ns";
|
||||||
|
|||||||
Reference in New Issue
Block a user