浏览代码

refactor(api): reduce code duplication in RRset views, simplify

Peter Thomassen 3 年之前
父节点
当前提交
6de5e10d78
共有 1 个文件被更改,包括 33 次插入39 次删除
  1. 33 39
      api/desecapi/views.py

+ 33 - 39
api/desecapi/views.py

@@ -67,31 +67,6 @@ class IdempotentDestroyMixin:
         return Response(status=status.HTTP_204_NO_CONTENT)
         return Response(status=status.HTTP_204_NO_CONTENT)
 
 
 
 
-class DomainViewMixin:
-
-    @property
-    def throttle_scope(self):
-        return 'dns_api_read' if self.request.method in SAFE_METHODS else 'dns_api_write_rrsets'
-
-    @property
-    def throttle_scope_bucket(self):
-        # Note: bucket should remain constant even when domain is recreated
-        return None if self.request.method in SAFE_METHODS else self.kwargs['name']
-
-    def get_serializer_context(self):
-        # noinspection PyUnresolvedReferences
-        return {**super().get_serializer_context(), 'domain': self.domain}
-
-    def initial(self, request, *args, **kwargs):
-        # noinspection PyUnresolvedReferences
-        super().initial(request, *args, **kwargs)
-        try:
-            # noinspection PyAttributeOutsideInit, PyUnresolvedReferences
-            self.domain = self.request.user.domains.get(name=self.kwargs['name'])
-        except models.Domain.DoesNotExist:
-            raise Http404
-
-
 class TokenViewSet(IdempotentDestroyMixin, viewsets.ModelViewSet):
 class TokenViewSet(IdempotentDestroyMixin, viewsets.ModelViewSet):
     serializer_class = serializers.TokenSerializer
     serializer_class = serializers.TokenSerializer
     permission_classes = (IsAuthenticated, ManageTokensPermission,)
     permission_classes = (IsAuthenticated, ManageTokensPermission,)
@@ -182,13 +157,42 @@ class SerialListView(APIView):
         return Response(serials)
         return Response(serials)
 
 
 
 
-class RRsetDetail(IdempotentDestroyMixin, DomainViewMixin, generics.RetrieveUpdateDestroyAPIView):
+class RRsetView:
     serializer_class = serializers.RRsetSerializer
     serializer_class = serializers.RRsetSerializer
     permission_classes = (IsAuthenticated, IsDomainOwner,)
     permission_classes = (IsAuthenticated, IsDomainOwner,)
 
 
+    @property
+    def throttle_scope(self):
+        return 'dns_api_read' if self.request.method in SAFE_METHODS else 'dns_api_write_rrsets'
+
+    @property
+    def throttle_scope_bucket(self):
+        # Note: bucket should remain constant even when domain is recreated
+        return None if self.request.method in SAFE_METHODS else self.kwargs['name']
+
     def get_queryset(self):
     def get_queryset(self):
         return self.domain.rrset_set
         return self.domain.rrset_set
 
 
+    def get_serializer_context(self):
+        # noinspection PyUnresolvedReferences
+        return {**super().get_serializer_context(), 'domain': self.domain}
+
+    def initial(self, request, *args, **kwargs):
+        # noinspection PyUnresolvedReferences
+        super().initial(request, *args, **kwargs)
+        try:
+            # noinspection PyAttributeOutsideInit, PyUnresolvedReferences
+            self.domain = self.request.user.domains.get(name=self.kwargs['name'])
+        except models.Domain.DoesNotExist:
+            raise Http404
+
+    def perform_update(self, serializer):
+        with PDNSChangeTracker():
+            super().perform_update(serializer)
+
+
+class RRsetDetail(RRsetView, IdempotentDestroyMixin, generics.RetrieveUpdateDestroyAPIView):
+
     def get_object(self):
     def get_object(self):
         queryset = self.filter_queryset(self.get_queryset())
         queryset = self.filter_queryset(self.get_queryset())
 
 
@@ -207,21 +211,15 @@ class RRsetDetail(IdempotentDestroyMixin, DomainViewMixin, generics.RetrieveUpda
             response.status_code = 204
             response.status_code = 204
         return response
         return response
 
 
-    def perform_update(self, serializer):
-        with PDNSChangeTracker():
-            super().perform_update(serializer)
-
     def perform_destroy(self, instance):
     def perform_destroy(self, instance):
         with PDNSChangeTracker():
         with PDNSChangeTracker():
             super().perform_destroy(instance)
             super().perform_destroy(instance)
 
 
 
 
-class RRsetList(EmptyPayloadMixin, DomainViewMixin, generics.ListCreateAPIView, generics.UpdateAPIView):
-    serializer_class = serializers.RRsetSerializer
-    permission_classes = (IsAuthenticated, IsDomainOwner,)
+class RRsetList(RRsetView, EmptyPayloadMixin, generics.ListCreateAPIView, generics.UpdateAPIView):
 
 
     def get_queryset(self):
     def get_queryset(self):
-        rrsets = models.RRset.objects.filter(domain=self.domain)
+        rrsets = super().get_queryset()
 
 
         for filter_field in ('subname', 'type'):
         for filter_field in ('subname', 'type'):
             value = self.request.query_params.get(filter_field)
             value = self.request.query_params.get(filter_field)
@@ -233,7 +231,7 @@ class RRsetList(EmptyPayloadMixin, DomainViewMixin, generics.ListCreateAPIView,
 
 
                 rrsets = rrsets.filter(**{'%s__exact' % filter_field: value})
                 rrsets = rrsets.filter(**{'%s__exact' % filter_field: value})
 
 
-        return rrsets
+        return rrsets.all()  # without .all(), cache is sometimes inconsistent with actual state in bulk tests. (Why?)
 
 
     def get_object(self):
     def get_object(self):
         # For this view, the object we're operating on is the queryset that one can also GET. Serializing a queryset
         # For this view, the object we're operating on is the queryset that one can also GET. Serializing a queryset
@@ -257,10 +255,6 @@ class RRsetList(EmptyPayloadMixin, DomainViewMixin, generics.ListCreateAPIView,
         with PDNSChangeTracker():
         with PDNSChangeTracker():
             super().perform_create(serializer)
             super().perform_create(serializer)
 
 
-    def perform_update(self, serializer):
-        with PDNSChangeTracker():
-            super().perform_update(serializer)
-
 
 
 class Root(APIView):
 class Root(APIView):
     def get(self, request, *args, **kwargs):
     def get(self, request, *args, **kwargs):