Ver código fonte

feat(api): rate limiting, fixes #130

Peter Thomassen 5 anos atrás
pai
commit
96c85bfb59

+ 16 - 0
api/api/settings.py

@@ -97,6 +97,21 @@ REST_FRAMEWORK = {
     'EXCEPTION_HANDLER': 'desecapi.exception_handlers.exception_handler',
     'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.NamespaceVersioning',
     'ALLOWED_VERSIONS': ['v1', 'v2'],
+    'DEFAULT_THROTTLE_CLASSES': [
+        'desecapi.throttling.ScopedRatesThrottle',
+        'rest_framework.throttling.UserRateThrottle',
+    ],
+    'DEFAULT_THROTTLE_RATES': {
+        # ScopedRatesThrottle
+        'account_management_active': ['3/min'],  # things with side effect, e.g. sending mail or zone creation on signup
+        'account_management_passive': ['10/min'],  # things like viewing your account or creating/deleting tokens
+        'dyndns': ['1/min'],  # dynDNS updates; anything above 1/min is a client misconfiguration
+        'dns_api_read': ['5/s', '50/min'],  # DNS API requests that do not involve pdns
+        'dns_api_write': ['3/s', '50/min', '200/h'],  # DNS API requests that do involve pdns
+        # UserRateThrottle
+        'user': '1000/d',  # hard limit on requests by a) an authenticated user, b) an unauthenticated IP address
+    },
+    'NUM_PROXIES': 0,  # Do not use X-Forwarded-For header when determining IP for throttling
 }
 
 PASSWORD_HASHER_TOKEN = 'desecapi.authentication.TokenHasher'
@@ -185,3 +200,4 @@ if os.environ.get('DESECSTACK_E2E_TEST', "").upper() == "TRUE":
     LIMIT_USER_DOMAIN_COUNT_DEFAULT = 5000
     USER_ACTIVATION_REQUIRED = False
     EMAIL_BACKEND = 'django.core.mail.backends.dummy.EmailBackend'
+    REST_FRAMEWORK['DEFAULT_THROTTLE_CLASSES'] = []

+ 7 - 0
api/api/settings_quick_test.py

@@ -18,7 +18,14 @@ PASSWORD_HASHERS = [
     PASSWORD_HASHER_TOKEN,
 ]
 
+CACHES = {
+    'default': {
+        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+    }
+}
+
 REST_FRAMEWORK['PAGE_SIZE'] = 20
+REST_FRAMEWORK['DEFAULT_THROTTLE_CLASSES'] = []
 
 # Carry email backend connection over to test mail outbox
 CELERY_EMAIL_MESSAGE_EXTRA_ATTRIBUTES = ['connection']

+ 66 - 0
api/desecapi/tests/test_throttling.py

@@ -0,0 +1,66 @@
+from unittest import mock
+import time
+
+from django.core.cache import cache
+from django.test import TestCase, override_settings
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework.test import APIRequestFactory
+
+
+def override_rates(rates):
+    return override_settings(REST_FRAMEWORK={'DEFAULT_THROTTLE_CLASSES': ['desecapi.throttling.ScopedRatesThrottle'],
+                                             'DEFAULT_THROTTLE_RATES': {'test_scope': rates}})
+
+
+class MockView(APIView):
+    throttle_scope = 'test_scope'
+
+    @property
+    def throttle_classes(self):
+        # Need to import here so that the module is only loaded once the settings override is in effect
+        from desecapi.throttling import ScopedRatesThrottle
+        return (ScopedRatesThrottle,)
+
+    def get(self, request):
+        return Response('foo')
+
+
+class ThrottlingTestCase(TestCase):
+    """
+    Based on DRF's test_throttling.py.
+    """
+    def setUp(self):
+        super().setUp()
+        self.factory = APIRequestFactory()
+
+    def _test_requests_are_throttled(self, rates, counts):
+        cache.clear()
+        request = self.factory.get('/')
+        with override_rates(rates):
+            view = MockView.as_view()
+            sum_delay = 0
+            for delay, count in counts:
+                sum_delay += delay
+                with mock.patch('desecapi.throttling.ScopedRatesThrottle.timer', return_value=time.time() + sum_delay):
+                    for _ in range(count):
+                        response = view(request)
+                        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+                    response = view(request)
+                    self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
+
+    def test_requests_are_throttled_4sec(self):
+        self._test_requests_are_throttled(['4/sec'], [(0, 4), (1, 4)])
+
+    def test_requests_are_throttled_4min(self):
+        self._test_requests_are_throttled(['4/min'], [(0, 4)])
+
+    def test_requests_are_throttled_multiple(self):
+        self._test_requests_are_throttled(['5/s', '4/day'], [(0, 4)])
+        self._test_requests_are_throttled(['4/s', '5/day'], [(0, 4)])
+
+    def test_requests_are_throttled_multiple_cascade(self):
+        # We test that we can do 4 requests in the first second and only 2 in the second second
+        self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4), (1, 2)])

