test_throttling.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from unittest import mock
  2. import time
  3. from django.core.cache import cache
  4. from django.test import TestCase, override_settings
  5. from rest_framework import status
  6. from rest_framework.response import Response
  7. from rest_framework.views import APIView
  8. from rest_framework.test import APIRequestFactory
  9. def override_rates(rates):
  10. return override_settings(
  11. REST_FRAMEWORK={
  12. "DEFAULT_THROTTLE_CLASSES": ["desecapi.throttling.ScopedRatesThrottle"],
  13. "DEFAULT_THROTTLE_RATES": {"test_scope": rates},
  14. }
  15. )
  16. class MockView(APIView):
  17. throttle_scope = "test_scope"
  18. @property
  19. def throttle_classes(self):
  20. # Need to import here so that the module is only loaded once the settings override is in effect
  21. from desecapi.throttling import ScopedRatesThrottle
  22. return (ScopedRatesThrottle,)
  23. def get(self, request):
  24. return Response("foo")
  25. class ThrottlingTestCase(TestCase):
  26. """
  27. Based on DRF's test_throttling.py.
  28. """
  29. def setUp(self):
  30. super().setUp()
  31. self.factory = APIRequestFactory()
  32. def _test_requests_are_throttled(self, rates, counts, buckets=None):
  33. def do_test():
  34. view = MockView.as_view()
  35. sum_delay = 0
  36. for delay, count, max_wait in counts:
  37. sum_delay += delay
  38. with mock.patch(
  39. "desecapi.throttling.ScopedRatesThrottle.timer",
  40. return_value=time.time() + sum_delay,
  41. ):
  42. for _ in range(count):
  43. response = view(request)
  44. self.assertEqual(response.status_code, status.HTTP_200_OK)
  45. response = view(request)
  46. self.assertEqual(
  47. response.status_code, status.HTTP_429_TOO_MANY_REQUESTS
  48. )
  49. self.assertTrue(
  50. max_wait - 1 <= float(response["Retry-After"]) <= max_wait
  51. )
  52. cache.clear()
  53. request = self.factory.get("/")
  54. with override_rates(rates):
  55. do_test()
  56. if buckets is not None:
  57. for bucket in buckets:
  58. MockView.throttle_scope_bucket = bucket
  59. do_test()
  60. def test_requests_are_throttled_4sec(self):
  61. self._test_requests_are_throttled(["4/sec"], [(0, 4, 1), (1, 4, 1)])
  62. def test_requests_are_throttled_4min(self):
  63. self._test_requests_are_throttled(["4/min"], [(0, 4, 60)])
  64. def test_requests_are_throttled_multiple(self):
  65. self._test_requests_are_throttled(["5/s", "4/day"], [(0, 4, 86400)])
  66. self._test_requests_are_throttled(["4/s", "5/day"], [(0, 4, 1)])
  67. def test_requests_are_throttled_multiple_cascade(self):
  68. # We test that we can do 4 requests in the first second and only 2 in the second second
  69. self._test_requests_are_throttled(["4/s", "6/day"], [(0, 4, 1), (1, 2, 86400)])
  70. def test_requests_are_throttled_multiple_cascade_with_buckets(self):
  71. # We test that we can do 4 requests in the first second and only 2 in the second second
  72. self._test_requests_are_throttled(
  73. ["4/s", "6/day"], [(0, 4, 1), (1, 2, 86400)], buckets=["foo", "bar"]
  74. )