import time
from unittest import mock

from django.core.cache import cache
from django.test import TestCase, override_settings
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.test import APIRequestFactory


def override_rates(rates):
    """Return a settings override that enables ``ScopedRatesThrottle`` with *rates* for ``test_scope``.

    :param rates: list of DRF rate strings (e.g. ``['4/s', '6/day']``) configured for the
        ``test_scope`` throttle scope.
    :return: a ``django.test.override_settings`` instance (usable as context manager or decorator).
    """
    return override_settings(REST_FRAMEWORK={
        'DEFAULT_THROTTLE_CLASSES': ['desecapi.throttling.ScopedRatesThrottle'],
        'DEFAULT_THROTTLE_RATES': {'test_scope': rates},
    })


class MockView(APIView):
    """Minimal view throttled under ``test_scope`` that answers every GET with ``'foo'``."""

    throttle_scope = 'test_scope'

    @property
    def throttle_classes(self):
        # Need to import here so that the module is only loaded once the settings override is in effect
        from desecapi.throttling import ScopedRatesThrottle
        return (ScopedRatesThrottle,)

    def get(self, request):
        return Response('foo')


class ThrottlingTestCase(TestCase):
    """
    Based on DRF's test_throttling.py.
    """

    def setUp(self):
        super().setUp()
        self.factory = APIRequestFactory()
        # Ensure no throttle history outlives this test, even if an assertion fails mid-way.
        self.addCleanup(cache.clear)

    def _test_requests_are_throttled(self, rates, counts, buckets=None):
        """Check that ``MockView`` lets exactly the expected number of requests through under *rates*.

        :param rates: list of rate strings passed to :func:`override_rates`.
        :param counts: sequence of ``(delay, count, max_wait)`` tuples. For each step the mocked
            clock is advanced by ``delay`` seconds (cumulatively), ``count`` requests must succeed,
            the next request must be rejected with HTTP 429, and its ``Retry-After`` header must
            lie within ``[max_wait - 1, max_wait]``.
        :param buckets: optional iterable of ``throttle_scope_bucket`` values; after the run without
            a bucket, the whole sequence is repeated once per bucket.
        """
        request = self.factory.get('/')

        def do_test():
            # Each run starts from an empty throttle history, regardless of what ran before.
            cache.clear()
            view = MockView.as_view()
            # Anchor the mocked clock once so that it advances by exactly the declared delays,
            # independent of how long the real requests take.
            start = time.time()
            elapsed = 0
            for delay, count, max_wait in counts:
                elapsed += delay
                with mock.patch('desecapi.throttling.ScopedRatesThrottle.timer', return_value=start + elapsed):
                    for _ in range(count):
                        response = view(request)
                        self.assertEqual(response.status_code, status.HTTP_200_OK)
                    response = view(request)
                    self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
                    retry_after = float(response['Retry-After'])
                    self.assertGreaterEqual(retry_after, max_wait - 1)
                    self.assertLessEqual(retry_after, max_wait)

        with override_rates(rates):
            do_test()
            if buckets is not None:
                for bucket in buckets:
                    # patch.object restores (or removes) the class attribute afterwards, so the
                    # bucket does not leak into other tests using MockView.
                    with mock.patch.object(MockView, 'throttle_scope_bucket', bucket, create=True):
                        do_test()

    def test_requests_are_throttled_4sec(self):
        # After one (mocked) second the per-second window has passed, so 4 more requests are allowed
        self._test_requests_are_throttled(['4/sec'], [(0, 4, 1), (1, 4, 1)])

    def test_requests_are_throttled_4min(self):
        self._test_requests_are_throttled(['4/min'], [(0, 4, 60)])

    def test_requests_are_throttled_multiple(self):
        # With several rates, the rate exhausted first limits the requests and determines Retry-After
        self._test_requests_are_throttled(['5/s', '4/day'], [(0, 4, 86400)])
        self._test_requests_are_throttled(['4/s', '5/day'], [(0, 4, 1)])

    def test_requests_are_throttled_multiple_cascade(self):
        # We test that we can do 4 requests in the first second and only 2 in the second second
        self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4, 1), (1, 2, 86400)])

    def test_requests_are_throttled_multiple_cascade_with_buckets(self):
        # We test that we can do 4 requests in the first second and only 2 in the second second
        self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4, 1), (1, 2, 86400)], buckets=['foo', 'bar'])