throttling.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from rest_framework import throttling
  2. from rest_framework.settings import api_settings
  3. class ScopedRatesThrottle(throttling.ScopedRateThrottle):
  4. """
  5. Like DRF's ScopedRateThrottle, but supports several rates per scope, e.g. for burst vs. sustained limit.
  6. """
  7. def parse_rate(self, rates):
  8. return [super(ScopedRatesThrottle, self).parse_rate(rate) for rate in rates]
  9. def allow_request(self, request, view):
  10. # We can only determine the scope once we're called by the view.
  11. self.scope = getattr(view, self.scope_attr, None)
  12. # If a view does not have a `throttle_scope` always allow the request
  13. if not self.scope:
  14. return True
  15. # Determine the allowed request rate as we normally would during
  16. # the `__init__` call.
  17. self.rate = self.get_rate()
  18. if self.rate is None:
  19. return True
  20. self.now = self.timer()
  21. self.num_requests, self.duration = zip(*self.parse_rate(self.rate))
  22. self.key = self.get_cache_key(request, view)
  23. self.history = {key: [] for key in self.key}
  24. self.history.update(self.cache.get_many(self.key))
  25. for num_requests, duration, key in zip(self.num_requests, self.duration, self.key):
  26. history = self.history[key]
  27. # Drop any requests from the history which have now passed the
  28. # throttle duration
  29. while history and history[-1] <= self.now - duration:
  30. history.pop()
  31. if len(history) >= num_requests:
  32. # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
  33. self.num_requests, self.duration, self.key, self.history = num_requests, duration, key, history
  34. return self.throttle_failure()
  35. self.history[key] = history
  36. return self.throttle_success()
  37. def throttle_success(self):
  38. for key in self.history:
  39. self.history[key].insert(0, self.now)
  40. self.cache.set_many(self.history, max(self.duration))
  41. return True
  42. # Override the static attribute of the parent class so that we can dynamically apply override settings for testing
  43. @property
  44. def THROTTLE_RATES(self):
  45. return api_settings.DEFAULT_THROTTLE_RATES
  46. def get_cache_key(self, request, view):
  47. key = super().get_cache_key(request, view)
  48. return [f'{key}_{duration}' for duration in self.duration]