瀏覽代碼

fix(api): correct field requirements for rrset PATCH/PUT, fixes #91

Peter Thomassen 7 年之前
父節點
當前提交
8759c13211
共有 2 個文件被更改,包括 59 次插入18 次删除
  1. 14 4
      api/desecapi/serializers.py
  2. 45 14
      api/desecapi/tests/testrrsets.py

+ 14 - 4
api/desecapi/serializers.py

@@ -1,4 +1,5 @@
 from rest_framework import serializers
 from rest_framework import serializers
+from rest_framework.exceptions import ValidationError
 from desecapi.models import Domain, Donation, User, RR, RRset
 from desecapi.models import Domain, Donation, User, RR, RRset
 from djoser import serializers as djoserSerializers
 from djoser import serializers as djoserSerializers
 from django.db import transaction
 from django.db import transaction
@@ -22,8 +23,15 @@ class RRsetSerializer(serializers.ModelSerializer):
         fields = ('domain', 'subname', 'name', 'records', 'ttl', 'type',)
         fields = ('domain', 'subname', 'name', 'records', 'ttl', 'type',)
 
 
     def _set_records(self, instance):
     def _set_records(self, instance):
-        records_data = [{'content': x}
-                        for x in self.context['request'].data['records']]
+        # Although serializer fields have required=True by default, that
+        # setting does not work for the SerializerMethodField "records".
+        # Thus, let's wrap our read access to include the validation check.
+        records = self.context['request'].data.get('records')
+        if records is None:
+            raise ValidationError({'records': 'This field is required.'},
+                                  code='required')
+
+        records_data = [{'content': x} for x in records]
         rr_serializer = RRSerializer(data=records_data, many=True,
         rr_serializer = RRSerializer(data=records_data, many=True,
                                      allow_empty=False)
                                      allow_empty=False)
         if not rr_serializer.is_valid():
         if not rr_serializer.is_valid():
@@ -42,8 +50,10 @@ class RRsetSerializer(serializers.ModelSerializer):
     @transaction.atomic
     @transaction.atomic
     def update(self, instance, validated_data):
     def update(self, instance, validated_data):
         instance = super().update(instance, validated_data)
         instance = super().update(instance, validated_data)
-        instance.records.all().delete()
-        self._set_records(instance)
+        # Update records only if required (PUT) or provided (PATCH)
+        if not self.partial or 'records' in self.context['request'].data:
+            instance.records.all().delete()
+            self._set_records(instance)
         return instance
         return instance
 
 
     def get_records(self, obj):
     def get_records(self, obj):

+ 45 - 14
api/desecapi/tests/testrrsets.py

@@ -147,6 +147,10 @@ class AuthenticatedRRsetTests(APITestCase):
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
 
