Browse Source

feat(tests): generalize RRset assertions to lists of RRsets

Peter Thomassen 6 years ago
parent
commit
e82ebb7637
1 changed files with 25 additions and 22 deletions
  1. 25 22
      api/desecapi/tests/testrrsets.py

+ 25 - 22
api/desecapi/tests/testrrsets.py

@@ -108,24 +108,25 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
                 )
                 )
 
 
     @staticmethod
     @staticmethod
-    def _filter_rr_sets(rr_sets, **kwargs):
-        return [
-            rr_set for rr_set in rr_sets
-            if reduce(operator.and_, [rr_set.get(key, None) == value for key, value in kwargs.items()])
-        ]
+    def _count_occurrences_by_mask(rr_sets, masks):
+        def _filter_rr_sets_by_mask(rr_sets_, mask):
+            return [rr_set for rr_set in rr_sets_
+                    if reduce(operator.and_, [rr_set.get(key, None) == value for key, value in mask.items()])
+            ]
+
+        return [len(_filter_rr_sets_by_mask(rr_sets, mask)) for mask in masks]
 
 
-    def assertRRSetCount(self, rr_sets, count, **kwargs):
-        filtered_rr_sets = self._filter_rr_sets(rr_sets, **kwargs)
-        if len(filtered_rr_sets) != count:
-            self.fail('Expected to find %i RR set(s) with %s, but only found %i in %s.' % (
-                count, kwargs, len(filtered_rr_sets), rr_sets
+    def assertRRSetsCount(self, rr_sets, masks, count=1):
+        counts = self._count_occurrences_by_mask(rr_sets, masks)
+        if not all([x == count for x in counts]):
+            self.fail('Expected to find %i RR set(s) for each of %s, but distribution is %s in %s.' % (
+                count, masks, counts, rr_sets
             ))
             ))
 
 
-    def assertContainsRRSet(self, rr_sets, **kwargs):
-        filtered_rr_sets = self._filter_rr_sets(rr_sets, **kwargs)
-        if not filtered_rr_sets:
-            self.fail('Expected to find RR set with %s, but only found %s.' % (
-                kwargs, rr_sets
+    def assertContainsRRSets(self, rr_sets_haystack, rr_sets_needle):
+        if not all(self._count_occurrences_by_mask(rr_sets_haystack, rr_sets_needle)):
+            self.fail('Expected to find RR sets with %s, but only got %s.' % (
+                rr_sets_needle, rr_sets_haystack
             ))
             ))
 
 
     def test_subname_validity(self):
     def test_subname_validity(self):
@@ -145,7 +146,7 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
         ]:
         ]:
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertEqual(len(response.data), 2, response.data)
             self.assertEqual(len(response.data), 2, response.data)
-            self.assertContainsRRSet(response.data, subname='', records=settings.DEFAULT_NS, type='NS')
+            self.assertContainsRRSets(response.data, [dict(subname='', records=settings.DEFAULT_NS, type='NS')])
 
 
     def test_retrieve_other_rr_sets(self):
     def test_retrieve_other_rr_sets(self):
         self.assertStatus(self.client.get_rr_sets(self.other_domain.name), status.HTTP_404_NOT_FOUND)
         self.assertStatus(self.client.get_rr_sets(self.other_domain.name), status.HTTP_404_NOT_FOUND)
@@ -160,13 +161,15 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
         for subname in self.SUBNAMES:
         for subname in self.SUBNAMES:
             response = self.client.get_rr_sets(self.my_rr_set_domain.name, subname=subname)
             response = self.client.get_rr_sets(self.my_rr_set_domain.name, subname=subname)
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertStatus(response, status.HTTP_200_OK)
-            self.assertRRSetCount(response.data, count=len(self._test_rr_sets(subname=subname)), subname=subname)
+            self.assertRRSetsCount(response.data, [dict(subname=subname)],
+                                   count=len(self._test_rr_sets(subname=subname)))
 
 
         for type_ in self.ALLOWED_TYPES:
         for type_ in self.ALLOWED_TYPES:
             response = self.client.get_rr_sets(self.my_rr_set_domain.name, type=type_)
             response = self.client.get_rr_sets(self.my_rr_set_domain.name, type=type_)
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertStatus(response, status.HTTP_200_OK)
             if type_ != 'NS':  # count does not match for NS, that's okay
             if type_ != 'NS':  # count does not match for NS, that's okay
-                self.assertRRSetCount(response.data, count=len(self._test_rr_sets(type_=type_)), type=type_)
+                self.assertRRSetsCount(response.data, [dict(type=type_)],
+                                       count=len(self._test_rr_sets(type_=type_)))
 
 
     def test_create_my_rr_sets(self):
     def test_create_my_rr_sets(self):
         for subname in ['', 'create-my-rr-sets', 'foo.create-my-rr-sets', 'bar.baz.foo.create-my-rr-sets']:
         for subname in ['', 'create-my-rr-sets', 'foo.create-my-rr-sets', 'bar.baz.foo.create-my-rr-sets']:
@@ -180,7 +183,7 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
 
 
                 response = self.client.get_rr_sets(self.my_empty_domain.name)
                 response = self.client.get_rr_sets(self.my_empty_domain.name)
                 self.assertStatus(response, status.HTTP_200_OK)
                 self.assertStatus(response, status.HTTP_200_OK)
-                self.assertRRSetCount(response.data, count=1, **data)
+                self.assertRRSetsCount(response.data, [data])
 
 
                 response = self.client.get_rr_set(self.my_empty_domain.name, data['subname'], data['type'])
                 response = self.client.get_rr_set(self.my_empty_domain.name, data['subname'], data['type'])
                 self.assertStatus(response, status.HTTP_200_OK)
                 self.assertStatus(response, status.HTTP_200_OK)
@@ -204,7 +207,7 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
 
 
                 response = self.client.get_rr_sets(self.my_domain.name)
                 response = self.client.get_rr_sets(self.my_domain.name)
                 self.assertStatus(response, status.HTTP_200_OK)
                 self.assertStatus(response, status.HTTP_200_OK)
-                self.assertRRSetCount(response.data, count=0, **data)
+                self.assertRRSetsCount(response.data, [data], count=0)
 
 
     def test_create_my_rr_sets_without_records(self):
     def test_create_my_rr_sets_without_records(self):
         for subname in ['', 'create-my-rr-sets', 'foo.create-my-rr-sets', 'bar.baz.foo.create-my-rr-sets']:
         for subname in ['', 'create-my-rr-sets', 'foo.create-my-rr-sets', 'bar.baz.foo.create-my-rr-sets']:
@@ -217,7 +220,7 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
 
 
                 response = self.client.get_rr_sets(self.my_empty_domain.name)
                 response = self.client.get_rr_sets(self.my_empty_domain.name)
                 self.assertStatus(response, status.HTTP_200_OK)
                 self.assertStatus(response, status.HTTP_200_OK)
-                self.assertRRSetCount(response.data, count=0, **data)
+                self.assertRRSetsCount(response.data, [data], count=0)
 
 
     def test_create_other_rr_sets(self):
     def test_create_other_rr_sets(self):
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
@@ -389,4 +392,4 @@ class AuthenticatedRRSetTestCase(DomainOwnerTestCase):
         ]:
         ]:
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertStatus(response, status.HTTP_200_OK)
             self.assertEqual(len(response.data), 1, response.data)
             self.assertEqual(len(response.data), 1, response.data)
-            self.assertContainsRRSet(response.data, subname='', records=settings.DEFAULT_NS, type='NS')
+            self.assertContainsRRSets(response.data, [dict(subname='', records=settings.DEFAULT_NS, type='NS')])