Ver código fonte

feat(api): bulk REST requests, closes #83

Peter Thomassen 7 anos atrás
pai
commit
0edfab3531
4 arquivos alterados com 167 adições e 66 exclusões
  1. 0 7
      api/desecapi/models.py
  2. 144 38
      api/desecapi/serializers.py
  3. 22 21
      api/desecapi/views.py
  4. 1 0
      api/requirements.txt

+ 0 - 7
api/desecapi/models.py

@@ -460,13 +460,6 @@ class RRset(models.Model, mixins.SetterMixin):
     def name(self):
         return '.'.join(filter(None, [self.subname, self.domain.name])) + '.'
 
-    @transaction.atomic
-    def set_rrs(self, contents, sync=True, notify=True):
-        self.records.all().delete()
-        self.records.set([RR(content=x) for x in contents], bulk=False)
-        if sync and not self.domain.owner.locked:
-            pdns.set_rrset(self, notify=notify)
-
     @transaction.atomic
     def delete(self, *args, **kwargs):
         # For locked users, we can't easily sync deleted RRsets to pdns later,

+ 144 - 38
api/desecapi/serializers.py

@@ -2,7 +2,9 @@ from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
 from desecapi.models import Domain, Donation, User, RR, RRset
 from djoser import serializers as djoserSerializers
-from django.db import transaction
+from django.db import models, transaction
+import django.core.exceptions
+from rest_framework_bulk import BulkListSerializer, BulkSerializerMixin
 
 
 class RRSerializer(serializers.ModelSerializer):
@@ -10,54 +12,153 @@ class RRSerializer(serializers.ModelSerializer):
         model = RR
         fields = ('content',)
 
+    def to_representation(self, instance):
+        return instance.content
 
-class RRsetSerializer(serializers.ModelSerializer):
+    def to_internal_value(self, data):
+        if not isinstance(data, dict):
+            data = {'content': data}
+        return self.Meta.model(**data)
+
+
+class RRsetBulkListSerializer(BulkListSerializer):
+    @transaction.atomic
+    def update(self, queryset, validated_data):
+        q = models.Q(pk__isnull=True)
+        for data in validated_data:
+            q |= models.Q(subname=data.get('subname', ''), type=data['type'])
+        rrsets = {(obj.subname, obj.type): obj for obj in queryset.filter(q)}
+        instance = [rrsets.get((data.get('subname', ''), data['type']), None)
+                    for data in validated_data]
+        return self.child._save(instance, validated_data)
+
+    @transaction.atomic
+    def create(self, validated_data):
+        return self.child._save([None] * len(validated_data), validated_data)
+
+
+class SlugRRField(serializers.SlugRelatedField):
+    def __init__(self, *args, **kwargs):
+        kwargs['slug_field'] = 'content'
+        kwargs['queryset'] = RR.objects.all()
+        super().__init__(*args, **kwargs)
+
+    def to_internal_value(self, data):
+        return RR(**{self.slug_field: data})
+
+
+class RRsetSerializer(BulkSerializerMixin, serializers.ModelSerializer):
     domain = serializers.StringRelatedField()
     subname = serializers.CharField(allow_blank=True, required=False)
-    type = serializers.CharField(required=False)
-    records = serializers.SerializerMethodField()
+    records = SlugRRField(many=True)
 
 
     class Meta:
         model = RRset
