Parcourir la source

feat(api): tests: check zone update body

Nils Wisiol il y a 6 ans
Parent
commit
05eb8f2d2a
1 fichiers modifiés avec 84 ajouts et 0 suppressions
  1. 84 0
      api/desecapi/tests/base.py

+ 84 - 0
api/desecapi/tests/base.py

@@ -5,8 +5,10 @@ import re
 import string
 from contextlib import nullcontext
 from functools import partial, reduce
+from typing import Union, List, Dict
 from unittest import mock
 
+from django.db import connection
 from django.utils import timezone
 from httpretty import httpretty, core as hr_core
 from rest_framework.reverse import reverse
@@ -101,6 +103,28 @@ class DesecAPIClient(APIClient):
     # TODO add and use {post,get,delete,...}_domain
 
 
+class ReadUncommitted:
+
+    def __init__(self):
+        self.read_uncommitted = None
+
+    def __enter__(self):
+        with connection.cursor() as cursor:
+            cursor.execute('PRAGMA read_uncommitted;')  # FIXME this is probably sqlite only?
+            self.read_uncommitted = True if cursor.fetchone()[0] else False
+            cursor.execute('PRAGMA read_uncommitted = true;')
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.read_uncommitted is None:
+            return
+
+        with connection.cursor() as cursor:
+            if self.read_uncommitted:
+                cursor.execute('PRAGMA read_uncommitted = true;')
+            else:
+                cursor.execute('PRAGMA read_uncommitted = false;')
+
+
 class AssertRequestsContextManager:
     """
     Checks that in its context, certain expected requests are made.
@@ -290,6 +314,7 @@ class MockPDNSTestCase(APITestCase):
                 return [200, response_headers, '']
 
         request = cls.request_pdns_zone_create_422()
+        # noinspection PyTypeChecker
         request['body'] = request_callback
         request.pop('status')
         return request
@@ -325,10 +350,69 @@ class MockPDNSTestCase(APITestCase):
                 return [200, response_headers, None]
 
         request = cls.request_pdns_zone_update(name)
+        # noinspection PyTypeChecker
         request['body'] = request_callback
         request.pop('status')
         return request
 
+    def request_pdns_zone_update_assert_body(self, name: str = None, updated_rr_sets: Union[List[RRset], Dict] = None):
+        if updated_rr_sets is None:
+            updated_rr_sets = []
+
+        def request_callback(r, _, response_headers):
+            if not updated_rr_sets:
+                # nothing to assert
+                return [200, response_headers, '']
+
+            body = json.loads(r.parsed_body)
+            self.failIf('rrsets' not in body.keys(),
+                        'pdns zone update request malformed: did not contain a list of RR sets.')
+
+            with ReadUncommitted():  # tests are wrapped in uncommitted transactions, so we need to see inside
+                # convert updated_rr_sets into a plain data type, if Django models were given
+                if isinstance(updated_rr_sets, list):
+                    updated_rr_sets_dict = {}
+                    for rr_set in updated_rr_sets:
+                        updated_rr_sets_dict[(rr_set.type, rr_set.subname, rr_set.ttl)] = rrs = []
+                        for rr in rr_set.records.all():
+                            rrs.append(rr.content)
+                elif isinstance(updated_rr_sets, dict):
+                    updated_rr_sets_dict = updated_rr_sets
+                else:
+                    raise ValueError('updated_rr_sets must be a list of RRSets or a dict.')
+
+                # check expectations
+                self.assertEqual(len(updated_rr_sets_dict), len(body['rrsets']),
+                                 'Saw an unexpected number of RR set updates: expected %i, intercepted %i.' %
+                                 (len(updated_rr_sets_dict), len(body['rrsets'])))
+                for (expected_type, expected_subname, expected_ttl), expected_records in updated_rr_sets_dict.items():
+                    expected_name = '.'.join(filter(None, [expected_subname, name])) + '.'
+                    for seen_rr_set in body['rrsets']:
+                        if (expected_name == seen_rr_set['name'] and
+                                expected_type == seen_rr_set['type']):
+                            # TODO replace the following asserts by assertTTL, assertRecords, ... or similar
+                            if len(expected_records):
+                                self.assertEqual(expected_ttl, seen_rr_set['ttl'])
+                            self.assertEqual(
+                                set(expected_records),
+                                set([rr['content'] for rr in seen_rr_set['records']]),
+                            )
+                            break
+                    else:
+                        # we did not break out, i.e. we did not find a matching RR set in body['rrsets']
+                        self.fail('Expected to see an pdns zone update request for RR set of domain `%s` with name '
+                                  '`%s` and type `%s`, but did not see one. Seen update request on %s for RR sets:'
+                                  '\n\n%s'
+                                  % (name, expected_name, expected_type, request['uri'],
+                                     json.dumps(body['rrsets'], indent=4)))
+            return [200, response_headers, '']
+
+        request = self.request_pdns_zone_update(name)
+        request.pop('status')
+        # noinspection PyTypeChecker
+        request['body'] = request_callback
+        return request
+
     @classmethod
     def request_pdns_zone_retrieve(cls, name=None):
         return {