test_throttling.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from unittest import mock
  2. import time
  3. from django.core.cache import cache
  4. from django.test import TestCase, override_settings
  5. from rest_framework import status
  6. from rest_framework.response import Response
  7. from rest_framework.views import APIView
  8. from rest_framework.test import APIRequestFactory
  9. def override_rates(rates):
  10. return override_settings(REST_FRAMEWORK={'DEFAULT_THROTTLE_CLASSES': ['desecapi.throttling.ScopedRatesThrottle'],
  11. 'DEFAULT_THROTTLE_RATES': {'test_scope': rates}})
  12. class MockView(APIView):
  13. throttle_scope = 'test_scope'
  14. @property
  15. def throttle_classes(self):
  16. # Need to import here so that the module is only loaded once the settings override is in effect
  17. from desecapi.throttling import ScopedRatesThrottle
  18. return (ScopedRatesThrottle,)
  19. def get(self, request):
  20. return Response('foo')
  21. class ThrottlingTestCase(TestCase):
  22. """
  23. Based on DRF's test_throttling.py.
  24. """
  25. def setUp(self):
  26. super().setUp()
  27. self.factory = APIRequestFactory()
  28. def _test_requests_are_throttled(self, rates, counts, buckets=None):
  29. def do_test():
  30. view = MockView.as_view()
  31. sum_delay = 0
  32. for delay, count in counts:
  33. sum_delay += delay
  34. with mock.patch('desecapi.throttling.ScopedRatesThrottle.timer', return_value=time.time() + sum_delay):
  35. for _ in range(count):
  36. response = view(request)
  37. self.assertEqual(response.status_code, status.HTTP_200_OK)
  38. response = view(request)
  39. self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
  40. cache.clear()
  41. request = self.factory.get('/')
  42. with override_rates(rates):
  43. do_test()
  44. if buckets is not None:
  45. for bucket in buckets:
  46. MockView.throttle_scope_bucket = bucket
  47. do_test()
  48. def test_requests_are_throttled_4sec(self):
  49. self._test_requests_are_throttled(['4/sec'], [(0, 4), (1, 4)])
  50. def test_requests_are_throttled_4min(self):
  51. self._test_requests_are_throttled(['4/min'], [(0, 4)])
  52. def test_requests_are_throttled_multiple(self):
  53. self._test_requests_are_throttled(['5/s', '4/day'], [(0, 4)])
  54. self._test_requests_are_throttled(['4/s', '5/day'], [(0, 4)])
  55. def test_requests_are_throttled_multiple_cascade(self):
  56. # We test that we can do 4 requests in the first second and only 2 in the second second
  57. self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4), (1, 2)])
  58. def test_requests_are_throttled_multiple_cascade_with_buckets(self):
  59. # We test that we can do 4 requests in the first second and only 2 in the second second
  60. self._test_requests_are_throttled(['4/s', '6/day'], [(0, 4), (1, 2)], buckets=['foo', 'bar'])