diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index f97c194ef..3cdfd7362 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -515,10 +515,15 @@ class APIViewTestCases: base_name = self.model._meta.verbose_name.lower().replace(' ', '_') return getattr(self, 'graphql_base_name', base_name) - def _build_query_with_filter(self, name, filter_string): + def _build_query_with_filter(self, name, filter_string, api_version='v2'): """ Called by either _build_query or _build_filtered_query - construct the actual query given a name and filter string + + Args: + name: The query field name (e.g., 'device_list') + filter_string: Filter parameters string (e.g., '(filters: {id: "1"})') + api_version: 'v1' or 'v2' to determine response format """ type_class = get_graphql_type_for_model(self.model) @@ -564,16 +569,26 @@ class APIViewTestCases: # Check if this is a list query (ends with '_list') if name.endswith('_list'): - # Wrap fields in 'results' for paginated queries - query = f""" - {{ - {name}{filter_string} {{ - results {{ + if api_version == 'v2': + # v2: Wrap fields in 'results' for paginated queries + query = f""" + {{ + {name}{filter_string} {{ + results {{ + {fields_string} + }} + }} + }} + """ + else: + # v1: Return direct array (no 'results' wrapper) + query = f""" + {{ + {name}{filter_string} {{ {fields_string} }} }} - }} - """ + """ else: # Single object query (no pagination) query = f""" @@ -586,9 +601,14 @@ class APIViewTestCases: return query - def _build_filtered_query(self, name, **filters): + def _build_filtered_query(self, name, api_version='v2', **filters): """ Create a filtered query: i.e. device_list(filters: {name: {i_contains: "akron"}}){. + + Args: + name: The query field name + api_version: 'v1' or 'v2' to determine response format + **filters: Filter parameters """ # TODO: This should be extended to support AND, OR multi-lookups if filters: @@ -604,11 +624,16 @@ class APIViewTestCases: else: filter_string = '' - return self._build_query_with_filter(name, filter_string) + return self._build_query_with_filter(name, filter_string, api_version) - def _build_query(self, name, **filters): + def _build_query(self, name, api_version='v2', **filters): """ Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){. + + Args: + name: The query field name + api_version: 'v1' or 'v2' to determine response format + **filters: Filter parameters """ if filters: filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items()) @@ -616,7 +641,7 @@ class APIViewTestCases: else: filter_string = '' - return self._build_query_with_filter(name, filter_string) + return self._build_query_with_filter(name, filter_string, api_version) @override_settings(LOGIN_REQUIRED=True) def test_graphql_get_object(self): @@ -664,54 +689,71 @@ class APIViewTestCases: @override_settings(LOGIN_REQUIRED=True) def test_graphql_list_objects(self): - url = reverse('graphql_v2') field_name = f'{self._get_graphql_base_name()}_list' - query = self._build_query(field_name) - # Non-authenticated requests should fail - header = { - 'HTTP_ACCEPT': 'application/json', - } - with disable_warnings('django.request'): - response = self.client.post(url, data={'query': query}, format="json", **header) - self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) + # Test both GraphQL API versions + for api_version, url_name in [('v1', 'graphql_v1'), ('v2', 'graphql_v2')]: + with self.subTest(api_version=api_version): + url = reverse(url_name) + query = self._build_query(field_name, api_version=api_version) - # Add constrained permission - obj_perm = ObjectPermission( - name='Test permission', - actions=['view'], - constraints={'id': 0} # Impossible constraint - ) - obj_perm.save() - obj_perm.users.add(self.user) - obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + # Non-authenticated requests should fail + header = { + 'HTTP_ACCEPT': 'application/json', + } + with disable_warnings('django.request'): + response = self.client.post(url, data={'query': query}, format="json", **header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) - # Request should succeed but return empty results list - response = self.client.post(url, data={'query': query}, format="json", **self.header) - self.assertHttpStatus(response, status.HTTP_200_OK) - data = json.loads(response.content) - self.assertNotIn('errors', data) - self.assertEqual(len(data['data'][field_name]['results']), 0) + # Add constrained permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['view'], + constraints={'id': 0} # Impossible constraint + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) - # Remove permission constraint - obj_perm.constraints = None - obj_perm.save() + # Request should succeed but return empty results list + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) - # Request should return all objects - response = self.client.post(url, data={'query': query}, format="json", **self.header) - self.assertHttpStatus(response, status.HTTP_200_OK) - data = json.loads(response.content) - self.assertNotIn('errors', data) - self.assertEqual(len(data['data'][field_name]['results']), self.model.objects.count()) + if api_version == 'v1': + # v1 returns direct array + self.assertEqual(len(data['data'][field_name]), 0) + else: + # v2 returns paginated response with results + self.assertEqual(len(data['data'][field_name]['results']), 0) + + # Remove permission constraint + obj_perm.constraints = None + obj_perm.save() + + # Request should return all objects + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + + if api_version == 'v1': + # v1 returns direct array + self.assertEqual(len(data['data'][field_name]), self.model.objects.count()) + else: + # v2 returns paginated response with results + self.assertEqual(len(data['data'][field_name]['results']), self.model.objects.count()) + + # Clean up permission for next iteration + obj_perm.delete() @override_settings(LOGIN_REQUIRED=True) def test_graphql_filter_objects(self): if not hasattr(self, 'graphql_filter'): return - url = reverse('graphql_v2') field_name = f'{self._get_graphql_base_name()}_list' - query = self._build_filtered_query(field_name, **self.graphql_filter) # Add object-level permission obj_perm = ObjectPermission( @@ -722,11 +764,26 @@ class APIViewTestCases: obj_perm.users.add(self.user) obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) - response = self.client.post(url, data={'query': query}, format="json", **self.header) - self.assertHttpStatus(response, status.HTTP_200_OK) - data = json.loads(response.content) - self.assertNotIn('errors', data) - self.assertGreater(len(data['data'][field_name]['results']), 0) + # Test both GraphQL API versions + for api_version, url_name in [('v1', 'graphql_v1'), ('v2', 'graphql_v2')]: + with self.subTest(api_version=api_version): + url = reverse(url_name) + query = self._build_filtered_query(field_name, api_version=api_version, **self.graphql_filter) + + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + + if api_version == 'v1': + # v1 returns direct array + self.assertGreater(len(data['data'][field_name]), 0) + else: + # v2 returns paginated response with results + self.assertGreater(len(data['data'][field_name]['results']), 0) + + # Clean up permission + obj_perm.delete() class APIViewTestCase( GetObjectViewTestCase,