diff --git a/netbox/extras/api/serializers_/scripts.py b/netbox/extras/api/serializers_/scripts.py index 6995eb480..30b7afe0d 100644 --- a/netbox/extras/api/serializers_/scripts.py +++ b/netbox/extras/api/serializers_/scripts.py @@ -53,20 +53,21 @@ class ScriptModuleSerializer(ValidatedModelSerializer): return data + def _save_upload(self, upload_file, validated_data): + storage = storages.create_storage(storages.backends["scripts"]) + storage.save(upload_file.name, upload_file) + validated_data['file_path'] = upload_file.name + def create(self, validated_data): upload_file = validated_data.pop('upload_file', None) if upload_file: - storage = storages.create_storage(storages.backends["scripts"]) - storage.save(upload_file.name, upload_file) - validated_data['file_path'] = upload_file.name + self._save_upload(upload_file, validated_data) return super().create(validated_data) def update(self, instance, validated_data): upload_file = validated_data.pop('upload_file', None) if upload_file: - storage = storages.create_storage(storages.backends["scripts"]) - storage.save(upload_file.name, upload_file) - validated_data['file_path'] = upload_file.name + self._save_upload(upload_file, validated_data) return super().update(instance, validated_data) diff --git a/netbox/extras/tests/test_api.py b/netbox/extras/tests/test_api.py index d2002823f..5a8b63cc8 100644 --- a/netbox/extras/tests/test_api.py +++ b/netbox/extras/tests/test_api.py @@ -1465,6 +1465,29 @@ class ScriptModuleTest(APITestCase): response = self.client.post(self.url_list, {}, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) + def test_update_script_module_upload(self): + self.add_permissions('extras.change_scriptmodule', 'core.change_managedfile') + module = self.modules[0] + url = reverse('extras-api:scriptmodule-detail', kwargs={'pk': module.pk}) + script_content = b"from extras.scripts import Script\nclass UpdatedScript(Script):\n pass\n" + upload_file = SimpleUploadedFile('updated_script.py', script_content, content_type='text/plain') + + mock_storage = MagicMock() + + with patch('extras.api.serializers_.scripts.storages') as mock_storages: + mock_storages.create_storage.return_value = mock_storage + mock_storages.backends = {'scripts': {}} + response = self.client.patch( + url, + {'upload_file': upload_file}, + format='multipart', + **self.header, + ) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(response.data['file_path'], 'updated_script.py') + mock_storage.save.assert_called_once() + def test_delete_script_module(self): self.add_permissions('extras.delete_scriptmodule', 'core.delete_managedfile') module = self.modules[0]