1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- # -*- coding: utf-8 -*-
- # !/usr/bin/env python
- from functools import wraps
- from logging import getLogger
- from time import sleep
- from .exceptions import BucketFullException
logger = getLogger(__name__)


class LimitContextDecorator(object):
    """Apply a rate limit to a function call or a block of code.

    An instance can be used as a:

    * decorator (regular functions)
    * async decorator (coroutine functions)
    * contextmanager (``with``)
    * async contextmanager (``async with``)

    Mainly used via ``Limiter.ratelimit()``. Depending on arguments, calls that exceed the rate
    limit will either raise an exception, or sleep until space is available in the bucket.

    Args:
        limiter: Limiter object; only its ``try_acquire(*identities)`` method is called here
        delay: Delay until the next request instead of raising an exception
        max_delay: The maximum allowed delay time (in seconds); anything over this will raise
            an exception. A falsy value (``None`` or ``0``) disables this cap.
        identities: Bucket identities. Note that they are positional arguments that come
            *after* ``delay`` and ``max_delay``.
    """

    def __init__(
        self,
        limiter,
        delay=False,
        max_delay=None,
        *identities
    ):
        self.delay = delay
        self.max_delay = max_delay
        # Bind the identities once so every acquire attempt targets the same buckets
        self.try_acquire = lambda: limiter.try_acquire(*identities)

    def __call__(self, func):
        """Allows usage as a decorator for both normal and async functions.

        Coroutine functions get an async wrapper that acquires when the coroutine runs and
        waits with ``asyncio.sleep`` (so the event loop is not blocked). All other callables
        get a sync wrapper that acquires before each call.
        """
        from inspect import iscoroutinefunction

        if iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                await self.async_delayed_acquire()
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            self.delayed_acquire()
            return func(*args, **kwargs)

        return wrapper

    def __enter__(self):
        """Allows usage as a contextmanager"""
        self.delayed_acquire()

    def __exit__(self, *exc):
        # Returning None means exceptions raised inside the ``with`` body propagate
        pass

    async def __aenter__(self):
        """Allows usage as an async contextmanager"""
        await self.async_delayed_acquire()

    async def __aexit__(self, *exc):
        # Returning None means exceptions raised inside the ``async with`` body propagate
        pass

    def delayed_acquire(self):
        """Delay and retry until we can successfully acquire an available bucket item.

        Raises:
            BucketFullException: If delaying is disabled or the delay exceeds ``max_delay``
        """
        while True:
            try:
                self.try_acquire()
            except BucketFullException as err:
                delay_time = self.delay_or_reraise(err)
                sleep(delay_time)
            else:
                break

    async def async_delayed_acquire(self):
        """Async variant of ``delayed_acquire`` that waits with ``asyncio.sleep``.

        Raises:
            BucketFullException: If delaying is disabled or the delay exceeds ``max_delay``
        """
        import asyncio

        while True:
            try:
                self.try_acquire()
            except BucketFullException as err:
                delay_time = self.delay_or_reraise(err)
                await asyncio.sleep(delay_time)
            else:
                break

    def delay_or_reraise(self, err):
        """Determine if we should delay after exceeding a rate limit. If so, return the delay time,
        otherwise re-raise the exception.

        Args:
            err: The rate-limit exception; its ``meta_info["remaining_time"]`` is used as the delay

        Returns:
            The number of seconds to wait before retrying

        Raises:
            err: If ``delay`` is falsy, or ``max_delay`` is set and the delay exceeds it
        """
        delay_time = err.meta_info["remaining_time"]
        logger.debug(err.meta_info)
        logger.info("Rate limit reached; %s seconds remaining before next request", delay_time)
        # A falsy max_delay (None or 0) means "no upper bound on the delay"
        exceeded_max_delay = bool(self.max_delay) and (delay_time > self.max_delay)
        if self.delay and not exceeded_max_delay:
            return delay_time
        raise err
|