Merge pull request #21806 from netbox-community/21771-rest-api-add-remove-tags

Closes #21771: Add `add_tags` & `remove_tags` fields for taggable objects
This commit is contained in:
bctiemann
2026-04-01 13:02:19 -04:00
committed by GitHub
3 changed files with 289 additions and 4 deletions

View File

@@ -5,16 +5,17 @@ from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework import status from rest_framework import status
from core.models import ObjectType
from dcim.choices import * from dcim.choices import *
from dcim.constants import * from dcim.constants import *
from dcim.models import * from dcim.models import *
from extras.models import ConfigTemplate from extras.models import ConfigTemplate, Tag
from ipam.choices import VLANQinQRoleChoices from ipam.choices import VLANQinQRoleChoices
from ipam.models import ASN, RIR, VLAN, VRF from ipam.models import ASN, RIR, VLAN, VRF
from netbox.api.serializers import GenericObjectSerializer from netbox.api.serializers import GenericObjectSerializer
from tenancy.models import Tenant from tenancy.models import Tenant
from users.constants import TOKEN_PREFIX from users.constants import TOKEN_PREFIX
from users.models import Token, User from users.models import ObjectPermission, Token, User
from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging
from virtualization.models import Cluster, ClusterType from virtualization.models import Cluster, ClusterType
from wireless.choices import WirelessChannelChoices from wireless.choices import WirelessChannelChoices
@@ -195,6 +196,222 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
}, },
] ]
def test_add_tags(self):
"""
Add tags to an existing object via the add_tags field.
"""
site = Site.objects.first()
tags = Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
Tag(name='Charlie', slug='charlie'),
))
site.tags.set([tags[0], tags[1]])
# Grant change permission
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'add_tags': [{'name': 'Charlie'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
# Verify all three tags are now assigned
tag_names = sorted(site.tags.values_list('name', flat=True))
self.assertEqual(tag_names, ['Alpha', 'Bravo', 'Charlie'])
# Verify add_tags and remove_tags are not in the response
self.assertNotIn('add_tags', response.data)
self.assertNotIn('remove_tags', response.data)
self.assertIn('tags', response.data)
def test_remove_tags(self):
"""
Remove tags from an existing object via the remove_tags field.
"""
site = Site.objects.first()
tags = Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
Tag(name='Charlie', slug='charlie'),
))
site.tags.set(tags)
# Grant change permission
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'remove_tags': [{'name': 'Charlie'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
# Verify only Alpha and Bravo remain
tag_names = sorted(site.tags.values_list('name', flat=True))
self.assertEqual(tag_names, ['Alpha', 'Bravo'])
def test_remove_tags_not_assigned(self):
"""
Removing a tag that is not assigned should not raise an error.
"""
site = Site.objects.first()
tags = Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
Tag(name='Charlie', slug='charlie'),
))
site.tags.set([tags[0], tags[1]])
# Grant change permission
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'remove_tags': [{'name': 'Charlie'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
# Tags should be unchanged
tag_names = sorted(site.tags.values_list('name', flat=True))
self.assertEqual(tag_names, ['Alpha', 'Bravo'])
def test_add_and_remove_tags(self):
"""
Add and remove tags in the same request.
"""
site = Site.objects.first()
tags = Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
Tag(name='Charlie', slug='charlie'),
))
site.tags.set([tags[0], tags[1]])
# Grant change permission
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'add_tags': [{'name': 'Charlie'}],
'remove_tags': [{'name': 'Alpha'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
# Verify Bravo and Charlie remain
tag_names = sorted(site.tags.values_list('name', flat=True))
self.assertEqual(tag_names, ['Bravo', 'Charlie'])
def test_tags_with_add_tags_error(self):
"""
Specifying tags together with add_tags or remove_tags should raise a validation error.
"""
site = Site.objects.first()
Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
))
# Grant change permission
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'tags': [{'name': 'Alpha'}],
'add_tags': [{'name': 'Bravo'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
def test_create_with_add_tags(self):
"""
Create a new object using add_tags.
"""
Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
))
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
data = {
'name': 'Site 10',
'slug': 'site-10',
'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
}
response = self.client.post(self._get_list_url(), data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
site = Site.objects.get(pk=response.data['id'])
tag_names = sorted(site.tags.values_list('name', flat=True))
self.assertEqual(tag_names, ['Alpha', 'Bravo'])
def test_create_with_remove_tags_error(self):
"""
Using remove_tags when creating a new object should raise a validation error.
"""
Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
))
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
data = {
'name': 'Site 10',
'slug': 'site-10',
'remove_tags': [{'name': 'Alpha'}],
}
response = self.client.post(self._get_list_url(), data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
def test_add_and_remove_same_tag_error(self):
"""
Including the same tag in both add_tags and remove_tags should raise a validation error.
"""
site = Site.objects.first()
Tag.objects.bulk_create((
Tag(name='Alpha', slug='alpha'),
Tag(name='Bravo', slug='bravo'),
))
obj_perm = ObjectPermission(name='Test permission', actions=['change'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
url = self._get_detail_url(site)
data = {
'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
'remove_tags': [{'name': 'Alpha'}],
}
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
class LocationTest(APIViewTestCases.APIViewTestCase): class LocationTest(APIViewTestCases.APIViewTestCase):
model = Location model = Location

View File

@@ -30,17 +30,78 @@ class TaggableModelSerializer(serializers.Serializer):
on create() and update(). on create() and update().
""" """
tags = NestedTagSerializer(many=True, required=False) tags = NestedTagSerializer(many=True, required=False)
add_tags = NestedTagSerializer(many=True, required=False, write_only=True)
remove_tags = NestedTagSerializer(many=True, required=False, write_only=True)
def to_internal_value(self, data):
ret = super().to_internal_value(data)
# Workaround to bypass requirement to include add_tags/remove_tags in Meta.fields on every serializer
if type(data) is dict:
tag_serializer = NestedTagSerializer(many=True)
for field_name in ('add_tags', 'remove_tags'):
if field_name in data:
ret[field_name] = tag_serializer.to_internal_value(data[field_name])
return ret
def validate(self, data):
# Skip validation for nested serializer representations (e.g. when used as a related field)
if type(data) is not dict:
return super().validate(data)
if data.get('tags') and (data.get('add_tags') or data.get('remove_tags')):
raise serializers.ValidationError({
'tags': 'Cannot specify "tags" together with "add_tags" or "remove_tags".'
})
if self.instance is None and data.get('remove_tags'):
raise serializers.ValidationError({
'remove_tags': 'Cannot use "remove_tags" when creating a new object.'
})
if data.get('add_tags') and data.get('remove_tags'):
add_pks = {t.pk for t in data['add_tags']}
remove_pks = {t.pk for t in data['remove_tags']}
overlap = [t for t in data['add_tags'] if t.pk in (add_pks & remove_pks)]
if overlap:
raise serializers.ValidationError({
'remove_tags':
f'Tags may not be present in both "add_tags" and "remove_tags": '
f'{", ".join(t.name for t in overlap)}'
})
# Pop add_tags/remove_tags before calling super() to prevent them from being passed
# to the model constructor during ValidatedModelSerializer validation
add_tags = data.pop('add_tags', None)
remove_tags = data.pop('remove_tags', None)
data = super().validate(data)
# Restore for use in create()/update()
if add_tags is not None:
data['add_tags'] = add_tags
if remove_tags is not None:
data['remove_tags'] = remove_tags
return data
def create(self, validated_data): def create(self, validated_data):
tags = validated_data.pop('tags', None) tags = validated_data.pop('tags', None)
add_tags = validated_data.pop('add_tags', None)
validated_data.pop('remove_tags', None)
instance = super().create(validated_data) instance = super().create(validated_data)
if tags is not None: if tags is not None:
return self._save_tags(instance, tags) return self._save_tags(instance, tags)
if add_tags is not None:
instance.tags.add(*[t.name for t in add_tags])
return instance return instance
def update(self, instance, validated_data): def update(self, instance, validated_data):
tags = validated_data.pop('tags', None) tags = validated_data.pop('tags', None)
add_tags = validated_data.pop('add_tags', None)
remove_tags = validated_data.pop('remove_tags', None)
# Cache tags on instance for change logging # Cache tags on instance for change logging
instance._tags = tags or [] instance._tags = tags or []
@@ -49,6 +110,13 @@ class TaggableModelSerializer(serializers.Serializer):
if tags is not None: if tags is not None:
return self._save_tags(instance, tags) return self._save_tags(instance, tags)
if add_tags is not None:
instance.tags.add(*[t.name for t in add_tags])
if remove_tags is not None:
instance.tags.remove(*[t.name for t in remove_tags])
if add_tags is not None or remove_tags is not None:
instance._tags = instance.tags.all()
return instance return instance
def _save_tags(self, instance, tags): def _save_tags(self, instance, tags):

View File

@@ -286,7 +286,7 @@ class APIViewTestCases:
self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data)) self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data))
for i, obj in enumerate(response.data): for i, obj in enumerate(response.data):
for field in self.create_data[i]: for field in self.create_data[i]:
if field == 'changelog_message': if field in ('changelog_message', 'add_tags', 'remove_tags'):
# Write-only field # Write-only field
continue continue
if field not in self.validation_excluded_fields: if field not in self.validation_excluded_fields:
@@ -444,7 +444,7 @@ class APIViewTestCases:
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
for i, obj in enumerate(response.data): for i, obj in enumerate(response.data):
for field in self.bulk_update_data: for field in self.bulk_update_data:
if field == 'changelog_data': if field in ('changelog_message', 'add_tags', 'remove_tags'):
# Write-only field # Write-only field
continue continue
self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response") self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response")