diff --git a/netbox/dcim/api/serializers_/devices.py b/netbox/dcim/api/serializers_/devices.py index f860c7879..51cb6c3f2 100644 --- a/netbox/dcim/api/serializers_/devices.py +++ b/netbox/dcim/api/serializers_/devices.py @@ -58,10 +58,30 @@ class DeviceSerializer(PrimaryModelSerializer): ) status = ChoiceField(choices=DeviceStatusChoices, required=False) airflow = ChoiceField(choices=DeviceAirflowChoices, allow_blank=True, required=False) - primary_ip = IPAddressSerializer(nested=True, read_only=True, allow_null=True) - primary_ip4 = IPAddressSerializer(nested=True, required=False, allow_null=True) - primary_ip6 = IPAddressSerializer(nested=True, required=False, allow_null=True) - oob_ip = IPAddressSerializer(nested=True, required=False, allow_null=True) + primary_ip = IPAddressSerializer( + nested=True, + read_only=True, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) + primary_ip4 = IPAddressSerializer( + nested=True, + required=False, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) + primary_ip6 = IPAddressSerializer( + nested=True, + required=False, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) + oob_ip = IPAddressSerializer( + nested=True, + required=False, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) parent_device = serializers.SerializerMethodField() cluster = ClusterSerializer(nested=True, required=False, allow_null=True) virtual_chassis = VirtualChassisSerializer(nested=True, required=False, allow_null=True, default=None) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index 1ae90ca8e..a87ef9e32 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -16,7 +16,13 @@ from netbox.api.serializers import GenericObjectSerializer from tenancy.models import Tenant from users.constants import TOKEN_PREFIX from users.models import ObjectPermission, Token, User -from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging +from utilities.testing import ( + APITestCase, + APIViewTestCases, + create_test_device, + create_test_nat_ip_pair, + disable_logging, +) from virtualization.models import Cluster, ClusterType from wireless.choices import WirelessChannelChoices from wireless.models import WirelessLAN @@ -1902,6 +1908,81 @@ class DeviceTest(APIViewTestCases.APIViewTestCase): response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header) self.assertHttpStatus(response, status.HTTP_200_OK) + def test_list_object_includes_nat_inside_on_primary_ip(self): + device = create_test_device('natted-device') + interface = Interface.objects.create(device=device, name='eth0', type='other') + + real_ip, nat_ip = create_test_nat_ip_pair( + real_address='10.0.0.10/32', + nat_address='198.51.100.10/32', + inside_interface=interface, + ) + + device.primary_ip4 = nat_ip + device.save() + + self.add_permissions('dcim.view_device', 'ipam.view_ipaddress') + response = self.client.get(f'{self._get_list_url()}?id={device.pk}', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + result = response.data['results'][0] + for field in ('primary_ip', 'primary_ip4'): + self.assertEqual(result[field]['address'], str(nat_ip.address)) + self.assertEqual(result[field]['nat_inside']['address'], str(real_ip.address)) + self.assertEqual(result[field]['nat_outside'], []) + + def test_get_object_includes_nat_outside_on_primary_ip(self): + device = create_test_device('real-ip-device') + interface = Interface.objects.create(device=device, name='eth0', type='other') + + real_ip, nat_ip = create_test_nat_ip_pair( + real_address='10.0.0.11/32', + nat_address='198.51.100.11/32', + inside_interface=interface, + ) + + device.primary_ip4 = real_ip + device.save() + + self.add_permissions('dcim.view_device', 'ipam.view_ipaddress') + response = self.client.get( + f'{self._get_detail_url(device)}?exclude=config_context', + **self.header, + ) + self.assertHttpStatus(response, status.HTTP_200_OK) + + for field in ('primary_ip', 'primary_ip4'): + self.assertEqual(response.data[field]['address'], str(real_ip.address)) + self.assertIsNone(response.data[field]['nat_inside']) + self.assertCountEqual( + [ip['address'] for ip in response.data[field]['nat_outside']], + [str(nat_ip.address)], + ) + + def test_get_object_includes_nat_on_oob_ip(self): + device = create_test_device('oob-nat-device') + interface = Interface.objects.create(device=device, name='oob0', type='other') + + real_ip, nat_ip = create_test_nat_ip_pair( + real_address='10.0.0.12/32', + nat_address='198.51.100.12/32', + inside_interface=interface, + ) + + device.oob_ip = nat_ip + device.save() + + self.add_permissions('dcim.view_device', 'ipam.view_ipaddress') + response = self.client.get( + f'{self._get_detail_url(device)}?exclude=config_context', + **self.header, + ) + self.assertHttpStatus(response, status.HTTP_200_OK) + + self.assertEqual(response.data['oob_ip']['address'], str(nat_ip.address)) + self.assertEqual(response.data['oob_ip']['nat_inside']['address'], str(real_ip.address)) + self.assertEqual(response.data['oob_ip']['nat_outside'], []) + class ModuleTest(APIViewTestCases.APIViewTestCase): model = Module diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index 076ce3b0c..9c6b1c08c 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -11,7 +11,7 @@ from django.urls import reverse from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ from rest_framework.permissions import BasePermission -from rest_framework.serializers import Serializer +from rest_framework.serializers import ListSerializer, Serializer from rest_framework.views import get_view_name as drf_get_view_name from extras.constants import HTTP_CONTENT_TYPE_JSON @@ -98,6 +98,30 @@ def get_view_name(view): return drf_get_view_name(view) +def _get_nested_serializer(serializer_field): + """ + Return the nested serializer instance for a declared serializer field. + """ + if isinstance(serializer_field, ListSerializer): + serializer_field = serializer_field.child + + if isinstance(serializer_field, Serializer) and hasattr(serializer_field, 'nested'): + return serializer_field + + return None + + +def _get_serializer_fields(serializer: Serializer): + """ + Return the effective field names for a serializer instance, honoring any + field-level fields=/omit= overrides. + """ + fields = getattr(serializer, '_include_fields', None) or serializer.Meta.fields + omit = getattr(serializer, '_omit_fields', []) or [] + + return [field_name for field_name in fields if field_name not in omit] + + def get_prefetches_for_serializer(serializer_class, fields=None, omit=None): """ Compile and return a list of fields which should be prefetched on the queryset for a serializer. @@ -119,7 +143,7 @@ def get_prefetches_for_serializer(serializer_class, fields=None, omit=None): # Determine the name of the model field referenced by the serializer field model_field_name = field_name - if serializer_field and serializer_field.source: + if serializer_field and getattr(serializer_field, 'source', None): model_field_name = serializer_field.source # If the serializer field does not map to a discrete model field, skip it. @@ -130,14 +154,13 @@ def get_prefetches_for_serializer(serializer_class, fields=None, omit=None): except FieldDoesNotExist: continue - # If this field is represented by a nested serializer, recurse to resolve prefetches - # for the related object. - if serializer_field: - if issubclass(type(serializer_field), Serializer): - # Determine which fields to prefetch for the nested object - subfields = serializer_field.Meta.brief_fields if serializer_field.nested else None - for subfield in get_prefetches_for_serializer(type(serializer_field), subfields): - prefetch_fields.append(f'{field_name}__{subfield}') + # If this field is represented by a nested serializer, recurse to resolve + # prefetches for the related object, honoring any field-level fields=/omit= + # constraints set on that serializer field instance. + if nested_serializer := _get_nested_serializer(serializer_field): + subfields = _get_serializer_fields(nested_serializer) + for subfield in get_prefetches_for_serializer(type(nested_serializer), fields=subfields): + prefetch_fields.append(f'{field.name}__{subfield}') return prefetch_fields diff --git a/netbox/utilities/testing/utils.py b/netbox/utilities/testing/utils.py index 7715daa39..34bf8a336 100644 --- a/netbox/utilities/testing/utils.py +++ b/netbox/utilities/testing/utils.py @@ -12,6 +12,7 @@ from core.models import ObjectType from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site from extras.choices import CustomFieldTypeChoices from extras.models import CustomField, Tag +from ipam.models import IPAddress from users.models import User from virtualization.models import Cluster, ClusterType, VirtualMachine @@ -65,6 +66,28 @@ def create_test_virtualmachine(name): return virtual_machine +def create_test_nat_ip_pair( + real_address='10.0.0.10/32', nat_address='198.51.100.10/32', inside_interface=None, outside_interface=None +): + """ + Convenience method for creating an inside IP and its NAT outside IP. + + Optionally, assign either address to an Interface or VMInterface. + Returns (real_ip, nat_ip). + """ + real_ip = IPAddress(address=real_address) + if inside_interface is not None: + real_ip.assigned_object = inside_interface + real_ip.save() + + nat_ip = IPAddress(address=nat_address, nat_inside=real_ip) + if outside_interface is not None: + nat_ip.assigned_object = outside_interface + nat_ip.save() + + return real_ip, nat_ip + + def create_test_user(username='testuser', permissions=None): """ Create a User with the given permissions. diff --git a/netbox/utilities/tests/test_api.py b/netbox/utilities/tests/test_api.py index 290e98f1e..f984aac3a 100644 --- a/netbox/utilities/tests/test_api.py +++ b/netbox/utilities/tests/test_api.py @@ -8,8 +8,9 @@ from dcim.models import Region, Site from extras.choices import CustomFieldTypeChoices from extras.models import CustomField from ipam.models import VLAN +from netbox.api.serializers import BaseModelSerializer from netbox.config import get_config -from utilities.api import get_view_name +from utilities.api import get_prefetches_for_serializer, get_view_name from utilities.testing import APITestCase, disable_warnings @@ -394,3 +395,82 @@ class GetViewNameTestCase(TestCase): name = get_view_name(view) self.assertEqual(name, 'Mock List') + + +class GetPrefetchesForSerializerTestCase(TestCase): + + def test_nested_serializer_honors_explicit_fields(self): + class RegionSerializer(BaseModelSerializer): + class Meta: + model = Region + fields = ('id', 'name', 'parent') + brief_fields = ('id', 'name') + + class SiteSerializer(BaseModelSerializer): + region = RegionSerializer(nested=True, fields=('id', 'parent')) + + class Meta: + model = Site + fields = ('id', 'name', 'region') + + self.assertListEqual( + get_prefetches_for_serializer(SiteSerializer), + ['region', 'region__parent'], + ) + + def test_nested_serializer_honors_explicit_omit(self): + class RegionSerializer(BaseModelSerializer): + class Meta: + model = Region + fields = ('id', 'name', 'parent') + brief_fields = ('id', 'name') + + class SiteSerializer(BaseModelSerializer): + region = RegionSerializer(nested=True, omit=('name',)) + + class Meta: + model = Site + fields = ('id', 'name', 'region') + + self.assertListEqual( + get_prefetches_for_serializer(SiteSerializer), + ['region', 'region__parent'], + ) + + def test_many_nested_serializer_honors_explicit_fields(self): + class SiteSerializer(BaseModelSerializer): + class Meta: + model = Site + fields = ('id', 'name', 'region') + brief_fields = ('id', 'name') + + class RegionSerializer(BaseModelSerializer): + sites = SiteSerializer(nested=True, many=True, fields=('id', 'region')) + + class Meta: + model = Region + fields = ('id', 'name', 'sites') + + self.assertListEqual( + get_prefetches_for_serializer(RegionSerializer), + ['sites', 'sites__region'], + ) + + def test_nested_serializer_uses_source_for_prefetch_path(self): + class RegionSerializer(BaseModelSerializer): + class Meta: + model = Region + fields = ('id', 'name', 'parent') + brief_fields = ('id', 'name') + + class SiteSerializer(BaseModelSerializer): + region_detail = RegionSerializer(source='region', nested=True, fields=('id', 'parent')) + + class Meta: + model = Site + fields = ('id', 'name', 'region_detail') + + self.assertListEqual( + get_prefetches_for_serializer(SiteSerializer), + ['region', 'region__parent'], + ) diff --git a/netbox/virtualization/api/serializers_/virtualmachines.py b/netbox/virtualization/api/serializers_/virtualmachines.py index ac2621511..cfb01e9c3 100644 --- a/netbox/virtualization/api/serializers_/virtualmachines.py +++ b/netbox/virtualization/api/serializers_/virtualmachines.py @@ -1,8 +1,7 @@ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers -from dcim.api.serializers_.device_components import MACAddressSerializer -from dcim.api.serializers_.devices import DeviceSerializer +from dcim.api.serializers_.devices import DeviceSerializer, MACAddressSerializer from dcim.api.serializers_.platforms import PlatformSerializer from dcim.api.serializers_.roles import DeviceRoleSerializer from dcim.api.serializers_.sites import SiteSerializer @@ -58,9 +57,24 @@ class VirtualMachineSerializer(PrimaryModelSerializer): role = DeviceRoleSerializer(nested=True, required=False, allow_null=True) tenant = TenantSerializer(nested=True, required=False, allow_null=True, default=None) platform = PlatformSerializer(nested=True, required=False, allow_null=True) - primary_ip = IPAddressSerializer(nested=True, read_only=True, allow_null=True) - primary_ip4 = IPAddressSerializer(nested=True, required=False, allow_null=True) - primary_ip6 = IPAddressSerializer(nested=True, required=False, allow_null=True) + primary_ip = IPAddressSerializer( + nested=True, + read_only=True, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) + primary_ip4 = IPAddressSerializer( + nested=True, + required=False, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) + primary_ip6 = IPAddressSerializer( + nested=True, + required=False, + allow_null=True, + fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'], + ) config_template = ConfigTemplateSerializer(nested=True, required=False, allow_null=True, default=None) # Counter fields diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index e7137d27e..6cfa60963 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -18,6 +18,7 @@ from utilities.testing import ( APITestCase, APIViewTestCases, create_test_device, + create_test_nat_ip_pair, create_test_virtualmachine, disable_logging, ) @@ -505,6 +506,57 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase): response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header) self.assertHttpStatus(response, status.HTTP_200_OK) + def test_list_object_includes_nat_inside_on_primary_ip(self): + virtualmachine = create_test_virtualmachine('natted-vm') + interface = VMInterface.objects.create(virtual_machine=virtualmachine, name='eth0') + + real_ip, nat_ip = create_test_nat_ip_pair( + real_address='10.0.1.10/32', + nat_address='198.51.100.20/32', + inside_interface=interface, + ) + + virtualmachine.primary_ip4 = nat_ip + virtualmachine.save() + + self.add_permissions('virtualization.view_virtualmachine', 'ipam.view_ipaddress') + response = self.client.get(f'{self._get_list_url()}?id={virtualmachine.pk}', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + result = response.data['results'][0] + for field in ('primary_ip', 'primary_ip4'): + self.assertEqual(result[field]['address'], str(nat_ip.address)) + self.assertEqual(result[field]['nat_inside']['address'], str(real_ip.address)) + self.assertEqual(result[field]['nat_outside'], []) + + def test_get_object_includes_nat_outside_on_primary_ip(self): + virtualmachine = create_test_virtualmachine('real-ip-vm') + interface = VMInterface.objects.create(virtual_machine=virtualmachine, name='eth0') + + real_ip, nat_ip = create_test_nat_ip_pair( + real_address='10.0.1.11/32', + nat_address='198.51.100.21/32', + inside_interface=interface, + ) + + virtualmachine.primary_ip4 = real_ip + virtualmachine.save() + + self.add_permissions('virtualization.view_virtualmachine', 'ipam.view_ipaddress') + response = self.client.get( + f'{self._get_detail_url(virtualmachine)}?exclude=config_context', + **self.header, + ) + self.assertHttpStatus(response, status.HTTP_200_OK) + + for field in ('primary_ip', 'primary_ip4'): + self.assertEqual(response.data[field]['address'], str(real_ip.address)) + self.assertIsNone(response.data[field]['nat_inside']) + self.assertCountEqual( + [ip['address'] for ip in response.data[field]['nat_outside']], + [str(nat_ip.address)], + ) + class VMInterfaceTest(APIViewTestCases.APIViewTestCase): model = VMInterface