123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- from hashlib import sha1
- from rest_framework import throttling
- from rest_framework.settings import api_settings
- from desecapi import metrics
class ScopedRatesThrottle(throttling.ScopedRateThrottle):
    """
    Like DRF's ScopedRateThrottle, but supports several rates per scope, e.g. for burst vs. sustained limit.

    The rate configured for a scope (looked up via ``THROTTLE_RATES``) is a
    sequence of DRF rate strings such as ``["10/s", "100/min"]``. A single rate
    string is also accepted and treated as a one-element sequence. Each rate
    keeps its own request history in the cache, under a key suffixed with that
    rate's duration. A request is allowed only if it is within every rate.

    If the view defines the attribute ``<scope_attr> + "_bucket"``, its (string)
    value is hashed with SHA-1 and appended to the scope. This gives each
    bucket separate histories.
    """

    def parse_rate(self, rates):
        """
        Parse one or several DRF rate strings.

        :param rates: an iterable of rate strings like ``"10/min"``, or a single
            rate string.
        :return: a list of ``(num_requests, duration)`` tuples as produced by
            DRF's single-rate ``parse_rate``, in the same order as ``rates``.
        """
        # A bare string would otherwise be iterated character by character and
        # fail inside DRF's parser; treat it as a one-element sequence instead.
        if isinstance(rates, str):
            rates = [rates]
        # Bind the parent's single-rate parser here: zero-argument super() cannot
        # be used inside a list comprehension before Python 3.12 (the
        # comprehension ran in its own function scope).
        parse_single_rate = super().parse_rate
        return [parse_single_rate(rate) for rate in rates]

    def allow_request(self, request, view):
        """
        Return whether ``request`` is within all rates configured for the view's scope.

        Views without a throttle scope, and scopes without a configured rate, are
        always allowed. When a rate is exceeded, ``self.num_requests``,
        ``self.duration``, ``self.key`` and ``self.history`` are narrowed to that
        single rate so DRF's ``wait()`` can compute the retry delay. The
        ``desecapi_throttle_failure`` metric is also incremented, and the
        result of ``throttle_failure()`` is returned.
        """
        # We can only determine the scope once we're called by the view. Always allow request if scope not set.
        scope = getattr(view, self.scope_attr, None)
        if not scope:
            return True

        # Determine the allowed request rate as we normally would during
        # the `__init__` call.
        self.scope = scope
        self.rate = self.get_rate()
        if self.rate is None:
            return True

        # Amend scope with optional bucket (hashed, presumably to keep the cache
        # key bounded in length and free of arbitrary characters).
        bucket = getattr(view, self.scope_attr + "_bucket", None)
        if bucket is not None:
            self.scope += ":" + sha1(bucket.encode()).hexdigest()

        self.now = self.timer()
        # Index-aligned tuples: num_requests[i] requests are allowed per duration[i].
        self.num_requests, self.duration = zip(*self.parse_rate(self.rate))
        # get_cache_key() reads self.duration, so it must run after the line above.
        self.key = self.get_cache_key(request, view)
        # Histories are lists of request timestamps, newest first; keys missing
        # from the cache start out empty.
        self.history = {key: [] for key in self.key}
        self.history.update(self.cache.get_many(self.key))

        for num_requests, duration, key in zip(
            self.num_requests, self.duration, self.key
        ):
            # Same list object as self.history[key]; the pruning below mutates it in place.
            history = self.history[key]
            # Drop any requests from the history which have now passed the
            # throttle duration (the oldest entries are at the end of the list).
            while history and history[-1] <= self.now - duration:
                history.pop()
            if len(history) >= num_requests:
                # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
                self.num_requests, self.duration, self.key, self.history = (
                    num_requests,
                    duration,
                    key,
                    history,
                )
                response = self.throttle_failure()
                metrics.get("desecapi_throttle_failure").labels(
                    request.method, scope, request.user.pk, bucket
                ).inc()
                return response
        return self.throttle_success()

    def throttle_success(self):
        """
        Record the current request in every rate's history and persist them.

        All histories are written with the longest configured duration as the
        cache timeout. No history can expire before its own window has passed;
        stale entries are pruned on the next read in ``allow_request()``.

        :return: ``True``.
        """
        for history in self.history.values():
            history.insert(0, self.now)
        self.cache.set_many(self.history, max(self.duration))
        return True

    # Override the static attribute of the parent class so that we can dynamically apply override settings for testing
    @property
    def THROTTLE_RATES(self):
        """Return the currently active ``DEFAULT_THROTTLE_RATES`` setting."""
        return api_settings.DEFAULT_THROTTLE_RATES

    def get_cache_key(self, request, view):
        """
        Return one cache key per configured rate.

        DRF's scope/ident-based key is suffixed with each rate's duration, in
        the order of ``self.duration``. Requires ``self.duration`` to have been
        set, as done in ``allow_request()``.
        """
        key = super().get_cache_key(request, view)
        return [f"{key}_{duration}" for duration in self.duration]
|