diff --git a/netbox/dcim/tests/test_forms.py b/netbox/dcim/tests/test_forms.py index 118c347fd..2f01525fc 100644 --- a/netbox/dcim/tests/test_forms.py +++ b/netbox/dcim/tests/test_forms.py @@ -10,7 +10,8 @@ from dcim.choices import ( ) from dcim.forms import * from dcim.models import * -from ipam.models import VLAN +from ipam.models import ASN, RIR, VLAN +from utilities.forms.rendering import M2MAddRemoveFields from utilities.testing import create_test_device from virtualization.models import Cluster, ClusterGroup, ClusterType @@ -417,3 +418,111 @@ class InterfaceTestCase(TestCase): self.assertNotIn('untagged_vlan', form.cleaned_data.keys()) self.assertNotIn('tagged_vlans', form.cleaned_data.keys()) self.assertNotIn('qinq_svlan', form.cleaned_data.keys()) + + +class SiteFormTestCase(TestCase): + """ + Tests for M2MAddRemoveFields using Site ASN assignments as the test case. + Covers both simple mode (single multi-select field) and add/remove mode (dual fields). + """ + + @classmethod + def setUpTestData(cls): + cls.rir = RIR.objects.create(name='RIR 1', slug='rir-1') + # Create 110 ASNs: 100 to pre-assign (triggering add/remove mode) plus 10 extras + ASN.objects.bulk_create([ASN(asn=i, rir=cls.rir) for i in range(1, 111)]) + cls.asns = list(ASN.objects.order_by('asn')) + + def _site_data(self, **kwargs): + data = {'name': 'Test Site', 'slug': 'test-site', 'status': 'active'} + data.update(kwargs) + return data + + def test_new_site_uses_simple_mode(self): + """A form for a new site uses the single 'asns' field (simple mode).""" + form = SiteForm(data=self._site_data()) + self.assertIn('asns', form.fields) + self.assertNotIn('add_asns', form.fields) + self.assertNotIn('remove_asns', form.fields) + + def test_existing_site_below_threshold_uses_simple_mode(self): + """A form for an existing site with fewer than THRESHOLD ASNs uses simple mode.""" + site = Site.objects.create(name='Site 1', slug='site-1') + site.asns.set(self.asns[:5]) + form = SiteForm(instance=site) + self.assertIn('asns', form.fields) + self.assertNotIn('add_asns', form.fields) + self.assertNotIn('remove_asns', form.fields) + + def test_existing_site_at_threshold_uses_add_remove_mode(self): + """A form for an existing site with THRESHOLD or more ASNs uses add/remove mode.""" + site = Site.objects.create(name='Site 2', slug='site-2') + site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD]) + form = SiteForm(instance=site) + self.assertNotIn('asns', form.fields) + self.assertIn('add_asns', form.fields) + self.assertIn('remove_asns', form.fields) + + def test_simple_mode_assigns_asns_on_create(self): + """Saving a new site via simple mode assigns the selected ASNs.""" + asn_pks = [asn.pk for asn in self.asns[:3]] + form = SiteForm(data=self._site_data(asns=asn_pks)) + self.assertTrue(form.is_valid(), form.errors) + site = form.save() + self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(asn_pks)) + + def test_simple_mode_replaces_asns_on_edit(self): + """Saving an existing site via simple mode replaces the current ASN assignments.""" + site = Site.objects.create(name='Site 3', slug='site-3') + site.asns.set(self.asns[:3]) + new_asn_pks = [asn.pk for asn in self.asns[3:6]] + form = SiteForm( + data=self._site_data(name='Site 3', slug='site-3', asns=new_asn_pks), + instance=site + ) + self.assertTrue(form.is_valid(), form.errors) + site = form.save() + self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(new_asn_pks)) + + def test_add_remove_mode_adds_asns(self): + """In add/remove mode, specifying 'add_asns' appends to current assignments.""" + site = Site.objects.create(name='Site 4', slug='site-4') + site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD]) + new_asn_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:]] + form = SiteForm( + data=self._site_data(name='Site 4', slug='site-4', add_asns=new_asn_pks), + instance=site + ) + self.assertTrue(form.is_valid(), form.errors) + site = form.save() + self.assertEqual(site.asns.count(), len(self.asns)) + + def test_add_remove_mode_removes_asns(self): + """In add/remove mode, specifying 'remove_asns' drops those assignments.""" + site = Site.objects.create(name='Site 5', slug='site-5') + site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD]) + remove_pks = [asn.pk for asn in self.asns[:5]] + form = SiteForm( + data=self._site_data(name='Site 5', slug='site-5', remove_asns=remove_pks), + instance=site + ) + self.assertTrue(form.is_valid(), form.errors) + site = form.save() + self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD - 5) + self.assertFalse(site.asns.filter(pk__in=remove_pks).exists()) + + def test_add_remove_mode_simultaneous_add_and_remove(self): + """In add/remove mode, add and remove operations are applied together.""" + site = Site.objects.create(name='Site 6', slug='site-6') + site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD]) + add_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:M2MAddRemoveFields.THRESHOLD + 3]] + remove_pks = [asn.pk for asn in self.asns[:3]] + form = SiteForm( + data=self._site_data(name='Site 6', slug='site-6', add_asns=add_pks, remove_asns=remove_pks), + instance=site + ) + self.assertTrue(form.is_valid(), form.errors) + site = form.save() + self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD) + self.assertTrue(site.asns.filter(pk__in=add_pks).count() == 3) + self.assertFalse(site.asns.filter(pk__in=remove_pks).exists()) diff --git a/netbox/netbox/forms/model_forms.py b/netbox/netbox/forms/model_forms.py index 45948e174..cfca17255 100644 --- a/netbox/netbox/forms/model_forms.py +++ b/netbox/netbox/forms/model_forms.py @@ -2,7 +2,6 @@ import json from django import forms from django.contrib.contenttypes.models import ContentType -from django.db import models from django.db.models.fields.related import ManyToManyRel from extras.choices import * @@ -77,15 +76,17 @@ class NetBoxModelForm( and add/remove (dual field) modes. """ self.instance._m2m_values = {} - for field in self.instance._meta.get_fields(): - # Determine the accessor name for this M2M relationship - if isinstance(field, models.ManyToManyField): - name = field.name - elif isinstance(field, ManyToManyRel): - name = field.get_accessor_name() - else: - continue + # Collect names to process: local M2M fields (includes TaggableManager from django-taggit) + # plus reverse M2M relations (ManyToManyRel). + names = [field.name for field in self.instance._meta.local_many_to_many] + names += [ + field.get_accessor_name() + for field in self.instance._meta.get_fields() + if isinstance(field, ManyToManyRel) + ] + + for name in names: if name in self.cleaned_data: # Simple mode: single multi-select field self.instance._m2m_values[name] = list(self.cleaned_data[name])