From 95011821bb96137b01805cea5e7f1f856bebb820 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Tue, 31 Mar 2026 16:01:40 -0400 Subject: [PATCH] Closes #21771: Add add_tags & remove_tags fields for taggable objects --- netbox/dcim/tests/test_api.py | 177 +++++++++++++++++++++- netbox/netbox/api/serializers/features.py | 52 +++++++ netbox/utilities/testing/api.py | 4 +- 3 files changed, 229 insertions(+), 4 deletions(-) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index df12d0dec..b74ed1e42 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -5,16 +5,17 @@ from django.urls import reverse from django.utils.translation import gettext as _ from rest_framework import status +from core.models import ObjectType from dcim.choices import * from dcim.constants import * from dcim.models import * -from extras.models import ConfigTemplate +from extras.models import ConfigTemplate, Tag from ipam.choices import VLANQinQRoleChoices from ipam.models import ASN, RIR, VLAN, VRF from netbox.api.serializers import GenericObjectSerializer from tenancy.models import Tenant from users.constants import TOKEN_PREFIX -from users.models import Token, User +from users.models import ObjectPermission, Token, User from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging from virtualization.models import Cluster, ClusterType from wireless.choices import WirelessChannelChoices @@ -195,6 +196,178 @@ class SiteTest(APIViewTestCases.APIViewTestCase): }, ] + def test_add_tags(self): + """ + Add tags to an existing object via the add_tags field. + """ + site = Site.objects.first() + tags = Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + Tag(name='Charlie', slug='charlie'), + )) + site.tags.set([tags[0], tags[1]]) + + # Grant change permission + obj_perm = ObjectPermission(name='Test permission', actions=['change']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + url = self._get_detail_url(site) + data = { + 'add_tags': [{'name': 'Charlie'}], + } + response = self.client.patch(url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + # Verify all three tags are now assigned + tag_names = sorted(site.tags.values_list('name', flat=True)) + self.assertEqual(tag_names, ['Alpha', 'Bravo', 'Charlie']) + + # Verify add_tags and remove_tags are not in the response + self.assertNotIn('add_tags', response.data) + self.assertNotIn('remove_tags', response.data) + self.assertIn('tags', response.data) + + def test_remove_tags(self): + """ + Remove tags from an existing object via the remove_tags field. + """ + site = Site.objects.first() + tags = Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + Tag(name='Charlie', slug='charlie'), + )) + site.tags.set(tags) + + # Grant change permission + obj_perm = ObjectPermission(name='Test permission', actions=['change']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + url = self._get_detail_url(site) + data = { + 'remove_tags': [{'name': 'Charlie'}], + } + response = self.client.patch(url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + # Verify only Alpha and Bravo remain + tag_names = sorted(site.tags.values_list('name', flat=True)) + self.assertEqual(tag_names, ['Alpha', 'Bravo']) + + def test_remove_tags_not_assigned(self): + """ + Removing a tag that is not assigned should not raise an error. + """ + site = Site.objects.first() + tags = Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + Tag(name='Charlie', slug='charlie'), + )) + site.tags.set([tags[0], tags[1]]) + + # Grant change permission + obj_perm = ObjectPermission(name='Test permission', actions=['change']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + url = self._get_detail_url(site) + data = { + 'remove_tags': [{'name': 'Charlie'}], + } + response = self.client.patch(url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + # Tags should be unchanged + tag_names = sorted(site.tags.values_list('name', flat=True)) + self.assertEqual(tag_names, ['Alpha', 'Bravo']) + + def test_add_and_remove_tags(self): + """ + Add and remove tags in the same request. + """ + site = Site.objects.first() + tags = Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + Tag(name='Charlie', slug='charlie'), + )) + site.tags.set([tags[0], tags[1]]) + + # Grant change permission + obj_perm = ObjectPermission(name='Test permission', actions=['change']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + url = self._get_detail_url(site) + data = { + 'add_tags': [{'name': 'Charlie'}], + 'remove_tags': [{'name': 'Alpha'}], + } + response = self.client.patch(url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + # Verify Bravo and Charlie remain + tag_names = sorted(site.tags.values_list('name', flat=True)) + self.assertEqual(tag_names, ['Bravo', 'Charlie']) + + def test_tags_with_add_tags_error(self): + """ + Specifying tags together with add_tags or remove_tags should raise a validation error. + """ + site = Site.objects.first() + Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + )) + + # Grant change permission + obj_perm = ObjectPermission(name='Test permission', actions=['change']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + url = self._get_detail_url(site) + data = { + 'tags': [{'name': 'Alpha'}], + 'add_tags': [{'name': 'Bravo'}], + } + response = self.client.patch(url, data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) + + def test_create_with_add_tags(self): + """ + Create a new object using add_tags. + """ + Tag.objects.bulk_create(( + Tag(name='Alpha', slug='alpha'), + Tag(name='Bravo', slug='bravo'), + )) + + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + + data = { + 'name': 'Site 10', + 'slug': 'site-10', + 'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}], + } + response = self.client.post(self._get_list_url(), data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_201_CREATED) + + site = Site.objects.get(pk=response.data['id']) + tag_names = sorted(site.tags.values_list('name', flat=True)) + self.assertEqual(tag_names, ['Alpha', 'Bravo']) + class LocationTest(APIViewTestCases.APIViewTestCase): model = Location diff --git a/netbox/netbox/api/serializers/features.py b/netbox/netbox/api/serializers/features.py index 6a89e7b1c..2ff1cc7f5 100644 --- a/netbox/netbox/api/serializers/features.py +++ b/netbox/netbox/api/serializers/features.py @@ -30,17 +30,62 @@ class TaggableModelSerializer(serializers.Serializer): on create() and update(). """ tags = NestedTagSerializer(many=True, required=False) + add_tags = NestedTagSerializer(many=True, required=False, write_only=True) + remove_tags = NestedTagSerializer(many=True, required=False, write_only=True) + + def to_internal_value(self, data): + ret = super().to_internal_value(data) + + # Workaround to bypass requirement to include add_tags/remove_tags in Meta.fields on every serializer + if type(data) is dict: + tag_serializer = NestedTagSerializer(many=True) + for field_name in ('add_tags', 'remove_tags'): + if field_name in data: + ret[field_name] = tag_serializer.to_internal_value(data[field_name]) + + return ret + + def validate(self, data): + # Skip validation for nested serializer representations (e.g. when used as a related field) + if type(data) is not dict: + return super().validate(data) + + if data.get('tags') and (data.get('add_tags') or data.get('remove_tags')): + raise serializers.ValidationError({ + 'tags': 'Cannot specify "tags" together with "add_tags" or "remove_tags".' + }) + + # Pop add_tags/remove_tags before calling super() to prevent them from being passed + # to the model constructor during ValidatedModelSerializer validation + add_tags = data.pop('add_tags', None) + remove_tags = data.pop('remove_tags', None) + + data = super().validate(data) + + # Restore for use in create()/update() + if add_tags is not None: + data['add_tags'] = add_tags + if remove_tags is not None: + data['remove_tags'] = remove_tags + + return data def create(self, validated_data): tags = validated_data.pop('tags', None) + add_tags = validated_data.pop('add_tags', None) + validated_data.pop('remove_tags', None) instance = super().create(validated_data) if tags is not None: return self._save_tags(instance, tags) + if add_tags is not None: + instance.tags.add(*[t.name for t in add_tags]) return instance def update(self, instance, validated_data): tags = validated_data.pop('tags', None) + add_tags = validated_data.pop('add_tags', None) + remove_tags = validated_data.pop('remove_tags', None) # Cache tags on instance for change logging instance._tags = tags or [] @@ -49,6 +94,13 @@ class TaggableModelSerializer(serializers.Serializer): if tags is not None: return self._save_tags(instance, tags) + if add_tags is not None: + instance.tags.add(*[t.name for t in add_tags]) + if remove_tags is not None: + instance.tags.remove(*[t.name for t in remove_tags]) + if add_tags is not None or remove_tags is not None: + instance._tags = instance.tags.all() + return instance def _save_tags(self, instance, tags): diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index c2a83a8fb..9899edd24 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -286,7 +286,7 @@ class APIViewTestCases: self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data)) for i, obj in enumerate(response.data): for field in self.create_data[i]: - if field == 'changelog_message': + if field in ('changelog_message', 'add_tags', 'remove_tags'): # Write-only field continue if field not in self.validation_excluded_fields: @@ -444,7 +444,7 @@ class APIViewTestCases: self.assertHttpStatus(response, status.HTTP_200_OK) for i, obj in enumerate(response.data): for field in self.bulk_update_data: - if field == 'changelog_data': + if field in ('changelog_message', 'add_tags', 'remove_tags'): # Write-only field continue self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response")