Add/fix tests

This commit is contained in:
Jeremy Stretch
2026-03-30 10:02:38 -04:00
parent a45e8571da
commit 55daf4c52f
2 changed files with 120 additions and 10 deletions

View File

@@ -10,7 +10,8 @@ from dcim.choices import (
) )
from dcim.forms import * from dcim.forms import *
from dcim.models import * from dcim.models import *
from ipam.models import VLAN from ipam.models import ASN, RIR, VLAN
from utilities.forms.rendering import M2MAddRemoveFields
from utilities.testing import create_test_device from utilities.testing import create_test_device
from virtualization.models import Cluster, ClusterGroup, ClusterType from virtualization.models import Cluster, ClusterGroup, ClusterType
@@ -417,3 +418,111 @@ class InterfaceTestCase(TestCase):
self.assertNotIn('untagged_vlan', form.cleaned_data.keys()) self.assertNotIn('untagged_vlan', form.cleaned_data.keys())
self.assertNotIn('tagged_vlans', form.cleaned_data.keys()) self.assertNotIn('tagged_vlans', form.cleaned_data.keys())
self.assertNotIn('qinq_svlan', form.cleaned_data.keys()) self.assertNotIn('qinq_svlan', form.cleaned_data.keys())
class SiteFormTestCase(TestCase):
"""
Tests for M2MAddRemoveFields using Site ASN assignments as the test case.
Covers both simple mode (single multi-select field) and add/remove mode (dual fields).
"""
@classmethod
def setUpTestData(cls):
cls.rir = RIR.objects.create(name='RIR 1', slug='rir-1')
# Create 110 ASNs: 100 to pre-assign (triggering add/remove mode) plus 10 extras
ASN.objects.bulk_create([ASN(asn=i, rir=cls.rir) for i in range(1, 111)])
cls.asns = list(ASN.objects.order_by('asn'))
def _site_data(self, **kwargs):
data = {'name': 'Test Site', 'slug': 'test-site', 'status': 'active'}
data.update(kwargs)
return data
def test_new_site_uses_simple_mode(self):
"""A form for a new site uses the single 'asns' field (simple mode)."""
form = SiteForm(data=self._site_data())
self.assertIn('asns', form.fields)
self.assertNotIn('add_asns', form.fields)
self.assertNotIn('remove_asns', form.fields)
def test_existing_site_below_threshold_uses_simple_mode(self):
"""A form for an existing site with fewer than THRESHOLD ASNs uses simple mode."""
site = Site.objects.create(name='Site 1', slug='site-1')
site.asns.set(self.asns[:5])
form = SiteForm(instance=site)
self.assertIn('asns', form.fields)
self.assertNotIn('add_asns', form.fields)
self.assertNotIn('remove_asns', form.fields)
def test_existing_site_at_threshold_uses_add_remove_mode(self):
"""A form for an existing site with THRESHOLD or more ASNs uses add/remove mode."""
site = Site.objects.create(name='Site 2', slug='site-2')
site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
form = SiteForm(instance=site)
self.assertNotIn('asns', form.fields)
self.assertIn('add_asns', form.fields)
self.assertIn('remove_asns', form.fields)
def test_simple_mode_assigns_asns_on_create(self):
"""Saving a new site via simple mode assigns the selected ASNs."""
asn_pks = [asn.pk for asn in self.asns[:3]]
form = SiteForm(data=self._site_data(asns=asn_pks))
self.assertTrue(form.is_valid(), form.errors)
site = form.save()
self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(asn_pks))
def test_simple_mode_replaces_asns_on_edit(self):
"""Saving an existing site via simple mode replaces the current ASN assignments."""
site = Site.objects.create(name='Site 3', slug='site-3')
site.asns.set(self.asns[:3])
new_asn_pks = [asn.pk for asn in self.asns[3:6]]
form = SiteForm(
data=self._site_data(name='Site 3', slug='site-3', asns=new_asn_pks),
instance=site
)
self.assertTrue(form.is_valid(), form.errors)
site = form.save()
self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(new_asn_pks))
def test_add_remove_mode_adds_asns(self):
"""In add/remove mode, specifying 'add_asns' appends to current assignments."""
site = Site.objects.create(name='Site 4', slug='site-4')
site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
new_asn_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:]]
form = SiteForm(
data=self._site_data(name='Site 4', slug='site-4', add_asns=new_asn_pks),
instance=site
)
self.assertTrue(form.is_valid(), form.errors)
site = form.save()
self.assertEqual(site.asns.count(), len(self.asns))
def test_add_remove_mode_removes_asns(self):
"""In add/remove mode, specifying 'remove_asns' drops those assignments."""
site = Site.objects.create(name='Site 5', slug='site-5')
site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
remove_pks = [asn.pk for asn in self.asns[:5]]
form = SiteForm(
data=self._site_data(name='Site 5', slug='site-5', remove_asns=remove_pks),
instance=site
)
self.assertTrue(form.is_valid(), form.errors)
site = form.save()
self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD - 5)
self.assertFalse(site.asns.filter(pk__in=remove_pks).exists())
def test_add_remove_mode_simultaneous_add_and_remove(self):
"""In add/remove mode, add and remove operations are applied together."""
site = Site.objects.create(name='Site 6', slug='site-6')
site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
add_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:M2MAddRemoveFields.THRESHOLD + 3]]
remove_pks = [asn.pk for asn in self.asns[:3]]
form = SiteForm(
data=self._site_data(name='Site 6', slug='site-6', add_asns=add_pks, remove_asns=remove_pks),
instance=site
)
self.assertTrue(form.is_valid(), form.errors)
site = form.save()
self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD)
self.assertTrue(site.asns.filter(pk__in=add_pks).count() == 3)
self.assertFalse(site.asns.filter(pk__in=remove_pks).exists())

View File

@@ -2,7 +2,6 @@ import json
from django import forms from django import forms
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models.fields.related import ManyToManyRel from django.db.models.fields.related import ManyToManyRel
from extras.choices import * from extras.choices import *
@@ -77,15 +76,17 @@ class NetBoxModelForm(
and add/remove (dual field) modes. and add/remove (dual field) modes.
""" """
self.instance._m2m_values = {} self.instance._m2m_values = {}
for field in self.instance._meta.get_fields():
# Determine the accessor name for this M2M relationship
if isinstance(field, models.ManyToManyField):
name = field.name
elif isinstance(field, ManyToManyRel):
name = field.get_accessor_name()
else:
continue
# Collect names to process: local M2M fields (includes TaggableManager from django-taggit)
# plus reverse M2M relations (ManyToManyRel).
names = [field.name for field in self.instance._meta.local_many_to_many]
names += [
field.get_accessor_name()
for field in self.instance._meta.get_fields()
if isinstance(field, ManyToManyRel)
]
for name in names:
if name in self.cleaned_data: if name in self.cleaned_data:
# Simple mode: single multi-select field # Simple mode: single multi-select field
self.instance._m2m_values[name] = list(self.cleaned_data[name]) self.instance._m2m_values[name] = list(self.cleaned_data[name])