+ 58 - 0
api/desecapi/throttling.py

@@ -0,0 +1,58 @@
+from rest_framework import throttling
+from rest_framework.settings import api_settings
+
+
+class ScopedRatesThrottle(throttling.ScopedRateThrottle):
+    """
+    Like DRF's ScopedRateThrottle, but supports several rates per scope, e.g. for burst vs. sustained limit.
+    """
+    def parse_rate(self, rates):
+        return [super(ScopedRatesThrottle, self).parse_rate(rate) for rate in rates]
+
+    def allow_request(self, request, view):
+        # We can only determine the scope once we're called by the view.
+        self.scope = getattr(view, self.scope_attr, None)
+
+        # If a view does not have a `throttle_scope` always allow the request
+        if not self.scope:
+            return True
+
+        # Determine the allowed request rate as we normally would during
+        # the `__init__` call.
+        self.rate = self.get_rate()
+        if self.rate is None:
+            return True
+
+        self.now = self.timer()
+        self.num_requests, self.duration = zip(*self.parse_rate(self.rate))
+        self.key = self.get_cache_key(request, view)
+        self.history = {key: [] for key in self.key}
+        self.history.update(self.cache.get_many(self.key))
+
+        for num_requests, duration, key in zip(self.num_requests, self.duration, self.key):
+            history = self.history[key]
+            # Drop any requests from the history which have now passed the
+            # throttle duration
+            while history and history[-1] <= self.now - duration:
+                history.pop()
+            if len(history) >= num_requests:
+                # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
+                self.num_requests, self.duration, self.key, self.history = num_requests, duration, key, history
+                return self.throttle_failure()
+            self.history[key] = history
+        return self.throttle_success()
+
+    def throttle_success(self):
+        for key in self.history:
+            self.history[key].insert(0, self.now)
+        self.cache.set_many(self.history, max(self.duration))
+        return True
+
+    # Override the static attribute of the parent class so that we can dynamically apply override settings for testing
+    @property
+    def THROTTLE_RATES(self):
+        return api_settings.DEFAULT_THROTTLE_RATES
+
+    def get_cache_key(self, request, view):
+        key = super().get_cache_key(request, view)
+        return [f'{key}_{duration}' for duration in self.duration]

+ 24 - 1
api/desecapi/views.py

@@ -13,7 +13,7 @@ from rest_framework import mixins
 from rest_framework import status
 from rest_framework.authentication import get_authorization_header
 from rest_framework.exceptions import (NotAcceptable, NotFound, PermissionDenied, ValidationError)
-from rest_framework.permissions import IsAuthenticated
+from rest_framework.permissions import IsAuthenticated, SAFE_METHODS
 from rest_framework.renderers import JSONRenderer, StaticHTMLRenderer
 from rest_framework.response import Response
 from rest_framework.reverse import reverse
@@ -58,6 +58,7 @@ class TokenViewSet(IdempotentDestroy,
                    GenericViewSet):
     serializer_class = serializers.TokenSerializer
     permission_classes = (IsAuthenticated,)
+    throttle_scope = 'account_management_passive'
 
     def get_queryset(self):
         return self.request.user.auth_tokens.all()
