瀏覽代碼

fix(api): validate CNAME exclusivity

Peter Thomassen 4 年之前
父節點
當前提交
29e9623df7

+ 21 - 0
api/desecapi/migrations/0006_cname_exclusivity.py

@@ -0,0 +1,21 @@
+# Generated by Django 3.1 on 2020-09-18 16:09
+
+import django.contrib.postgres.constraints
+from django.contrib.postgres.operations import BtreeGistExtension
+from django.db import migrations, models
+import django.db.models.expressions
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('desecapi', '0005_subname_validation'),
+    ]
+
+    operations = [
+        BtreeGistExtension(),
+        migrations.AddConstraint(
+            model_name='rrset',
+            constraint=django.contrib.postgres.constraints.ExclusionConstraint(expressions=[('domain', '='), ('subname', '='), (django.db.models.expressions.RawSQL("int4(type = 'CNAME')", ()), '<>')], name='cname_exclusivity'),
+        ),
+    ]

+ 13 - 0
api/desecapi/models.py

@@ -18,11 +18,14 @@ import rest_framework.authtoken.models
 from django.conf import settings
 from django.contrib.auth.hashers import make_password
 from django.contrib.auth.models import BaseUserManager, AbstractBaseUser
+from django.contrib.postgres.constraints import ExclusionConstraint
+from django.contrib.postgres.fields import RangeOperators
 from django.core.exceptions import ValidationError
 from django.core.mail import EmailMessage, get_connection
 from django.core.validators import RegexValidator
 from django.db import models
 from django.db.models import Manager, Q
+from django.db.models.expressions import RawSQL
 from django.template.loader import get_template
 from django.utils import timezone
 from django_prometheus.models import ExportModelOperationsMixin
@@ -490,6 +493,16 @@ class RRset(ExportModelOperationsMixin('RRset'), models.Model):
     objects = RRsetManager()
 
     class Meta:
+        constraints = [
+            ExclusionConstraint(
+                name='cname_exclusivity',
+                expressions=[
+                    ('domain', RangeOperators.EQUAL),
+                    ('subname', RangeOperators.EQUAL),
+                    (RawSQL("int4(type = 'CNAME')", ()), RangeOperators.NOT_EQUAL),
+                ],
+            ),
+        ]
         unique_together = (("domain", "subname", "type"),)
 
     @staticmethod

+ 67 - 12
api/desecapi/serializers.py

@@ -1,4 +1,5 @@
 import binascii
+import copy
 import json
 import re
 from base64 import urlsafe_b64decode, urlsafe_b64encode, b64encode
@@ -15,7 +16,7 @@ from rest_framework.settings import api_settings
 from rest_framework.validators import UniqueTogetherValidator, UniqueValidator, qs_filter
 
 from api import settings
-from desecapi import crypto, metrics, models
+from desecapi import crypto, metrics, models, validators
 
 
 class CaptchaSerializer(serializers.ModelSerializer):
@@ -237,10 +238,20 @@ class RRsetSerializer(ConditionalExistenceModelSerializer):
         return fields
 
     def get_validators(self):
-        return [UniqueTogetherValidator(
-            self.domain.rrset_set, ('subname', 'type'),
-            message='Another RRset with the same subdomain and type exists for this domain.'
-        )]
+        return [
+            UniqueTogetherValidator(
+                self.domain.rrset_set,
+                ('subname', 'type'),
+                message='Another RRset with the same subdomain and type exists for this domain.',
+            ),
+            validators.ExclusionConstraintValidator(
+                self.domain.rrset_set,
+                ('subname',),
+                exclusion_condition=('type', 'CNAME',),
+                message='RRset with conflicting type present: database ({types}).'
+                        ' (No other RRsets are allowed alongside CNAME.)',
+            ),
+        ]
 
     @staticmethod
     def validate_type(value):
@@ -338,7 +349,22 @@ class RRsetListSerializer(serializers.ListSerializer):
 
     @staticmethod
     def _key(data_item):
-        return data_item.get('subname', None), data_item.get('type', None)
+        return data_item.get('subname'), data_item.get('type')
+
+    @staticmethod
+    def _types_by_position_string(conflicting_indices_by_type):
+        types_by_position = {}
+        for type_, conflict_positions in conflicting_indices_by_type.items():
+            for position in conflict_positions:
+                types_by_position.setdefault(position, []).append(type_)
+        # Sort by position, None at the end
+        types_by_position = dict(sorted(types_by_position.items(), key=lambda x: (x[0] is None, x)))
+        db_conflicts = types_by_position.pop(None, None)
+        if db_conflicts: types_by_position['database'] = db_conflicts
+        for position, types in types_by_position.items():
+            types_by_position[position] = ', '.join(sorted(types))
+        types_by_position = [f'{position} ({types})' for position, types in types_by_position.items()]
+        return ', '.join(types_by_position)
 
     def to_internal_value(self, data):
         if not isinstance(data, list):