+        data = {'ttl': 60, 'type': 'A'}
+        response = self.client.post(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
     def testCantPostRestrictedTypes(self):
     def testCantPostRestrictedTypes(self):
         for type_ in self.restricted_types:
         for type_ in self.restricted_types:
             url = reverse('rrsets', args=(self.ownedDomains[1].name,))
             url = reverse('rrsets', args=(self.ownedDomains[1].name,))
@@ -275,7 +279,7 @@ class AuthenticatedRRsetTests(APITestCase):
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
 
 
         url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
         url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
-        data = {'records': ['2.2.3.4'], 'ttl': 30, 'type': 'A'}
+        data = {'records': ['2.2.3.4'], 'ttl': 30}
         response = self.client.put(url, json.dumps(data), content_type='application/json')
         response = self.client.put(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
 
@@ -284,12 +288,23 @@ class AuthenticatedRRsetTests(APITestCase):
         self.assertEqual(response.data['records'][0], '2.2.3.4')
         self.assertEqual(response.data['records'][0], '2.2.3.4')
         self.assertEqual(response.data['ttl'], 30)
         self.assertEqual(response.data['ttl'], 30)
 
 
+        url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
+        data = {'records': ['3.2.3.4']}
+        response = self.client.put(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+        url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
+        data = {'ttl': 37}
+        response = self.client.put(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
     def testCanPatchOwnRRset(self):
     def testCanPatchOwnRRset(self):
         url = reverse('rrsets', args=(self.ownedDomains[1].name,))
         url = reverse('rrsets', args=(self.ownedDomains[1].name,))
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
 
 
+        # Change records and TTL
         url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
         url = reverse('rrset', args=(self.ownedDomains[1].name, '', 'A',))
         data = {'records': ['3.2.3.4'], 'ttl': 32}
         data = {'records': ['3.2.3.4'], 'ttl': 32}
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
@@ -300,20 +315,28 @@ class AuthenticatedRRsetTests(APITestCase):
         self.assertEqual(response.data['records'][0], '3.2.3.4')
         self.assertEqual(response.data['records'][0], '3.2.3.4')
         self.assertEqual(response.data['ttl'], 32)
         self.assertEqual(response.data['ttl'], 32)
 
 
-    def testCantPatchOForeignRRset(self):
-        self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.otherToken)
-        url = reverse('rrsets', args=(self.otherDomains[0].name,))
-        data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
-        response = self.client.post(url, json.dumps(data), content_type='application/json')
-        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        # Change records alone
+        data = {'records': ['5.2.3.4']}
+        response = self.client.patch(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.data['records'][0], '5.2.3.4')
+        self.assertEqual(response.data['ttl'], 32)
 
 
-        self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.token)
-        url = reverse('rrset', args=(self.otherDomains[0].name, '', 'A',))
-        data = {'records': ['3.2.3.4'], 'ttl': 32}
+        # Change TTL alone
+        data = {'ttl': 37}
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.data['records'][0], '5.2.3.4')
+        self.assertEqual(response.data['ttl'], 37)
+
+        # Change nothing
+        data = {}
+        response = self.client.patch(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.data['records'][0], '5.2.3.4')
+        self.assertEqual(response.data['ttl'], 37)
 
 
-    def testCantPutForeignRRset(self):
+    def testCantChangeForeignRRset(self):
         self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.otherToken)
         self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.otherToken)
         url = reverse('rrsets', args=(self.otherDomains[0].name,))
         url = reverse('rrsets', args=(self.otherDomains[0].name,))
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A'}
@@ -323,25 +346,33 @@ class AuthenticatedRRsetTests(APITestCase):
         self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.token)
         self.client.credentials(HTTP_AUTHORIZATION='Token ' + self.token)
         url = reverse('rrset', args=(self.otherDomains[0].name, '', 'A',))
         url = reverse('rrset', args=(self.otherDomains[0].name, '', 'A',))
         data = {'records': ['3.2.3.4'], 'ttl': 30, 'type': 'A'}
         data = {'records': ['3.2.3.4'], 'ttl': 30, 'type': 'A'}
+
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
 
 
+        response = self.client.put(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
     def testCantChangeEssentialProperties(self):
     def testCantChangeEssentialProperties(self):
         url = reverse('rrsets', args=(self.ownedDomains[1].name,))
         url = reverse('rrsets', args=(self.ownedDomains[1].name,))
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A', 'subname': 'test1'}
         data = {'records': ['1.2.3.4'], 'ttl': 60, 'type': 'A', 'subname': 'test1'}
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         response = self.client.post(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
 
 
-        # Changing the type is expected to cause an error
+        # Changing the subname is expected to cause an error
         url = reverse('rrset', args=(self.ownedDomains[1].name, 'test1', 'A',))
         url = reverse('rrset', args=(self.ownedDomains[1].name, 'test1', 'A',))
         data = {'records': ['3.2.3.4'], 'ttl': 120, 'subname': 'test2'}
         data = {'records': ['3.2.3.4'], 'ttl': 120, 'subname': 'test2'}
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
         self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
+        response = self.client.put(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
 
 
-        # Changing the subname is expected to cause an error
+        # Changing the type is expected to cause an error
         data = {'records': ['3.2.3.4'], 'ttl': 120, 'type': 'TXT'}
         data = {'records': ['3.2.3.4'], 'ttl': 120, 'type': 'TXT'}
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         response = self.client.patch(url, json.dumps(data), content_type='application/json')
         self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
         self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
+        response = self.client.put(url, json.dumps(data), content_type='application/json')
+        self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
 
 
         # Check that nothing changed
         # Check that nothing changed
         response = self.client.get(url)
         response = self.client.get(url)