@@ -83,6 +84,10 @@ class DomainViewSet(IdempotentDestroy,
     lookup_field = 'name'
     lookup_value_regex = r'[^/]+'
 
+    @property
+    def throttle_scope(self):
+        return 'dns_api_read' if self.request.method in SAFE_METHODS else 'dns_api_write'
+
     def get_queryset(self):
         return self.request.user.domains
 
@@ -116,6 +121,10 @@ class RRsetDetail(IdempotentDestroy, DomainView, generics.RetrieveUpdateDestroyA
     serializer_class = serializers.RRsetSerializer
     permission_classes = (IsAuthenticated, IsDomainOwner,)
 
+    @property
+    def throttle_scope(self):
+        return 'dns_api_read' if self.request.method in SAFE_METHODS else 'dns_api_write'
+
     def get_queryset(self):
         return self.domain.rrset_set
 
@@ -153,6 +162,10 @@ class RRsetList(DomainView, generics.ListCreateAPIView, generics.UpdateAPIView):
     serializer_class = serializers.RRsetSerializer
     permission_classes = (IsAuthenticated, IsDomainOwner,)
 
+    @property
+    def throttle_scope(self):
+        return 'dns_api_read' if self.request.method in SAFE_METHODS else 'dns_api_write'
+
     def get_queryset(self):
         rrsets = models.RRset.objects.filter(domain=self.domain)
 
@@ -220,6 +233,7 @@ class Root(APIView):
 class DynDNS12Update(APIView):
     authentication_classes = (auth.TokenAuthentication, auth.BasicTokenAuthentication, auth.URLParamAuthentication,)
     renderer_classes = [PlainTextRenderer]
+    throttle_scope = 'dyndns'
 
     def _find_domain(self, request):
         def find_domain_name(r):
@@ -368,6 +382,7 @@ class DonationList(generics.CreateAPIView):
 
 class AccountCreateView(generics.CreateAPIView):
     serializer_class = serializers.RegisterAccountSerializer
+    throttle_scope = 'account_management_active'
 
     def create(self, request, *args, **kwargs):
         # Create user and send trigger email verification.
@@ -407,6 +422,7 @@ class AccountCreateView(generics.CreateAPIView):
 class AccountView(generics.RetrieveAPIView):
     permission_classes = (IsAuthenticated,)
     serializer_class = serializers.UserSerializer
+    throttle_scope = 'account_management_passive'
 
     def get_object(self):
         return self.request.user
@@ -419,6 +435,7 @@ class AccountDeleteView(generics.GenericAPIView):
         data={'detail': 'To delete your user account, first delete all of your domains.'},
         status=status.HTTP_409_CONFLICT,
     )
+    throttle_scope = 'account_management_active'
 
     def post(self, request, *args, **kwargs):
         if self.request.user.domains.exists():
@@ -436,6 +453,7 @@ class AccountDeleteView(generics.GenericAPIView):
 class AccountLoginView(generics.GenericAPIView):
     authentication_classes = (auth.EmailPasswordPayloadAuthentication,)
     permission_classes = (IsAuthenticated,)
+    throttle_scope = 'account_management_passive'
 
     def post(self, request, *args, **kwargs):
         user = self.request.user
@@ -450,6 +468,7 @@ class AccountLoginView(generics.GenericAPIView):
 class AccountLogoutView(generics.GenericAPIView, mixins.DestroyModelMixin):
     authentication_classes = (auth.TokenAuthentication,)
     permission_classes = (IsAuthenticated,)
+    throttle_classes = []  # always allow people to log out
 
     def get_object(self):
         # self.request.auth contains the hashed key as it is stored in the database
@@ -463,6 +482,7 @@ class AccountChangeEmailView(generics.GenericAPIView):
     authentication_classes = (auth.EmailPasswordPayloadAuthentication,)
     permission_classes = (IsAuthenticated,)
     serializer_class = serializers.ChangeEmailSerializer
+    throttle_scope = 'account_management_active'
 
     def post(self, request, *args, **kwargs):
         # Check password and extract email
@@ -485,6 +505,7 @@ class AccountChangeEmailView(generics.GenericAPIView):
 
 class AccountResetPasswordView(generics.GenericAPIView):
     serializer_class = serializers.ResetPasswordSerializer
+    throttle_scope = 'account_management_active'
 
     def post(self, request, *args, **kwargs):
         serializer = self.get_serializer(data=request.data)
@@ -521,6 +542,7 @@ class AuthenticatedActionView(generics.GenericAPIView):
     html_url = None
     http_method_names = ['get', 'post']  # GET is for redirect only
     renderer_classes = [JSONRenderer, StaticHTMLRenderer]
+    throttle_scope = 'account_management_active'
 
     def get_serializer_context(self):
         return {**super().get_serializer_context(), 'code': self.kwargs['code']}
@@ -647,3 +669,4 @@ class AuthenticatedDeleteUserActionView(AuthenticatedActionView):
 
 class CaptchaView(generics.CreateAPIView):
     serializer_class = serializers.CaptchaSerializer
+    throttle_scope = 'account_management_passive'