@@ -366,22 +392,48 @@ class RRsetListSerializer(serializers.ListSerializer):
             if not isinstance(item, dict):
                 self.fail('invalid', datatype=type(item).__name__)
             s, t = self._key(item)  # subname, type
-            items = indices.setdefault(s, {}).setdefault(t, set())
+            # Construct an index of the RRsets in `data` by `s` and `t`. As (subname, type) may be given multiple times
+            # (although invalid), we make indices[s][t] a set to properly keep track. We also check and record RRsets
+            # which are known in the database (once per subname), using index `None` (for checking CNAME exclusivity).
+            if s not in indices:
+                types = self.child.domain.rrset_set.filter(subname=s).values_list('type', flat=True)
+                indices[s] = {type_: {None} for type_ in types}
+            items = indices[s].setdefault(t, set())
             items.add(idx)
 
+        collapsed_indices = copy.deepcopy(indices)
+        for idx, item in enumerate(data):
+            if item.get('records') == []:
+                s, t = self._key(item)
+                collapsed_indices[s][t] -= {idx, None}
+
         # Iterate over all rows in the data given
         for idx, item in enumerate(data):
             try:
                 # see if other rows have the same key
                 s, t = self._key(item)
-                if len(indices[s][t]) > 1:
+                data_indices = indices[s][t] - {None}
+                if len(data_indices) > 1:
                     raise serializers.ValidationError({
                         'non_field_errors': [
                             'Same subname and type as in position(s) %s, but must be unique.' %
-                            ', '.join(map(str, indices[s][t] - {idx}))
+                            ', '.join(map(str, data_indices - {idx}))
                         ]
                     })
 
+                # see if other rows violate CNAME exclusivity
+                if item.get('records') != []:
+                    conflicting_indices_by_type = {k: v for k, v in collapsed_indices[s].items()
+                                                   if (k == 'CNAME') != (t == 'CNAME')}
+                    if any(conflicting_indices_by_type.values()):
+                        types_by_position = self._types_by_position_string(conflicting_indices_by_type)
+                        raise serializers.ValidationError({
+                            'non_field_errors': [
+                                f'RRset with conflicting type present: {types_by_position}.'
+                                ' (No other RRsets are allowed alongside CNAME.)'
+                            ]
+                        })
+
                 # determine if this is a partial update (i.e. PATCH):
                 # we allow partial update if a partial update method (i.e. PATCH) is used, as indicated by self.partial,
                 # and if this is not actually a create request because it is unknown and nonempty
@@ -466,6 +518,12 @@ class RRsetListSerializer(serializers.ListSerializer):
 
         ret = []
 
+        # The above algorithm makes sure that created, updated, and deleted are disjoint. Thus, no "override cases"
+        # (such as: an RRset should be updated and delete, what should be applied last?) need to be considered.
+        # We apply deletion first to get any possible CNAME exclusivity collisions out of the way.
+        for subname, type_ in deleted:
+            instance_index[(subname, type_)].delete()
+
         for subname, type_ in created:
             ret.append(self.child.create(
                 validated_data=data_index[(subname, type_)]
@@ -477,9 +535,6 @@ class RRsetListSerializer(serializers.ListSerializer):
                 validated_data=data_index[(subname, type_)]
             ))
 
-        for subname, type_ in deleted:
-            instance_index[(subname, type_)].delete()
-
         return ret
 
     def save(self, **kwargs):

+ 16 - 1
api/desecapi/tests/test_rrsets.py

@@ -135,6 +135,7 @@ class AuthenticatedRRSetTestCase(AuthenticatedRRSetBaseTestCase):
                 {'subname': subname, 'records': ['1.2.3.4'], 'ttl': 3660, 'type': 'A'},
                 {'subname': '' if subname is None else subname, 'records': ['desec.io.'], 'ttl': 36900, 'type': 'PTR'},
                 {'subname': '' if subname is None else subname, 'ttl': 3650, 'type': 'TXT', 'records': ['"foo"']},
+                {'subname': f'{subname}.cname'.lower(), 'ttl': 3600, 'type': 'CNAME', 'records': ['example.com.']},
             ]:
                 # Try POST with missing subname
                 if data['subname'] is None:
@@ -142,11 +143,11 @@ class AuthenticatedRRSetTestCase(AuthenticatedRRSetBaseTestCase):
 
                 with self.assertPdnsRequests(self.requests_desec_rr_sets_update(name=self.my_empty_domain.name)):
                     response = self.client.post_rr_set(domain_name=self.my_empty_domain.name, **data)