-        fields = ('domain', 'subname', 'name', 'records', 'ttl', 'type',)
-
-    def _set_records(self, instance):
-        # Although serializer fields have required=True by default, that
-        # setting does not work for the SerializerMethodField "records".
-        # Thus, let's wrap our read access to include the validation check.
-        records = self.context['request'].data.get('records')
-        if records is None:
-            raise ValidationError({'records': 'This field is required.'},
-                                  code='required')
-
-        records_data = [{'content': x} for x in records]
-        rr_serializer = RRSerializer(data=records_data, many=True,
-                                     allow_empty=False)
-        if not rr_serializer.is_valid():
-            errors = rr_serializer.errors
-            if 'non_field_errors' in errors:
-                errors['records'] = errors.pop('non_field_errors')
-            raise serializers.ValidationError(errors)
-        instance.set_rrs([x['content'] for x in rr_serializer.validated_data])
-
-    @transaction.atomic
-    def create(self, validated_data):
-        instance = super().create(validated_data)
-        self._set_records(instance)
-        return instance
+        fields = ('id', 'domain', 'subname', 'name', 'records', 'ttl', 'type',)
+        list_serializer_class = RRsetBulkListSerializer
+
+    def _save(self, instance, validated_data):
+        bulk = isinstance(instance, list)
+        if not bulk:
+            instance = [instance]
+            validated_data = [validated_data]
+
+        name = self.context['view'].kwargs['name']
+        domain = self.context['request'].user.domains.get(name=name)
+        method = self.context['request'].method
+
+        errors = []
+        rrsets = {}
+        rrsets_seen = set()
+        for rrset, data in zip(instance, validated_data):
+            # Construct RRset
+            records = data.pop('records', None)
+            if rrset:
+                # We have a known instance (update). Update fields if given.
+                rrset.subname = data.get('subname', rrset.subname)
+                rrset.type = data.get('type', rrset.type)
+                rrset.ttl = data.get('ttl', rrset.ttl)
+            else:
+                # No known instance (creation or meaningless request)
+                if not 'ttl' in data:
+                    if records:
+                        # If we have records, this is a creation request, so we
+                        # need a TTL.
+                        errors.append({'ttl': ['This field is required for new RRsets.']})
+                        continue
+                    else:
+                        # If this request is meaningless, we still want it to
+                        # be processed by pdns for type validation. In this
+                        # case, we need some dummy TTL.
+                        data['ttl'] = data.get('ttl', 1)
+                data.pop('id', None)
+                data['domain'] = domain
+                rrset = RRset(**data)
+
+            # Verify that we have not seen this RRset before
+            if (rrset.subname, rrset.type) in rrsets_seen:
+                errors.append({'__all__': ['RRset repeated with same subname and type.']})
+                continue
+            rrsets_seen.add((rrset.subname, rrset.type))
+
+            # Validate RRset. Raises error if type or subname have been changed
+            # or if new RRset is not unique.
+            validate_unique = (method == 'POST')
+            try:
+                rrset.full_clean(exclude=['updated'],
+                                 validate_unique=validate_unique)
+            except django.core.exceptions.ValidationError as e:
+                errors.append(e.message_dict)
+                continue
+
+            # Construct dictionary of RR lists to write, indexed by their RRset
+            if records is None:
+                rrsets[rrset] = None
+            else:
+                rr_data = [{'content': x.content, 'rrset': rrset} for x in records]
+
+                # Use RRSerializer to validate records inputs
+                allow_empty = (method in ('PATCH', 'PUT'))
+                rr_serializer = RRSerializer(data=rr_data, many=True,
+                                             allow_empty=allow_empty)
+
+                if not rr_serializer.is_valid():
+                    error = rr_serializer.errors
+                    if 'non_field_errors' in error:
+                        error['records'] = error.pop('non_field_errors')
+                    errors.append(error)
+                    continue
+
+                # Blessings have been given, so add RRset to the to-write dict
+                rrsets[rrset] = [rr for rr in rr_serializer.validated_data]
+
+            errors.append({})
+
+        if any(errors):
+            raise ValidationError(errors if bulk else errors[0])
+
+        # Now try to save RRsets
+        try:
+            rrsets = domain.write_rrsets(rrsets)
+        except django.core.exceptions.ValidationError as e:
+            for attr in ['errors', 'error_dict', 'message']:
+                detail = getattr(e, attr, None)
+                if detail:
+                    raise ValidationError(detail)
+            raise ValidationError(str(e))
+        except ValueError as e:
+            raise ValidationError({'__all__': str(e)})
+
+        return rrsets if bulk else rrsets[0]
 
     @transaction.atomic
     def update(self, instance, validated_data):
-        instance = super().update(instance, validated_data)
-        # Update records only if required (PUT) or provided (PATCH)
-        if not self.partial or 'records' in self.context['request'].data:
-            instance.records.all().delete()
-            self._set_records(instance)
-        return instance
+        return self._save(instance, validated_data)
 
-    def get_records(self, obj):
-        return list(obj.records.values_list('content', flat=True))
+    @transaction.atomic
+    def create(self, validated_data):
+        return self._save(None, validated_data)
 
     def validate_type(self, value):
         if value in RRset.RESTRICTED_TYPES:
