Compare commits

...

2 Commits

Author SHA1 Message Date
Jeremy Stretch
3880f2d416 Fixes #21538: Fix annotated count for contacts assigned to multiple contact groups 2026-04-15 10:04:28 -04:00
Sergio López
660ca42149 Closes #21875: Allow subclasses of dict for API_TOKEN_PEPPERS 2026-04-14 16:59:49 -04:00
5 changed files with 103 additions and 29 deletions

View File

@@ -42,13 +42,7 @@ class TenantViewSet(NetBoxModelViewSet):
#
class ContactGroupViewSet(MPTTLockedMixin, NetBoxModelViewSet):
queryset = ContactGroup.objects.add_related_count(
ContactGroup.objects.all(),
Contact,
'groups',
'contact_count',
cumulative=True
)
queryset = ContactGroup.objects.annotate_contacts()
serializer_class = serializers.ContactGroupSerializer
filterset_class = filtersets.ContactGroupFilterSet

View File

@@ -1,12 +1,14 @@
from django.contrib.contenttypes.fields import GenericForeignKey
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.expressions import RawSQL
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from netbox.models import ChangeLoggedModel, NestedGroupModel, OrganizationalModel, PrimaryModel
from netbox.models.features import CustomFieldsMixin, ExportTemplatesMixin, TagsMixin, has_feature
from tenancy.choices import *
from utilities.mptt import TreeManager
__all__ = (
'Contact',
@@ -16,10 +18,34 @@ __all__ = (
)
class ContactGroupManager(TreeManager):
def annotate_contacts(self):
"""
Annotate the total number of Contacts belonging to each ContactGroup.
This returns both direct children and children of child groups. Raw SQL is used here to avoid double-counting
contacts which are assigned to multiple child groups of the parent.
"""
return self.annotate(
contact_count=RawSQL(
"SELECT COUNT(DISTINCT m2m.contact_id)"
" FROM tenancy_contact_groups m2m"
" INNER JOIN tenancy_contactgroup cg ON m2m.contactgroup_id = cg.id"
" WHERE cg.tree_id = tenancy_contactgroup.tree_id"
" AND cg.lft >= tenancy_contactgroup.lft"
" AND cg.lft <= tenancy_contactgroup.rght",
()
)
)
class ContactGroup(NestedGroupModel):
"""
An arbitrary collection of Contacts.
"""
objects = ContactGroupManager()
class Meta:
ordering = ['name']
# Empty tuple triggers Django migration detection for MPTT indexes

View File

@@ -0,0 +1,72 @@
from django.test import TestCase
from tenancy.models import Contact, ContactGroup
class ContactGroupTestCase(TestCase):
@classmethod
def setUpTestData(cls):
# Create a tree of contact groups:
# - Group A
# - Group A1
# - Group A2
# - Group B
cls.group_a = ContactGroup.objects.create(name='Group A', slug='group-a')
cls.group_a1 = ContactGroup.objects.create(name='Group A1', slug='group-a1', parent=cls.group_a)
cls.group_a2 = ContactGroup.objects.create(name='Group A2', slug='group-a2', parent=cls.group_a)
cls.group_b = ContactGroup.objects.create(name='Group B', slug='group-b')
# Create contacts
cls.contact1 = Contact.objects.create(name='Contact 1')
cls.contact2 = Contact.objects.create(name='Contact 2')
cls.contact3 = Contact.objects.create(name='Contact 3')
cls.contact4 = Contact.objects.create(name='Contact 4')
def test_annotate_contacts_direct(self):
"""Contacts assigned directly to a group should be counted."""
self.contact1.groups.set([self.group_a])
self.contact2.groups.set([self.group_a])
queryset = ContactGroup.objects.annotate_contacts()
self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 2)
def test_annotate_contacts_cumulative(self):
"""Contacts assigned to child groups should be included in the parent's count."""
self.contact1.groups.set([self.group_a1])
self.contact2.groups.set([self.group_a2])
queryset = ContactGroup.objects.annotate_contacts()
self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 2)
self.assertEqual(queryset.get(pk=self.group_a1.pk).contact_count, 1)
self.assertEqual(queryset.get(pk=self.group_a2.pk).contact_count, 1)
def test_annotate_contacts_no_double_counting(self):
"""A contact assigned to multiple child groups must be counted only once for the parent."""
self.contact1.groups.set([self.group_a1, self.group_a2])
queryset = ContactGroup.objects.annotate_contacts()
self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 1)
def test_annotate_contacts_mixed(self):
"""Test a mix of direct and inherited contacts with overlap."""
self.contact1.groups.set([self.group_a])
self.contact2.groups.set([self.group_a1])
self.contact3.groups.set([self.group_a1, self.group_a2])
self.contact4.groups.set([self.group_b])
queryset = ContactGroup.objects.annotate_contacts()
# Group A: contact1 (direct) + contact2 (via A1) + contact3 (via A1 & A2) = 3
self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 3)
# Group A1: contact2 + contact3 = 2
self.assertEqual(queryset.get(pk=self.group_a1.pk).contact_count, 2)
# Group A2: contact3 = 1
self.assertEqual(queryset.get(pk=self.group_a2.pk).contact_count, 1)
# Group B: contact4 = 1
self.assertEqual(queryset.get(pk=self.group_b.pk).contact_count, 1)
def test_annotate_contacts_empty(self):
"""Groups with no contacts should return a count of zero."""
queryset = ContactGroup.objects.annotate_contacts()
self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 0)
self.assertEqual(queryset.get(pk=self.group_b.pk).contact_count, 0)

View File

@@ -205,13 +205,7 @@ class TenantBulkDeleteView(generic.BulkDeleteView):
@register_model_view(ContactGroup, 'list', path='', detail=False)
class ContactGroupListView(generic.ObjectListView):
queryset = ContactGroup.objects.add_related_count(
ContactGroup.objects.all(),
Contact,
'groups',
'contact_count',
cumulative=True
)
queryset = ContactGroup.objects.annotate_contacts()
filterset = filtersets.ContactGroupFilterSet
filterset_form = forms.ContactGroupFilterForm
table = tables.ContactGroupTable
@@ -280,13 +274,7 @@ class ContactGroupBulkImportView(generic.BulkImportView):
@register_model_view(ContactGroup, 'bulk_edit', path='edit', detail=False)
class ContactGroupBulkEditView(generic.BulkEditView):
queryset = ContactGroup.objects.add_related_count(
ContactGroup.objects.all(),
Contact,
'groups',
'contact_count',
cumulative=True
)
queryset = ContactGroup.objects.annotate_contacts()
filterset = filtersets.ContactGroupFilterSet
table = tables.ContactGroupTable
form = forms.ContactGroupBulkEditForm
@@ -300,13 +288,7 @@ class ContactGroupBulkRenameView(generic.BulkRenameView):
@register_model_view(ContactGroup, 'bulk_delete', path='delete', detail=False)
class ContactGroupBulkDeleteView(generic.BulkDeleteView):
queryset = ContactGroup.objects.add_related_count(
ContactGroup.objects.all(),
Contact,
'groups',
'contact_count',
cumulative=True
)
queryset = ContactGroup.objects.annotate_contacts()
filterset = filtersets.ContactGroupFilterSet
table = tables.ContactGroupTable

View File

@@ -9,7 +9,7 @@ def validate_peppers(peppers):
"""
Validate the given dictionary of cryptographic peppers for type & sufficient length.
"""
if type(peppers) is not dict:
if not isinstance(peppers, dict):
raise ImproperlyConfigured("API_TOKEN_PEPPERS must be a dictionary.")
for key, pepper in peppers.items():
if type(key) is not int: