throttling.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from hashlib import sha1
  2. from rest_framework import throttling
  3. from rest_framework.settings import api_settings
  4. from desecapi import metrics
  5. class ScopedRatesThrottle(throttling.ScopedRateThrottle):
  6. """
  7. Like DRF's ScopedRateThrottle, but supports several rates per scope, e.g. for burst vs. sustained limit.
  8. """
  9. def parse_rate(self, rates):
  10. return [super(ScopedRatesThrottle, self).parse_rate(rate) for rate in rates]
  11. def allow_request(self, request, view):
  12. # We can only determine the scope once we're called by the view. Always allow request if scope not set.
  13. scope = getattr(view, self.scope_attr, None)
  14. if not scope:
  15. return True
  16. # Determine the allowed request rate as we normally would during
  17. # the `__init__` call.
  18. self.scope = scope
  19. self.rate = self.get_rate()
  20. if self.rate is None:
  21. return True
  22. # Amend scope with optional bucket
  23. bucket = getattr(view, self.scope_attr + "_bucket", None)
  24. if bucket is not None:
  25. self.scope += ":" + sha1(bucket.encode()).hexdigest()
  26. self.now = self.timer()
  27. self.num_requests, self.duration = zip(*self.parse_rate(self.rate))
  28. self.key = self.get_cache_key(request, view)
  29. self.history = {key: [] for key in self.key}
  30. self.history.update(self.cache.get_many(self.key))
  31. for num_requests, duration, key in zip(
  32. self.num_requests, self.duration, self.key
  33. ):
  34. history = self.history[key]
  35. # Drop any requests from the history which have now passed the
  36. # throttle duration
  37. while history and history[-1] <= self.now - duration:
  38. history.pop()
  39. if len(history) >= num_requests:
  40. # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
  41. self.num_requests, self.duration, self.key, self.history = (
  42. num_requests,
  43. duration,
  44. key,
  45. history,
  46. )
  47. response = self.throttle_failure()
  48. metrics.get("desecapi_throttle_failure").labels(
  49. request.method, scope, request.user.pk, bucket
  50. ).inc()
  51. return response
  52. self.history[key] = history
  53. return self.throttle_success()
  54. def throttle_success(self):
  55. for key in self.history:
  56. self.history[key].insert(0, self.now)
  57. self.cache.set_many(self.history, max(self.duration))
  58. return True
  59. # Override the static attribute of the parent class so that we can dynamically apply override settings for testing
  60. @property
  61. def THROTTLE_RATES(self):
  62. return api_settings.DEFAULT_THROTTLE_RATES
  63. def get_cache_key(self, request, view):
  64. key = super().get_cache_key(request, view)
  65. return [f"{key}_{duration}" for duration in self.duration]