mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-28 04:11:51 +01:00
fix: parameterize storage-layer queries and add injection tests
This commit is contained in:
@@ -137,8 +137,12 @@ impl FileInfo {
|
||||
/// # Returns
|
||||
/// * `Result<FileInfo, FileError>` - The matching `FileInfo`, or `FileError::FileNotFound` if no file has this hash.
|
||||
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
|
||||
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
|
||||
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
|
||||
let mut response = db_client
|
||||
.client
|
||||
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
|
||||
.bind(("sha256", sha256.to_owned()))
|
||||
.await?;
|
||||
let response: Vec<FileInfo> = response.take(0)?;
|
||||
|
||||
response
|
||||
.into_iter()
|
||||
@@ -665,6 +669,36 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test: `FileInfo::get_by_sha` must treat its `sha256` argument as a
// plain value. A stored record with sha "known_sha_value" must NOT be returned
// when the lookup string merely starts with that sha and then appends a
// quote-based payload.
#[tokio::test]
|
||||
async fn test_get_by_sha_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
// A fresh UUID is used as the database name for this test run.
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
let now = Utc::now();
|
||||
// The single record that a successful injection would incorrectly match.
let file_info = FileInfo {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id: "user123".to_string(),
|
||||
sha256: "known_sha_value".to_string(),
|
||||
path: "/path/to/file.txt".to_string(),
|
||||
file_name: "file.txt".to_string(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
};
|
||||
|
||||
db.store_item(file_info)
|
||||
.await
|
||||
.expect("Failed to store test file info");
|
||||
|
||||
// Payload: the real sha followed by a closing quote, `OR true`, and a trailing `--`.
let malicious_sha = "known_sha_value' OR true --";
|
||||
let result = FileInfo::get_by_sha(malicious_sha, &db).await;
|
||||
|
||||
// No stored sha equals the full payload string, so the lookup must report FileNotFound.
assert!(matches!(result, Err(FileError::FileNotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_manual_file_info_creation() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -174,12 +174,15 @@ impl KnowledgeEntity {
|
||||
// Delete embeddings first, while we can still look them up via the entity's source_id
|
||||
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -761,6 +764,69 @@ mod tests {
|
||||
assert_eq!(different_remaining[0].id, different_entity.id);
|
||||
}
|
||||
|
||||
// Regression test: `KnowledgeEntity::delete_by_source_id` must treat its
// `source_id` argument as a plain value. Two entities are stored, the delete is
// invoked with a quote-based payload, and both entities must still exist.
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
// A fresh UUID is used as the database name for this test run.
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
// NOTE(review): the 3 presumably is the index dimension, matching the
// 3-element embeddings stored below; confirm against redefine_hnsw_index.
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
let entity1 = KnowledgeEntity::new(
|
||||
"safe_source".to_string(),
|
||||
"Entity 1".to_string(),
|
||||
"Description 1".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
let entity2 = KnowledgeEntity::new(
|
||||
"other_source".to_string(),
|
||||
"Entity 2".to_string(),
|
||||
"Description 2".to_string(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
user_id,
|
||||
);
|
||||
|
||||
KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db)
|
||||
.await
|
||||
.expect("store entity1");
|
||||
KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db)
|
||||
.await
|
||||
.expect("store entity2");
|
||||
|
||||
// Payload: "safe_source" followed by a closing quote, `OR 1=1`, and a trailing `--`.
let malicious_source = "safe_source' OR 1=1 --";
|
||||
// The call itself is expected to succeed; it should simply match nothing.
KnowledgeEntity::delete_by_source_id(malicious_source, &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
// Read back every row of the entity table to see what survived the delete.
let remaining: Vec<KnowledgeEntity> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table)")
|
||||
.bind(("table", KnowledgeEntity::table_name()))
|
||||
.await
|
||||
.expect("query failed")
|
||||
.take(0)
|
||||
.expect("take failed");
|
||||
|
||||
// Both stored entities must remain: the payload equals neither source id.
assert_eq!(
|
||||
remaining.len(),
|
||||
2,
|
||||
"malicious input must not delete unrelated entities"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_returns_empty_when_no_embeddings() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
@@ -40,22 +40,28 @@ impl KnowledgeRelationship {
|
||||
}
|
||||
}
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
r#"DELETE relates_to:`{rel_id}`;
|
||||
RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}`
|
||||
SET
|
||||
metadata.user_id = '{user_id}',
|
||||
metadata.source_id = '{source_id}',
|
||||
metadata.relationship_type = '{relationship_type}'"#,
|
||||
rel_id = self.id,
|
||||
in_id = self.in_,
|
||||
out_id = self.out,
|
||||
user_id = self.metadata.user_id.as_str(),
|
||||
source_id = self.metadata.source_id.as_str(),
|
||||
relationship_type = self.metadata.relationship_type.as_str()
|
||||
);
|
||||
|
||||
db_client.query(query).await?.check()?;
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
r#"BEGIN TRANSACTION;
|
||||
LET $in_entity = type::thing('knowledge_entity', $in_id);
|
||||
LET $out_entity = type::thing('knowledge_entity', $out_id);
|
||||
LET $relation = type::thing('relates_to', $rel_id);
|
||||
DELETE type::thing('relates_to', $rel_id);
|
||||
RELATE $in_entity->$relation->$out_entity SET
|
||||
metadata.user_id = $user_id,
|
||||
metadata.source_id = $source_id,
|
||||
metadata.relationship_type = $relationship_type;
|
||||
COMMIT TRANSACTION;"#,
|
||||
)
|
||||
.bind(("rel_id", self.id.clone()))
|
||||
.bind(("in_id", self.in_.clone()))
|
||||
.bind(("out_id", self.out.clone()))
|
||||
.bind(("user_id", self.metadata.user_id.clone()))
|
||||
.bind(("source_id", self.metadata.source_id.clone()))
|
||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -64,11 +70,12 @@ impl KnowledgeRelationship {
|
||||
source_id: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let query = format!(
|
||||
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
|
||||
);
|
||||
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE knowledge_entity -> relates_to WHERE metadata.source_id = $source_id")
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -79,15 +86,20 @@ impl KnowledgeRelationship {
|
||||
db_client: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let mut authorized_result = db_client
|
||||
.query(format!(
|
||||
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
|
||||
))
|
||||
.client
|
||||
.query(
|
||||
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
|
||||
)
|
||||
.bind(("id", id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
|
||||
|
||||
if authorized.is_empty() {
|
||||
let mut exists_result = db_client
|
||||
.query(format!("SELECT * FROM relates_to:`{id}`"))
|
||||
.client
|
||||
.query("SELECT * FROM type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?;
|
||||
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
|
||||
|
||||
@@ -99,7 +111,12 @@ impl KnowledgeRelationship {
|
||||
Err(AppError::NotFound(format!("Relationship {id} not found")))
|
||||
}
|
||||
} else {
|
||||
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE type::thing('relates_to', $id)")
|
||||
.bind(("id", id.to_owned()))
|
||||
.await?
|
||||
.check()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -210,6 +227,49 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// Regression test: `store_relationship` must store metadata strings verbatim,
// even when they contain single quotes and embedded statements. The test stores
// one relationship with such values and reads it back by id.
#[tokio::test]
|
||||
async fn test_store_relationship_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
// A fresh UUID is used as the database name for this test run.
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
let entity1_id = create_test_entity("Entity 1", &db).await;
|
||||
let entity2_id = create_test_entity("Entity 2", &db).await;
|
||||
|
||||
// All three string arguments contain a single quote; the last two also
// carry a trailing statement and `--`.
let relationship = KnowledgeRelationship::new(
|
||||
entity1_id,
|
||||
entity2_id,
|
||||
"user'123".to_string(),
|
||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||
);
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.expect("store relationship should safely handle quote-containing values");
|
||||
|
||||
// Look the relationship up by its own id, using a bound parameter.
let mut res = db
|
||||
.client
|
||||
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||
.bind(("id", relationship.id.clone()))
|
||||
.await
|
||||
.expect("query relationship by id failed");
|
||||
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||
|
||||
// Exactly one row exists, and its source_id is the payload string, unchanged.
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(
|
||||
rows[0].metadata.source_id,
|
||||
"source123'; DELETE FROM relates_to; --"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_and_delete_relationship() {
|
||||
// Setup in-memory database for testing
|
||||
|
||||
@@ -47,12 +47,15 @@ impl TextChunk {
|
||||
// Delete embeddings first
|
||||
TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?;
|
||||
|
||||
let query = format!(
|
||||
"DELETE {} WHERE source_id = '{}'",
|
||||
Self::table_name(),
|
||||
source_id
|
||||
);
|
||||
db_client.query(query).await?;
|
||||
db_client
|
||||
.client
|
||||
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("source_id", source_id.to_owned()))
|
||||
.await
|
||||
.map_err(AppError::Database)?
|
||||
.check()
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -617,6 +620,57 @@ mod tests {
|
||||
assert_eq!(remaining.len(), 1);
|
||||
}
|
||||
|
||||
// Regression test: `TextChunk::delete_by_source_id` must treat its `source_id`
// argument as a plain value. Two chunks are stored, the delete is invoked with
// a quote-based payload, and both chunks must still exist.
#[tokio::test]
|
||||
async fn test_delete_by_source_id_resists_query_injection() {
|
||||
let namespace = "test_ns";
|
||||
// A fresh UUID is used as the database name for this test run.
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
// NOTE(review): the 5 presumably is the index dimension, matching the
// 5-element embeddings stored below; confirm against redefine_hnsw_index.
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
|
||||
let chunk1 = TextChunk::new(
|
||||
"safe_source".to_string(),
|
||||
"Safe chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
let chunk2 = TextChunk::new(
|
||||
"other_source".to_string(),
|
||||
"Other chunk".to_string(),
|
||||
"user123".to_string(),
|
||||
);
|
||||
|
||||
TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
|
||||
.await
|
||||
.expect("store chunk1");
|
||||
TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
|
||||
.await
|
||||
.expect("store chunk2");
|
||||
|
||||
// Payload: "safe_source" followed by a closing quote, `OR 1=1`, and a trailing `--`.
let malicious_source = "safe_source' OR 1=1 --";
|
||||
// The call itself is expected to succeed; it should simply match nothing.
TextChunk::delete_by_source_id(malicious_source, &db)
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
// Read back every row of the chunk table to see what survived the delete.
let remaining: Vec<TextChunk> = db
|
||||
.client
|
||||
.query("SELECT * FROM type::table($table)")
|
||||
.bind(("table", TextChunk::table_name()))
|
||||
.await
|
||||
.expect("query failed")
|
||||
.take(0)
|
||||
.expect("take failed");
|
||||
|
||||
// Both stored chunks must remain: the payload equals neither source id.
assert_eq!(
|
||||
remaining.len(),
|
||||
2,
|
||||
"malicious input must not delete unrelated rows"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_with_embedding_creates_both_records() {
|
||||
let namespace = "test_ns";
|
||||
|
||||
Reference in New Issue
Block a user