test_throttling.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from unittest import mock
  2. import time
  3. from django.core.cache import cache
  4. from django.test import TestCase, override_settings
  5. from rest_framework import status
  6. from rest_framework.response import Response
  7. from rest_framework.views import APIView
  8. from rest_framework.test import APIRequestFactory
  9. def override_rates(rates):
  10. return override_settings(REST_FRAMEWORK={'DEFAULT_THROTTLE_CLASSES': ['desecapi.throttling.ScopedRatesThrottle'],
  11. 'DEFAULT_THROTTLE_RATES': {'test_scope': rates}})
  12. class MockView(APIView):
  13. throttle_scope = 'test_scope'
  14. @property
  15. def throttle_classes(self):
  16. # Need to import here so that the module is only loaded once the settings override is in effect
  17. from desecapi.throttling import ScopedRatesThrottle
  18. return (ScopedRatesThrottle,)
  19. def get(self, request):
  20. return Response('foo')
  21. class ThrottlingTestCase(TestCase):
  22. """
  23. Based on DRF's test_throttling.py.
  24. """
  25. def setUp(self):
  26. super().setUp()
  27. self.factory = APIRequestFactory()
  28. def _test_requests_are_throttled(self, rates, counts, buckets=None):
  29. def do_test():
  30. view = MockView.as_view()
  31. sum_delay = 0
  32. for delay, count, max_wait in counts:
  33. sum_delay += delay
  34. with mock.patch('desecapi.throttling.ScopedRatesThrottle.timer', return_value=time.time() + sum_delay):
  35. for _ in range(count):
  36. response = view(request)
  37. self.assertEqual(response.status_code, status.HTTP_200_OK)
  38. response = view(request)
  39. self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
  40. self.assertTrue(max_wait - 1 <= float(response['Retry-After']) <= max_wait)
  41. cache.clear()
  42. request = self.factory.get('/')
  43. with override_rates(rates):
  44. do_test()
  45. if buckets is not None:
  46. for bucket in buckets:
  47. MockView.throttle_scope_bucket = bucket
  48. do_test()
  49. def test_requests_are_throttled_4sec(self):
  50. self._test_requests_are_throttled(['4/sec'], [(0, 4, 1), (1, 4, 1)])
  51. def test_requests_are_throttled_4min(self):
  52. self._test_requests_are_throttled(['4/min'], [(0, 4, 60)])
  53. def test_requests_are_throttled_multiple(self):
  54. self._test_requests_are_throttled(['5/s', '4/day'], [(0, 4, 86400)])
  55. self._test_requests_are_throttled(['4/s', '5/day'], [(0, 4, 1)])
  56. def test_requests_are_throttled_multiple_cascade(self):
  57. # We test that we can do 4 requests in the first second and only 2 in the second second
  58. self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4, 1), (1, 2, 86400)])
  59. def test_requests_are_throttled_multiple_cascade_with_buckets(self):
  60. # We test that we can do 4 requests in the first second and only 2 in the second second
  61. self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4, 1), (1, 2, 86400)], buckets=['foo', 'bar'])