test_throttling.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from unittest import mock
  2. import time
  3. from django.core.cache import cache
  4. from django.test import TestCase, override_settings
  5. from rest_framework import status
  6. from rest_framework.response import Response
  7. from rest_framework.views import APIView
  8. from rest_framework.test import APIRequestFactory
  9. def override_rates(rates):
  10. return override_settings(REST_FRAMEWORK={'DEFAULT_THROTTLE_CLASSES': ['desecapi.throttling.ScopedRatesThrottle'],
  11. 'DEFAULT_THROTTLE_RATES': {'test_scope': rates}})
  12. class MockView(APIView):
  13. throttle_scope = 'test_scope'
  14. @property
  15. def throttle_classes(self):
  16. # Need to import here so that the module is only loaded once the settings override is in effect
  17. from desecapi.throttling import ScopedRatesThrottle
  18. return (ScopedRatesThrottle,)
  19. def get(self, request):
  20. return Response('foo')
  21. class ThrottlingTestCase(TestCase):
  22. """
  23. Based on DRF's test_throttling.py.
  24. """
  25. def setUp(self):
  26. super().setUp()
  27. self.factory = APIRequestFactory()
  28. def _test_requests_are_throttled(self, rates, counts):
  29. cache.clear()
  30. request = self.factory.get('/')
  31. with override_rates(rates):
  32. view = MockView.as_view()
  33. sum_delay = 0
  34. for delay, count in counts:
  35. sum_delay += delay
  36. with mock.patch('desecapi.throttling.ScopedRatesThrottle.timer', return_value=time.time() + sum_delay):
  37. for _ in range(count):
  38. response = view(request)
  39. self.assertEqual(response.status_code, status.HTTP_200_OK)
  40. response = view(request)
  41. self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
  42. def test_requests_are_throttled_4sec(self):
  43. self._test_requests_are_throttled(['4/sec'], [(0, 4), (1, 4)])
  44. def test_requests_are_throttled_4min(self):
  45. self._test_requests_are_throttled(['4/min'], [(0, 4)])
  46. def test_requests_are_throttled_multiple(self):
  47. self._test_requests_are_throttled(['5/s', '4/day'], [(0, 4)])
  48. self._test_requests_are_throttled(['4/s', '5/day'], [(0, 4)])
  49. def test_requests_are_throttled_multiple_cascade(self):
  50. # We test that we can do 4 requests in the first second and only 2 in the second second
  51. self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4), (1, 2)])