瀏覽代碼

feat(api): use ModelSerializer for RR validation, fix test

Peter Thomassen 5 年之前
父節點
當前提交
3653ccc006
共有 2 個文件被更改,包括 24 次插入32 次删除
  1. 20 25
      api/desecapi/serializers.py
  2. 4 7
      api/desecapi/tests/test_rrsets_bulk.py

+ 20 - 25
api/desecapi/serializers.py

@@ -8,7 +8,6 @@ from django.core.validators import MinValueValidator
 from django.db import IntegrityError, OperationalError
 from django.db.models import Model, Q
 from rest_framework import serializers
-from rest_framework.exceptions import ValidationError
 from rest_framework.settings import api_settings
 from rest_framework.validators import UniqueTogetherValidator, UniqueValidator, qs_filter
 
@@ -108,27 +107,6 @@ class ReadOnlyOnUpdateValidator(Validator):
             raise serializers.ValidationError(self.message, code='read-only-on-update')
 
 
-class StringField(serializers.CharField):
-
-    def to_internal_value(self, data):
-        return data
-
-    def run_validation(self, data=serializers.empty):
-        data = super().run_validation(data)
-        if not isinstance(data, str):
-            raise serializers.ValidationError('Must be a string.', code='must-be-a-string')
-        return data
-
-
-class RRsField(serializers.ListField):
-
-    def __init__(self, **kwargs):
-        super().__init__(child=StringField(), **kwargs)
-
-    def to_representation(self, data):
-        return [rr.content for rr in data.all()]
-
-
 class ConditionalExistenceModelSerializer(serializers.ModelSerializer):
     """
     Only considers data with certain condition as existing data.
@@ -206,9 +184,24 @@ class NonBulkOnlyDefault:
         return '%s(%s)' % (self.__class__.__name__, repr(self.default))
 
 
+class RRSerializer(serializers.ModelSerializer):
+
+    class Meta:
+        model = models.RR
+        fields = ('content',)
+
+    def to_internal_value(self, data):
+        if not isinstance(data, str):
+            raise serializers.ValidationError('Must be a string.', code='must-be-a-string')
+        return super().to_internal_value({'content': data})
+
+    def to_representation(self, instance):
+        return instance.content
+
+
 class RRsetSerializer(ConditionalExistenceModelSerializer):
     domain = serializers.SlugRelatedField(read_only=True, slug_field='name')
-    records = RRsField(allow_empty=True)
+    records = RRSerializer(many=True)
     ttl = serializers.IntegerField(max_value=604800)
 
     class Meta:
@@ -288,7 +281,7 @@ class RRsetSerializer(ConditionalExistenceModelSerializer):
         return instance
 
     @staticmethod
-    def _set_all_record_contents(rrset: models.RRset, record_contents):
+    def _set_all_record_contents(rrset: models.RRset, rrs):
         """
         Updates this RR set's resource records, discarding any old values.
 
@@ -297,8 +290,10 @@ class RRsetSerializer(ConditionalExistenceModelSerializer):
         Changes are saved to the database immediately.
 
         :param rrset: the RRset at which we overwrite all RRs
-        :param record_contents: set of strings
+        :param rrs: list of RR representations
         """
+        record_contents = [rr['content'] for rr in rrs]
+
         # Remove RRs that we didn't see in the new list
         removed_rrs = rrset.records.exclude(content__in=record_contents)  # one SELECT
         for rr in removed_rrs:

+ 4 - 7
api/desecapi/tests/test_rrsets_bulk.py

@@ -363,13 +363,10 @@ class AuthenticatedRRSetBulkTestCase(AuthenticatedRRSetBaseTestCase):
             [True, '1.1.1.1'],
             dict(foobar='foobar', asdf='asdf'),
         ]:
-            s = self.client.bulk_put_rr_sets(domain_name=self.my_empty_domain.name, payload=[
-                    {'subname': 'a.2', 'ttl': 50, 'type': 'MX', 'records': records}
-                ])
-            self.assertStatus(
-                s,
-                status.HTTP_400_BAD_REQUEST
-            )
+            payload = [{'subname': 'a.2', 'ttl': 3600, 'type': 'MX', 'records': records}]
+            response = self.client.bulk_put_rr_sets(domain_name=self.my_empty_domain.name, payload=payload)
+            self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
+            self.assertTrue('records' in response.data[0])
 
     def test_bulk_put_empty_records(self):
         with self.assertPdnsRequests(self.requests_desec_rr_sets_update(name=self.bulk_domain.name)):