change to ScriptModule

This commit is contained in:
Arthur
2026-03-31 08:09:43 -07:00
parent 37be38b582
commit 92f99910ee
6 changed files with 73 additions and 65 deletions

View File

@@ -23,7 +23,6 @@ __all__ = (
class ScriptModuleSerializer(ValidatedModelSerializer): class ScriptModuleSerializer(ValidatedModelSerializer):
url = None
data_source = DataSourceSerializer(nested=True, required=False, allow_null=True) data_source = DataSourceSerializer(nested=True, required=False, allow_null=True)
data_file = DataFileSerializer(nested=True, required=False, allow_null=True) data_file = DataFileSerializer(nested=True, required=False, allow_null=True)
upload_file = serializers.FileField(write_only=True, required=False, allow_null=True) upload_file = serializers.FileField(write_only=True, required=False, allow_null=True)
@@ -32,11 +31,11 @@ class ScriptModuleSerializer(ValidatedModelSerializer):
class Meta: class Meta:
model = ScriptModule model = ScriptModule
fields = [ fields = [
'id', 'display', 'file_path', 'upload_file', 'id', 'url', 'display', 'file_path', 'upload_file',
'data_source', 'data_file', 'auto_sync_enabled', 'data_source', 'data_file', 'auto_sync_enabled',
'created', 'last_updated', 'created', 'last_updated',
] ]
brief_fields = ('id', 'display') brief_fields = ('id', 'url', 'display')
def validate(self, data): def validate(self, data):
upload_file = data.pop('upload_file', None) upload_file = data.pop('upload_file', None)
@@ -119,6 +118,14 @@ class ScriptModuleSerializer(ValidatedModelSerializer):
except Exception: except Exception:
pass pass
def update(self, instance, validated_data):
upload_file = validated_data.pop('upload_file', None)
if upload_file:
self._save_upload(upload_file, validated_data)
elif data_file := validated_data.get('data_file'):
self._sync_data_file(data_file, validated_data)
return super().update(instance, validated_data)
class ScriptSerializer(ValidatedModelSerializer): class ScriptSerializer(ValidatedModelSerializer):
description = serializers.SerializerMethodField(read_only=True) description = serializers.SerializerMethodField(read_only=True)

View File

@@ -26,6 +26,7 @@ router.register('journal-entries', views.JournalEntryViewSet)
router.register('config-contexts', views.ConfigContextViewSet) router.register('config-contexts', views.ConfigContextViewSet)
router.register('config-context-profiles', views.ConfigContextProfileViewSet) router.register('config-context-profiles', views.ConfigContextProfileViewSet)
router.register('config-templates', views.ConfigTemplateViewSet) router.register('config-templates', views.ConfigTemplateViewSet)
router.register('script-modules', views.ScriptModuleViewSet)
router.register('scripts', views.ScriptViewSet, basename='script') router.register('scripts', views.ScriptViewSet, basename='script')
app_name = 'extras-api' app_name = 'extras-api'

View File

@@ -5,7 +5,7 @@ from django_rq.queues import get_connection
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status from rest_framework import status
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed, PermissionDenied from rest_framework.exceptions import PermissionDenied
from rest_framework.generics import RetrieveUpdateDestroyAPIView from rest_framework.generics import RetrieveUpdateDestroyAPIView
from rest_framework.mixins import ListModelMixin, RetrieveModelMixin from rest_framework.mixins import ListModelMixin, RetrieveModelMixin
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
@@ -265,11 +265,12 @@ class ConfigTemplateViewSet(SyncedDataMixin, ConfigTemplateRenderMixin, NetBoxMo
# Scripts # Scripts
# #
@extend_schema_view( class ScriptModuleViewSet(SyncedDataMixin, NetBoxModelViewSet):
create=extend_schema(request=serializers.ScriptModuleSerializer), queryset = ScriptModule.objects.all()
update=extend_schema(exclude=True), serializer_class = serializers.ScriptModuleSerializer
partial_update=extend_schema(exclude=True), filterset_class = filtersets.ScriptModuleFilterSet
)
class ScriptViewSet(ModelViewSet): class ScriptViewSet(ModelViewSet):
permission_classes = [IsAuthenticatedOrLoginNotRequired] permission_classes = [IsAuthenticatedOrLoginNotRequired]
queryset = Script.objects.all() queryset = Script.objects.all()
@@ -283,43 +284,9 @@ class ScriptViewSet(ModelViewSet):
super().initial(request, *args, **kwargs) super().initial(request, *args, **kwargs)
# Restrict the view's QuerySet to allow only the permitted objects # Restrict the view's QuerySet to allow only the permitted objects
if request.user.is_authenticated and self.action != 'create': if request.user.is_authenticated:
perm_action = 'run' if request.method == 'POST' else 'view' action = 'run' if request.method == 'POST' else 'view'
self.queryset = self.queryset.restrict(request.user, perm_action) self.queryset = self.queryset.restrict(request.user, action)
def create(self, request, *args, **kwargs):
"""
Upload a new Script module (.py file) and return the created ScriptModule.
"""
if not request.user.has_perm('extras.add_scriptmodule'):
raise PermissionDenied(_("This user does not have permission to add script modules."))
if not request.user.has_perm('core.add_managedfile'):
raise PermissionDenied(_("This user does not have permission to add managed files."))
serializer = serializers.ScriptModuleSerializer(
data=request.data,
context={'request': request},
)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
return Response(serializer.data, status=status.HTTP_201_CREATED)
# PUT and PATCH are intentionally unsupported: ScriptSerializer has no writable fields
# and there is no implementation for replacing the underlying module file via these methods.
# They remain registered by ModelViewSet and return 405 rather than 404.
def update(self, request, *args, **kwargs):
raise MethodNotAllowed(request.method)
def partial_update(self, request, *args, **kwargs):
raise MethodNotAllowed(request.method)
def destroy(self, request, *args, **kwargs):
script = self._get_script(kwargs[self.lookup_field])
if not request.user.has_perm('extras.delete_scriptmodule', script.module):
raise PermissionDenied(_("This user does not have permission to delete script modules."))
script.module.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
def _get_script(self, pk): def _get_script(self, pk):
# If pk is numeric, retrieve script by ID # If pk is numeric, retrieve script by ID

View File

@@ -33,6 +33,7 @@ __all__ = (
'NotificationGroupFilterSet', 'NotificationGroupFilterSet',
'SavedFilterFilterSet', 'SavedFilterFilterSet',
'ScriptFilterSet', 'ScriptFilterSet',
'ScriptModuleFilterSet',
'TableConfigFilterSet', 'TableConfigFilterSet',
'TagFilterSet', 'TagFilterSet',
'TaggedItemFilterSet', 'TaggedItemFilterSet',
@@ -64,6 +65,24 @@ class ScriptFilterSet(BaseFilterSet):
) )
class ScriptModuleFilterSet(BaseFilterSet):
q = django_filters.CharFilter(
method='search',
label=_('Search'),
)
class Meta:
model = ScriptModule
fields = ('id', 'file_path')
def search(self, queryset, name, value):
if not value.strip():
return queryset
return queryset.filter(
Q(file_path__icontains=value)
)
@register_filterset @register_filterset
class WebhookFilterSet(OwnerFilterMixin, NetBoxModelFilterSet): class WebhookFilterSet(OwnerFilterMixin, NetBoxModelFilterSet):
q = django_filters.CharFilter( q = django_filters.CharFilter(

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.files.uploadedfile import SimpleUploadedFile from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings
from django.urls import reverse from django.urls import reverse
from django.utils.timezone import make_aware, now from django.utils.timezone import make_aware, now
from rest_framework import status from rest_framework import status
@@ -1388,7 +1389,7 @@ class NotificationTest(APIViewTestCases.APIViewTestCase):
] ]
class ScriptUploadTest(APITestCase): class ScriptModuleTest(APITestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@@ -1402,10 +1403,30 @@ class ScriptUploadTest(APITestCase):
hash=hashlib.sha256(script_content).hexdigest(), hash=hashlib.sha256(script_content).hexdigest(),
data=script_content, data=script_content,
) )
# Use bulk_create to bypass ScriptModule.save() which tries to sync classes from disk
cls.modules = ScriptModule.objects.bulk_create((
ScriptModule(file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='module1.py'),
ScriptModule(file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='module2.py'),
ScriptModule(file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='module3.py'),
))
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.url_list = reverse('extras-api:script-list') self.url_list = reverse('extras-api:scriptmodule-list')
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_list_script_modules(self):
response = self.client.get(self.url_list, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 3)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_get_script_module(self):
module = self.modules[0]
url = reverse('extras-api:scriptmodule-detail', kwargs={'pk': module.pk})
response = self.client.get(url, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['file_path'], module.file_path)
def test_upload_script_module_without_permission(self): def test_upload_script_module_without_permission(self):
script_content = b"from extras.scripts import Script\nclass TestScript(Script):\n pass\n" script_content = b"from extras.scripts import Script\nclass TestScript(Script):\n pass\n"
@@ -1490,28 +1511,25 @@ class ScriptUploadTest(APITestCase):
self.assertEqual(response.data['file_path'], 'test_datasource.py') self.assertEqual(response.data['file_path'], 'test_datasource.py')
self.assertTrue(ScriptModule.objects.filter(file_path='test_datasource.py').exists()) self.assertTrue(ScriptModule.objects.filter(file_path='test_datasource.py').exists())
def test_destroy_script_module(self): def test_delete_script_module(self):
"""DELETE removes the ScriptModule and returns 204.""" """DELETE removes the ScriptModule and returns 204."""
self.add_permissions('extras.delete_scriptmodule', 'extras.view_script') self.add_permissions('extras.delete_scriptmodule', 'core.delete_managedfile',
from extras.models import Script 'extras.view_scriptmodule')
module = ScriptModule.objects.create( module = ScriptModule.objects.create(
file_root='scripts', file_path='to_delete.py', file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='to_delete.py',
) )
script = Script.objects.create(module=module, name='ToDelete', is_executable=True) url = reverse('extras-api:scriptmodule-detail', kwargs={'pk': module.pk})
url = reverse('extras-api:script-detail', kwargs={'pk': script.pk})
response = self.client.delete(url, **self.header) response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertFalse(ScriptModule.objects.filter(pk=module.pk).exists()) self.assertFalse(ScriptModule.objects.filter(pk=module.pk).exists())
def test_destroy_script_module_without_permission(self): def test_delete_script_module_without_permission(self):
"""DELETE without delete_scriptmodule permission returns 403.""" """DELETE without delete_scriptmodule permission returns 403."""
self.add_permissions('extras.view_script') self.add_permissions('extras.view_scriptmodule')
from extras.models import Script
module = ScriptModule.objects.create( module = ScriptModule.objects.create(
file_root='scripts', file_path='no_delete.py', file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='no_delete.py',
) )
script = Script.objects.create(module=module, name='NoDelete', is_executable=True) url = reverse('extras-api:scriptmodule-detail', kwargs={'pk': module.pk})
url = reverse('extras-api:script-detail', kwargs={'pk': script.pk})
response = self.client.delete(url, **self.header) response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
self.assertTrue(ScriptModule.objects.filter(pk=module.pk).exists()) self.assertTrue(ScriptModule.objects.filter(pk=module.pk).exists())

View File

@@ -142,11 +142,7 @@ class ObjectPermissionMixin:
# Also accept permissions for proxy models whose concrete model matches the object's. # Also accept permissions for proxy models whose concrete model matches the object's.
model = obj._meta.concrete_model model = obj._meta.concrete_model
if model._meta.label_lower != '.'.join((app_label, model_name)): if model._meta.label_lower != '.'.join((app_label, model_name)):
try: if apps.get_model(app_label, model_name)._meta.concrete_model != model:
perm_model = apps.get_model(app_label, model_name)
except LookupError:
perm_model = None
if not perm_model or perm_model._meta.concrete_model != model:
raise ValueError(_("Invalid permission {permission} for model {model}").format( raise ValueError(_("Invalid permission {permission} for model {model}").format(
permission=perm, model=model permission=perm, model=model
)) ))