+                    self.assertStatus(response, status.HTTP_201_CREATED)
                     self.assertTrue(all(field in response.data for field in
                                         ['created', 'domain', 'subname', 'name', 'records', 'ttl', 'type', 'touched']))
                     self.assertEqual(self.my_empty_domain.touched,
                                      max(rrset.touched for rrset in self.my_empty_domain.rrset_set.all()))
-                    self.assertStatus(response, status.HTTP_201_CREATED)
 
                 # Check for uniqueness on second attempt
                 response = self.client.post_rr_set(domain_name=self.my_empty_domain.name, **data)
@@ -186,6 +187,20 @@ class AuthenticatedRRSetTestCase(AuthenticatedRRSetBaseTestCase):
         response = self.client.post_rr_set(self.my_empty_domain.name, **data)
         self.assertContains(response, 'CNAME RRset cannot have empty subname', status_code=status.HTTP_400_BAD_REQUEST)
 
+    def test_create_my_rr_sets_cname_exclusivity(self):
+        self.create_rr_set(self.my_domain, ['1.2.3.4'], type='A', ttl=3600, subname='a')
+        self.create_rr_set(self.my_domain, ['example.com.'], type='CNAME', ttl=3600, subname='cname')
+
+        # Can't add a CNAME where something else is
+        data = {'subname': 'a', 'ttl': 3600, 'type': 'CNAME', 'records': ['foobar.com.']}
+        response = self.client.post_rr_set(self.my_domain.name, **data)
+        self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
+
+        # Can't add something else where a CNAME is
+        data = {'subname': 'cname', 'ttl': 3600, 'type': 'A', 'records': ['4.3.2.1']}
+        response = self.client.post_rr_set(self.my_domain.name, **data)
+        self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
+
     def test_create_my_rr_sets_without_records(self):
         for subname in ['', 'create-my-rr-sets', 'foo.create-my-rr-sets', 'bar.baz.foo.create-my-rr-sets']:
             for data in [

+ 47 - 2
api/desecapi/tests/test_rrsets_bulk.py

@@ -13,7 +13,7 @@ class AuthenticatedRRSetBulkTestCase(AuthenticatedRRSetBaseTestCase):
         super().setUpTestDataWithPdns()
 
         cls.data = [
-            {'subname': 'my-bulk', 'records': ['1.2.3.4'], 'ttl': 3600, 'type': 'A'},
+            {'subname': 'my-cname', 'records': ['example.com.'], 'ttl': 3600, 'type': 'CNAME'},
             {'subname': 'my-bulk', 'records': ['desec.io.', 'foobar.example.'], 'ttl': 3600, 'type': 'PTR'},
         ]
 
@@ -142,6 +142,20 @@ class AuthenticatedRRSetBulkTestCase(AuthenticatedRRSetBaseTestCase):
             ]
         )
 
