123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- from unittest import mock
- import time
- from django.core.cache import cache
- from django.test import TestCase, override_settings
- from rest_framework import status
- from rest_framework.response import Response
- from rest_framework.views import APIView
- from rest_framework.test import APIRequestFactory
def override_rates(rates):
    """
    Build a settings override that makes ``ScopedRatesThrottle`` the default
    DRF throttle and assigns ``rates`` to the ``test_scope`` scope.

    The override replaces the entire ``REST_FRAMEWORK`` setting.

    :param rates: list of DRF rate strings, e.g. ``["4/s", "6/day"]``.
    :return: a ``django.test.override_settings`` object, usable as a context
        manager or decorator.
    """
    throttle_rates = {"test_scope": rates}
    drf_settings = {
        "DEFAULT_THROTTLE_CLASSES": ["desecapi.throttling.ScopedRatesThrottle"],
        "DEFAULT_THROTTLE_RATES": throttle_rates,
    }
    return override_settings(REST_FRAMEWORK=drf_settings)
class MockView(APIView):
    """Minimal view answering ``"foo"``, throttled under the ``test_scope`` scope."""

    throttle_scope = "test_scope"

    @property
    def throttle_classes(self):
        # Deferred import: desecapi.throttling must only be loaded once the
        # settings override is in effect.
        from desecapi.throttling import ScopedRatesThrottle as throttle_cls

        return (throttle_cls,)

    def get(self, request):
        """Return a constant response body."""
        payload = "foo"
        return Response(payload)
class ThrottlingTestCase(TestCase):
    """
    Tests for ``desecapi.throttling.ScopedRatesThrottle`` with one or more
    rates per scope.

    Based on DRF's test_throttling.py.
    """

    def setUp(self):
        super().setUp()
        self.factory = APIRequestFactory()

    def _test_requests_are_throttled(self, rates, counts, buckets=None):
        """
        Drive ``MockView`` through a throttling scenario and check the results.

        :param rates: rate strings installed for the ``test_scope`` scope via
            ``override_rates``, e.g. ``["4/s", "6/day"]``.
        :param counts: sequence of ``(delay, count, max_wait)`` tuples. For each
            step:
            - the mocked throttle clock is advanced by ``delay`` seconds
              (cumulatively);
            - ``count`` requests must return 200;
            - the next request must return 429;
            - its ``Retry-After`` header must lie in ``[max_wait - 1, max_wait]``.
        :param buckets: optional iterable of bucket names. After the run without
            a bucket, the same scenario is repeated once per bucket with
            ``MockView.throttle_scope_bucket`` set to that name. The cache is
            not cleared between these runs, so each bucket run presumably
            relies on buckets being tracked separately.
        """
        request = self.factory.get("/")

        def do_test():
            view = MockView.as_view()
            # Take the reference time once per run so that every step is offset
            # from the same base. Re-reading time.time() per step would add real
            # elapsed time on top of the intended delays.
            start = time.time()
            elapsed = 0
            for delay, count, max_wait in counts:
                elapsed += delay
                with mock.patch(
                    "desecapi.throttling.ScopedRatesThrottle.timer",
                    return_value=start + elapsed,
                ):
                    for _ in range(count):
                        response = view(request)
                        self.assertEqual(response.status_code, status.HTTP_200_OK)
                    response = view(request)
                    self.assertEqual(
                        response.status_code, status.HTTP_429_TOO_MANY_REQUESTS
                    )
                    # Separate assertions report the actual value on failure.
                    retry_after = float(response["Retry-After"])
                    self.assertGreaterEqual(retry_after, max_wait - 1)
                    self.assertLessEqual(retry_after, max_wait)

        cache.clear()
        with override_rates(rates):
            do_test()
            for bucket in buckets or ():
                # patch.object(create=True) deletes the attribute again on exit.
                # The bucket therefore does not leak into other tests through
                # the shared MockView class.
                with mock.patch.object(
                    MockView, "throttle_scope_bucket", bucket, create=True
                ):
                    do_test()

    def test_requests_are_throttled_4sec(self):
        # 4 requests are allowed, then 4 more once a full second has passed.
        self._test_requests_are_throttled(["4/sec"], [(0, 4, 1), (1, 4, 1)])

    def test_requests_are_throttled_4min(self):
        self._test_requests_are_throttled(["4/min"], [(0, 4, 60)])

    def test_requests_are_throttled_multiple(self):
        # With several rates, the one whose count is reached first determines
        # both the number of allowed requests and the Retry-After value.
        self._test_requests_are_throttled(["5/s", "4/day"], [(0, 4, 86400)])
        self._test_requests_are_throttled(["4/s", "5/day"], [(0, 4, 1)])

    def test_requests_are_throttled_multiple_cascade(self):
        # 4 requests fit into the first second (per-second limit). In the next
        # second only 2 more fit before the daily limit of 6 is exhausted.
        self._test_requests_are_throttled(["4/s", "6/day"], [(0, 4, 1), (1, 2, 86400)])

    def test_requests_are_throttled_multiple_cascade_with_buckets(self):
        # Same cascade as above, repeated for each bucket.
        self._test_requests_are_throttled(
            ["4/s", "6/day"], [(0, 4, 1), (1, 2, 86400)], buckets=["foo", "bar"]
        )
|