@@ -65,6 +166,11 @@ class RRsetSerializer(serializers.ModelSerializer):
                 "You cannot tinker with the %s RRset." % value)
         return value
 
+    def to_representation(self, instance):
+        data = super().to_representation(instance)
+        data.pop('id')
+        return data
+
 
 class DomainSerializer(serializers.ModelSerializer):
     owner = serializers.ReadOnlyField(source='owner.email')

+ 22 - 21
api/desecapi/views.py

@@ -31,6 +31,7 @@ from django.db.models import Q
 from desecapi.emails import send_account_lock_email, send_token_email
 import re
 import ipaddress, os
+from rest_framework_bulk import ListBulkCreateUpdateAPIView
 
 patternDyn = re.compile(r'^[A-Za-z-][A-Za-z0-9_-]*\.dedyn\.io$')
 patternNonDyn = re.compile(r'^([A-Za-z0-9-][A-Za-z0-9_-]*\.)+[A-Za-z]+$')
@@ -162,6 +163,9 @@ class RRsetDetail(generics.RetrieveUpdateDestroyAPIView):
         if request.data.get('records') == []:
             return self.delete(request, *args, **kwargs)
 
+        for k in ('type', 'subname'):
+            request.data[k] = request.data.pop(k, self.kwargs[k])
+
         try:
             return super().update(request, *args, **kwargs)
         except django.core.exceptions.ValidationError as e:
@@ -170,7 +174,7 @@ class RRsetDetail(generics.RetrieveUpdateDestroyAPIView):
             raise ex
 
 
-class RRsetList(generics.ListCreateAPIView):
+class RRsetList(ListBulkCreateUpdateAPIView):
     authentication_classes = (TokenAuthentication, auth.IPAuthentication,)
     serializer_class = RRsetSerializer
     permission_classes = (permissions.IsAuthenticated, IsDomainOwner,)
@@ -195,30 +199,27 @@ class RRsetList(generics.ListCreateAPIView):
             return super().create(request, *args, **kwargs)
         except Domain.DoesNotExist:
             raise Http404
-        except django.core.exceptions.ValidationError as e:
-            ex = ValidationError(detail=e.message_dict)
-            all = e.message_dict.get('__all__')
-            if all is not None \
-                    and any(msg.endswith(' already exists.') for msg in all):
-                ex.status_code = status.HTTP_409_CONFLICT
-            raise ex
+        except ValidationError as e:
+            if isinstance(e.detail, dict):
+                detail = e.detail.get('__all__')
+                if isinstance(detail, list) \
+                        and any(m.endswith(' already exists.') for m in detail):
+                    e.status_code = status.HTTP_409_CONFLICT
+            raise e
 
     def perform_create(self, serializer):
+        # For new RRsets without a subname, set it empty. We don't use
+        # default='' in the serializer field definition so that during PUT, the
+        # subname value is retained if omitted.
+        if isinstance(self.request.data, list):
+            serializer._validated_data = [{**{'subname': ''}, **data}
+                                         for data in serializer.validated_data]
+        else:
+            serializer._validated_data = {**{'subname': ''}, **serializer.validated_data}
+
         # Associate RRset with proper domain
         domain = self.request.user.domains.get(name=self.kwargs['name'])
-        kwargs = {'domain': domain}
-
-        # If this RRset is new and a subname has not been given, set it empty
-        #
-        # Notes:
-        # - We don't use default='' in the serializer so that during PUT, the
-        #   subname value is retained if omitted.)
-        # - Don't use kwargs['subname'] = self.request.data.get('subname', ''),
-        #   giving preference to what's in serializer.validated_data at this point
-        if self.request.method == 'POST' and self.request.data.get('subname') is None:
-            kwargs['subname'] = ''
-
-        serializer.save(**kwargs)
+        serializer.save(domain=domain)
 
     def get(self, request, *args, **kwargs):
         name = self.kwargs['name']

+ 1 - 0
api/requirements.txt

@@ -8,3 +8,4 @@ requests==2.18.*
 uwsgi==2.0.*
 django-nocaptcha-recaptcha==0.0.19  # updated manually
 sqlparse==0.2.*
+djangorestframework-bulk==0.2.*