+    def test_bulk_patch_cname_exclusivity(self):
+        response = self.client.bulk_patch_rr_sets(
+            domain_name=self.my_rr_set_domain.name,
+            payload=[
+                {'subname': 'test', 'type': 'A', 'ttl': 3600, 'records': ['1.2.3.4']},
+                {'subname': 'test', 'type': 'CNAME', 'ttl': 3600, 'records': ['example.com.']},
+            ]
+        )
+        self.assertResponse(response, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), [
+            {"non_field_errors":["RRset with conflicting type present: 1 (CNAME). (No other RRsets are allowed alongside CNAME.)"]},
+            {"non_field_errors":["RRset with conflicting type present: 0 (A), database (A, TXT). (No other RRsets are allowed alongside CNAME.)"]},
+        ])
+
     def test_bulk_post_accepts_empty_list(self):
         self.assertResponse(
             self.client.bulk_post_rr_sets(domain_name=self.my_empty_domain.name, payload=[]),
@@ -186,6 +200,37 @@ class AuthenticatedRRSetBulkTestCase(AuthenticatedRRSetBaseTestCase):
         response = self.client.bulk_patch_rr_sets(domain_name=self.my_empty_domain.name, payload=None)
         self.assertContains(response, 'No data provided', status_code=status.HTTP_400_BAD_REQUEST)
 
+    def test_bulk_patch_cname_exclusivity_atomic_rrset_replacement(self):
+        self.create_rr_set(self.my_empty_domain, subname='test', type='A', records=['1.2.3.4'], ttl=3600)
+
+        with self.assertPdnsRequests(self.requests_desec_rr_sets_update(self.my_empty_domain.name)):
+            response = self.client.bulk_patch_rr_sets(
+                domain_name=self.my_empty_domain.name,
+                payload=[
+                    {'subname': 'test', 'type': 'CNAME', 'ttl': 3605, 'records': ['example.com.']},
+                    {'subname': 'test', 'type': 'A', 'records': []},
+                ]
+            )
+            self.assertResponse(response, status.HTTP_200_OK)
+            self.assertEqual(len(response.data), 1)
+            self.assertEqual(response.data[0]['type'], 'CNAME')
+            self.assertEqual(response.data[0]['records'], ['example.com.'])
+            self.assertEqual(response.data[0]['ttl'], 3605)
+
+        with self.assertPdnsRequests(self.requests_desec_rr_sets_update(self.my_empty_domain.name)):
+            response = self.client.bulk_patch_rr_sets(
+                domain_name=self.my_empty_domain.name,
+                payload=[
+                    {'subname': 'test', 'type': 'CNAME', 'records': []},
+                    {'subname': 'test', 'type': 'A', 'ttl': 3600, 'records': ['5.4.2.1']},
+                ]
+            )
+            self.assertResponse(response, status.HTTP_200_OK)
+            self.assertEqual(len(response.data), 1)
+            self.assertEqual(response.data[0]['type'], 'A')
+            self.assertEqual(response.data[0]['records'], ['5.4.2.1'])
+            self.assertEqual(response.data[0]['ttl'], 3600)
+
     def test_bulk_patch_full_on_empty_domain(self):
         # Full patch always works
         with self.assertPdnsRequests(self.requests_desec_rr_sets_update(name=self.my_empty_domain.name)):
@@ -199,7 +244,7 @@ class AuthenticatedRRSetBulkTestCase(AuthenticatedRRSetBaseTestCase):
 
     def test_bulk_patch_change_records(self):
         data_no_ttl = copy.deepcopy(self.data_no_ttl)
-        data_no_ttl[0]['records'] = ['4.3.2.1', '8.8.1.2']
+        data_no_ttl[0]['records'] = ['example.org.']
         with self.assertPdnsRequests(self.requests_desec_rr_sets_update(name=self.bulk_domain.name)):
             response = self.client.bulk_patch_rr_sets(domain_name=self.bulk_domain.name, payload=data_no_ttl)
             self.assertStatus(response, status.HTTP_200_OK)

+ 56 - 0
api/desecapi/validators.py

@@ -0,0 +1,56 @@
+from django.db import DataError
+from rest_framework.exceptions import ValidationError
+from rest_framework.validators import qs_exists, qs_filter, UniqueTogetherValidator
+
+
+def qs_exclude(queryset, **kwargs):
+    try:
+        return queryset.exclude(**kwargs)
+    except (TypeError, ValueError, DataError):
+        return queryset.none()
+
+
+class ExclusionConstraintValidator(UniqueTogetherValidator):
+    """
+    Validator that implements ExclusionConstraints, currently very basic with support for one field only.
+    Should be applied to the serializer class, not to an individual field.
+    No-op if parent serializer is a list serializer (many=True). We expect the list serializer to assure exclusivity.
+    """
+    message = 'This field violates an exclusion constraint.'
+
+    def __init__(self, queryset, fields, exclusion_condition, message=None):
+        super().__init__(queryset, fields, message)
+        self.exclusion_condition = exclusion_condition
+
+    def filter_queryset(self, attrs, queryset, serializer):
+        qs = super().filter_queryset(attrs, queryset, serializer)
+
+        # Determine the exclusion filters and prepare the queryset.
+        field_name = self.exclusion_condition[0]
+        value = self.exclusion_condition[1]
+        source = serializer.fields[field_name].source
+        if serializer.instance is not None:
+            if source not in attrs:
+                attrs[source] = getattr(serializer.instance, source)
+        exclusion_method = qs_exclude if attrs[source] == value else qs_filter
+        return exclusion_method(qs, **{field_name: value})
+
+    def __call__(self, attrs, serializer, *args, **kwargs):
+        # Ignore validation if the many flag is set
+        if getattr(serializer.root, 'many', False):
+            return
+
+        self.enforce_required_fields(attrs, serializer)
+        queryset = self.queryset
+        queryset = self.filter_queryset(attrs, queryset, serializer)
+        queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
+
+        # Ignore validation if any field is None
+        checked_values = [
+            value for field, value in attrs.items() if field in self.fields
+        ]
+        if None not in checked_values and qs_exists(queryset):
+            types = queryset.values_list('type', flat=True)
+            types = ', '.join(types)
+            message = self.message.format(types=types)
+            raise ValidationError(message, code='exclusive')