diff --git a/netbox/circuits/models/circuits.py b/netbox/circuits/models/circuits.py index dfdf86a3c..c4c6e0853 100644 --- a/netbox/circuits/models/circuits.py +++ b/netbox/circuits/models/circuits.py @@ -347,6 +347,13 @@ class CircuitTermination( verbose_name = _('circuit termination') verbose_name_plural = _('circuit terminations') + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Cache original values to detect changes + self._orig_circuit_id = self.__dict__.get('circuit_id') + self._orig_term_side = self.__dict__.get('term_side') + def __str__(self): return f'{self.circuit}: Termination {self.term_side}' @@ -360,11 +367,39 @@ class CircuitTermination( raise ValidationError(_("A circuit termination must attach to a terminating object.")) def save(self, *args, **kwargs): + is_new = self._state.adding + update_fields = kwargs.get('update_fields') + + # Only consider circuit/term_side changes if those fields + # are actually being persisted + if update_fields is not None: + tracking_relevant = 'circuit' in update_fields or 'term_side' in update_fields + else: + tracking_relevant = True + + circuit_changed = tracking_relevant and self._orig_circuit_id and self._orig_circuit_id != self.circuit_id + term_side_changed = tracking_relevant and self._orig_term_side and self._orig_term_side != self.term_side + # Cache objects associated with the terminating object (for filtering) self.cache_related_objects() super().save(*args, **kwargs) + # Clear the old termination reference if circuit or term_side changed + if circuit_changed or term_side_changed: + old_termination_name = f'termination_{self._orig_term_side.lower()}' + Circuit.objects.filter(pk=self._orig_circuit_id).update(**{old_termination_name: None}) + + # Update the cache if this is a new termination or circuit/term_side changed + if is_new or circuit_changed or term_side_changed: + # Update the new circuit's termination reference + termination_name = f'termination_{self.term_side.lower()}' + Circuit.objects.filter(pk=self.circuit_id).update(**{termination_name: self.pk}) + + # Update cached values for subsequent saves + self._orig_circuit_id = self.circuit_id + self._orig_term_side = self.term_side + def cache_related_objects(self): self._provider_network = self._region = self._site_group = self._site = self._location = None if self.termination_type: diff --git a/netbox/circuits/signals.py b/netbox/circuits/signals.py index 6405a380b..4db765c80 100644 --- a/netbox/circuits/signals.py +++ b/netbox/circuits/signals.py @@ -6,17 +6,6 @@ from dcim.signals import rebuild_paths from .models import CircuitTermination -@receiver(post_save, sender=CircuitTermination) -def update_circuit(instance, **kwargs): - """ - When a CircuitTermination has been modified, update its parent Circuit. - """ - termination_name = f'termination_{instance.term_side.lower()}' - instance.circuit.refresh_from_db() - setattr(instance.circuit, termination_name, instance) - instance.circuit.save() - - @receiver((post_save, post_delete), sender=CircuitTermination) def rebuild_cablepaths(instance, raw=False, **kwargs): """ diff --git a/netbox/circuits/tests/test_models.py b/netbox/circuits/tests/test_models.py new file mode 100644 index 000000000..f837cab65 --- /dev/null +++ b/netbox/circuits/tests/test_models.py @@ -0,0 +1,148 @@ +from django.test import TestCase + +from circuits.models import Circuit, CircuitTermination, CircuitType, Provider, ProviderNetwork +from dcim.models import Site + + +class CircuitTerminationTestCase(TestCase): + + @classmethod + def setUpTestData(cls): + provider = Provider.objects.create(name='Provider 1', slug='provider-1') + circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1') + + cls.sites = ( + Site.objects.create(name='Site 1', slug='site-1'), + Site.objects.create(name='Site 2', slug='site-2'), + ) + + cls.circuits = ( + Circuit.objects.create(cid='Circuit 1', provider=provider, type=circuit_type), + Circuit.objects.create(cid='Circuit 2', provider=provider, type=circuit_type), + ) + + cls.provider_network = ProviderNetwork.objects.create(name='Provider Network 1', provider=provider) + + def test_circuit_termination_creation_populates_circuit_cache(self): + """ + When a CircuitTermination is created, the parent Circuit's termination_a or termination_z + cache field should be populated. + """ + # Create A termination + termination_a = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='A', + termination=self.sites[0], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination_a) + self.assertIsNone(self.circuits[0].termination_z) + + # Create Z termination + termination_z = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='Z', + termination=self.sites[1], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination_a) + self.assertEqual(self.circuits[0].termination_z, termination_z) + + def test_circuit_termination_circuit_change_clears_old_cache(self): + """ + When a CircuitTermination's circuit is changed, the old Circuit's cache should be cleared + and the new Circuit's cache should be populated. + """ + # Create termination on self.circuits[0] + termination = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='A', + termination=self.sites[0], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination) + + # Move termination to self.circuits[1] + termination.circuit = self.circuits[1] + termination.save() + + self.circuits[0].refresh_from_db() + self.circuits[1].refresh_from_db() + + # Old circuit's cache should be cleared + self.assertIsNone(self.circuits[0].termination_a) + # New circuit's cache should be populated + self.assertEqual(self.circuits[1].termination_a, termination) + + def test_circuit_termination_term_side_change_clears_old_cache(self): + """ + When a CircuitTermination's term_side is changed, the old side's cache should be cleared + and the new side's cache should be populated. + """ + # Create A termination + termination = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='A', + termination=self.sites[0], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination) + self.assertIsNone(self.circuits[0].termination_z) + + # Change from A to Z + termination.term_side = 'Z' + termination.save() + + self.circuits[0].refresh_from_db() + + # A side should be cleared, Z side should be populated + self.assertIsNone(self.circuits[0].termination_a) + self.assertEqual(self.circuits[0].termination_z, termination) + + def test_circuit_termination_circuit_and_term_side_change(self): + """ + When both circuit and term_side are changed, the old Circuit's old side cache should be + cleared and the new Circuit's new side cache should be populated. + """ + # Create A termination on self.circuits[0] + termination = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='A', + termination=self.sites[0], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination) + + # Change to self.circuits[1] Z side + termination.circuit = self.circuits[1] + termination.term_side = 'Z' + termination.save() + + self.circuits[0].refresh_from_db() + self.circuits[1].refresh_from_db() + + # Old circuit's A side should be cleared + self.assertIsNone(self.circuits[0].termination_a) + self.assertIsNone(self.circuits[0].termination_z) + # New circuit's Z side should be populated + self.assertIsNone(self.circuits[1].termination_a) + self.assertEqual(self.circuits[1].termination_z, termination) + + def test_circuit_termination_deletion_clears_cache(self): + """ + When a CircuitTermination is deleted, the parent Circuit's cache should be cleared. + """ + termination = CircuitTermination.objects.create( + circuit=self.circuits[0], + term_side='A', + termination=self.sites[0], + ) + self.circuits[0].refresh_from_db() + self.assertEqual(self.circuits[0].termination_a, termination) + + # Delete the termination + termination.delete() + self.circuits[0].refresh_from_db() + + # Cache should be cleared (SET_NULL behavior) + self.assertIsNone(self.circuits[0].termination_a)