mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-29 19:00:51 +02:00
chore: harden analytics, conversation access, and per-user file dedup
Use UPSERT for analytics counters, enforce message ownership in SQL, return NotFound when patch_title updates nothing, scope file dedup by user_id with a composite unique index, and expand tests for auth, ordering, and edge cases.
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
-- Per-user deduplication: same SHA256 may exist for different users.
|
||||
REMOVE INDEX IF EXISTS file_sha256_idx ON file;
|
||||
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||
@@ -0,0 +1 @@
|
||||
{"schemas":"--- original\n+++ modified\n@@ -45,9 +45,8 @@\n DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;\n DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;\n\n-# Indexes based on usage (get_by_sha, potentially user lookups)\n-# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates\n-DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;\n+# Indexes based on usage (get_by_sha scoped by user_id, user lookups)\n+DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;\n DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;\n\n # Defines the schema for the 'ingestion_task' table (used by IngestionTask).\n","events":null}
|
||||
@@ -13,7 +13,6 @@ DEFINE FIELD IF NOT EXISTS file_name ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS mime_type ON file TYPE string;
|
||||
DEFINE FIELD IF NOT EXISTS user_id ON file TYPE string;
|
||||
|
||||
# Indexes based on usage (get_by_sha, potentially user lookups)
|
||||
# Using UNIQUE based on the logic in FileInfo::new to prevent duplicates
|
||||
DEFINE INDEX IF NOT EXISTS file_sha256_idx ON file FIELDS sha256 UNIQUE;
|
||||
# Indexes based on usage (get_by_sha scoped by user_id, user lookups)
|
||||
DEFINE INDEX IF NOT EXISTS file_user_sha256_idx ON file FIELDS user_id, sha256 UNIQUE;
|
||||
DEFINE INDEX IF NOT EXISTS file_user_id_idx ON file FIELDS user_id;
|
||||
|
||||
@@ -22,24 +22,24 @@ impl StoredObject for Analytics {
|
||||
}
|
||||
|
||||
impl Analytics {
|
||||
const RECORD_ID: &'static str = "current";
|
||||
|
||||
/// Ensures the singleton analytics record exists (idempotent).
|
||||
///
|
||||
/// Production databases are also seeded by `20250503_215025_initial_setup.surql`;
|
||||
/// this uses an atomic `UPSERT` for tests and recovery.
|
||||
pub async fn ensure_initialized(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let analytics = db.get_item::<Self>("current").await?;
|
||||
|
||||
if analytics.is_none() {
|
||||
let created_analytics = Analytics {
|
||||
id: "current".to_string(),
|
||||
visitors: 0,
|
||||
page_loads: 0,
|
||||
};
|
||||
|
||||
let stored: Option<Self> = db.store_item(created_analytics).await?;
|
||||
return stored.ok_or(AppError::Validation(
|
||||
"failed to initialize analytics".into(),
|
||||
));
|
||||
}
|
||||
let analytics: Option<Self> = db
|
||||
.client
|
||||
.query(
|
||||
"UPSERT type::thing('analytics', $id) SET visitors = visitors ?? 0, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
analytics.ok_or(AppError::Validation(
|
||||
"Failed to initialize analytics".into(),
|
||||
"failed to initialize analytics".into(),
|
||||
))
|
||||
}
|
||||
pub async fn get_current(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
@@ -50,7 +50,10 @@ impl Analytics {
|
||||
pub async fn increment_visitors(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('analytics', 'current') SET visitors += 1 RETURN AFTER")
|
||||
.query(
|
||||
"UPSERT type::thing('analytics', $id) SET visitors = (visitors ?? 0) + 1, page_loads = page_loads ?? 0 RETURN AFTER",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -60,7 +63,10 @@ impl Analytics {
|
||||
pub async fn increment_page_loads(db: &SurrealDbClient) -> Result<Self, AppError> {
|
||||
let updated: Option<Self> = db
|
||||
.client
|
||||
.query("UPDATE type::thing('analytics', 'current') SET page_loads += 1 RETURN AFTER")
|
||||
.query(
|
||||
"UPSERT type::thing('analytics', $id) SET page_loads = (page_loads ?? 0) + 1, visitors = visitors ?? 0 RETURN AFTER",
|
||||
)
|
||||
.bind(("id", Self::RECORD_ID))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -227,6 +233,53 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_visitors_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let analytics = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(analytics.visitors, 1);
|
||||
assert_eq!(analytics.page_loads, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_increment_page_loads_without_prior_init() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let analytics = Analytics::increment_page_loads(&db).await?;
|
||||
assert_eq!(analytics.page_loads, 1);
|
||||
assert_eq!(analytics.visitors, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_visitor_and_page_load_increments_are_independent() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let after_visitors = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(after_visitors.visitors, 1);
|
||||
assert_eq!(after_visitors.page_loads, 0);
|
||||
|
||||
let after_page_load = Analytics::increment_page_loads(&db).await?;
|
||||
assert_eq!(after_page_load.visitors, 1);
|
||||
assert_eq!(after_page_load.page_loads, 1);
|
||||
|
||||
let after_second_visitor = Analytics::increment_visitors(&db).await?;
|
||||
assert_eq!(after_second_visitor.visitors, 2);
|
||||
assert_eq!(after_second_visitor.page_loads, 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_current_nonexistent() -> anyhow::Result<()> {
|
||||
// Setup in-memory database for testing
|
||||
|
||||
@@ -88,10 +88,15 @@ impl Conversation {
|
||||
));
|
||||
}
|
||||
|
||||
let messages:Vec<Message> = db.client.
|
||||
query("SELECT * FROM type::table($table_name) WHERE conversation_id = $conversation_id ORDER BY updated_at").
|
||||
bind(("table_name", Message::table_name())).
|
||||
bind(("conversation_id", conversation_id.to_string()))
|
||||
let messages: Vec<Message> = db
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at",
|
||||
)
|
||||
.bind(("message_table", Message::table_name()))
|
||||
.bind(("conversation_table", Self::table_name()))
|
||||
.bind(("conversation_id", conversation_id.to_string()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -114,7 +119,7 @@ impl Conversation {
|
||||
));
|
||||
}
|
||||
|
||||
let _updated: Option<Self> = db
|
||||
let updated: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/title", new_title.to_string()))
|
||||
.patch(PatchOp::replace(
|
||||
@@ -123,6 +128,10 @@ impl Conversation {
|
||||
))
|
||||
.await?;
|
||||
|
||||
if updated.is_none() {
|
||||
return Err(AppError::NotFound("conversation not found".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -152,6 +161,24 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
const MESSAGE_QUERY_FOR_OWNER: &str = "SELECT * FROM type::table($message_table) WHERE conversation_id = $conversation_id AND type::thing($conversation_table, $conversation_id).user_id = $user_id ORDER BY updated_at";
|
||||
|
||||
async fn fetch_messages_for_owner(
|
||||
db: &SurrealDbClient,
|
||||
conversation_id: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Message>, AppError> {
|
||||
db.client
|
||||
.query(MESSAGE_QUERY_FOR_OWNER)
|
||||
.bind(("message_table", Message::table_name()))
|
||||
.bind(("conversation_table", Conversation::table_name()))
|
||||
.bind(("conversation_id", conversation_id.to_string()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await?
|
||||
.take(0)
|
||||
.map_err(AppError::from)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_conversation() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
@@ -488,4 +515,146 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||
let item: SidebarConversation =
|
||||
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#).unwrap();
|
||||
assert_eq!(item.id, "conv-plain");
|
||||
assert_eq!(item.title, "My chat");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sidebar_conversation_deserializes_id_from_db_record() {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let owner = "sidebar_owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "Sidebar title".to_string());
|
||||
let expected_id = conversation.id.clone();
|
||||
db.store_item(conversation)
|
||||
.await
|
||||
.expect("Failed to store conversation");
|
||||
|
||||
let items = Conversation::get_user_sidebar_conversations(owner, &db)
|
||||
.await
|
||||
.expect("Failed to load sidebar");
|
||||
assert_eq!(items.len(), 1);
|
||||
assert_eq!(items[0].id, expected_id);
|
||||
assert_eq!(items[0].title, "Sidebar title");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_message_query_filters_by_owner_user_id_in_sql() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let owner = "owner_user";
|
||||
let intruder = "intruder_user";
|
||||
let conversation = Conversation::new(owner.to_string(), "Private".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
|
||||
db.store_item(conversation).await?;
|
||||
db.store_item(Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"secret message".to_string(),
|
||||
None,
|
||||
))
|
||||
.await?;
|
||||
|
||||
let owner_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||
assert_eq!(owner_messages.len(), 1);
|
||||
assert_eq!(owner_messages[0].content, "secret message");
|
||||
|
||||
let intruder_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||
assert!(
|
||||
intruder_messages.is_empty(),
|
||||
"SQL owner filter must not return messages for a non-owner user_id"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_complete_conversation_orders_messages_by_updated_at() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let user_id = "order_user";
|
||||
let conversation = Conversation::new(user_id.to_string(), "Ordered".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
db.store_item(conversation).await?;
|
||||
|
||||
let base = Utc::now();
|
||||
let mut first = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::User,
|
||||
"first".to_string(),
|
||||
None,
|
||||
);
|
||||
first.updated_at = base - chrono::Duration::minutes(20);
|
||||
|
||||
let mut second = Message::new(
|
||||
conversation_id.clone(),
|
||||
MessageRole::AI,
|
||||
"second".to_string(),
|
||||
None,
|
||||
);
|
||||
second.updated_at = base - chrono::Duration::minutes(5);
|
||||
|
||||
db.store_item(first).await?;
|
||||
db.store_item(second).await?;
|
||||
|
||||
let (_, messages) =
|
||||
Conversation::get_complete_conversation(&conversation_id, user_id, &db).await?;
|
||||
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].content, "first");
|
||||
assert_eq!(messages[1].content, "second");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_patch_title_not_found_when_conversation_deleted() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let owner = "owner";
|
||||
let conversation = Conversation::new(owner.to_string(), "To delete".to_string());
|
||||
let conversation_id = conversation.id.clone();
|
||||
db.store_item(conversation).await?;
|
||||
db.delete_item::<Conversation>(&conversation_id).await?;
|
||||
|
||||
let result = Conversation::patch_title(&conversation_id, owner, "New title", &db).await;
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(AppError::NotFound(_)) => {}
|
||||
other => anyhow::bail!("expected NotFound, got {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_conversation_new_initializes_timestamps_and_id() {
|
||||
let before = Utc::now();
|
||||
let conversation = Conversation::new("user".to_string(), "Title".to_string());
|
||||
let after = Utc::now();
|
||||
|
||||
assert!(!conversation.id.is_empty());
|
||||
assert!(conversation.created_at >= before && conversation.created_at <= after);
|
||||
assert_eq!(conversation.created_at, conversation.updated_at);
|
||||
assert_eq!(conversation.user_id, "user");
|
||||
assert_eq!(conversation.title, "Title");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use mime_guess::from_path;
|
||||
use object_store::Error as ObjectStoreError;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
io::{BufReader, Read},
|
||||
io::{BufReader, Read, Seek, SeekFrom},
|
||||
path::Path,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
@@ -73,21 +73,47 @@ impl FileInfo {
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Calculates the SHA256 hash of the given file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file` - The file to hash.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, FileError>` - The SHA256 hash as a hex string or an error.
|
||||
fn sanitize_name_segment(segment: &str) -> String {
|
||||
segment
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores in
|
||||
/// both the stem and extension.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
if let Some(idx) = file_name.rfind('.') {
|
||||
let name = Self::sanitize_name_segment(&file_name[..idx]);
|
||||
let ext = Self::sanitize_name_segment(&file_name[idx + 1..]);
|
||||
if ext.is_empty() {
|
||||
name
|
||||
} else {
|
||||
format!("{name}.{ext}")
|
||||
}
|
||||
} else {
|
||||
Self::sanitize_name_segment(file_name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads the file once and returns its SHA-256 hex digest and raw bytes.
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
async fn get_sha(file: &NamedTempFile) -> Result<String, FileError> {
|
||||
async fn read_and_hash(file: &NamedTempFile) -> Result<(String, Vec<u8>), FileError> {
|
||||
let mut file_clone = file.as_file().try_clone()?;
|
||||
|
||||
let digest = task::spawn_blocking(move || -> Result<_, std::io::Error> {
|
||||
task::spawn_blocking(move || -> Result<_, std::io::Error> {
|
||||
file_clone.seek(SeekFrom::Start(0))?;
|
||||
let mut reader = BufReader::new(&mut file_clone);
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = [0u8; 8192]; // 8KB buffer
|
||||
let mut buffer = [0u8; 8192];
|
||||
let mut bytes = Vec::new();
|
||||
|
||||
loop {
|
||||
let n = reader.read(&mut buffer)?;
|
||||
@@ -95,60 +121,27 @@ impl FileInfo {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buffer[..n]);
|
||||
bytes.extend_from_slice(&buffer[..n]);
|
||||
}
|
||||
|
||||
Ok::<_, std::io::Error>(hasher.finalize())
|
||||
Ok((format!("{:x}", hasher.finalize()), bytes))
|
||||
})
|
||||
.await
|
||||
.map_err(std::io::Error::other)??;
|
||||
|
||||
Ok(format!("{digest:x}"))
|
||||
.map_err(std::io::Error::other)?
|
||||
.map_err(FileError::from)
|
||||
}
|
||||
|
||||
/// Sanitizes the file name to prevent security vulnerabilities like directory traversal.
|
||||
/// Replaces any non-alphanumeric characters (excluding '.' and '_') with underscores.
|
||||
fn sanitize_file_name(file_name: &str) -> String {
|
||||
if let Some(idx) = file_name.rfind('.') {
|
||||
let (name, ext) = file_name.split_at(idx);
|
||||
let sanitized_name: String = name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
format!("{sanitized_name}{ext}")
|
||||
} else {
|
||||
// No extension
|
||||
file_name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a `FileInfo` by SHA256.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `sha256` - The SHA256 hash string.
|
||||
/// * `db_client` - Reference to the SurrealDbClient.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Option<FileInfo>, FileError>` - The `FileInfo` or `None` if not found.
|
||||
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
/// Retrieves a `FileInfo` for this user by SHA256.
|
||||
async fn get_by_sha(
|
||||
sha256: &str,
|
||||
user_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<FileInfo, FileError> {
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
|
||||
.query("SELECT * FROM file WHERE sha256 = $sha256 AND user_id = $user_id LIMIT 1")
|
||||
.bind(("sha256", sha256.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
let response: Vec<FileInfo> = response.take(0)?;
|
||||
|
||||
@@ -197,33 +190,33 @@ impl FileInfo {
|
||||
.ok_or(FileError::MissingFileName)?;
|
||||
let original_file_name = file_name.clone();
|
||||
|
||||
// Calculate SHA256
|
||||
let sha256 = Self::get_sha(&file).await?;
|
||||
let (sha256, bytes) = Self::read_and_hash(&file).await?;
|
||||
|
||||
// Early return if file already exists
|
||||
match Self::get_by_sha(&sha256, db_client).await {
|
||||
Ok(existing_file) => {
|
||||
info!("File already exists with SHA256: {}", sha256);
|
||||
return Ok(existing_file);
|
||||
}
|
||||
Err(FileError::FileNotFound(_)) => (), // Expected case for new files
|
||||
Err(e) => return Err(e), // Propagate unexpected errors
|
||||
if let Ok(existing_file) = Self::get_by_sha(&sha256, user_id, db_client).await {
|
||||
info!(
|
||||
"File already exists for user {} with SHA256: {}",
|
||||
user_id, sha256
|
||||
);
|
||||
return Ok(existing_file);
|
||||
}
|
||||
|
||||
// Generate UUID and prepare paths
|
||||
let uuid = Uuid::new_v4();
|
||||
let sanitized_file_name = Self::sanitize_file_name(&file_name);
|
||||
let now = Utc::now();
|
||||
let storage_prefix = format!("{user_id}/{uuid}");
|
||||
|
||||
let path =
|
||||
Self::persist_file_with_storage(&uuid, file, &sanitized_file_name, user_id, storage)
|
||||
.await?;
|
||||
let path = Self::persist_bytes_with_storage(
|
||||
&storage_prefix,
|
||||
&sanitized_file_name,
|
||||
bytes,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create FileInfo struct
|
||||
let file_info = FileInfo {
|
||||
id: uuid.to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
sha256,
|
||||
sha256: sha256.clone(),
|
||||
file_name: original_file_name,
|
||||
path,
|
||||
mime_type: Self::guess_mime_type(Path::new(&file_name)),
|
||||
@@ -231,13 +224,22 @@ impl FileInfo {
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
// Store in database
|
||||
db_client
|
||||
.store_item(file_info.clone())
|
||||
.await
|
||||
.map_err(FileError::SurrealError)?;
|
||||
|
||||
Ok(file_info)
|
||||
match db_client.store_item(file_info.clone()).await {
|
||||
Ok(_) => Ok(file_info),
|
||||
Err(e) => {
|
||||
if let Err(cleanup_err) = storage.delete_prefix(&storage_prefix).await {
|
||||
tracing::warn!(
|
||||
prefix = %storage_prefix,
|
||||
error = %cleanup_err,
|
||||
"failed to remove orphaned file after database error"
|
||||
);
|
||||
}
|
||||
if let Ok(existing) = Self::get_by_sha(&sha256, user_id, db_client).await {
|
||||
return Ok(existing);
|
||||
}
|
||||
Err(FileError::SurrealError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a FileInfo by ID using StorageManager for storage operations.
|
||||
@@ -294,29 +296,16 @@ impl FileInfo {
|
||||
.map_err(AppError::Storage)
|
||||
}
|
||||
|
||||
/// Persist file to storage using StorageManager.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `uuid` - The UUID for the file
|
||||
/// * `file` - The temporary file to persist
|
||||
/// * `file_name` - The name of the file
|
||||
/// * `user_id` - The user ID
|
||||
/// * `storage` - A StorageManager instance for storage operations
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<String, FileError>` - The logical object location or an error.
|
||||
async fn persist_file_with_storage(
|
||||
uuid: &Uuid,
|
||||
file: NamedTempFile,
|
||||
/// Persist bytes to storage using StorageManager.
|
||||
async fn persist_bytes_with_storage(
|
||||
storage_prefix: &str,
|
||||
file_name: &str,
|
||||
user_id: &str,
|
||||
bytes: Vec<u8>,
|
||||
storage: &StorageManager,
|
||||
) -> Result<String, FileError> {
|
||||
// Logical object location relative to the store root
|
||||
let location = format!("{user_id}/{uuid}/{file_name}");
|
||||
let location = format!("{storage_prefix}/{file_name}");
|
||||
info!("Persisting to object location: {}", location);
|
||||
|
||||
let bytes = tokio::fs::read(file.path()).await?;
|
||||
storage
|
||||
.put(&location, bytes.into())
|
||||
.await
|
||||
@@ -332,6 +321,7 @@ mod tests {
|
||||
use anyhow::{self, Context};
|
||||
|
||||
use super::*;
|
||||
use crate::error::AppError;
|
||||
use crate::storage::store::testing::TestStorageManager;
|
||||
use axum::http::HeaderMap;
|
||||
use axum_typed_multipart::FieldMetadata;
|
||||
@@ -362,6 +352,26 @@ mod tests {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_test_file_without_name(content: &[u8]) -> anyhow::Result<FieldData<NamedTempFile>> {
|
||||
let mut temp_file =
|
||||
NamedTempFile::new().with_context(|| "Failed to create temp file".to_string())?;
|
||||
temp_file
|
||||
.write_all(content)
|
||||
.with_context(|| "Failed to write to temp file".to_string())?;
|
||||
|
||||
let metadata = FieldMetadata {
|
||||
name: Some("file".to_string()),
|
||||
file_name: None,
|
||||
content_type: None,
|
||||
headers: HeaderMap::default(),
|
||||
};
|
||||
|
||||
Ok(FieldData {
|
||||
metadata,
|
||||
contents: temp_file,
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fileinfo_create_read_delete_with_storage_manager() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
@@ -669,6 +679,10 @@ mod tests {
|
||||
FileInfo::sanitize_file_name("../dangerous.txt"),
|
||||
"___dangerous.txt"
|
||||
);
|
||||
assert_eq!(
|
||||
FileInfo::sanitize_file_name("file.evil/../../x"),
|
||||
"file_evil_____._x"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -680,8 +694,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
|
||||
// Try to find a file with a SHA that doesn't exist
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", &db).await;
|
||||
let result = FileInfo::get_by_sha("nonexistent_sha_hash", "user123", &db).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result {
|
||||
@@ -716,11 +729,173 @@ mod tests {
|
||||
.expect("Failed to store test file info");
|
||||
|
||||
let malicious_sha = "known_sha_value' OR true --";
|
||||
let result = FileInfo::get_by_sha(malicious_sha, &db).await;
|
||||
let result = FileInfo::get_by_sha(malicious_sha, "user123", &db).await;
|
||||
|
||||
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_detection_is_per_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.with_context(|| "Failed to start in-memory surrealdb".to_string())?;
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
|
||||
let content = b"shared content across users";
|
||||
let test_storage = TestStorageManager::new_memory()
|
||||
.await
|
||||
.with_context(|| "create test storage manager".to_string())?;
|
||||
|
||||
let user_a = "user_a";
|
||||
let user_b = "user_b";
|
||||
|
||||
let file_a = FileInfo::new_with_storage(
|
||||
create_test_file(content, "a.txt")?,
|
||||
&db,
|
||||
user_a,
|
||||
test_storage.storage(),
|
||||
)
|
||||
.await?;
|
||||
let file_b = FileInfo::new_with_storage(
|
||||
create_test_file(content, "b.txt")?,
|
||||
&db,
|
||||
user_b,
|
||||
test_storage.storage(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(file_a.sha256, file_b.sha256);
|
||||
assert_ne!(file_a.id, file_b.id);
|
||||
assert_eq!(file_a.user_id, user_a);
|
||||
assert_eq!(file_b.user_id, user_b);
|
||||
assert_ne!(file_a.path, file_b.path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_by_sha_not_found_for_other_user() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
|
||||
let now = Utc::now();
|
||||
let sha = "abc123sha";
|
||||
let owner = "owner_user";
|
||||
let other = "other_user";
|
||||
|
||||
db.store_item(FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id: owner.to_string(),
|
||||
sha256: sha.to_string(),
|
||||
path: format!("{owner}/id/file.txt"),
|
||||
file_name: "file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let result = FileInfo::get_by_sha(sha, other, &db).await;
|
||||
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||
|
||||
let found = FileInfo::get_by_sha(sha, owner, &db).await?;
|
||||
assert_eq!(found.sha256, sha);
|
||||
assert_eq!(found.user_id, owner);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_missing_file_name() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let field_data = create_test_file_without_name(b"data")?;
|
||||
let result =
|
||||
FileInfo::new_with_storage(field_data, &db, "user", test_storage.storage()).await;
|
||||
|
||||
assert!(matches!(result, Err(FileError::MissingFileName)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_with_storage_empty_file() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
|
||||
let file_info = FileInfo::new_with_storage(
|
||||
create_test_file(&[], "empty.bin")?,
|
||||
&db,
|
||||
"empty_user",
|
||||
test_storage.storage(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert!(file_info.sha256.len() == 64);
|
||||
let bytes = file_info
|
||||
.get_content_with_storage(test_storage.storage())
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
assert!(bytes.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_upload_persists_single_row_per_user_sha() -> anyhow::Result<()> {
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database).await?;
|
||||
db.apply_migrations().await?;
|
||||
let test_storage = TestStorageManager::new_memory().await?;
|
||||
let storage = test_storage.storage();
|
||||
let user_id = "dedup_user";
|
||||
let content = b"dedup content";
|
||||
|
||||
let first = FileInfo::new_with_storage(
|
||||
create_test_file(content, "first.txt")?,
|
||||
&db,
|
||||
user_id,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
let second = FileInfo::new_with_storage(
|
||||
create_test_file(content, "second.txt")?,
|
||||
&db,
|
||||
user_id,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(first.id, second.id);
|
||||
assert_eq!(first.sha256, second.sha256);
|
||||
|
||||
let rows: Vec<FileInfo> = db
|
||||
.client
|
||||
.query("SELECT * FROM file WHERE user_id = $user_id AND sha256 = $sha256")
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.bind(("sha256", first.sha256.clone()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
assert_eq!(
|
||||
rows.len(),
|
||||
1,
|
||||
"unique (user_id, sha256) index should keep a single row"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
Reference in New Issue
Block a user