diff --git a/common/src/storage/types/file_info.rs b/common/src/storage/types/file_info.rs index f06ec68..37448d2 100644 --- a/common/src/storage/types/file_info.rs +++ b/common/src/storage/types/file_info.rs @@ -137,8 +137,12 @@ impl FileInfo { /// # Returns /// * `Result, FileError>` - The `FileInfo` or `None` if not found. async fn get_by_sha(sha256: &str, db_client: &SurrealDbClient) -> Result { - let query = format!("SELECT * FROM file WHERE sha256 = '{}'", &sha256); - let response: Vec = db_client.client.query(query).await?.take(0)?; + let mut response = db_client + .client + .query("SELECT * FROM file WHERE sha256 = $sha256 LIMIT 1") + .bind(("sha256", sha256.to_owned())) + .await?; + let response: Vec = response.take(0)?; response .into_iter() @@ -665,6 +669,36 @@ mod tests { } } + #[tokio::test] + async fn test_get_by_sha_resists_query_injection() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + let now = Utc::now(); + let file_info = FileInfo { + id: Uuid::new_v4().to_string(), + created_at: now, + updated_at: now, + user_id: "user123".to_string(), + sha256: "known_sha_value".to_string(), + path: "/path/to/file.txt".to_string(), + file_name: "file.txt".to_string(), + mime_type: "text/plain".to_string(), + }; + + db.store_item(file_info) + .await + .expect("Failed to store test file info"); + + let malicious_sha = "known_sha_value' OR true --"; + let result = FileInfo::get_by_sha(malicious_sha, &db).await; + + assert!(matches!(result, Err(FileError::FileNotFound(_)))); + } + #[tokio::test] async fn test_manual_file_info_creation() { let namespace = "test_ns"; diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 9acc36d..2c79444 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -174,12 +174,15 @@ impl KnowledgeEntity { // Delete embeddings first, while we can still look them up via the entity's source_id KnowledgeEntityEmbedding::delete_by_source_id(source_id, db_client).await?; - let query = format!( - "DELETE {} WHERE source_id = '{}'", - Self::table_name(), - source_id - ); - db_client.query(query).await?; + db_client + .client + .query("DELETE FROM type::table($table) WHERE source_id = $source_id") + .bind(("table", Self::table_name())) + .bind(("source_id", source_id.to_owned())) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; Ok(()) } @@ -761,6 +764,69 @@ mod tests { assert_eq!(different_remaining[0].id, different_entity.id); } + #[tokio::test] + async fn test_delete_by_source_id_resists_query_injection() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3) + .await + .expect("Failed to redefine index length"); + + let user_id = "user123".to_string(); + + let entity1 = KnowledgeEntity::new( + "safe_source".to_string(), + "Entity 1".to_string(), + "Description 1".to_string(), + KnowledgeEntityType::Document, + None, + user_id.clone(), + ); + + let entity2 = KnowledgeEntity::new( + "other_source".to_string(), + "Entity 2".to_string(), + "Description 2".to_string(), + KnowledgeEntityType::Document, + None, + user_id, + ); + + KnowledgeEntity::store_with_embedding(entity1, vec![0.1, 0.2, 0.3], &db) + .await + .expect("store entity1"); + KnowledgeEntity::store_with_embedding(entity2, vec![0.3, 0.2, 0.1], &db) + .await + .expect("store entity2"); + + let malicious_source = "safe_source' OR 1=1 --"; + KnowledgeEntity::delete_by_source_id(malicious_source, &db) + .await + .expect("delete call should succeed"); + + let remaining: Vec = db + .client + .query("SELECT * FROM type::table($table)") + .bind(("table", KnowledgeEntity::table_name())) + .await + .expect("query failed") + .take(0) + .expect("take failed"); + + assert_eq!( + remaining.len(), + 2, + "malicious input must not delete unrelated entities" + ); + } + #[tokio::test] async fn test_vector_search_returns_empty_when_no_embeddings() { let namespace = "test_ns"; diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 56a4603..664588b 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -40,22 +40,28 @@ impl KnowledgeRelationship { } } pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> { - let query = format!( - r#"DELETE relates_to:`{rel_id}`; - RELATE knowledge_entity:`{in_id}`->relates_to:`{rel_id}`->knowledge_entity:`{out_id}` - SET - metadata.user_id = '{user_id}', - metadata.source_id = '{source_id}', - metadata.relationship_type = '{relationship_type}'"#, - rel_id = self.id, - in_id = self.in_, - out_id = self.out, - user_id = self.metadata.user_id.as_str(), - source_id = self.metadata.source_id.as_str(), - relationship_type = self.metadata.relationship_type.as_str() - ); - - db_client.query(query).await?.check()?; + db_client + .client + .query( + r#"BEGIN TRANSACTION; + LET $in_entity = type::thing('knowledge_entity', $in_id); + LET $out_entity = type::thing('knowledge_entity', $out_id); + LET $relation = type::thing('relates_to', $rel_id); + DELETE type::thing('relates_to', $rel_id); + RELATE $in_entity->$relation->$out_entity SET + metadata.user_id = $user_id, + metadata.source_id = $source_id, + metadata.relationship_type = $relationship_type; + COMMIT TRANSACTION;"#, + ) + .bind(("rel_id", self.id.clone())) + .bind(("in_id", self.in_.clone())) + .bind(("out_id", self.out.clone())) + .bind(("user_id", self.metadata.user_id.clone())) + .bind(("source_id", self.metadata.source_id.clone())) + .bind(("relationship_type", self.metadata.relationship_type.clone())) + .await? + .check()?; Ok(()) } @@ -64,11 +70,12 @@ impl KnowledgeRelationship { source_id: &str, db_client: &SurrealDbClient, ) -> Result<(), AppError> { - let query = format!( - "DELETE knowledge_entity -> relates_to WHERE metadata.source_id = '{source_id}'" - ); - - db_client.query(query).await?; + db_client + .client + .query("DELETE knowledge_entity -> relates_to WHERE metadata.source_id = $source_id") + .bind(("source_id", source_id.to_owned())) + .await? + .check()?; Ok(()) } @@ -79,15 +86,20 @@ impl KnowledgeRelationship { db_client: &SurrealDbClient, ) -> Result<(), AppError> { let mut authorized_result = db_client - .query(format!( - "SELECT * FROM relates_to WHERE id = relates_to:`{id}` AND metadata.user_id = '{user_id}'" - )) + .client + .query( + "SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id) AND metadata.user_id = $user_id", + ) + .bind(("id", id.to_owned())) + .bind(("user_id", user_id.to_owned())) .await?; let authorized: Vec = authorized_result.take(0).unwrap_or_default(); if authorized.is_empty() { let mut exists_result = db_client - .query(format!("SELECT * FROM relates_to:`{id}`")) + .client + .query("SELECT * FROM type::thing('relates_to', $id)") + .bind(("id", id.to_owned())) .await?; let existing: Option = exists_result.take(0)?; @@ -99,7 +111,12 @@ impl KnowledgeRelationship { Err(AppError::NotFound(format!("Relationship {id} not found"))) } } else { - db_client.query(format!("DELETE relates_to:`{id}`")).await?; + db_client + .client + .query("DELETE type::thing('relates_to', $id)") + .bind(("id", id.to_owned())) + .await? + .check()?; Ok(()) } } @@ -210,6 +227,49 @@ mod tests { ); } + #[tokio::test] + async fn test_store_relationship_resists_query_injection() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + + db.apply_migrations() + .await + .expect("Failed to apply migrations"); + + let entity1_id = create_test_entity("Entity 1", &db).await; + let entity2_id = create_test_entity("Entity 2", &db).await; + + let relationship = KnowledgeRelationship::new( + entity1_id, + entity2_id, + "user'123".to_string(), + "source123'; DELETE FROM relates_to; --".to_string(), + "references'; UPDATE user SET admin = true; --".to_string(), + ); + + relationship + .store_relationship(&db) + .await + .expect("store relationship should safely handle quote-containing values"); + + let mut res = db + .client + .query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)") + .bind(("id", relationship.id.clone())) + .await + .expect("query relationship by id failed"); + let rows: Vec = res.take(0).expect("take rows"); + + assert_eq!(rows.len(), 1); + assert_eq!( + rows[0].metadata.source_id, + "source123'; DELETE FROM relates_to; --" + ); + } + #[tokio::test] async fn test_store_and_delete_relationship() { // Setup in-memory database for testing diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index e2d7ee3..805c931 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -47,12 +47,15 @@ impl TextChunk { // Delete embeddings first TextChunkEmbedding::delete_by_source_id(source_id, db_client).await?; - let query = format!( - "DELETE {} WHERE source_id = '{}'", - Self::table_name(), - source_id - ); - db_client.query(query).await?; + db_client + .client + .query("DELETE FROM type::table($table) WHERE source_id = $source_id") + .bind(("table", Self::table_name())) + .bind(("source_id", source_id.to_owned())) + .await + .map_err(AppError::Database)? + .check() + .map_err(AppError::Database)?; Ok(()) } @@ -617,6 +620,57 @@ mod tests { assert_eq!(remaining.len(), 1); } + #[tokio::test] + async fn test_delete_by_source_id_resists_query_injection() { + let namespace = "test_ns"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .expect("Failed to start in-memory surrealdb"); + db.apply_migrations().await.expect("migrations"); + TextChunkEmbedding::redefine_hnsw_index(&db, 5) + .await + .expect("redefine index"); + + let chunk1 = TextChunk::new( + "safe_source".to_string(), + "Safe chunk".to_string(), + "user123".to_string(), + ); + let chunk2 = TextChunk::new( + "other_source".to_string(), + "Other chunk".to_string(), + "user123".to_string(), + ); + + TextChunk::store_with_embedding(chunk1.clone(), vec![0.1, 0.2, 0.3, 0.4, 0.5], &db) + .await + .expect("store chunk1"); + TextChunk::store_with_embedding(chunk2.clone(), vec![0.5, 0.4, 0.3, 0.2, 0.1], &db) + .await + .expect("store chunk2"); + + let malicious_source = "safe_source' OR 1=1 --"; + TextChunk::delete_by_source_id(malicious_source, &db) + .await + .expect("delete call should succeed"); + + let remaining: Vec = db + .client + .query("SELECT * FROM type::table($table)") + .bind(("table", TextChunk::table_name())) + .await + .expect("query failed") + .take(0) + .expect("take failed"); + + assert_eq!( + remaining.len(), + 2, + "malicious input must not delete unrelated rows" + ); + } + #[tokio::test] async fn test_store_with_embedding_creates_both_records() { let namespace = "test_ns";