fix: parameterize storage-layer queries and add injection tests

This commit is contained in:
Per Stark
2026-02-12 21:42:46 +01:00
parent 0133eead63
commit b89171d934
4 changed files with 254 additions and 40 deletions

View File

@@ -137,8 +137,12 @@ impl FileInfo {
/// # Returns
/// * `Result<FileInfo, FileError>` - The matching `FileInfo`, or `FileError::FileNotFound` if no file has this hash.
async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result<FileInfo, FileError> {
let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256);
let response: Vec<FileInfo> = db_client.client.query(query).await?.take(0)?;
let mut response = db_client
.client
.query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1")
.bind(("sha256", sha256.to_owned()))
.await?;
let response: Vec<FileInfo> = response.take(0)?;
response
.into_iter()
@@ -665,6 +669,36 @@ mod tests {
}
}
/// Regression test: a `sha256` argument carrying a quote-based injection payload must be
/// treated as an opaque value, so the lookup reports `FileNotFound` rather than matching
/// the stored row.
#[tokio::test]
async fn test_get_by_sha_resists_query_injection() {
    // A random database name keeps this test isolated from the others.
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("test_ns", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");

    // Seed a single file whose hash is the prefix of the payload used below.
    let timestamp = Utc::now();
    db.store_item(FileInfo {
        id: Uuid::new_v4().to_string(),
        created_at: timestamp,
        updated_at: timestamp,
        user_id: "user123".to_string(),
        sha256: "known_sha_value".to_string(),
        path: "/path/to/file.txt".to_string(),
        file_name: "file.txt".to_string(),
        mime_type: "text/plain".to_string(),
    })
    .await
    .expect("Failed to store test file info");

    // With string interpolation this payload would turn the filter into `... OR true`.
    let lookup = FileInfo::get_by_sha("known_sha_value' OR true --", &db).await;

    assert!(matches!(lookup, Err(FileError::FileNotFound(_))));
}
#[tokio::test]
async fn test_manual_file_info_creation() {
let namespace = "test_ns";

View File

@@ -174,12 +174,15 @@ impl KnowledgeEntity {
// Delete embeddings first, while we can still look them up via the entity's source_id
KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?;
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
@@ -761,6 +764,69 @@ mod tests {
assert_eq!(different_remaining[0].id, different_entity.id);
}
/// Regression test: deleting by a `source_id` that embeds `' OR 1=1 --` must not match
/// (and therefore must not delete) any stored entity.
#[tokio::test]
async fn test_delete_by_source_id_resists_query_injection() {
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("test_ns", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations()
        .await
        .expect("Failed to apply migrations");
    // NOTE(review): presumably sets the HNSW index dimension to 3 so it matches the
    // 3-element embeddings stored below — confirm against `redefine_hnsw_index`.
    KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
        .await
        .expect("Failed to redefine index length");

    // Two entities with distinct source ids; neither equals the malicious string.
    let owner = "user123".to_string();
    let safe_entity = KnowledgeEntity::new(
        "safe_source".to_string(),
        "Entity 1".to_string(),
        "Description 1".to_string(),
        KnowledgeEntityType::Document,
        None,
        owner.clone(),
    );
    let other_entity = KnowledgeEntity::new(
        "other_source".to_string(),
        "Entity 2".to_string(),
        "Description 2".to_string(),
        KnowledgeEntityType::Document,
        None,
        owner,
    );

    KnowledgeEntity::store_with_embedding(safe_entity, vec![0.1, 0.2, 0.3], &db)
        .await
        .expect("store entity1");
    KnowledgeEntity::store_with_embedding(other_entity, vec![0.3, 0.2, 0.1], &db)
        .await
        .expect("store entity2");

    // If interpolated into the statement, this would widen the filter to every row.
    KnowledgeEntity::delete_by_source_id("safe_source' OR 1=1 --", &db)
        .await
        .expect("delete call should succeed");

    let mut response = db
        .client
        .query("SELECT * FROM type::table($table)")
        .bind(("table", KnowledgeEntity::table_name()))
        .await
        .expect("query failed");
    let survivors: Vec<KnowledgeEntity> = response.take(0).expect("take failed");

    assert_eq!(
        survivors.len(),
        2,
        "malicious input must not delete unrelated entities"
    );
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";

View File

@@ -40,22 +40,28 @@ impl KnowledgeRelationship {
}
}
/// Persists this relationship as a `relates_to` edge between two `knowledge_entity` records.
///
/// Any existing `relates_to` record with the same id is deleted first and the edge is then
/// re-created, both inside a single transaction.
///
/// # Arguments
/// * `db_client` - The SurrealDB client used to run the statements.
///
/// # Errors
/// Returns an `AppError` if the request fails or if any statement in the transaction
/// reports an error (surfaced via `.check()`).
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
    // Every caller-supplied value (record ids and metadata) is passed as a bound parameter.
    // Nothing is interpolated into the statement text, so quotes or SurrealQL fragments in
    // these fields are stored as plain data instead of being parsed as part of the query.
    db_client
        .client
        .query(
            r#"BEGIN TRANSACTION;
            LET $in_entity = type::thing('knowledge_entity', $in_id);
            LET $out_entity = type::thing('knowledge_entity', $out_id);
            LET $relation = type::thing('relates_to', $rel_id);
            DELETE type::thing('relates_to', $rel_id);
            RELATE $in_entity->$relation->$out_entity SET
                metadata.user_id = $user_id,
                metadata.source_id = $source_id,
                metadata.relationship_type = $relationship_type;
            COMMIT TRANSACTION;"#,
        )
        .bind(("rel_id", self.id.clone()))
        .bind(("in_id", self.in_.clone()))
        .bind(("out_id", self.out.clone()))
        .bind(("user_id", self.metadata.user_id.clone()))
        .bind(("source_id", self.metadata.source_id.clone()))
        .bind(("relationship_type", self.metadata.relationship_type.clone()))
        .await?
        .check()?;

    Ok(())
}
@@ -64,11 +70,12 @@ impl KnowledgeRelationship {
source_id: &str,
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let query = format!(
"DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'"
);
db_client.query(query).await?;
db_client
.client
.query("DELETE knowledge_entity -> relates_to WHERE metadata.source_id = $source_id")
.bind(("source_id", source_id.to_owned()))
.await?
.check()?;
Ok(())
}
@@ -79,15 +86,20 @@ impl KnowledgeRelationship {
db_client: &SurrealDbClient,
) -> Result<(), AppError> {
let mut authorized_result = db_client
.query(format!(
"SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'"
))
.client
.query(
"SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id",
)
.bind(("id", id.to_owned()))
.bind(("user_id", user_id.to_owned()))
.await?;
let authorized: Vec<KnowledgeRelationship> = authorized_result.take(0).unwrap_or_default();
if authorized.is_empty() {
let mut exists_result = db_client
.query(format!("SELECT * FROM relates_to:`{id}`"))
.client
.query("SELECT * FROM type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?;
let existing: Option<KnowledgeRelationship> = exists_result.take(0)?;
@@ -99,7 +111,12 @@ impl KnowledgeRelationship {
Err(AppError::NotFound(format!("Relationship {id} not found")))
}
} else {
db_client.query(format!("DELETE relates_to:`{id}`")).await?;
db_client
.client
.query("DELETE type::thing('relates_to', $id)")
.bind(("id", id.to_owned()))
.await?
.check()?;
Ok(())
}
}
@@ -210,6 +227,49 @@ mod tests {
);
}
/// Regression test: metadata values containing quotes and SurrealQL fragments must be
/// stored verbatim by `store_relationship` rather than being executed as statements.
#[tokio::test]
async fn test_store_relationship_resists_query_injection() {
    let namespace = "test_ns";
    let database = &Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory(namespace, database)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations()
        .await
        .expect("Failed to apply migrations");

    let entity1_id = create_test_entity("Entity 1", &db).await;
    let entity2_id = create_test_entity("Entity 2", &db).await;

    // Every metadata field carries a payload that would break a string-built query.
    let relationship = KnowledgeRelationship::new(
        entity1_id,
        entity2_id,
        "user'123".to_string(),
        "source123'; DELETE FROM relates_to; --".to_string(),
        "references'; UPDATE user SET admin = true; --".to_string(),
    );

    relationship
        .store_relationship(&db)
        .await
        .expect("store relationship should safely handle quote-containing values");

    let mut res = db
        .client
        .query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
        .bind(("id", relationship.id.clone()))
        .await
        .expect("query relationship by id failed");
    let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");

    assert_eq!(rows.len(), 1);
    assert_eq!(
        rows[0].metadata.source_id,
        "source123'; DELETE FROM relates_to; --"
    );
    // The other two bound fields must round-trip unchanged as well. Comparing against the
    // in-memory struct avoids depending on the constructor's argument order.
    assert_eq!(
        rows[0].metadata.user_id, relationship.metadata.user_id,
        "user_id must be stored verbatim"
    );
    assert_eq!(
        rows[0].metadata.relationship_type, relationship.metadata.relationship_type,
        "relationship_type must be stored verbatim"
    );
}
#[tokio::test]
async fn test_store_and_delete_relationship() {
// Setup in-memory database for testing

View File

@@ -47,12 +47,15 @@ impl TextChunk {
// Delete embeddings first
TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?;
let query = format!(
"DELETE {} WHERE source_id = '{}'",
Self::table_name(),
source_id
);
db_client.query(query).await?;
db_client
.client
.query("DELETE FROM type::table($table) WHERE source_id = $source_id")
.bind(("table", Self::table_name()))
.bind(("source_id", source_id.to_owned()))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
Ok(())
}
@@ -617,6 +620,57 @@ mod tests {
assert_eq!(remaining.len(), 1);
}
/// Regression test: deleting by a `source_id` that embeds `' OR 1=1 --` must leave every
/// stored chunk in place.
#[tokio::test]
async fn test_delete_by_source_id_resists_query_injection() {
    let db_name = Uuid::new_v4().to_string();
    let db = SurrealDbClient::memory("test_ns", &db_name)
        .await
        .expect("Failed to start in-memory surrealdb");
    db.apply_migrations().await.expect("migrations");
    // NOTE(review): presumably sets the HNSW index dimension to 5 so it matches the
    // 5-element embeddings stored below — confirm against `redefine_hnsw_index`.
    TextChunkEmbedding::redefine_hnsw_index(&db, 5)
        .await
        .expect("redefine index");

    // Two chunks with distinct source ids; neither equals the malicious string.
    let safe_chunk = TextChunk::new(
        "safe_source".to_string(),
        "Safe chunk".to_string(),
        "user123".to_string(),
    );
    let other_chunk = TextChunk::new(
        "other_source".to_string(),
        "Other chunk".to_string(),
        "user123".to_string(),
    );

    TextChunk::store_with_embedding(safe_chunk, vec![0.1, 0.2, 0.3, 0.4, 0.5], &db)
        .await
        .expect("store chunk1");
    TextChunk::store_with_embedding(other_chunk, vec![0.5, 0.4, 0.3, 0.2, 0.1], &db)
        .await
        .expect("store chunk2");

    // If interpolated into the statement, this would widen the filter to every row.
    TextChunk::delete_by_source_id("safe_source' OR 1=1 --", &db)
        .await
        .expect("delete call should succeed");

    let mut response = db
        .client
        .query("SELECT * FROM type::table($table)")
        .bind(("table", TextChunk::table_name()))
        .await
        .expect("query failed");
    let survivors: Vec<TextChunk> = response.take(0).expect("take failed");

    assert_eq!(
        survivors.len(),
        2,
        "malicious input must not delete unrelated rows"
    );
}
#[tokio::test]
async fn test_store_with_embedding_creates_both_records() {
let namespace = "test_ns";