diff --git a/docs/customization/custom-scripts.md b/docs/customization/custom-scripts.md index add8fafb1..0947a2c94 100644 --- a/docs/customization/custom-scripts.md +++ b/docs/customization/custom-scripts.md @@ -384,6 +384,18 @@ A calendar date. Returns a `datetime.date` object. A complete date & time. Returns a `datetime.datetime` object. +## Uploading Scripts via the API + +Script modules can be uploaded to NetBox via the REST API by sending a `multipart/form-data` POST request to `/api/extras/scripts/upload/`. The caller must have the `extras.add_scriptmodule` and `core.add_managedfile` permissions. + +```no-highlight +curl -X POST \ +-H "Authorization: Token $TOKEN" \ +-H "Accept: application/json; indent=4" \ +-F "file=@/path/to/myscript.py" \ +http://netbox/api/extras/scripts/upload/ +``` + ## Running Custom Scripts !!! note diff --git a/netbox/extras/api/serializers_/scripts.py b/netbox/extras/api/serializers_/scripts.py index a7d5b9c2a..9f0afe3f1 100644 --- a/netbox/extras/api/serializers_/scripts.py +++ b/netbox/extras/api/serializers_/scripts.py @@ -1,19 +1,70 @@ -from django.utils.translation import gettext as _ +import logging + +from django.core.files.storage import storages +from django.db import IntegrityError +from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from core.api.serializers_.jobs import JobSerializer -from extras.models import Script +from core.choices import ManagedFileRootPathChoices +from extras.models import Script, ScriptModule from netbox.api.serializers import ValidatedModelSerializer from utilities.datetime import local_now +logger = logging.getLogger(__name__) + __all__ = ( 'ScriptDetailSerializer', 'ScriptInputSerializer', + 'ScriptModuleSerializer', 'ScriptSerializer', ) +class ScriptModuleSerializer(ValidatedModelSerializer): + file = serializers.FileField(write_only=True) + file_path = serializers.CharField(read_only=True) + + class Meta: + model = ScriptModule + fields = ['id', 'display', 'file_path', 'file', 'created', 'last_updated'] + brief_fields = ('id', 'display') + + def validate(self, data): + # ScriptModule.save() sets file_root; inject it here so full_clean() succeeds. + # Pop 'file' before model instantiation — ScriptModule has no such field. + file = data.pop('file', None) + data['file_root'] = ManagedFileRootPathChoices.SCRIPTS + data = super().validate(data) + data.pop('file_root', None) + if file is not None: + data['file'] = file + return data + + def create(self, validated_data): + file = validated_data.pop('file') + storage = storages.create_storage(storages.backends["scripts"]) + validated_data['file_path'] = storage.save(file.name, file) + created = False + try: + instance = super().create(validated_data) + created = True + return instance + except IntegrityError as e: + if 'file_path' in str(e): + raise serializers.ValidationError( + _("A script module with this file name already exists.") + ) + raise + finally: + if not created and (file_path := validated_data.get('file_path')): + try: + storage.delete(file_path) + except Exception: + logger.warning(f"Failed to delete orphaned script file '{file_path}' from storage.") + + class ScriptSerializer(ValidatedModelSerializer): description = serializers.SerializerMethodField(read_only=True) vars = serializers.SerializerMethodField(read_only=True) diff --git a/netbox/extras/api/urls.py b/netbox/extras/api/urls.py index 9478fbeb2..cd1a9f683 100644 --- a/netbox/extras/api/urls.py +++ b/netbox/extras/api/urls.py @@ -26,6 +26,7 @@ router.register('journal-entries', views.JournalEntryViewSet) router.register('config-contexts', views.ConfigContextViewSet) router.register('config-context-profiles', views.ConfigContextProfileViewSet) router.register('config-templates', views.ConfigTemplateViewSet) +router.register('scripts/upload', views.ScriptModuleViewSet) router.register('scripts', views.ScriptViewSet, basename='script') app_name = 'extras-api' diff --git a/netbox/extras/api/views.py b/netbox/extras/api/views.py index e72ad1ab5..5a2c03212 100644 --- a/netbox/extras/api/views.py +++ b/netbox/extras/api/views.py @@ -6,7 +6,7 @@ from rest_framework import status from rest_framework.decorators import action from rest_framework.exceptions import PermissionDenied from rest_framework.generics import RetrieveUpdateDestroyAPIView -from rest_framework.mixins import ListModelMixin, RetrieveModelMixin +from rest_framework.mixins import CreateModelMixin, ListModelMixin, RetrieveModelMixin from rest_framework.renderers import JSONRenderer from rest_framework.response import Response from rest_framework.routers import APIRootView @@ -21,6 +21,7 @@ from netbox.api.features import SyncedDataMixin from netbox.api.metadata import ContentTypeMetadata from netbox.api.renderers import TextRenderer from netbox.api.viewsets import BaseViewSet, NetBoxModelViewSet +from netbox.api.viewsets.mixins import ObjectValidationMixin from utilities.exceptions import RQWorkerNotRunningException from utilities.request import copy_safe_request @@ -264,6 +265,11 @@ class ConfigTemplateViewSet(SyncedDataMixin, ConfigTemplateRenderMixin, NetBoxMo # Scripts # +class ScriptModuleViewSet(ObjectValidationMixin, CreateModelMixin, BaseViewSet): + queryset = ScriptModule.objects.all() + serializer_class = serializers.ScriptModuleSerializer + + @extend_schema_view( update=extend_schema(request=serializers.ScriptInputSerializer), partial_update=extend_schema(request=serializers.ScriptInputSerializer), diff --git a/netbox/extras/tests/test_api.py b/netbox/extras/tests/test_api.py index 1c4996bcc..c85433982 100644 --- a/netbox/extras/tests/test_api.py +++ b/netbox/extras/tests/test_api.py @@ -1,7 +1,9 @@ import datetime import hashlib +from unittest.mock import MagicMock, patch from django.contrib.contenttypes.models import ContentType +from django.core.files.uploadedfile import SimpleUploadedFile from django.urls import reverse from django.utils.timezone import make_aware, now from rest_framework import status @@ -1384,3 +1386,54 @@ class NotificationTest(APIViewTestCases.APIViewTestCase): 'event_type': OBJECT_DELETED, }, ] + + +class ScriptModuleTest(APITestCase): + """ + Tests for the POST /api/extras/scripts/upload/ endpoint. + + ScriptModule is a proxy of core.ManagedFile (a different app) so the standard + APIViewTestCases mixins cannot be used directly. All tests use add_permissions() + with explicit Django model-level permissions. + """ + + def setUp(self): + super().setUp() + self.url = reverse('extras-api:scriptmodule-list') # /api/extras/scripts/upload/ + + def test_upload_script_module_without_permission(self): + script_content = b"from extras.scripts import Script\nclass TestScript(Script):\n pass\n" + upload_file = SimpleUploadedFile('test_upload.py', script_content, content_type='text/plain') + response = self.client.post( + self.url, + {'file': upload_file}, + format='multipart', + **self.header, + ) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) + + def test_upload_script_module(self): + # ScriptModule is a proxy of core.ManagedFile; both permissions required. + self.add_permissions('extras.add_scriptmodule', 'core.add_managedfile') + script_content = b"from extras.scripts import Script\nclass TestScript(Script):\n pass\n" + upload_file = SimpleUploadedFile('test_upload.py', script_content, content_type='text/plain') + mock_storage = MagicMock() + mock_storage.save.return_value = 'test_upload.py' + with patch('extras.api.serializers_.scripts.storages') as mock_storages: + mock_storages.create_storage.return_value = mock_storage + mock_storages.backends = {'scripts': {}} + response = self.client.post( + self.url, + {'file': upload_file}, + format='multipart', + **self.header, + ) + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(response.data['file_path'], 'test_upload.py') + mock_storage.save.assert_called_once() + self.assertTrue(ScriptModule.objects.filter(file_path='test_upload.py').exists()) + + def test_upload_script_module_without_file_fails(self): + self.add_permissions('extras.add_scriptmodule', 'core.add_managedfile') + response = self.client.post(self.url, {}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)